diff --git a/handler.go b/handler.go index 10e7da4..c2a688b 100644 --- a/handler.go +++ b/handler.go @@ -3,6 +3,7 @@ package pancheri import ( "github.com/miekg/dns" "github.com/sirupsen/logrus" + "strings" ) type Handler struct { @@ -28,20 +29,274 @@ func (h *Handler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - question := r.Question[0] + q := r.Question[0] - // okay, do we have upstream resolution enabled? + // is it ours? + for authority, zone := range h.A.Zones { + if strings.HasSuffix(q.Name, authority) { + // ok, its an owned domain - respond to it + logrus.WithFields(logrus.Fields{ + "name": q.Name, + "qtype": q.Qtype, + "zone": zone.Zonefile, + "raw": q.String(), + }).Trace("responding to query for authoritative zone") + + if q.Qtype == dns.TypeA { + record, ok := zone.ARecords[q.Name] + if !ok { + // SPECIAL CASE: for A and AAAA records, return with a CNAME if and only if that cname exists + cname, cok := zone.CNAMERecords[q.Name] + if cok { + // return with the CNAME record instead and resolve the CNAME + rendered := cname.Render() + msg.Rcode = dns.RcodeSuccess + msg.Answer = []dns.RR{ + rendered, + } + msg.RecursionAvailable = true + msg.Authoritative = true + + for authority, zone := range h.A.Zones { + if strings.HasSuffix(cname.Target, authority) { + // its authoritative + if rec, ok := zone.CNAMERecords[cname.Target]; ok { + // double cname isn't allowed right now + logrus.WithFields(logrus.Fields{ + "in": q.Name, + "target": cname.Target, + "target2": rec.Target, + }).Error("double cname") + msg.Answer = []dns.RR{} + msg.Rcode = dns.RcodeNameError + break + } + if rec, ok := zone.ARecords[cname.Target]; ok { + msg.Answer = append(msg.Answer, rec.Render()) + } + } + } + if len(msg.Answer) == 1 && msg.Rcode == dns.RcodeSuccess { + // ok, it's not ours + // this is also not allowed right now (TODO) + logrus.WithFields(logrus.Fields{ + "in": q.Name, + "target": cname.Target, + }).Error("external cname (TODO)") + msg.Answer = []dns.RR{} + msg.Rcode = dns.RcodeNameError + break + } + + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } else { + // send an nxdomain + if r.RecursionDesired { + msg.RecursionAvailable = true + } + msg.Rcode = dns.RcodeNameError + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } + } + // return with the record + rendered := record.Render() + msg.Rcode = dns.RcodeSuccess + msg.Answer = []dns.RR{ + rendered, + } + msg.RecursionAvailable = true + msg.Authoritative = true + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } else if q.Qtype == dns.TypeAAAA { + record, ok := zone.AAAARecords[q.Name] + if !ok { + // SPECIAL CASE: for A and AAAA records, return with a CNAME if and only if that cname exists + cname, cok := zone.CNAMERecords[q.Name] + if cok { + // return with the CNAME record instead and resolve the CNAME + rendered := cname.Render() + msg.Rcode = dns.RcodeSuccess + msg.Answer = []dns.RR{ + rendered, + } + msg.RecursionAvailable = true + msg.Authoritative = true + + for authority, zone := range h.A.Zones { + if strings.HasSuffix(cname.Target, authority) { + // its authoritative + if rec, ok := zone.CNAMERecords[cname.Target]; ok { + // double cname isn't allowed right now + logrus.WithFields(logrus.Fields{ + "in": q.Name, + "target": cname.Target, + "target2": rec.Target, + }).Error("double cname") + msg.Answer = []dns.RR{} + msg.Rcode = dns.RcodeNameError + break + } + if rec, ok := zone.AAAARecords[cname.Target]; ok { + msg.Answer = append(msg.Answer, rec.Render()) + } + } + } + if len(msg.Answer) == 1 && msg.Rcode == dns.RcodeSuccess { + // ok, it's not ours + // this is also not allowed right now (TODO) + logrus.WithFields(logrus.Fields{ + "in": q.Name, + "target": cname.Target, + }).Error("external cname (TODO)") + msg.Answer = []dns.RR{} + msg.Rcode = dns.RcodeNameError + break + } + + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } else { + // send an nxdomain + if r.RecursionDesired { + msg.RecursionAvailable = true + } + msg.Rcode = dns.RcodeNameError + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } + } + // return with the record + rendered := record.Render() + msg.Rcode = dns.RcodeSuccess + msg.Answer = []dns.RR{ + rendered, + } + msg.RecursionAvailable = true + msg.Authoritative = true + err := w.WriteMsg(msg) + logrus.Trace(msg.String()) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } else if q.Qtype == dns.TypeCNAME { + record, ok := zone.CNAMERecords[q.Name] + if !ok { + // send an nxdomain + if r.RecursionDesired { + msg.RecursionAvailable = true + } + msg.Rcode = dns.RcodeNameError + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } + // return with the record + rendered := record.Render() + msg.Rcode = dns.RcodeSuccess + msg.Answer = []dns.RR{ + rendered, + } + msg.RecursionAvailable = true + msg.Authoritative = true + err := w.WriteMsg(msg) + logrus.Trace(msg.String()) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } else if q.Qtype == dns.TypeTXT { + record, ok := zone.TXTRecords[q.Name] + if !ok { + // send an nxdomain + if r.RecursionDesired { + msg.RecursionAvailable = true + } + msg.Rcode = dns.RcodeNameError + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } + // return with the record + rendered := record.Render() + msg.Rcode = dns.RcodeSuccess + msg.Answer = []dns.RR{ + rendered, + } + msg.RecursionAvailable = true + msg.Authoritative = true + err := w.WriteMsg(msg) + logrus.Trace(msg.String()) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } else { + // not supported + logrus.WithFields(logrus.Fields{ + "name": q.Name, + "qtype": q.Qtype, + }).Error("received unsupported question type") + // send an nxdomain + if r.RecursionDesired { + msg.RecursionAvailable = true + } + msg.Rcode = dns.RcodeNameError + err := w.WriteMsg(msg) + if err != nil { + logrus.Errorf("error responding: %s", err) + return + } + return + } + + return + } + } + // no. do we have upstream resolution enabled? if h.C.Resolver.Enable { // alright, resolve it with the resolver - answers, rcode, err := h.R.Resolve(question.Name, question.Qtype) + resp, err := h.R.Resolve(q.Name, q.Qtype) + resp.SetReply(r) + if err != nil { logrus.Errorf("error resolving: %s", err) return } - msg.Rcode = rcode - msg.Answer = append(msg.Answer, answers...) - err = w.WriteMsg(msg) + err = w.WriteMsg(resp) if err != nil { logrus.Errorf("error responding: %s", err) return diff --git a/resolver.go b/resolver.go index 76df16e..58e4e0e 100644 --- a/resolver.go +++ b/resolver.go @@ -17,7 +17,7 @@ func NewResolver(upstream string) *Resolver { } } -func (r *Resolver) Resolve(domain string, qtype uint16) ([]dns.RR, int, error) { +func (r *Resolver) Resolve(domain string, qtype uint16) (*dns.Msg, error) { logrus.WithFields(logrus.Fields{ "domain": domain, "qtype": qtype, @@ -33,5 +33,5 @@ func (r *Resolver) Resolve(domain string, qtype uint16) ([]dns.RR, int, error) { return nil, err } - return in.Answer, in.Rcode, nil + return in, nil } diff --git a/zone.go b/zone.go index d73bbc2..4c6088c 100644 --- a/zone.go +++ b/zone.go @@ -6,23 +6,20 @@ import ( ) type Zone struct { - Root string `yaml:"root"` - ReducedHash string `yaml:"rsha"` - Zonefile string `yaml:"zf"` - ARecords []RecordA `yaml:"ra"` - AAAARecords []RecordAAAA `yaml:"rav6"` - CNAMERecords []RecordCNAME `yaml:"rcn"` - TXTRecords []RecordTXT `yaml:"rtx"` + Root string `yaml:"root"` + ReducedHash string `yaml:"rsha"` + Zonefile string `yaml:"zf"` + ARecords map[string]RecordA `yaml:"ra"` + AAAARecords map[string]RecordAAAA `yaml:"rav6"` + CNAMERecords map[string]RecordCNAME `yaml:"rcn"` + TXTRecords map[string]RecordTXT `yaml:"rtx"` } func (z *Zone) RenderZone() string { outString := "" outString += fmt.Sprintf(";; Rendered zonefile for %s (rsha %s) at %s\n", z.Zonefile, z.ReducedHash, time.Now().Format(time.RFC3339)) outString += ";; Generated by pancheri-render. Note: this will NOT work out of the box!\n" - outString += ";; At the very least, you'll need to change the SOA and NS values.\n" - outString += "\n" - outString += ";; SOA & NS records\n" - outString += ";; TODO\n" + outString += ";; At the very least, you'll need add SOA and NS values.\n" outString += "\n" outString += ";; A Records\n" diff --git a/zone_config.go b/zone_config.go index 0605d07..9e4cef2 100644 --- a/zone_config.go +++ b/zone_config.go @@ -61,10 +61,10 @@ func LoadZone(path string) (*Zone, error) { Root: cfg.Zone.Root, ReducedHash: reducedHash, Zonefile: path, - ARecords: nil, - AAAARecords: nil, - CNAMERecords: nil, - TXTRecords: nil, + ARecords: make(map[string]RecordA), + AAAARecords: make(map[string]RecordAAAA), + CNAMERecords: make(map[string]RecordCNAME), + TXTRecords: make(map[string]RecordTXT), } for _, record := range cfg.Zone.Records { @@ -83,11 +83,14 @@ func LoadZone(path string) (*Zone, error) { if !strings.HasSuffix(domain, ".") { domain = domain + "." + cfg.Zone.Root + "." } - zone.ARecords = append(zone.ARecords, RecordA{ + if _, ok := zone.ARecords[domain]; ok { + return nil, errors.New("duplicate A record " + domain) + } + zone.ARecords[domain] = RecordA{ In: domain, Ip: record.Ipv4.To4(), TTL: record.TTL, - }) + } } } else if record.RecordType == RuleTypeAAAA { // req.d fields: in, ip, ttl @@ -104,11 +107,14 @@ func LoadZone(path string) (*Zone, error) { if !strings.HasSuffix(domain, ".") { domain = domain + "." + cfg.Zone.Root + "." } - zone.AAAARecords = append(zone.AAAARecords, RecordAAAA{ + if _, ok := zone.AAAARecords[domain]; ok { + return nil, errors.New("duplicate AAAA record " + domain) + } + zone.AAAARecords[domain] = RecordAAAA{ In: domain, Ip: record.Ipv6, TTL: record.TTL, - }) + } } } else if record.RecordType == RuleTypeCNAME { // req.d fields: in, ip, ttl @@ -121,15 +127,21 @@ func LoadZone(path string) (*Zone, error) { if record.TTL == 0 { return nil, errors.New("CNAME record TTL cannot be 0 or empty") } + if !strings.HasSuffix(record.Target, ".") { + record.Target = record.Target + "." + cfg.Zone.Root + "." + } for _, domain := range record.Domains { if !strings.HasSuffix(domain, ".") { domain = domain + "." + cfg.Zone.Root + "." } - zone.CNAMERecords = append(zone.CNAMERecords, RecordCNAME{ + if _, ok := zone.CNAMERecords[domain]; ok { + return nil, errors.New("duplicate CNAME record " + domain) + } + zone.CNAMERecords[domain] = RecordCNAME{ In: domain, Target: record.Target, TTL: record.TTL, - }) + } } } else if record.RecordType == RuleTypeTXT { // req.d fields: in, content, ttl @@ -146,11 +158,14 @@ func LoadZone(path string) (*Zone, error) { if !strings.HasSuffix(domain, ".") { domain = domain + "." + cfg.Zone.Root + "." } - zone.TXTRecords = append(zone.TXTRecords, RecordTXT{ + if _, ok := zone.TXTRecords[domain]; ok { + return nil, errors.New("duplicate TXT record " + domain) + } + zone.TXTRecords[domain] = RecordTXT{ In: domain, Content: record.Content, TTL: record.TTL, - }) + } } } }