Modify the rewrite plugin to write multiple EDNS0 options (#936) (#1096)

* Add processing mode

* Add processing mode

* Update UTs

* Update README.md

* Change to use the constant Stop

* Fix README per review comments
This commit is contained in:
Thong Huynh
2017-09-20 13:06:53 -07:00
committed by John Belamaric
parent 36c7aa6437
commit ec21f83425
7 changed files with 139 additions and 28 deletions

View File

@@ -8,7 +8,7 @@ Rewrites are invisible to the client. There are simple rewrites (fast) and compl
## Syntax ## Syntax
~~~ ~~~
rewrite FIELD FROM TO rewrite [continue|stop] FIELD FROM TO
~~~ ~~~
* **FIELD** is (`type`, `class`, `name`, ...) * **FIELD** is (`type`, `class`, `name`, ...)
@@ -26,8 +26,11 @@ needs to be a full match of the name, e.g., `rewrite name miek.nl example.org`.
When the FIELD is `edns0` an EDNS0 option can be appended to the request as described below. When the FIELD is `edns0` an EDNS0 option can be appended to the request as described below.
If you specify multiple rules and an incoming query matches on multiple (simple) rules, only If you specify multiple rules and an incoming query matches on multiple rules, the rewrite
the first rewrite is applied. will behave as following
* `continue` will continue apply the next rule in the rule list.
* `stop` will consider the current rule is the last rule and will not continue. Default behaviour
for not specifying this rule processing mode is `stop`
## EDNS0 Options ## EDNS0 Options

View File

@@ -33,3 +33,8 @@ func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
} }
return RewriteIgnored return RewriteIgnored
} }
// Mode returns the processing mode
func (rule *classRule) Mode() string {
return Stop
}

View File

@@ -15,6 +15,7 @@ import (
// edns0LocalRule is a rewrite rule for EDNS0_LOCAL options // edns0LocalRule is a rewrite rule for EDNS0_LOCAL options
type edns0LocalRule struct { type edns0LocalRule struct {
mode string
action string action string
code uint16 code uint16
data []byte data []byte
@@ -22,6 +23,7 @@ type edns0LocalRule struct {
// edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable // edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable
type edns0VariableRule struct { type edns0VariableRule struct {
mode string
action string action string
code uint16 code uint16
variable string variable string
@@ -29,6 +31,7 @@ type edns0VariableRule struct {
// ends0NsidRule is a rewrite rule for EDNS0_NSID options // ends0NsidRule is a rewrite rule for EDNS0_NSID options
type edns0NsidRule struct { type edns0NsidRule struct {
mode string
action string action string
} }
@@ -70,6 +73,11 @@ Option:
return result return result
} }
// Mode returns the processing mode
func (rule *edns0NsidRule) Mode() string {
return rule.mode
}
// Rewrite will alter the request EDNS0 local options // Rewrite will alter the request EDNS0 local options
func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored result := RewriteIgnored
@@ -102,8 +110,13 @@ func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
return result return result
} }
// Mode returns the processing mode
func (rule *edns0LocalRule) Mode() string {
return rule.mode
}
// newEdns0Rule creates an EDNS0 rule of the appropriate type based on the args // newEdns0Rule creates an EDNS0 rule of the appropriate type based on the args
func newEdns0Rule(args ...string) (Rule, error) { func newEdns0Rule(mode string, args ...string) (Rule, error) {
if len(args) < 2 { if len(args) < 2 {
return nil, fmt.Errorf("too few arguments for an EDNS0 rule") return nil, fmt.Errorf("too few arguments for an EDNS0 rule")
} }
@@ -125,25 +138,25 @@ func newEdns0Rule(args ...string) (Rule, error) {
} }
//Check for variable option //Check for variable option
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
return newEdns0VariableRule(action, args[2], args[3]) return newEdns0VariableRule(mode, action, args[2], args[3])
} }
return newEdns0LocalRule(action, args[2], args[3]) return newEdns0LocalRule(mode, action, args[2], args[3])
case "nsid": case "nsid":
if len(args) != 2 { if len(args) != 2 {
return nil, fmt.Errorf("EDNS0 NSID rules do not accept args") return nil, fmt.Errorf("EDNS0 NSID rules do not accept args")
} }
return &edns0NsidRule{action: action}, nil return &edns0NsidRule{mode: mode, action: action}, nil
case "subnet": case "subnet":
if len(args) != 4 { if len(args) != 4 {
return nil, fmt.Errorf("EDNS0 subnet rules require exactly three args") return nil, fmt.Errorf("EDNS0 subnet rules require exactly three args")
} }
return newEdns0SubnetRule(action, args[2], args[3]) return newEdns0SubnetRule(mode, action, args[2], args[3])
default: default:
return nil, fmt.Errorf("invalid rule type %q", ruleType) return nil, fmt.Errorf("invalid rule type %q", ruleType)
} }
} }
func newEdns0LocalRule(action, code, data string) (*edns0LocalRule, error) { func newEdns0LocalRule(mode, action, code, data string) (*edns0LocalRule, error) {
c, err := strconv.ParseUint(code, 0, 16) c, err := strconv.ParseUint(code, 0, 16)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -156,11 +169,11 @@ func newEdns0LocalRule(action, code, data string) (*edns0LocalRule, error) {
return nil, err return nil, err
} }
} }
return &edns0LocalRule{action: action, code: uint16(c), data: decoded}, nil return &edns0LocalRule{mode: mode, action: action, code: uint16(c), data: decoded}, nil
} }
// newEdns0VariableRule creates an EDNS0 rule that handles variable substitution // newEdns0VariableRule creates an EDNS0 rule that handles variable substitution
func newEdns0VariableRule(action, code, variable string) (*edns0VariableRule, error) { func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRule, error) {
c, err := strconv.ParseUint(code, 0, 16) c, err := strconv.ParseUint(code, 0, 16)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -169,7 +182,7 @@ func newEdns0VariableRule(action, code, variable string) (*edns0VariableRule, er
if !isValidVariable(variable) { if !isValidVariable(variable) {
return nil, fmt.Errorf("unsupported variable name %q", variable) return nil, fmt.Errorf("unsupported variable name %q", variable)
} }
return &edns0VariableRule{action: action, code: uint16(c), variable: variable}, nil return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil
} }
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. // ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
@@ -294,6 +307,11 @@ func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result
return result return result
} }
// Mode returns the processing mode
func (rule *edns0VariableRule) Mode() string {
return rule.mode
}
func isValidVariable(variable string) bool { func isValidVariable(variable string) bool {
switch variable { switch variable {
case case
@@ -311,12 +329,13 @@ func isValidVariable(variable string) bool {
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options // ends0SubnetRule is a rewrite rule for EDNS0 subnet options
type edns0SubnetRule struct { type edns0SubnetRule struct {
mode string
v4BitMaskLen uint8 v4BitMaskLen uint8
v6BitMaskLen uint8 v6BitMaskLen uint8
action string action string
} }
func newEdns0SubnetRule(action, v4BitMaskLen, v6BitMaskLen string) (*edns0SubnetRule, error) { func newEdns0SubnetRule(mode, action, v4BitMaskLen, v6BitMaskLen string) (*edns0SubnetRule, error) {
v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16) v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -335,7 +354,7 @@ func newEdns0SubnetRule(action, v4BitMaskLen, v6BitMaskLen string) (*edns0Subnet
return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6Len) return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6Len)
} }
return &edns0SubnetRule{action: action, return &edns0SubnetRule{mode: mode, action: action,
v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len)}, nil v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len)}, nil
} }
@@ -400,6 +419,11 @@ func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
return result return result
} }
// Mode returns the processing mode
func (rule *edns0SubnetRule) Mode() string {
return rule.mode
}
// These are all defined actions. // These are all defined actions.
const ( const (
Replace = "replace" Replace = "replace"

View File

@@ -22,3 +22,8 @@ func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
} }
return RewriteIgnored return RewriteIgnored
} }
// Mode returns the processing mode
func (rule *nameRule) Mode() string {
return Stop
}

