From 63ef6d3d55663d3500d96a9e414aa9588f196f0f Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Sun, 19 Jan 2020 08:30:13 +0100 Subject: [PATCH] Return all records for SRV queries Return all SRV records and assume the client is smart enough to make the call. Signed-off-by: Miek Gieben --- plugin/traffic/README.md | 14 ++--- plugin/traffic/traffic.go | 93 +++++++++++++++++++++++++++----- plugin/traffic/traffic_test.go | 71 ++++++++++++++++++++++-- plugin/traffic/xds/assignment.go | 46 +++++++++++----- plugin/traffic/xds/client.go | 13 +++-- 5 files changed, 198 insertions(+), 39 deletions(-) diff --git a/plugin/traffic/README.md b/plugin/traffic/README.md index ba2320773..c08d83478 100644 --- a/plugin/traffic/README.md +++ b/plugin/traffic/README.md @@ -24,13 +24,15 @@ endpoints need to be drained from it. discovered every 10 seconds. The plugin hands out responses that adhere to these assignments. Only endpoints that are *healthy* are handed out. -Each DNS response contains a single IP address (or SRV record) that's considered the best one. -*Traffic* will load balance A, AAAA and SRV queries. The TTL on these answer is set to 5s. It will -only return successful responses either with an answer or otherwise a NODATA response. Queries for -non-existent clusters get a NXDOMAIN, where the minimal TTL is also set to 5s. +For A and AAAA queries each DNS response contains a single IP address that's considered the best +one. The TTL on these answer is set to 5s. It will only return successful responses either with an +answer or otherwise a NODATA response. Queries for non-existent clusters get a NXDOMAIN, where the +minimal TTL is also set to 5s. -When an SRV record is returned an endpoint DNS name is synthesized `endpoint-0..` that -carries the IP address. Querying for these synthesized names works as well. +For SRV queries all healthy backends will be returned - assuming the client doing the query is smart +enough to select the best one. When SRV records are returned, the endpoint DNS names are synthesized +`endpoint-..` that carries the IP address. Querying for these synthesized names +works as well. The *traffic* plugin has no notion of draining, drop overload and anything that advanced, *it just acts upon assignments*. This is means that if a endpoint goes down and *traffic* has not seen a new diff --git a/plugin/traffic/traffic.go b/plugin/traffic/traffic.go index 2377e4c56..f9d7c8f4a 100644 --- a/plugin/traffic/traffic.go +++ b/plugin/traffic/traffic.go @@ -2,6 +2,8 @@ package traffic import ( "context" + "fmt" + "strconv" "strings" "time" @@ -34,11 +36,12 @@ func (t *Traffic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg break } } + m := new(dns.Msg) m.SetReply(r) m.Authoritative = true - addr, port, ok := t.c.Select(cluster) + sockaddr, ok := t.c.Select(cluster) if !ok { // ok the cluster (which has potentially extra labels), doesn't exist, but we may have a query for endpoint-0.. // check if we have 2 labels and that the first equals endpoint-0. @@ -49,25 +52,96 @@ func (t *Traffic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg return 0, nil } labels := dns.SplitDomainName(cluster) - if strings.Compare(labels[0], "endpoint-0") == 0 { + if strings.HasPrefix(labels[0], "endpoint-") { // recheck if the cluster exist. - addr, port, ok = t.c.Select(labels[1]) + cluster = labels[1] + sockaddr, ok = t.c.Select(cluster) if !ok { m.Ns = soa(state.Zone) m.Rcode = dns.RcodeNameError w.WriteMsg(m) return 0, nil } + return t.ServeEndpoint(ctx, state, labels[0], cluster) } } - if addr == nil { + if sockaddr == nil { log.Debugf("No (healthy) endpoints found for %q", cluster) m.Ns = soa(state.Zone) w.WriteMsg(m) return 0, nil } + switch state.QType() { + case dns.TypeA: + if sockaddr.Address().To4() == nil { // it's an IPv6 address, return nodata in that case. + m.Ns = soa(state.Zone) + break + } + m.Answer = []dns.RR{&dns.A{Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}, A: sockaddr.Address()}} + + case dns.TypeAAAA: + if sockaddr.Address().To4() != nil { // it's an IPv4 address, return nodata in that case. + m.Ns = soa(state.Zone) + break + } + m.Answer = []dns.RR{&dns.AAAA{Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 5}, AAAA: sockaddr.Address()}} + case dns.TypeSRV: + sockaddrs, _ := t.c.All(cluster) + for i, sa := range sockaddrs { + target := fmt.Sprintf("endpoint-%d.%s.%s", i, cluster, state.Zone) + + m.Answer = append(m.Answer, &dns.SRV{ + Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}, + Priority: 100, Weight: 100, Port: sa.Port(), Target: target}) + + if sa.Address().To4() == nil { + m.Extra = []dns.RR{&dns.AAAA{Hdr: dns.RR_Header{Name: target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}, AAAA: sa.Address()}} + } else { + m.Extra = []dns.RR{&dns.A{Hdr: dns.RR_Header{Name: target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}, A: sa.Address()}} + } + } + default: + m.Ns = soa(state.Zone) + } + + w.WriteMsg(m) + return 0, nil +} + +func (t *Traffic) ServeEndpoint(ctx context.Context, state request.Request, endpoint, cluster string) (int, error) { + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + + // get endpoint number + i := strings.Index(endpoint, "-") + if i == -1 || i == len(endpoint) { + m.Ns = soa(state.Zone) + m.Rcode = dns.RcodeNameError + state.W.WriteMsg(m) + return 0, nil + } + + end := endpoint[i+1:] // +1 to remove '-' + nr, err := strconv.Atoi(end) + if err != nil { + m.Ns = soa(state.Zone) + m.Rcode = dns.RcodeNameError + state.W.WriteMsg(m) + return 0, nil + } + + sockaddrs, _ := t.c.All(cluster) + if len(sockaddrs) < nr { + m.Ns = soa(state.Zone) + m.Rcode = dns.RcodeNameError + state.W.WriteMsg(m) + return 0, nil + } + + addr := sockaddrs[nr].Address() switch state.QType() { case dns.TypeA: if addr.To4() == nil { // it's an IPv6 address, return nodata in that case. @@ -82,20 +156,11 @@ func (t *Traffic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg break } m.Answer = []dns.RR{&dns.AAAA{Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 5}, AAAA: addr}} - case dns.TypeSRV: - target := dnsutil.Join("endpoint-0", cluster) + state.Zone - m.Answer = []dns.RR{&dns.SRV{Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}, - Priority: 100, Weight: 100, Port: port, Target: target}} - if addr.To4() == nil { - m.Extra = []dns.RR{&dns.AAAA{Hdr: dns.RR_Header{Name: target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}, AAAA: addr}} - } else { - m.Extra = []dns.RR{&dns.A{Hdr: dns.RR_Header{Name: target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}, A: addr}} - } default: m.Ns = soa(state.Zone) } - w.WriteMsg(m) + state.W.WriteMsg(m) return 0, nil } diff --git a/plugin/traffic/traffic_test.go b/plugin/traffic/traffic_test.go index 8a2ec73c9..43b5ad605 100644 --- a/plugin/traffic/traffic_test.go +++ b/plugin/traffic/traffic_test.go @@ -43,7 +43,7 @@ func TestTraffic(t *testing.T) { cla: &xdspb.ClusterLoadAssignment{}, cluster: "does-not-exist", qtype: dns.TypeA, rcode: dns.RcodeNameError, ns: true, }, - // healthy backend + // healthy endpoint { cla: &xdspb.ClusterLoadAssignment{ ClusterName: "web", @@ -58,7 +58,7 @@ func TestTraffic(t *testing.T) { }, cluster: "web", qtype: dns.TypeAAAA, rcode: dns.RcodeSuccess, answer: "::1", }, - // unknown backend + // unknown endpoint { cla: &xdspb.ClusterLoadAssignment{ ClusterName: "web", @@ -66,7 +66,7 @@ func TestTraffic(t *testing.T) { }, cluster: "web", qtype: dns.TypeA, rcode: dns.RcodeSuccess, ns: true, }, - // unknown backend and healthy backend + // unknown endpoint and healthy endpoint { cla: &xdspb.ClusterLoadAssignment{ ClusterName: "web", @@ -77,7 +77,7 @@ func TestTraffic(t *testing.T) { }, cluster: "web", qtype: dns.TypeA, rcode: dns.RcodeSuccess, answer: "127.0.0.2", }, - // SRV query healthy backend + // SRV query healthy endpoint { cla: &xdspb.ClusterLoadAssignment{ ClusterName: "web", @@ -97,6 +97,17 @@ func TestTraffic(t *testing.T) { }, cluster: "endpoint-0.web", qtype: dns.TypeA, rcode: dns.RcodeSuccess, answer: "127.0.0.2", }, + // A query for endpoint-1. + { + cla: &xdspb.ClusterLoadAssignment{ + ClusterName: "web", + Endpoints: endpoints([]EndpointHealth{ + {"127.0.0.2", 18008, corepb.HealthStatus_HEALTHY}, + {"127.0.0.3", 18008, corepb.HealthStatus_HEALTHY}, + }), + }, + cluster: "endpoint-1.web", qtype: dns.TypeA, rcode: dns.RcodeSuccess, answer: "127.0.0.3", + }, } ctx := context.TODO() @@ -142,6 +153,58 @@ func TestTraffic(t *testing.T) { } } +func TestTrafficSRV(t *testing.T) { + c, err := xds.New("127.0.0.1:0", "test-id", grpc.WithInsecure()) + if err != nil { + t.Fatal(err) + } + tr := &Traffic{c: c, origins: []string{"lb.example.org."}} + + tests := []struct { + cla *xdspb.ClusterLoadAssignment + cluster string + qtype uint16 + rcode int + answer int // number of records in answer section + }{ + // SRV query healthy endpoint + { + cla: &xdspb.ClusterLoadAssignment{ + ClusterName: "web", + Endpoints: endpoints([]EndpointHealth{ + {"127.0.0.2", 18008, corepb.HealthStatus_HEALTHY}, + {"127.0.0.3", 18008, corepb.HealthStatus_HEALTHY}, + }), + }, + cluster: "web", qtype: dns.TypeSRV, rcode: dns.RcodeSuccess, answer: 2, + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + a := xds.NewAssignment() + a.SetClusterLoadAssignment("web", tc.cla) // web is our cluster + c.SetAssignments(a) + + m := new(dns.Msg) + cl := dnsutil.Join(tc.cluster, tr.origins[0]) + m.SetQuestion(cl, tc.qtype) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := tr.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %q", i, err) + } + if rec.Msg.Rcode != tc.rcode { + t.Errorf("Test %d: Expected no rcode %d, but got %d", i, tc.rcode, rec.Msg.Rcode) + } + if tc.answer != len(rec.Msg.Answer) { + t.Fatalf("Test %d: Expected %d answers, but got %d", i, tc.answer, len(rec.Msg.Answer)) + } + } +} + type EndpointHealth struct { Address string Port uint16 diff --git a/plugin/traffic/xds/assignment.go b/plugin/traffic/xds/assignment.go index 4bea5cb0f..299b4a5e0 100644 --- a/plugin/traffic/xds/assignment.go +++ b/plugin/traffic/xds/assignment.go @@ -9,6 +9,13 @@ import ( corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" ) +type SocketAddress struct { + *corepb.SocketAddress +} + +func (s *SocketAddress) Address() net.IP { return net.ParseIP(s.GetAddress()) } +func (s *SocketAddress) Port() uint16 { return uint16(s.GetPortValue()) } + type assignment struct { mu sync.RWMutex cla map[string]*xdspb.ClusterLoadAssignment @@ -59,11 +66,11 @@ func (a *assignment) clusters() []string { return clusters } -// Select selects a backend from cluster load assignments, using weighted random selection. It only selects backends that are reporting healthy. -func (a *assignment) Select(cluster string) (ip net.IP, port uint16, exists bool) { +// Select selects a endpoint from cluster load assignments, using weighted random selection. It only selects endpoints that are reporting healthy. +func (a *assignment) Select(cluster string) (*SocketAddress, bool) { cla := a.ClusterLoadAssignment(cluster) if cla == nil { - return nil, 0, false + return nil, false } total := 0 @@ -78,7 +85,7 @@ func (a *assignment) Select(cluster string) (ip net.IP, port uint16, exists bool } } if healthy == 0 { - return nil, 0, true + return nil, true } if total == 0 { @@ -91,14 +98,12 @@ func (a *assignment) Select(cluster string) (ip net.IP, port uint16, exists bool continue } if r == i { - addr := net.ParseIP(lb.GetEndpoint().GetAddress().GetSocketAddress().GetAddress()) - port := uint16(lb.GetEndpoint().GetAddress().GetSocketAddress().GetPortValue()) - return addr, port, true + return &SocketAddress{lb.GetEndpoint().GetAddress().GetSocketAddress()}, true } i++ } } - return nil, 0, true + return nil, true } r := rand.Intn(total) + 1 @@ -109,11 +114,28 @@ func (a *assignment) Select(cluster string) (ip net.IP, port uint16, exists bool } r -= int(lb.GetLoadBalancingWeight().GetValue()) if r <= 0 { - addr := net.ParseIP(lb.GetEndpoint().GetAddress().GetSocketAddress().GetAddress()) - port := uint16(lb.GetEndpoint().GetAddress().GetSocketAddress().GetPortValue()) - return addr, port, true + return &SocketAddress{lb.GetEndpoint().GetAddress().GetSocketAddress()}, true } } } - return nil, 0, true + return nil, true +} + +// All returns all healthy endpoints. +func (a *assignment) All(cluster string) ([]*SocketAddress, bool) { + cla := a.ClusterLoadAssignment(cluster) + if cla == nil { + return nil, false + } + + sa := []*SocketAddress{} + for _, ep := range cla.Endpoints { + for _, lb := range ep.GetLbEndpoints() { + if lb.GetHealthStatus() != corepb.HealthStatus_HEALTHY { + continue + } + sa = append(sa, &SocketAddress{lb.GetEndpoint().GetAddress().GetSocketAddress()}) + } + } + return sa, true } diff --git a/plugin/traffic/xds/client.go b/plugin/traffic/xds/client.go index 9e08631e7..5302f1ffc 100644 --- a/plugin/traffic/xds/client.go +++ b/plugin/traffic/xds/client.go @@ -23,7 +23,6 @@ package xds import ( "context" "fmt" - "net" "os" "sync" "time" @@ -228,9 +227,17 @@ func (c *Client) receive(stream adsStream) error { // Select returns an address that is deemed to be the correct one for this cluster. The returned // boolean indicates if the cluster exists. -func (c *Client) Select(cluster string) (net.IP, uint16, bool) { +func (c *Client) Select(cluster string) (*SocketAddress, bool) { if cluster == "" { - return nil, 0, false + return nil, false } return c.assignments.Select(cluster) } + +// All returns all endpoints. +func (c *Client) All(cluster string) ([]*SocketAddress, bool) { + if cluster == "" { + return nil, false + } + return c.assignments.All(cluster) +}