Files
coredns/plugin/nomad/nomad.go
Olli Janatuinen 83ce0baeac plugin/nomad: Add a Nomad plugin (#7467)
Signed-off-by: Olli Janatuinen <olli.janatuinen@gmail.com>
2025-09-30 09:35:32 -07:00

159 lines
4.3 KiB
Go

package nomad
import (
"context"
"fmt"
"net"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/hashicorp/nomad/api"
"github.com/miekg/dns"
)
// pluginName is the name this plugin reports from Name and logs under.
const pluginName = "nomad"
var (
// log is the plugin-scoped logger, created with the plugin name.
log = clog.NewWithPlugin(pluginName)
// defaultTTL is not referenced in this part of the file; presumably the
// default record TTL in seconds applied during setup — TODO confirm.
defaultTTL = 30
)
// Nomad holds the plugin state: it answers DNS queries under Zone from Nomad
// service registrations and hands everything else to Next.
type Nomad struct {
// Next is the handler a query is passed to when its name is outside Zone
// or does not contain both a service and a namespace label.
Next plugin.Handler
// ttl is the TTL placed on the records and the SOA built by this plugin.
ttl uint32
// Zone is the zone suffix trimmed from query names before they are parsed.
Zone string
// clients are the Nomad API clients; one is picked by getClient (defined
// outside this part of the file).
clients []*api.Client
// current is presumably the index of the active entry in clients — TODO confirm.
current int
}
// Name returns the plugin name ("nomad").
func (n *Nomad) Name() string {
return pluginName
}
// ServeDNS answers queries of the form <service>.<namespace>.<zone> from
// Nomad service registrations.
//
// Queries whose name is outside n.Zone, or that have fewer than two labels
// left once the zone is trimmed, are passed to the next plugin. A failed Nomad
// lookup is logged and delegated to handleServiceLookupError; an empty result
// or a failure while building the answer is delegated to handleResponseError.
//
// On success the reply is written to w and (dns.RcodeSuccess, nil) is
// returned. If the reply cannot be written, the request is not counted as a
// success and dns.RcodeServerFailure is returned with the wrapped write error,
// matching what handleResponseError does for its own write failures.
func (n Nomad) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := request.Request{W: w, Req: r}

	qname, originalQName, err := processQName(state.Name(), n.Zone)
	if err != nil {
		return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
	}

	namespace, serviceName, err := extractNamespaceAndService(qname)
	if err != nil {
		return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
	}

	m, header := initializeMessage(state, n.ttl)

	svcRegistrations, _, err := fetchServiceRegistrations(n, serviceName, namespace)
	if err != nil {
		log.Warning(err)
		return handleServiceLookupError(w, m, ctx, namespace)
	}

	if len(svcRegistrations) == 0 {
		// The lookup succeeded, so there is no error to forward; pass nil
		// explicitly instead of the (always nil) err variable.
		return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, nil)
	}

	if err := addServiceResponses(m, svcRegistrations, header, state.QType(), originalQName, n.ttl); err != nil {
		return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, err)
	}

	if err := w.WriteMsg(m); err != nil {
		// The client never received the answer: do not record a success.
		return dns.RcodeServerFailure, fmt.Errorf("write message error: %w", err)
	}
	requestSuccessCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
	return dns.RcodeSuccess, nil
}
// processQName converts qname to fully-qualified form and strips zone from it.
//
// It returns the part of the name left after trimming, the fully-qualified
// query name, and whatever error dnsutil.TrimZone reported.
func processQName(qname, zone string) (string, string, error) {
	fqdn := dns.Fqdn(qname)
	fqZone := dns.Fqdn(zone)

	trimmed, err := dnsutil.TrimZone(fqdn, fqZone)
	return trimmed, fqdn, err
}
// extractNamespaceAndService splits a zone-trimmed query name into its labels
// and returns (namespace, service, nil), where the service is the first label
// and the namespace is the second. Any further labels are ignored. A name with
// fewer than two labels yields an error.
func extractNamespaceAndService(qname string) (string, string, error) {
	labels := dns.SplitDomainName(qname)
	if len(labels) >= 2 {
		service, namespace := labels[0], labels[1]
		return namespace, service, nil
	}
	return "", "", fmt.Errorf("invalid query name")
}
// initializeMessage prepares the reply for the request held in state.
//
// It returns a reply message marked authoritative, compressed, and with rcode
// success, together with a resource-record header template carrying the query
// name and type, class INET, and the given ttl.
func initializeMessage(state request.Request, ttl uint32) (*dns.Msg, dns.RR_Header) {
	reply := new(dns.Msg)
	reply.SetReply(state.Req)
	reply.Authoritative = true
	reply.Compress = true
	reply.Rcode = dns.RcodeSuccess

	var hdr dns.RR_Header
	hdr.Name = state.QName()
	hdr.Rrtype = state.QType()
	hdr.Class = dns.ClassINET
	hdr.Ttl = ttl

	return reply, hdr
}
// fetchServiceRegistrations asks Nomad for the registrations of serviceName in
// namespace, using the client returned by n.getClient.
//
// It returns the registrations and query metadata from the Nomad API, or the
// error from obtaining the client or from the API call.
func fetchServiceRegistrations(n Nomad, serviceName, namespace string) ([]*api.ServiceRegistration, *api.QueryMeta, error) {
	log.Debugf("Looking up record for svc: %s namespace: %s", serviceName, namespace)

	client, err := n.getClient()
	if err != nil {
		return nil, nil, err
	}

	opts := &api.QueryOptions{Namespace: namespace}
	return client.Services().Get(serviceName, opts)
}
// handleServiceLookupError reports a failed Nomad lookup as SERVFAIL.
//
// It records the failure in requestFailedCount and returns
// dns.RcodeServerFailure without writing to w. CoreDNS treats a SERVFAIL
// return value as "nothing has been written to the client" and sends the
// error response itself. Writing a NOERROR reply here as well would give the
// client an empty, cacheable answer for what is a backend failure, followed by
// a second response from the server.
//
// w and m are unused; they are kept so the signature stays unchanged for
// existing callers.
func handleServiceLookupError(w dns.ResponseWriter, m *dns.Msg, ctx context.Context, namespace string) (int, error) {
	requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
	return dns.RcodeServerFailure, nil
}
// addServiceResponses appends one record per matching service registration to m.
//
// For A queries only registrations with an IPv4 address are used, for AAAA
// queries only those with a non-IPv4 address; SRV queries produce a record for
// every registration. A registration whose address does not parse as an IP
// aborts with an error. Any other query type sets m.Rcode to NotImplemented
// and returns an error. Records appended before an error are left on m.
func addServiceResponses(m *dns.Msg, svcRegistrations []*api.ServiceRegistration, header dns.RR_Header, qtype uint16, originalQName string, ttl uint32) error {
	for _, reg := range svcRegistrations {
		ip := net.ParseIP(reg.Address)
		if ip == nil {
			return fmt.Errorf("error parsing IP address")
		}
		isV4 := ip.To4() != nil

		switch {
		case qtype == dns.TypeA:
			if isV4 {
				addARecord(m, header, ip)
			}
		case qtype == dns.TypeAAAA:
			if !isV4 {
				addAAAARecord(m, header, ip)
			}
		case qtype == dns.TypeSRV:
			if srvErr := addSRVRecord(m, reg, header, originalQName, ip, ttl); srvErr != nil {
				return srvErr
			}
		default:
			m.Rcode = dns.RcodeNotImplemented
			return fmt.Errorf("query type not implemented")
		}
	}
	return nil
}
// handleResponseError writes a negative response for a query that could not be
// answered from the registrations.
//
// If m still carries rcode success it is set to NXDOMAIN; a more specific
// rcode already placed on m (for example NotImplemented, set by
// addServiceResponses for an unsupported query type) is kept rather than
// overwritten. An SOA record built from originalQName, ttl, and n.Zone is
// appended, and the message is written to w.
//
// It returns (dns.RcodeSuccess, err) once the message is written, forwarding
// the caller's err unchanged (it may be nil). If the write fails, it returns
// dns.RcodeServerFailure with the wrapped write error and does not count the
// request as failed.
func handleResponseError(n Nomad, w dns.ResponseWriter, m *dns.Msg, originalQName string, ttl uint32, ctx context.Context, namespace string, err error) (int, error) {
	if m.Rcode == dns.RcodeSuccess {
		m.Rcode = dns.RcodeNameError
	}

	// NOTE(review): the SOA is appended to the answer section; negative
	// responses conventionally carry it in the authority section (RFC 2308).
	// Left as is because consumers may depend on the current placement — confirm.
	m.Answer = append(m.Answer, createSOARecord(originalQName, ttl, n.Zone))

	if writeErr := w.WriteMsg(m); writeErr != nil {
		return dns.RcodeServerFailure, fmt.Errorf("write message error: %w", writeErr)
	}

	requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
	return dns.RcodeSuccess, err
}