View File

@@ -24,6 +24,14 @@ const (
RewriteStatus RewriteStatus
) )
// These are defined processing mode.
const (
// Processing should stop after completing this rule
Stop = "stop"
// Processing should continue to next rule
Continue = "continue"
)
// Rewrite is plugin to rewrite requests internally before being handled. // Rewrite is plugin to rewrite requests internally before being handled.
type Rewrite struct { type Rewrite struct {
Next plugin.Handler Next plugin.Handler
@@ -37,10 +45,12 @@ func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
for _, rule := range rw.Rules { for _, rule := range rw.Rules {
switch result := rule.Rewrite(w, r); result { switch result := rule.Rewrite(w, r); result {
case RewriteDone: case RewriteDone:
if rw.noRevert { if rule.Mode() == Stop {
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) if rw.noRevert {
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
} }
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
case RewriteIgnored: case RewriteIgnored:
break break
case RewriteStatus: case RewriteStatus:
@@ -60,6 +70,8 @@ func (rw Rewrite) Name() string { return "rewrite" }
type Rule interface { type Rule interface {
// Rewrite rewrites the current request. // Rewrite rewrites the current request.
Rewrite(dns.ResponseWriter, *dns.Msg) Result Rewrite(dns.ResponseWriter, *dns.Msg) Result
// Mode returns the processing mode stop or continue
Mode() string
} }
func newRule(args ...string) (Rule, error) { func newRule(args ...string) (Rule, error) {
@@ -67,19 +79,39 @@ func newRule(args ...string) (Rule, error) {
return nil, fmt.Errorf("no rule type specified for rewrite") return nil, fmt.Errorf("no rule type specified for rewrite")
} }
ruleType := strings.ToLower(args[0]) arg0 := strings.ToLower(args[0])
if ruleType != "edns0" && len(args) != 3 { var ruleType string
var expectNumArgs, startArg int
mode := Stop
switch arg0 {
case Continue:
mode = arg0
ruleType = strings.ToLower(args[1])
expectNumArgs = len(args) - 1
startArg = 2
case Stop:
ruleType = strings.ToLower(args[1])
expectNumArgs = len(args) - 1
startArg = 2
default:
// for backward compability
ruleType = arg0
expectNumArgs = len(args)
startArg = 1
}
if ruleType != "edns0" && expectNumArgs != 3 {
return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType) return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
} }
switch ruleType { switch ruleType {
case "name": case "name":
return newNameRule(args[1], args[2]) return newNameRule(args[startArg], args[startArg+1])
case "class": case "class":
return newClassRule(args[1], args[2]) return newClassRule(args[startArg], args[startArg+1])
case "type": case "type":
return newTypeRule(args[1], args[2]) return newTypeRule(args[startArg], args[startArg+1])
case "edns0": case "edns0":
return newEdns0Rule(args[1:]...) return newEdns0Rule(mode, args[startArg:]...)
default: default:
return nil, fmt.Errorf("invalid rule type %q", args[0]) return nil, fmt.Errorf("invalid rule type %q", args[0])
} }

View File

@@ -88,6 +88,43 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, {[]string{"edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, {[]string{"edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, {[]string{"edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"unknown-action", "name", "a.com", "b.com"}, true, nil},
{[]string{"stop", "name", "a.com", "b.com"}, false, reflect.TypeOf(&nameRule{})},
{[]string{"continue", "name", "a.com", "b.com"}, false, reflect.TypeOf(&nameRule{})},
{[]string{"unknown-action", "type", "any", "a"}, true, nil},
{[]string{"stop", "type", "any", "a"}, false, reflect.TypeOf(&typeRule{})},
{[]string{"continue", "type", "any", "a"}, false, reflect.TypeOf(&typeRule{})},
{[]string{"unknown-action", "class", "ch", "in"}, true, nil},
{[]string{"stop", "class", "ch", "in"}, false, reflect.TypeOf(&classRule{})},
{[]string{"continue", "class", "ch", "in"}, false, reflect.TypeOf(&classRule{})},
{[]string{"unknown-action", "edns0", "local", "set", "0xffee", "abcedef"}, true, nil},
{[]string{"stop", "edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})},
{[]string{"unknown-action", "edns0", "nsid", "set"}, true, nil},
{[]string{"stop", "edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"continue", "edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"unknown-action", "edns0", "local", "set", "0xffee", "{qname}"}, true, nil},
{[]string{"stop", "edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"stop", "edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"stop", "edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"stop", "edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"stop", "edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"stop", "edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"stop", "edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"continue", "edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"unknown-action", "edns0", "subnet", "set", "24", "64"}, true, nil},
{[]string{"stop", "edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"stop", "edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"stop", "edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"continue", "edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"continue", "edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"continue", "edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
} }
for i, tc := range tests { for i, tc := range tests {
@@ -208,7 +245,7 @@ func TestRewriteEDNS0Local(t *testing.T) {
m.SetQuestion("example.com.", dns.TypeA) m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET m.Question[0].Qclass = dns.ClassINET
r, err := newEdns0Rule(tc.args...) r, err := newEdns0Rule("stop", tc.args...)
if err != nil { if err != nil {
t.Errorf("Error creating test rule: %s", err) t.Errorf("Error creating test rule: %s", err)
continue continue
@@ -232,9 +269,9 @@ func TestRewriteEDNS0Local(t *testing.T) {
func TestEdns0LocalMultiRule(t *testing.T) { func TestEdns0LocalMultiRule(t *testing.T) {
rules := []Rule{} rules := []Rule{}
r, _ := newEdns0Rule("local", "replace", "0xffee", "abcdef") r, _ := newEdns0Rule("stop", "local", "replace", "0xffee", "abcdef")
rules = append(rules, r) rules = append(rules, r)
r, _ = newEdns0Rule("local", "set", "0xffee", "fedcba") r, _ = newEdns0Rule("stop", "local", "set", "0xffee", "fedcba")
rules = append(rules, r) rules = append(rules, r)
rw := Rewrite{ rw := Rewrite{
@@ -399,7 +436,7 @@ func TestRewriteEDNS0LocalVariable(t *testing.T) {
m.SetQuestion("example.com.", dns.TypeA) m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET m.Question[0].Qclass = dns.ClassINET
r, err := newEdns0Rule(tc.args...) r, err := newEdns0Rule("stop", tc.args...)
if err != nil { if err != nil {
t.Errorf("Error creating test rule: %s", err) t.Errorf("Error creating test rule: %s", err)
continue continue
@@ -510,7 +547,7 @@ func TestRewriteEDNS0Subnet(t *testing.T) {
m.SetQuestion("example.com.", dns.TypeA) m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET m.Question[0].Qclass = dns.ClassINET
r, err := newEdns0Rule(tc.args...) r, err := newEdns0Rule("stop", tc.args...)
if err != nil { if err != nil {
t.Errorf("Error creating test rule: %s", err) t.Errorf("Error creating test rule: %s", err)
continue continue

View File

@@ -35,3 +35,8 @@ func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
} }
return RewriteIgnored return RewriteIgnored
} }
// Mode returns the processing mode
func (rule *typeRule) Mode() string {
return Stop
}