diff --git a/plugin/traffic/traffic_test.go b/plugin/traffic/traffic_test.go index 75e8c1e53..aa98e452c 100644 --- a/plugin/traffic/traffic_test.go +++ b/plugin/traffic/traffic_test.go @@ -153,6 +153,63 @@ func TestTraffic(t *testing.T) { } } +func TestTrafficLocality(t *testing.T) { + c, err := xds.New("127.0.0.1:0", "test-id", grpc.WithInsecure()) + if err != nil { + t.Fatal(err) + } + tr := &Traffic{c: c, origins: []string{"lb.example.org."}} + + tests := []struct { + cla *xdspb.ClusterLoadAssignment + cluster string + loc xds.Locality // where we run + qtype uint16 + rcode int + answer int // number of records in answer section + }{ + { + cla: &xdspb.ClusterLoadAssignment{ + ClusterName: "web", + Endpoints: append(endpointsWithLocality([]EndpointHealth{ + {"127.0.0.1", 18008, corepb.HealthStatus_HEALTHY}, + {"127.0.0.2", 18008, corepb.HealthStatus_HEALTHY}}, + xds.Locality{Region: "us"}), + endpointsWithLocality([]EndpointHealth{ + {"127.0.1.1", 18008, corepb.HealthStatus_HEALTHY}, + {"127.0.1.2", 18008, corepb.HealthStatus_HEALTHY}}, + xds.Locality{Region: "eu"})...), + }, + cluster: "web", qtype: dns.TypeA, rcode: dns.RcodeSuccess, answer: 2, + loc: xds.Locality{Region: "eu"}, + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + a := xds.NewAssignment() + a.SetClusterLoadAssignment("web", tc.cla) // web is our cluster + c.SetAssignments(a) + + m := new(dns.Msg) + cl := dnsutil.Join(tc.cluster, tr.origins[0]) + m.SetQuestion(cl, tc.qtype) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := tr.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %q", i, err) + } + if rec.Msg.Rcode != tc.rcode { + t.Errorf("Test %d: Expected no rcode %d, but got %d", i, tc.rcode, rec.Msg.Rcode) + } + if tc.answer != len(rec.Msg.Answer) { + t.Fatalf("Test %d: Expected %d answers, but got %d", i, tc.answer, len(rec.Msg.Answer)) + } + } +} + func TestTrafficSRV(t *testing.T) { c, err := xds.New("127.0.0.1:0", "test-id", grpc.WithInsecure()) if err != nil {