diff --git a/plugin/rewrite/README.md b/plugin/rewrite/README.md index 826c3ec1f..9740f9d5a 100644 --- a/plugin/rewrite/README.md +++ b/plugin/rewrite/README.md @@ -463,8 +463,30 @@ rewrite edns0 subnet set 24 56 * If the query's source IP address is an IPv4 address, the first 24 bits in the IP will be the network subnet. * If the query's source IP address is an IPv6 address, the first 56 bits in the IP will be the network subnet. +### EDNS0 Revert -### CNAME Field Rewrites +Using the `revert` flag, you can revert the changes made by this rewrite call, so the response will not contain this option. + +This example sets option, but response will not contain it +~~~ corefile +. { + rewrite edns0 local set 0xffee abcd revert +} +~~~ + +If only some calls contain the `revert` flag, then the value in the response will be changed to the previous one. So, in this example, the response will contain `abcd` data at `0xffee` +~~~ corefile +. { + rewrite continue { + edns0 local set 0xffee abcd + } + + rewrite edns0 local replace 0xffee bcde revert +} +~~~ + + +## CNAME Field Rewrites There might be a scenario where you want the `CNAME` target of the response to be rewritten. You can do this by using the `CNAME` field rewrite. This will generate new answer records according to the new `CNAME` target. diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go index 85146c7ec..44e457919 100644 --- a/plugin/rewrite/edns0.go +++ b/plugin/rewrite/edns0.go @@ -22,6 +22,7 @@ type edns0LocalRule struct { action string code uint16 data []byte + revert bool } // edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable. @@ -30,12 +31,43 @@ type edns0VariableRule struct { action string code uint16 variable string + revert bool } // ends0NsidRule is a rewrite rule for EDNS0_NSID options. type edns0NsidRule struct { mode string action string + revert bool +} + +type edns0SetResponseRule struct { + code uint16 +} + +func (r *edns0SetResponseRule) RewriteResponse(res *dns.Msg, _ dns.RR) { + ednsOpt := res.IsEdns0() + for idx, opt := range ednsOpt.Option { + if opt.Option() == r.code { + ednsOpt.Option = append(ednsOpt.Option[:idx], ednsOpt.Option[idx+1:]...) + return + } + } +} + +type edns0ReplaceResponseRule[T dns.EDNS0] struct { + code uint16 + source T +} + +func (r *edns0ReplaceResponseRule[T]) RewriteResponse(res *dns.Msg, _ dns.RR) { + ednsOpt := res.IsEdns0() + for idx, opt := range ednsOpt.Option { + if opt.Option() == r.code { + ednsOpt.Option[idx] = r.source + return + } + } } // setupEdns0Opt will retrieve the EDNS0 OPT or create it if it does not exist. @@ -52,11 +84,17 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT { func (rule *edns0NsidRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { o := setupEdns0Opt(state.Req) + var resp ResponseRules + for _, s := range o.Option { if e, ok := s.(*dns.EDNS0_NSID); ok { if rule.action == Replace || rule.action == Set { + if rule.revert { + old := *e + resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_NSID]{code: e.Code, source: &old}) + } e.Nsid = "" // make sure it is empty for request - return nil, RewriteDone + return resp, RewriteDone } } } @@ -64,7 +102,10 @@ func (rule *edns0NsidRule) Rewrite(ctx context.Context, state request.Request) ( // add option if not found if rule.action == Append || rule.action == Set { o.Option = append(o.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}) - return nil, RewriteDone + if rule.revert { + resp = append(resp, &edns0SetResponseRule{code: dns.EDNS0NSID}) + } + return resp, RewriteDone } return nil, RewriteIgnored @@ -77,12 +118,18 @@ func (rule *edns0NsidRule) Mode() string { return rule.mode } func (rule *edns0LocalRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { o := setupEdns0Opt(state.Req) + var resp ResponseRules + for _, s := range o.Option { if e, ok := s.(*dns.EDNS0_LOCAL); ok { if rule.code == e.Code { if rule.action == Replace || rule.action == Set { + if rule.revert { + old := *e + resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_LOCAL]{code: rule.code, source: &old}) + } e.Data = rule.data - return nil, RewriteDone + return resp, RewriteDone } } } @@ -91,7 +138,10 @@ func (rule *edns0LocalRule) Rewrite(ctx context.Context, state request.Request) // add option if not found if rule.action == Append || rule.action == Set { o.Option = append(o.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: rule.data}) - return nil, RewriteDone + if rule.revert { + resp = append(resp, &edns0SetResponseRule{code: rule.code}) + } + return resp, RewriteDone } return nil, RewriteIgnored @@ -116,32 +166,39 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) { return nil, fmt.Errorf("invalid action: %q", action) } + // Extract "revert" parameter. + var revert bool + if args[len(args)-1] == "revert" { + revert = true + args = args[:len(args)-1] + } + switch ruleType { case "local": if len(args) != 4 { - return nil, fmt.Errorf("EDNS0 local rules require exactly three args") + return nil, fmt.Errorf("EDNS0 local rules require three or four args") } // Check for variable option. if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { - return newEdns0VariableRule(mode, action, args[2], args[3]) + return newEdns0VariableRule(mode, action, args[2], args[3], revert) } - return newEdns0LocalRule(mode, action, args[2], args[3]) + return newEdns0LocalRule(mode, action, args[2], args[3], revert) case "nsid": if len(args) != 2 { - return nil, fmt.Errorf("EDNS0 NSID rules do not accept args") + return nil, fmt.Errorf("EDNS0 NSID rules can accept no more than one arg") } - return &edns0NsidRule{mode: mode, action: action}, nil + return &edns0NsidRule{mode: mode, action: action, revert: revert}, nil case "subnet": if len(args) != 4 { - return nil, fmt.Errorf("EDNS0 subnet rules require exactly three args") + return nil, fmt.Errorf("EDNS0 subnet rules require three or four args") } - return newEdns0SubnetRule(mode, action, args[2], args[3]) + return newEdns0SubnetRule(mode, action, args[2], args[3], revert) default: return nil, fmt.Errorf("invalid rule type %q", ruleType) } } -func newEdns0LocalRule(mode, action, code, data string) (*edns0LocalRule, error) { +func newEdns0LocalRule(mode, action, code, data string, revert bool) (*edns0LocalRule, error) { c, err := strconv.ParseUint(code, 0, 16) if err != nil { return nil, err @@ -158,11 +215,11 @@ func newEdns0LocalRule(mode, action, code, data string) (*edns0LocalRule, error) // Add this code to the ones the server supports. edns.SetSupportedOption(uint16(c)) - return &edns0LocalRule{mode: mode, action: action, code: uint16(c), data: decoded}, nil + return &edns0LocalRule{mode: mode, action: action, code: uint16(c), data: decoded, revert: revert}, nil } // newEdns0VariableRule creates an EDNS0 rule that handles variable substitution -func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRule, error) { +func newEdns0VariableRule(mode, action, code, variable string, revert bool) (*edns0VariableRule, error) { c, err := strconv.ParseUint(code, 0, 16) if err != nil { return nil, err @@ -175,7 +232,7 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu // Add this code to the ones the server supports. edns.SetSupportedOption(uint16(c)) - return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil + return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable, revert: revert}, nil } // ruleData returns the data specified by the variable. @@ -221,13 +278,19 @@ func (rule *edns0VariableRule) Rewrite(ctx context.Context, state request.Reques return nil, RewriteIgnored } + var resp ResponseRules + o := setupEdns0Opt(state.Req) for _, s := range o.Option { if e, ok := s.(*dns.EDNS0_LOCAL); ok { if rule.code == e.Code { if rule.action == Replace || rule.action == Set { + if rule.revert { + old := *e + resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_LOCAL]{code: rule.code, source: &old}) + } e.Data = data - return nil, RewriteDone + return resp, RewriteDone } return nil, RewriteIgnored } @@ -237,7 +300,10 @@ func (rule *edns0VariableRule) Rewrite(ctx context.Context, state request.Reques // add option if not found if rule.action == Append || rule.action == Set { o.Option = append(o.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: data}) - return nil, RewriteDone + if rule.revert { + resp = append(resp, &edns0SetResponseRule{code: rule.code}) + } + return resp, RewriteDone } return nil, RewriteIgnored @@ -271,9 +337,10 @@ type edns0SubnetRule struct { v4BitMaskLen uint8 v6BitMaskLen uint8 action string + revert bool } -func newEdns0SubnetRule(mode, action, v4BitMaskLen, v6BitMaskLen string) (*edns0SubnetRule, error) { +func newEdns0SubnetRule(mode, action, v4BitMaskLen, v6BitMaskLen string, revert bool) (*edns0SubnetRule, error) { v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16) if err != nil { return nil, err @@ -293,7 +360,7 @@ func newEdns0SubnetRule(mode, action, v4BitMaskLen, v6BitMaskLen string) (*edns0 } return &edns0SubnetRule{mode: mode, action: action, - v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len)}, nil + v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len), revert: revert}, nil } // fillEcsData sets the subnet data into the ecs option @@ -326,11 +393,17 @@ func (rule *edns0SubnetRule) fillEcsData(state request.Request, ecs *dns.EDNS0_S func (rule *edns0SubnetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { o := setupEdns0Opt(state.Req) + var resp ResponseRules + for _, s := range o.Option { if e, ok := s.(*dns.EDNS0_SUBNET); ok { if rule.action == Replace || rule.action == Set { + if rule.revert { + old := *e + resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_SUBNET]{code: e.Code, source: &old}) + } if rule.fillEcsData(state, e) == nil { - return nil, RewriteDone + return resp, RewriteDone } } return nil, RewriteIgnored @@ -342,7 +415,10 @@ func (rule *edns0SubnetRule) Rewrite(ctx context.Context, state request.Request) opt := &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET} if rule.fillEcsData(state, opt) == nil { o.Option = append(o.Option, opt) - return nil, RewriteDone + if rule.revert { + resp = append(resp, &edns0SetResponseRule{code: dns.EDNS0SUBNET}) + } + return resp, RewriteDone } } diff --git a/plugin/rewrite/rewrite_test.go b/plugin/rewrite/rewrite_test.go index 03d4fff1d..74f596acf 100644 --- a/plugin/rewrite/rewrite_test.go +++ b/plugin/rewrite/rewrite_test.go @@ -70,14 +70,20 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "set", "0xffee"}, true, nil}, {[]string{"edns0", "local", "set", "65518", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, {[]string{"edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "abcdefg", "revert"}, false, reflect.TypeOf(&edns0LocalRule{})}, {[]string{"edns0", "local", "append", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "abcdefg", "revert"}, false, reflect.TypeOf(&edns0LocalRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "abcdefg", "revert"}, false, reflect.TypeOf(&edns0LocalRule{})}, {[]string{"edns0", "local", "foo", "0xffee", "abcdefg"}, true, nil}, {[]string{"edns0", "local", "set", "0xffee", "0xabcdefg"}, true, nil}, {[]string{"edns0", "nsid", "set", "junk"}, true, nil}, {[]string{"edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "set", "revert"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "append", "revert"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "replace", "revert"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "foo"}, true, nil}, {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil}, {[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, @@ -87,6 +93,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{server_port}", "revert"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil}, {[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, @@ -95,6 +102,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{server_port}", "revert"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil}, {[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, @@ -103,13 +111,18 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "replace", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{server_port}", "revert"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "subnet", "set", "-1", "56"}, true, nil}, {[]string{"edns0", "subnet", "set", "24", "-56"}, true, nil}, {[]string{"edns0", "subnet", "set", "33", "56"}, true, nil}, {[]string{"edns0", "subnet", "set", "24", "129"}, true, nil}, {[]string{"edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "set", "24", "56", "revert"}, false, reflect.TypeOf(&edns0SubnetRule{})}, {[]string{"edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "append", "24", "56", "72"}, true, nil}, + {[]string{"edns0", "subnet", "append", "24", "56", "revert"}, false, reflect.TypeOf(&edns0SubnetRule{})}, {[]string{"edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "replace", "24", "56", "revert"}, false, reflect.TypeOf(&edns0SubnetRule{})}, {[]string{"unknown-action", "name", "a.com", "b.com"}, true, nil}, {[]string{"stop", "name", "a.com", "b.com"}, false, reflect.TypeOf(&exactNameRule{})}, {[]string{"continue", "name", "a.com", "b.com"}, false, reflect.TypeOf(&exactNameRule{})}, @@ -387,30 +400,207 @@ func TestRewriteEDNS0Local(t *testing.T) { } } -func TestEdns0LocalMultiRule(t *testing.T) { - rules := []Rule{} - r, _ := newEdns0Rule("stop", "local", "replace", "0xffee", "abcdef") - rules = append(rules, r) - r, _ = newEdns0Rule("stop", "local", "set", "0xffee", "fedcba") - rules = append(rules, r) - - rw := Rewrite{ - Next: plugin.HandlerFunc(msgPrinter), - Rules: rules, - RevertPolicy: NoRevertPolicy(), - } - +func TestEdns0MultiRule(t *testing.T) { tests := []struct { - fromOpts []dns.EDNS0 - toOpts []dns.EDNS0 + rules [][]string + fromOpts []dns.EDNS0 + toOpts []dns.EDNS0 + revertPolicy RevertPolicy }{ + // Local. { + [][]string{ + {"stop", "local", "replace", "0xffee", "abcdef"}, + {"stop", "local", "set", "0xffee", "fedcba"}, + }, nil, []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("fedcba")}}, + NoRevertPolicy(), }, { + [][]string{ + {"stop", "local", "replace", "0xffee", "abcdef"}, + {"stop", "local", "set", "0xffee", "fedcba"}, + }, []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("abcdef")}}, + NoRevertPolicy(), + }, + // Local with "revert". + { + [][]string{ + {"stop", "local", "replace", "0xffee", "abcdef", "revert"}, + {"stop", "local", "set", "0xffee", "fedcba", "revert"}, + }, + nil, + []dns.EDNS0{}, + NewRevertPolicy(false, false), + }, + { + [][]string{ + {"stop", "local", "replace", "0xffee", "abcdef", "revert"}, + {"stop", "local", "set", "0xffee", "fedcba", "revert"}, + }, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + NewRevertPolicy(false, false), + }, + // Local variable. + { + [][]string{ + {"stop", "local", "replace", "0xffee", "{qname}"}, + {"stop", "local", "set", "0xffee", "{qtype}"}, + }, + nil, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x01}}}, + NoRevertPolicy(), + }, + { + [][]string{ + {"stop", "local", "replace", "0xffee", "{qname}"}, + {"stop", "local", "set", "0xffee", "{qtype}"}, + }, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("example.com.")}}, + NoRevertPolicy(), + }, + // Local variable with "revert". + { + [][]string{ + {"stop", "local", "replace", "0xffee", "{qname}", "revert"}, + {"stop", "local", "set", "0xffee", "{qtype}", "revert"}, + }, + nil, + []dns.EDNS0{}, + NewRevertPolicy(false, false), + }, + { + [][]string{ + {"stop", "local", "replace", "0xffee", "{qname}", "revert"}, + {"stop", "local", "set", "0xffee", "{qtype}", "revert"}, + }, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + NewRevertPolicy(false, false), + }, + // Nsid. + { + [][]string{ + {"stop", "nsid", "replace"}, + {"stop", "nsid", "set"}, + }, + nil, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + NoRevertPolicy(), + }, + { + [][]string{ + {"stop", "nsid", "replace"}, + {"stop", "nsid", "set"}, + }, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + NoRevertPolicy(), + }, + { + [][]string{ + {"stop", "nsid", "replace"}, + {"stop", "nsid", "set"}, + }, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + NoRevertPolicy(), + }, + // Nsid with "revert". + { + [][]string{ + {"stop", "nsid", "replace", "revert"}, + {"stop", "nsid", "set", "revert"}, + }, + nil, + []dns.EDNS0{}, + NewRevertPolicy(false, false), + }, + { + [][]string{ + {"stop", "nsid", "replace", "revert"}, + {"stop", "nsid", "set", "revert"}, + }, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + NewRevertPolicy(false, false), + }, + { + [][]string{ + {"stop", "nsid", "replace", "revert"}, + {"stop", "nsid", "set", "revert"}, + }, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + NewRevertPolicy(false, false), + }, + // Subnet. + { + [][]string{ + {"stop", "subnet", "replace", "32", "56"}, + {"stop", "subnet", "set", "0", "56"}, + }, + nil, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00}, + }}, + NoRevertPolicy(), + }, + { + [][]string{ + {"stop", "subnet", "replace", "32", "56"}, + {"stop", "subnet", "set", "0", "56"}, + }, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00}, + }}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x20, + SourceScope: 0x0, + Address: []byte{0x0A, 0xF0, 0x00, 0x01}, + }}, + NoRevertPolicy(), + }, + // Subnet with "revert". + { + [][]string{ + {"stop", "subnet", "replace", "32", "56", "revert"}, + {"stop", "subnet", "set", "0", "56", "revert"}, + }, + nil, + []dns.EDNS0{}, + NewRevertPolicy(false, false), + }, + { + [][]string{ + {"stop", "subnet", "replace", "32", "56", "revert"}, + {"stop", "subnet", "set", "0", "56", "revert"}, + }, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00}, + }}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00}, + }}, + NewRevertPolicy(false, false), }, } @@ -428,6 +618,19 @@ func TestEdns0LocalMultiRule(t *testing.T) { o.Option = append(o.Option, tc.fromOpts...) } rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + rules := make([]Rule, 0, len(tc.rules)) + for _, rule := range tc.rules { + r, _ := newEdns0Rule(rule[0], rule[1:]...) + rules = append(rules, r) + } + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + RevertPolicy: tc.revertPolicy, + } + rw.ServeDNS(ctx, rec, m) resp := rec.Msg @@ -745,3 +948,101 @@ func TestRewriteEDNS0Subnet(t *testing.T) { } } } + +func TestRewriteEDNS0Revert(t *testing.T) { + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + RevertPolicy: NewRevertPolicy(false, false), + } + + tests := []struct { + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + doBool bool + }{ + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "0xabcdef", "revert"}, + []dns.EDNS0{}, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "append", "0xffee", "abcdefghijklmnop", "revert"}, + []dns.EDNS0{}, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "replace", "0xffee", "abcdefghijklmnop", "revert"}, + []dns.EDNS0{}, + true, + }, + { + []dns.EDNS0{}, + []string{"nsid", "set", "revert"}, + []dns.EDNS0{}, + false, + }, + { + []dns.EDNS0{}, + []string{"nsid", "append", "revert"}, + []dns.EDNS0{}, + true, + }, + { + []dns.EDNS0{}, + []string{"nsid", "replace"}, + []dns.EDNS0{}, + true, + }, + + { + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffed, Data: []byte{0xab, 0xcd, 0xef}}}, + []string{"local", "set", "0xffee", "0xabcd", "revert"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffed, Data: []byte{0xab, 0xcd, 0xef}}}, + false, + }, + { + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffef, Data: []byte{0xab, 0xcd, 0xef}}}, + []string{"local", "replace", "0xffee", "abcdefghijklmnop"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffef, Data: []byte{0xab, 0xcd, 0xef}}}, + true, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + r, err := newEdns0Rule("stop", tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + o.SetDo(tc.doBool) + if tc.fromOpts != nil { + o.Option = append(o.Option, tc.fromOpts...) + } + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if o.Do() != tc.doBool { + t.Errorf("Test %d: Expected %v but got %v", i, tc.doBool, o.Do()) + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +}