mirror of
https://github.com/coredns/coredns.git
synced 2025-10-26 15:54:16 -04:00
159 lines
4.3 KiB
Go
159 lines
4.3 KiB
Go
package nomad
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
|
|
"github.com/coredns/coredns/plugin"
|
|
"github.com/coredns/coredns/plugin/metrics"
|
|
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
|
clog "github.com/coredns/coredns/plugin/pkg/log"
|
|
"github.com/coredns/coredns/request"
|
|
|
|
"github.com/hashicorp/nomad/api"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// pluginName is the name reported by Nomad.Name and used to create the
// plugin's logger.
const pluginName = "nomad"
|
|
|
|
var (
|
|
log = clog.NewWithPlugin(pluginName)
|
|
defaultTTL = 30
|
|
)
|
|
|
|
type Nomad struct {
|
|
Next plugin.Handler
|
|
|
|
ttl uint32
|
|
Zone string
|
|
clients []*api.Client
|
|
current int
|
|
}
|
|
|
|
func (n *Nomad) Name() string {
|
|
return pluginName
|
|
}
|
|
|
|
func (n Nomad) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
|
state := request.Request{W: w, Req: r}
|
|
qname, originalQName, err := processQName(state.Name(), n.Zone)
|
|
if err != nil {
|
|
return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
|
|
}
|
|
|
|
namespace, serviceName, err := extractNamespaceAndService(qname)
|
|
if err != nil {
|
|
return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
|
|
}
|
|
|
|
m, header := initializeMessage(state, n.ttl)
|
|
|
|
svcRegistrations, _, err := fetchServiceRegistrations(n, serviceName, namespace)
|
|
if err != nil {
|
|
log.Warning(err)
|
|
return handleServiceLookupError(w, m, ctx, namespace)
|
|
}
|
|
|
|
if len(svcRegistrations) == 0 {
|
|
return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, err)
|
|
}
|
|
|
|
if err := addServiceResponses(m, svcRegistrations, header, state.QType(), originalQName, n.ttl); err != nil {
|
|
return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, err)
|
|
}
|
|
|
|
err = w.WriteMsg(m)
|
|
requestSuccessCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
|
|
return dns.RcodeSuccess, err
|
|
}
|
|
|
|
func processQName(qname, zone string) (string, string, error) {
|
|
original := dns.Fqdn(qname)
|
|
base, err := dnsutil.TrimZone(original, dns.Fqdn(zone))
|
|
return base, original, err
|
|
}
|
|
|
|
func extractNamespaceAndService(qname string) (string, string, error) {
|
|
qnameSplit := dns.SplitDomainName(qname)
|
|
if len(qnameSplit) < 2 {
|
|
return "", "", fmt.Errorf("invalid query name")
|
|
}
|
|
return qnameSplit[1], qnameSplit[0], nil
|
|
}
|
|
|
|
func initializeMessage(state request.Request, ttl uint32) (*dns.Msg, dns.RR_Header) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(state.Req)
|
|
m.Authoritative, m.Compress, m.Rcode = true, true, dns.RcodeSuccess
|
|
|
|
header := dns.RR_Header{
|
|
Name: state.QName(),
|
|
Rrtype: state.QType(),
|
|
Class: dns.ClassINET,
|
|
Ttl: ttl,
|
|
}
|
|
|
|
return m, header
|
|
}
|
|
|
|
func fetchServiceRegistrations(n Nomad, serviceName, namespace string) ([]*api.ServiceRegistration, *api.QueryMeta, error) {
|
|
log.Debugf("Looking up record for svc: %s namespace: %s", serviceName, namespace)
|
|
nc, err := n.getClient()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return nc.Services().Get(serviceName, (&api.QueryOptions{Namespace: namespace}))
|
|
}
|
|
|
|
func handleServiceLookupError(w dns.ResponseWriter, m *dns.Msg, ctx context.Context, namespace string) (int, error) {
|
|
m.Rcode = dns.RcodeSuccess
|
|
err := w.WriteMsg(m)
|
|
requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
|
|
return dns.RcodeServerFailure, err
|
|
}
|
|
|
|
func addServiceResponses(m *dns.Msg, svcRegistrations []*api.ServiceRegistration, header dns.RR_Header, qtype uint16, originalQName string, ttl uint32) error {
|
|
for _, s := range svcRegistrations {
|
|
addr := net.ParseIP(s.Address)
|
|
if addr == nil {
|
|
return fmt.Errorf("error parsing IP address")
|
|
}
|
|
|
|
switch qtype {
|
|
case dns.TypeA:
|
|
if addr.To4() == nil {
|
|
continue
|
|
}
|
|
addARecord(m, header, addr)
|
|
case dns.TypeAAAA:
|
|
if addr.To4() != nil {
|
|
continue
|
|
}
|
|
addAAAARecord(m, header, addr)
|
|
case dns.TypeSRV:
|
|
err := addSRVRecord(m, s, header, originalQName, addr, ttl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
m.Rcode = dns.RcodeNotImplemented
|
|
return fmt.Errorf("query type not implemented")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func handleResponseError(n Nomad, w dns.ResponseWriter, m *dns.Msg, originalQName string, ttl uint32, ctx context.Context, namespace string, err error) (int, error) {
|
|
m.Rcode = dns.RcodeNameError
|
|
m.Answer = append(m.Answer, createSOARecord(originalQName, ttl, n.Zone))
|
|
|
|
if writeErr := w.WriteMsg(m); writeErr != nil {
|
|
return dns.RcodeServerFailure, fmt.Errorf("write message error: %w", writeErr)
|
|
}
|
|
|
|
requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
|
|
|
|
return dns.RcodeSuccess, err
|
|
}
|