mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-31 10:13:14 -04:00 
			
		
		
		
	
		
			
	
	
		
			159 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			159 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package nomad
 | ||
|  | 
 | ||
|  | import (
 | ||
|  | 	"context"
 | ||
|  | 	"fmt"
 | ||
|  | 	"net"
 | ||
|  | 
 | ||
|  | 	"github.com/coredns/coredns/plugin"
 | ||
|  | 	"github.com/coredns/coredns/plugin/metrics"
 | ||
|  | 	"github.com/coredns/coredns/plugin/pkg/dnsutil"
 | ||
|  | 	clog "github.com/coredns/coredns/plugin/pkg/log"
 | ||
|  | 	"github.com/coredns/coredns/request"
 | ||
|  | 
 | ||
|  | 	"github.com/hashicorp/nomad/api"
 | ||
|  | 	"github.com/miekg/dns"
 | ||
|  | )
 | ||
|  | 
 | ||
|  | const pluginName = "nomad"
 | ||
|  | 
 | ||
|  | var (
 | ||
|  | 	log        = clog.NewWithPlugin(pluginName)
 | ||
|  | 	defaultTTL = 30
 | ||
|  | )
 | ||
|  | 
 | ||
// Nomad is the plugin handler that answers queries under Zone from Nomad
// service registrations.
type Nomad struct {
	// Next is the next handler in the plugin chain; ServeDNS hands it any
	// query it does not answer itself.
	Next plugin.Handler

	// ttl is the TTL used for records in this plugin's responses.
	ttl     uint32
	// Zone is the DNS zone this plugin is authoritative for; it is trimmed
	// from incoming query names.
	Zone    string
	// clients are Nomad API clients. NOTE(review): presumably getClient
	// (defined elsewhere) selects one of these — verify.
	clients []*api.Client
	// current presumably tracks which entry of clients is in use — TODO
	// confirm against getClient.
	current int
}
 | ||
|  | 
 | ||
|  | func (n *Nomad) Name() string {
 | ||
|  | 	return pluginName
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func (n Nomad) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
 | ||
|  | 	state := request.Request{W: w, Req: r}
 | ||
|  | 	qname, originalQName, err := processQName(state.Name(), n.Zone)
 | ||
|  | 	if err != nil {
 | ||
|  | 		return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	namespace, serviceName, err := extractNamespaceAndService(qname)
 | ||
|  | 	if err != nil {
 | ||
|  | 		return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	m, header := initializeMessage(state, n.ttl)
 | ||
|  | 
 | ||
|  | 	svcRegistrations, _, err := fetchServiceRegistrations(n, serviceName, namespace)
 | ||
|  | 	if err != nil {
 | ||
|  | 		log.Warning(err)
 | ||
|  | 		return handleServiceLookupError(w, m, ctx, namespace)
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	if len(svcRegistrations) == 0 {
 | ||
|  | 		return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, err)
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	if err := addServiceResponses(m, svcRegistrations, header, state.QType(), originalQName, n.ttl); err != nil {
 | ||
|  | 		return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, err)
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	err = w.WriteMsg(m)
 | ||
|  | 	requestSuccessCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
 | ||
|  | 	return dns.RcodeSuccess, err
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func processQName(qname, zone string) (string, string, error) {
 | ||
|  | 	original := dns.Fqdn(qname)
 | ||
|  | 	base, err := dnsutil.TrimZone(original, dns.Fqdn(zone))
 | ||
|  | 	return base, original, err
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func extractNamespaceAndService(qname string) (string, string, error) {
 | ||
|  | 	qnameSplit := dns.SplitDomainName(qname)
 | ||
|  | 	if len(qnameSplit) < 2 {
 | ||
|  | 		return "", "", fmt.Errorf("invalid query name")
 | ||
|  | 	}
 | ||
|  | 	return qnameSplit[1], qnameSplit[0], nil
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func initializeMessage(state request.Request, ttl uint32) (*dns.Msg, dns.RR_Header) {
 | ||
|  | 	m := new(dns.Msg)
 | ||
|  | 	m.SetReply(state.Req)
 | ||
|  | 	m.Authoritative, m.Compress, m.Rcode = true, true, dns.RcodeSuccess
 | ||
|  | 
 | ||
|  | 	header := dns.RR_Header{
 | ||
|  | 		Name:   state.QName(),
 | ||
|  | 		Rrtype: state.QType(),
 | ||
|  | 		Class:  dns.ClassINET,
 | ||
|  | 		Ttl:    ttl,
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	return m, header
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func fetchServiceRegistrations(n Nomad, serviceName, namespace string) ([]*api.ServiceRegistration, *api.QueryMeta, error) {
 | ||
|  | 	log.Debugf("Looking up record for svc: %s namespace: %s", serviceName, namespace)
 | ||
|  | 	nc, err := n.getClient()
 | ||
|  | 	if err != nil {
 | ||
|  | 		return nil, nil, err
 | ||
|  | 	}
 | ||
|  | 	return nc.Services().Get(serviceName, (&api.QueryOptions{Namespace: namespace}))
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func handleServiceLookupError(w dns.ResponseWriter, m *dns.Msg, ctx context.Context, namespace string) (int, error) {
 | ||
|  | 	m.Rcode = dns.RcodeSuccess
 | ||
|  | 	err := w.WriteMsg(m)
 | ||
|  | 	requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
 | ||
|  | 	return dns.RcodeServerFailure, err
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func addServiceResponses(m *dns.Msg, svcRegistrations []*api.ServiceRegistration, header dns.RR_Header, qtype uint16, originalQName string, ttl uint32) error {
 | ||
|  | 	for _, s := range svcRegistrations {
 | ||
|  | 		addr := net.ParseIP(s.Address)
 | ||
|  | 		if addr == nil {
 | ||
|  | 			return fmt.Errorf("error parsing IP address")
 | ||
|  | 		}
 | ||
|  | 
 | ||
|  | 		switch qtype {
 | ||
|  | 		case dns.TypeA:
 | ||
|  | 			if addr.To4() == nil {
 | ||
|  | 				continue
 | ||
|  | 			}
 | ||
|  | 			addARecord(m, header, addr)
 | ||
|  | 		case dns.TypeAAAA:
 | ||
|  | 			if addr.To4() != nil {
 | ||
|  | 				continue
 | ||
|  | 			}
 | ||
|  | 			addAAAARecord(m, header, addr)
 | ||
|  | 		case dns.TypeSRV:
 | ||
|  | 			err := addSRVRecord(m, s, header, originalQName, addr, ttl)
 | ||
|  | 			if err != nil {
 | ||
|  | 				return err
 | ||
|  | 			}
 | ||
|  | 		default:
 | ||
|  | 			m.Rcode = dns.RcodeNotImplemented
 | ||
|  | 			return fmt.Errorf("query type not implemented")
 | ||
|  | 		}
 | ||
|  | 	}
 | ||
|  | 	return nil
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func handleResponseError(n Nomad, w dns.ResponseWriter, m *dns.Msg, originalQName string, ttl uint32, ctx context.Context, namespace string, err error) (int, error) {
 | ||
|  | 	m.Rcode = dns.RcodeNameError
 | ||
|  | 	m.Answer = append(m.Answer, createSOARecord(originalQName, ttl, n.Zone))
 | ||
|  | 
 | ||
|  | 	if writeErr := w.WriteMsg(m); writeErr != nil {
 | ||
|  | 		return dns.RcodeServerFailure, fmt.Errorf("write message error: %w", writeErr)
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
 | ||
|  | 
 | ||
|  | 	return dns.RcodeSuccess, err
 | ||
|  | }
 |