Remove the word middleware (#1067)

* Rename middleware to plugin

first pass; mostly used 'sed', few spots where I manually changed
text.

This still builds a coredns binary.

* fmt error

* Rename AddMiddleware to AddPlugin

* Readd AddMiddleware to remain backwards compat
This commit is contained in:
Miek Gieben
2017-09-14 09:36:06 +01:00
committed by GitHub
parent b984aa4559
commit d8714e64e4
354 changed files with 974 additions and 969 deletions

91
plugin/rewrite/README.md Normal file
View File

@@ -0,0 +1,91 @@
# rewrite
*rewrite* performs internal message rewriting.
Rewrites are invisible to the client. There are simple rewrites (fast) and complex rewrites
(slower), but they're powerful enough to accommodate most dynamic back-end applications.
## Syntax
~~~
rewrite FIELD FROM TO
~~~
* **FIELD** is (`type`, `class`, `name`, ...)
* **FROM** is the exact name of type to match
* **TO** is the destination name or type to rewrite to
When the FIELD is `type` and FROM is (`A`, `MX`, etc.), the type of the message will be rewritten;
e.g., to rewrite ANY queries to HINFO, use `rewrite type ANY HINFO`.
When the FIELD is `class` and FROM is (`IN`, `CH`, or `HS`) the class of the message will be
rewritten; e.g., to rewrite CH queries to IN use `rewrite class CH IN`.
When the FIELD is `name` the query name in the message is rewritten; this
needs to be a full match of the name, e.g., `rewrite name miek.nl example.org`.
When the FIELD is `edns0` an EDNS0 option can be appended to the request as described below.
If you specify multiple rules and an incoming query matches on multiple (simple) rules, only
the first rewrite is applied.
## EDNS0 Options
Using FIELD edns0, you can set, append, or replace specific EDNS0 options on the request.
* `replace` will modify any matching (what that means may vary based on EDNS0 type) option with the specified option
* `append` will add the option regardless of what options already exist
* `set` will modify a matching option or add one if none is found
Currently supported are `EDNS0_LOCAL`, `EDNS0_NSID` and `EDNS0_SUBNET`.
### `EDNS0_LOCAL`
This has two fields, code and data. A match is defined as having the same code. Data may be a string or a variable.
* A string data can be treated as hex if it starts with `0x`. Example:
~~~
rewrite edns0 local set 0xffee 0x61626364
~~~
rewrites the first local option with code 0xffee, setting the data to "abcd". Equivalent:
~~~
rewrite edns0 local set 0xffee abcd
~~~
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables:
* {qname}
* {qtype}
* {client_ip}
* {client_port}
* {protocol}
* {server_ip}
* {server_port}
Example:
~~~
rewrite edns0 local set 0xffee {client_ip}
~~~
### `EDNS0_NSID`
This has no fields; it will add an NSID option with an empty string for the NSID. If the option already exists
and the action is `replace` or `set`, then the NSID in the option will be set to the empty string.
### `EDNS0_SUBNET`
This has two fields, IPv4 bitmask length and IPv6 bitmask length. The bitmask
length is used to extract the client subnet from the source IP address in the query.
Example:
~~~
rewrite edns0 subnet set 24 56
~~~
* If the query has source IP as IPv4, the first 24 bits in the IP will be the network subnet.
* If the query has source IP as IPv6, the first 56 bits in the IP will be the network subnet.

35
plugin/rewrite/class.go Normal file
View File

@@ -0,0 +1,35 @@
package rewrite
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
type classRule struct {
fromClass, toClass uint16
}
func newClassRule(fromS, toS string) (Rule, error) {
var from, to uint16
var ok bool
if from, ok = dns.StringToClass[strings.ToUpper(fromS)]; !ok {
return nil, fmt.Errorf("invalid class %q", strings.ToUpper(fromS))
}
if to, ok = dns.StringToClass[strings.ToUpper(toS)]; !ok {
return nil, fmt.Errorf("invalid class %q", strings.ToUpper(toS))
}
return &classRule{fromClass: from, toClass: to}, nil
}
// Rewrite rewrites the the current request.
func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromClass > 0 && rule.toClass > 0 {
if r.Question[0].Qclass == rule.fromClass {
r.Question[0].Qclass = rule.toClass
return RewriteDone
}
}
return RewriteIgnored
}

132
plugin/rewrite/condition.go Normal file
View File

@@ -0,0 +1,132 @@
package rewrite
import (
"fmt"
"regexp"
"strings"
"github.com/coredns/coredns/plugin/pkg/replacer"
"github.com/miekg/dns"
)
// Operators
const (
Is = "is"
Not = "not"
Has = "has"
NotHas = "not_has"
StartsWith = "starts_with"
EndsWith = "ends_with"
Match = "match"
NotMatch = "not_match"
)
func operatorError(operator string) error {
return fmt.Errorf("invalid operator %v", operator)
}
func newReplacer(r *dns.Msg) replacer.Replacer {
return replacer.New(r, nil, "")
}
// condition is a rewrite condition.
type condition func(string, string) bool
var conditions = map[string]condition{
Is: isFunc,
Not: notFunc,
Has: hasFunc,
NotHas: notHasFunc,
StartsWith: startsWithFunc,
EndsWith: endsWithFunc,
Match: matchFunc,
NotMatch: notMatchFunc,
}
// isFunc is condition for Is operator.
// It checks for equality.
func isFunc(a, b string) bool {
return a == b
}
// notFunc is condition for Not operator.
// It checks for inequality.
func notFunc(a, b string) bool {
return a != b
}
// hasFunc is condition for Has operator.
// It checks if b is a substring of a.
func hasFunc(a, b string) bool {
return strings.Contains(a, b)
}
// notHasFunc is condition for NotHas operator.
// It checks if b is not a substring of a.
func notHasFunc(a, b string) bool {
return !strings.Contains(a, b)
}
// startsWithFunc is condition for StartsWith operator.
// It checks if b is a prefix of a.
func startsWithFunc(a, b string) bool {
return strings.HasPrefix(a, b)
}
// endsWithFunc is condition for EndsWith operator.
// It checks if b is a suffix of a.
func endsWithFunc(a, b string) bool {
// TODO(miek): IsSubDomain
return strings.HasSuffix(a, b)
}
// matchFunc is condition for Match operator.
// It does regexp matching of a against pattern in b
// and returns if they match.
func matchFunc(a, b string) bool {
matched, _ := regexp.MatchString(b, a)
return matched
}
// notMatchFunc is condition for NotMatch operator.
// It does regexp matching of a against pattern in b
// and returns if they do not match.
func notMatchFunc(a, b string) bool {
matched, _ := regexp.MatchString(b, a)
return !matched
}
// If is statement for a rewrite condition.
type If struct {
A string
Operator string
B string
}
// True returns true if the condition is true and false otherwise.
// If r is not nil, it replaces placeholders before comparison.
func (i If) True(r *dns.Msg) bool {
if c, ok := conditions[i.Operator]; ok {
a, b := i.A, i.B
if r != nil {
replacer := newReplacer(r)
a = replacer.Replace(i.A)
b = replacer.Replace(i.B)
}
return c(a, b)
}
return false
}
// NewIf creates a new If condition.
func NewIf(a, operator, b string) (If, error) {
if _, ok := conditions[operator]; !ok {
return If{}, operatorError(operator)
}
return If{
A: a,
Operator: operator,
B: b,
}, nil
}

View File

@@ -0,0 +1,102 @@
package rewrite
/*
func TestConditions(t *testing.T) {
tests := []struct {
condition string
isTrue bool
}{
{"a is b", false},
{"a is a", true},
{"a not b", true},
{"a not a", false},
{"a has a", true},
{"a has b", false},
{"ba has b", true},
{"bab has b", true},
{"bab has bb", false},
{"a not_has a", false},
{"a not_has b", true},
{"ba not_has b", false},
{"bab not_has b", false},
{"bab not_has bb", true},
{"bab starts_with bb", false},
{"bab starts_with ba", true},
{"bab starts_with bab", true},
{"bab ends_with bb", false},
{"bab ends_with bab", true},
{"bab ends_with ab", true},
{"a match *", false},
{"a match a", true},
{"a match .*", true},
{"a match a.*", true},
{"a match b.*", false},
{"ba match b.*", true},
{"ba match b[a-z]", true},
{"b0 match b[a-z]", false},
{"b0a match b[a-z]", false},
{"b0a match b[a-z]+", false},
{"b0a match b[a-z0-9]+", true},
{"a not_match *", true},
{"a not_match a", false},
{"a not_match .*", false},
{"a not_match a.*", false},
{"a not_match b.*", true},
{"ba not_match b.*", false},
{"ba not_match b[a-z]", false},
{"b0 not_match b[a-z]", true},
{"b0a not_match b[a-z]", true},
{"b0a not_match b[a-z]+", true},
{"b0a not_match b[a-z0-9]+", false},
}
for i, test := range tests {
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(nil)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
invalidOperators := []string{"ss", "and", "if"}
for _, op := range invalidOperators {
_, err := NewIf("a", op, "b")
if err == nil {
t.Errorf("Invalid operator %v used, expected error.", op)
}
}
replaceTests := []struct {
url string
condition string
isTrue bool
}{
{"/home", "{uri} match /home", true},
{"/hom", "{uri} match /home", false},
{"/hom", "{uri} starts_with /home", false},
{"/hom", "{uri} starts_with /h", true},
{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
}
for i, test := range replaceTests {
r, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Error(err)
}
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(r)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
}
*/

425
plugin/rewrite/edns0.go Normal file
View File

@@ -0,0 +1,425 @@
// Package rewrite is plugin for rewriting requests internally to something different.
package rewrite
import (
"encoding/binary"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// edns0LocalRule is a rewrite rule for EDNS0_LOCAL options
type edns0LocalRule struct {
action string
code uint16
data []byte
}
// edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable
type edns0VariableRule struct {
action string
code uint16
variable string
}
// ends0NsidRule is a rewrite rule for EDNS0_NSID options
type edns0NsidRule struct {
action string
}
// setupEdns0Opt will retrieve the EDNS0 OPT or create it if it does not exist
func setupEdns0Opt(r *dns.Msg) *dns.OPT {
o := r.IsEdns0()
if o == nil {
r.SetEdns0(4096, true)
o = r.IsEdns0()
}
return o
}
// Rewrite will alter the request EDNS0 NSID option
func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
Option:
for _, s := range o.Option {
switch e := s.(type) {
case *dns.EDNS0_NSID:
if rule.action == Replace || rule.action == Set {
e.Nsid = "" // make sure it is empty for request
result = RewriteDone
}
found = true
break Option
}
}
// add option if not found
if !found && (rule.action == Append || rule.action == Set) {
o.SetDo()
o.Option = append(o.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""})
result = RewriteDone
}
return result
}
// Rewrite will alter the request EDNS0 local options
func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
for _, s := range o.Option {
switch e := s.(type) {
case *dns.EDNS0_LOCAL:
if rule.code == e.Code {
if rule.action == Replace || rule.action == Set {
e.Data = rule.data
result = RewriteDone
}
found = true
break
}
}
}
// add option if not found
if !found && (rule.action == Append || rule.action == Set) {
o.SetDo()
var opt dns.EDNS0_LOCAL
opt.Code = rule.code
opt.Data = rule.data
o.Option = append(o.Option, &opt)
result = RewriteDone
}
return result
}
// newEdns0Rule creates an EDNS0 rule of the appropriate type based on the args
func newEdns0Rule(args ...string) (Rule, error) {
if len(args) < 2 {
return nil, fmt.Errorf("too few arguments for an EDNS0 rule")
}
ruleType := strings.ToLower(args[0])
action := strings.ToLower(args[1])
switch action {
case Append:
case Replace:
case Set:
default:
return nil, fmt.Errorf("invalid action: %q", action)
}
switch ruleType {
case "local":
if len(args) != 4 {
return nil, fmt.Errorf("EDNS0 local rules require exactly three args")
}
//Check for variable option
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
return newEdns0VariableRule(action, args[2], args[3])
}
return newEdns0LocalRule(action, args[2], args[3])
case "nsid":
if len(args) != 2 {
return nil, fmt.Errorf("EDNS0 NSID rules do not accept args")
}
return &edns0NsidRule{action: action}, nil
case "subnet":
if len(args) != 4 {
return nil, fmt.Errorf("EDNS0 subnet rules require exactly three args")
}
return newEdns0SubnetRule(action, args[2], args[3])
default:
return nil, fmt.Errorf("invalid rule type %q", ruleType)
}
}
func newEdns0LocalRule(action, code, data string) (*edns0LocalRule, error) {
c, err := strconv.ParseUint(code, 0, 16)
if err != nil {
return nil, err
}
decoded := []byte(data)
if strings.HasPrefix(data, "0x") {
decoded, err = hex.DecodeString(data[2:])
if err != nil {
return nil, err
}
}
return &edns0LocalRule{action: action, code: uint16(c), data: decoded}, nil
}
// newEdns0VariableRule creates an EDNS0 rule that handles variable substitution
func newEdns0VariableRule(action, code, variable string) (*edns0VariableRule, error) {
c, err := strconv.ParseUint(code, 0, 16)
if err != nil {
return nil, err
}
//Validate
if !isValidVariable(variable) {
return nil, fmt.Errorf("unsupported variable name %q", variable)
}
return &edns0VariableRule{action: action, code: uint16(c), variable: variable}, nil
}
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) {
switch family {
case 1:
return net.ParseIP(ipAddr).To4(), nil
case 2:
return net.ParseIP(ipAddr).To16(), nil
}
return nil, fmt.Errorf("Invalid IP address family (i.e. version) %d", family)
}
// uint16ToWire writes unit16 to wire/binary format
func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(data))
return buf
}
// portToWire writes port to wire/binary format, 2 bytes
func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
return rule.uint16ToWire(uint16(port)), nil
}
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
func (rule *edns0VariableRule) family(ip net.Addr) int {
var a net.IP
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
if i, ok := ip.(*net.TCPAddr); ok {
a = i.IP
}
if a.To4() != nil {
return 1
}
return 2
}
// ruleData returns the data specified by the variable
func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
req := request.Request{W: w, Req: r}
switch rule.variable {
case queryName:
//Query name is written as ascii string
return []byte(req.QName()), nil
case queryType:
return rule.uint16ToWire(req.QType()), nil
case clientIP:
return rule.ipToWire(req.Family(), req.IP())
case clientPort:
return rule.portToWire(req.Port())
case protocol:
// Proto is written as ascii string
return []byte(req.Proto()), nil
case serverIP:
ip, _, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
ip = w.RemoteAddr().String()
}
return rule.ipToWire(rule.family(w.RemoteAddr()), ip)
case serverPort:
_, port, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
port = "0"
}
return rule.portToWire(port)
}
return nil, fmt.Errorf("Unable to extract data for variable %s", rule.variable)
}
// Rewrite will alter the request EDNS0 local options with specified variables
func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
data, err := rule.ruleData(w, r)
if err != nil || data == nil {
return result
}
o := setupEdns0Opt(r)
found := false
for _, s := range o.Option {
switch e := s.(type) {
case *dns.EDNS0_LOCAL:
if rule.code == e.Code {
if rule.action == Replace || rule.action == Set {
e.Data = data
result = RewriteDone
}
found = true
break
}
}
}
// add option if not found
if !found && (rule.action == Append || rule.action == Set) {
o.SetDo()
var opt dns.EDNS0_LOCAL
opt.Code = rule.code
opt.Data = data
o.Option = append(o.Option, &opt)
result = RewriteDone
}
return result
}
func isValidVariable(variable string) bool {
switch variable {
case
queryName,
queryType,
clientIP,
clientPort,
protocol,
serverIP,
serverPort:
return true
}
return false
}
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options
type edns0SubnetRule struct {
v4BitMaskLen uint8
v6BitMaskLen uint8
action string
}
func newEdns0SubnetRule(action, v4BitMaskLen, v6BitMaskLen string) (*edns0SubnetRule, error) {
v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16)
if err != nil {
return nil, err
}
// Validate V4 length
if v4Len > maxV4BitMaskLen {
return nil, fmt.Errorf("invalid IPv4 bit mask length %d", v4Len)
}
v6Len, err := strconv.ParseUint(v6BitMaskLen, 0, 16)
if err != nil {
return nil, err
}
//Validate V6 length
if v6Len > maxV6BitMaskLen {
return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6Len)
}
return &edns0SubnetRule{action: action,
v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len)}, nil
}
// fillEcsData sets the subnet data into the ecs option
func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg,
ecs *dns.EDNS0_SUBNET) error {
req := request.Request{W: w, Req: r}
family := req.Family()
if (family != 1) && (family != 2) {
return fmt.Errorf("unable to fill data for EDNS0 subnet due to invalid IP family")
}
ecs.DraftOption = false
ecs.Family = uint16(family)
ecs.SourceScope = 0
ipAddr := req.IP()
switch family {
case 1:
ipv4Mask := net.CIDRMask(int(rule.v4BitMaskLen), 32)
ipv4Addr := net.ParseIP(ipAddr)
ecs.SourceNetmask = rule.v4BitMaskLen
ecs.Address = ipv4Addr.Mask(ipv4Mask).To4()
case 2:
ipv6Mask := net.CIDRMask(int(rule.v6BitMaskLen), 128)
ipv6Addr := net.ParseIP(ipAddr)
ecs.SourceNetmask = rule.v6BitMaskLen
ecs.Address = ipv6Addr.Mask(ipv6Mask).To16()
}
return nil
}
// Rewrite will alter the request EDNS0 subnet option
func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
for _, s := range o.Option {
switch e := s.(type) {
case *dns.EDNS0_SUBNET:
if rule.action == Replace || rule.action == Set {
if rule.fillEcsData(w, r, e) == nil {
result = RewriteDone
}
}
found = true
break
}
}
// add option if not found
if !found && (rule.action == Append || rule.action == Set) {
o.SetDo()
opt := dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET}
if rule.fillEcsData(w, r, &opt) == nil {
o.Option = append(o.Option, &opt)
result = RewriteDone
}
}
return result
}
// These are all defined actions.
const (
Replace = "replace"
Set = "set"
Append = "append"
)
// Supported local EDNS0 variables
const (
queryName = "{qname}"
queryType = "{qtype}"
clientIP = "{client_ip}"
clientPort = "{client_port}"
protocol = "{protocol}"
serverIP = "{server_ip}"
serverPort = "{server_port}"
)
// Subnet maximum bit mask length
const (
maxV4BitMaskLen = 32
maxV6BitMaskLen = 128
)

24
plugin/rewrite/name.go Normal file
View File

@@ -0,0 +1,24 @@
package rewrite
import (
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
type nameRule struct {
From, To string
}
func newNameRule(from, to string) (Rule, error) {
return &nameRule{plugin.Name(from).Normalize(), plugin.Name(to).Normalize()}, nil
}
// Rewrite rewrites the the current request.
func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.From == r.Question[0].Name {
r.Question[0].Name = rule.To
return RewriteDone
}
return RewriteIgnored
}

View File

@@ -0,0 +1,39 @@
package rewrite
import "github.com/miekg/dns"
// ResponseReverter reverses the operations done on the question section of a packet.
// This is need because the client will otherwise disregards the response, i.e.
// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN'
type ResponseReverter struct {
dns.ResponseWriter
original dns.Question
}
// NewResponseReverter returns a pointer to a new ResponseReverter.
func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter {
return &ResponseReverter{
ResponseWriter: w,
original: r.Question[0],
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *ResponseReverter) WriteMsg(res *dns.Msg) error {
res.Question[0] = r.original
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the size of the message that gets written.
func (r *ResponseReverter) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
return n, err
}
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *ResponseReverter) Hijack() {
r.ResponseWriter.Hijack()
return
}

86
plugin/rewrite/rewrite.go Normal file
View File

@@ -0,0 +1,86 @@
package rewrite
import (
"fmt"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// Result is the result of a rewrite
type Result int
const (
// RewriteIgnored is returned when rewrite is not done on request.
RewriteIgnored Result = iota
// RewriteDone is returned when rewrite is done on request.
RewriteDone
// RewriteStatus is returned when rewrite is not needed and status code should be set
// for the request.
RewriteStatus
)
// Rewrite is plugin to rewrite requests internally before being handled.
type Rewrite struct {
Next plugin.Handler
Rules []Rule
noRevert bool
}
// ServeDNS implements the plugin.Handler interface.
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules {
switch result := rule.Rewrite(w, r); result {
case RewriteDone:
if rw.noRevert {
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
case RewriteIgnored:
break
case RewriteStatus:
// only valid for complex rules.
// if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 {
// return cRule.Status, nil
// }
}
}
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}
// Name implements the Handler interface.
func (rw Rewrite) Name() string { return "rewrite" }
// Rule describes a rewrite rule.
type Rule interface {
// Rewrite rewrites the current request.
Rewrite(dns.ResponseWriter, *dns.Msg) Result
}
func newRule(args ...string) (Rule, error) {
if len(args) == 0 {
return nil, fmt.Errorf("no rule type specified for rewrite")
}
ruleType := strings.ToLower(args[0])
if ruleType != "edns0" && len(args) != 3 {
return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
}
switch ruleType {
case "name":
return newNameRule(args[1], args[2])
case "class":
return newClassRule(args[1], args[2])
case "type":
return newTypeRule(args[1], args[2])
case "edns0":
return newEdns0Rule(args[1:]...)
default:
return nil, fmt.Errorf("invalid rule type %q", args[0])
}
}

View File

@@ -0,0 +1,532 @@
package rewrite
import (
"bytes"
"reflect"
"testing"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsrecorder"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
func msgPrinter(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
w.WriteMsg(r)
return 0, nil
}
func TestNewRule(t *testing.T) {
tests := []struct {
args []string
shouldError bool
expType reflect.Type
}{
{[]string{}, true, nil},
{[]string{"foo"}, true, nil},
{[]string{"name"}, true, nil},
{[]string{"name", "a.com"}, true, nil},
{[]string{"name", "a.com", "b.com", "c.com"}, true, nil},
{[]string{"name", "a.com", "b.com"}, false, reflect.TypeOf(&nameRule{})},
{[]string{"type"}, true, nil},
{[]string{"type", "a"}, true, nil},
{[]string{"type", "any", "a", "a"}, true, nil},
{[]string{"type", "any", "a"}, false, reflect.TypeOf(&typeRule{})},
{[]string{"type", "XY", "WV"}, true, nil},
{[]string{"type", "ANY", "WV"}, true, nil},
{[]string{"class"}, true, nil},
{[]string{"class", "IN"}, true, nil},
{[]string{"class", "ch", "in", "in"}, true, nil},
{[]string{"class", "ch", "in"}, false, reflect.TypeOf(&classRule{})},
{[]string{"class", "XY", "WV"}, true, nil},
{[]string{"class", "IN", "WV"}, true, nil},
{[]string{"edns0"}, true, nil},
{[]string{"edns0", "local"}, true, nil},
{[]string{"edns0", "local", "set"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee"}, true, nil},
{[]string{"edns0", "local", "set", "65518", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})},
{[]string{"edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})},
{[]string{"edns0", "local", "append", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})},
{[]string{"edns0", "local", "foo", "0xffee", "abcdefg"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "0xabcdefg"}, true, nil},
{[]string{"edns0", "nsid", "set", "junk"}, true, nil},
{[]string{"edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "foo"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "subnet", "set", "-1", "56"}, true, nil},
{[]string{"edns0", "subnet", "set", "24", "-56"}, true, nil},
{[]string{"edns0", "subnet", "set", "33", "56"}, true, nil},
{[]string{"edns0", "subnet", "set", "24", "129"}, true, nil},
{[]string{"edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
{[]string{"edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})},
}
for i, tc := range tests {
r, err := newRule(tc.args...)
if err == nil && tc.shouldError {
t.Errorf("Test %d: expected error but got success", i)
} else if err != nil && !tc.shouldError {
t.Errorf("Test %d: expected success but got error: %s", i, err)
}
if !tc.shouldError && reflect.TypeOf(r) != tc.expType {
t.Errorf("Test %d: expected %q but got %q", i, tc.expType, r)
}
}
}
func TestRewrite(t *testing.T) {
rules := []Rule{}
r, _ := newNameRule("from.nl.", "to.nl.")
rules = append(rules, r)
r, _ = newClassRule("CH", "IN")
rules = append(rules, r)
r, _ = newTypeRule("ANY", "HINFO")
rules = append(rules, r)
rw := Rewrite{
Next: plugin.HandlerFunc(msgPrinter),
Rules: rules,
noRevert: true,
}
tests := []struct {
from string
fromT uint16
fromC uint16
to string
toT uint16
toC uint16
}{
{"from.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET},
{"a.nl.", dns.TypeA, dns.ClassINET, "a.nl.", dns.TypeA, dns.ClassINET},
{"a.nl.", dns.TypeA, dns.ClassCHAOS, "a.nl.", dns.TypeA, dns.ClassINET},
{"a.nl.", dns.TypeANY, dns.ClassINET, "a.nl.", dns.TypeHINFO, dns.ClassINET},
// name is rewritten, type is not.
{"from.nl.", dns.TypeANY, dns.ClassINET, "to.nl.", dns.TypeANY, dns.ClassINET},
// name is not, type is, but class is, because class is the 2nd rule.
{"a.nl.", dns.TypeANY, dns.ClassCHAOS, "a.nl.", dns.TypeANY, dns.ClassINET},
}
ctx := context.TODO()
for i, tc := range tests {
m := new(dns.Msg)
m.SetQuestion(tc.from, tc.fromT)
m.Question[0].Qclass = tc.fromC
rec := dnsrecorder.New(&test.ResponseWriter{})
rw.ServeDNS(ctx, rec, m)
resp := rec.Msg
if resp.Question[0].Name != tc.to {
t.Errorf("Test %d: Expected Name to be %q but was %q", i, tc.to, resp.Question[0].Name)
}
if resp.Question[0].Qtype != tc.toT {
t.Errorf("Test %d: Expected Type to be '%d' but was '%d'", i, tc.toT, resp.Question[0].Qtype)
}
if resp.Question[0].Qclass != tc.toC {
t.Errorf("Test %d: Expected Class to be '%d' but was '%d'", i, tc.toC, resp.Question[0].Qclass)
}
}
}
func TestRewriteEDNS0Local(t *testing.T) {
rw := Rewrite{
Next: plugin.HandlerFunc(msgPrinter),
noRevert: true,
}
tests := []struct {
fromOpts []dns.EDNS0
args []string
toOpts []dns.EDNS0
}{
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "0xabcdef"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0xab, 0xcd, 0xef}}},
},
{
[]dns.EDNS0{},
[]string{"local", "append", "0xffee", "abcdefghijklmnop"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("abcdefghijklmnop")}},
},
{
[]dns.EDNS0{},
[]string{"local", "replace", "0xffee", "abcdefghijklmnop"},
[]dns.EDNS0{},
},
{
[]dns.EDNS0{},
[]string{"nsid", "set"},
[]dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}},
},
{
[]dns.EDNS0{},
[]string{"nsid", "append"},
[]dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}},
},
{
[]dns.EDNS0{},
[]string{"nsid", "replace"},
[]dns.EDNS0{},
},
}
ctx := context.TODO()
for i, tc := range tests {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET
r, err := newEdns0Rule(tc.args...)
if err != nil {
t.Errorf("Error creating test rule: %s", err)
continue
}
rw.Rules = []Rule{r}
rec := dnsrecorder.New(&test.ResponseWriter{})
rw.ServeDNS(ctx, rec, m)
resp := rec.Msg
o := resp.IsEdns0()
if o == nil {
t.Errorf("Test %d: EDNS0 options not set", i)
continue
}
if !optsEqual(o.Option, tc.toOpts) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o)
}
}
}
func TestEdns0LocalMultiRule(t *testing.T) {
rules := []Rule{}
r, _ := newEdns0Rule("local", "replace", "0xffee", "abcdef")
rules = append(rules, r)
r, _ = newEdns0Rule("local", "set", "0xffee", "fedcba")
rules = append(rules, r)
rw := Rewrite{
Next: plugin.HandlerFunc(msgPrinter),
Rules: rules,
noRevert: true,
}
tests := []struct {
fromOpts []dns.EDNS0
toOpts []dns.EDNS0
}{
{
nil,
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("fedcba")}},
},
{
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("abcdef")}},
},
}
ctx := context.TODO()
for i, tc := range tests {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET
if tc.fromOpts != nil {
o := m.IsEdns0()
if o == nil {
m.SetEdns0(4096, true)
o = m.IsEdns0()
}
o.Option = append(o.Option, tc.fromOpts...)
}
rec := dnsrecorder.New(&test.ResponseWriter{})
rw.ServeDNS(ctx, rec, m)
resp := rec.Msg
o := resp.IsEdns0()
if o == nil {
t.Errorf("Test %d: EDNS0 options not set", i)
continue
}
if !optsEqual(o.Option, tc.toOpts) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o)
}
}
}
func optsEqual(a, b []dns.EDNS0) bool {
if len(a) != len(b) {
return false
}
for i := range a {
switch aa := a[i].(type) {
case *dns.EDNS0_LOCAL:
if bb, ok := b[i].(*dns.EDNS0_LOCAL); ok {
if aa.Code != bb.Code {
return false
}
if !bytes.Equal(aa.Data, bb.Data) {
return false
}
} else {
return false
}
case *dns.EDNS0_NSID:
if bb, ok := b[i].(*dns.EDNS0_NSID); ok {
if aa.Nsid != bb.Nsid {
return false
}
} else {
return false
}
case *dns.EDNS0_SUBNET:
if bb, ok := b[i].(*dns.EDNS0_SUBNET); ok {
if aa.Code != bb.Code {
return false
}
if aa.Family != bb.Family {
return false
}
if aa.SourceNetmask != bb.SourceNetmask {
return false
}
if aa.SourceScope != bb.SourceScope {
return false
}
if !bytes.Equal(aa.Address, bb.Address) {
return false
}
if aa.DraftOption != bb.DraftOption {
return false
}
} else {
return false
}
default:
return false
}
}
return true
}
func TestRewriteEDNS0LocalVariable(t *testing.T) {
rw := Rewrite{
Next: plugin.HandlerFunc(msgPrinter),
noRevert: true,
}
// test.ResponseWriter has the following values:
// The remote will always be 10.240.0.1 and port 40212.
// The local address is always 127.0.0.1 and port 53.
tests := []struct {
fromOpts []dns.EDNS0
args []string
toOpts []dns.EDNS0
}{
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "{qname}"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("example.com.")}},
},
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "{qtype}"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x01}}},
},
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "{client_ip}"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x0A, 0xF0, 0x00, 0x01}}},
},
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "{client_port}"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x9D, 0x14}}},
},
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "{protocol}"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("udp")}},
},
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "{server_ip}"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x7F, 0x00, 0x00, 0x01}}},
},
{
[]dns.EDNS0{},
[]string{"local", "set", "0xffee", "{server_port}"},
[]dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x35}}},
},
}
ctx := context.TODO()
for i, tc := range tests {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET
r, err := newEdns0Rule(tc.args...)
if err != nil {
t.Errorf("Error creating test rule: %s", err)
continue
}
rw.Rules = []Rule{r}
rec := dnsrecorder.New(&test.ResponseWriter{})
rw.ServeDNS(ctx, rec, m)
resp := rec.Msg
o := resp.IsEdns0()
if o == nil {
t.Errorf("Test %d: EDNS0 options not set", i)
continue
}
if !optsEqual(o.Option, tc.toOpts) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o)
}
}
}
func TestRewriteEDNS0Subnet(t *testing.T) {
rw := Rewrite{
Next: plugin.HandlerFunc(msgPrinter),
noRevert: true,
}
tests := []struct {
writer dns.ResponseWriter
fromOpts []dns.EDNS0
args []string
toOpts []dns.EDNS0
}{
{
&test.ResponseWriter{},
[]dns.EDNS0{},
[]string{"subnet", "set", "24", "56"},
[]dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8,
Family: 0x1,
SourceNetmask: 0x18,
SourceScope: 0x0,
Address: []byte{0x0A, 0xF0, 0x00, 0x00},
DraftOption: false}},
},
{
&test.ResponseWriter{},
[]dns.EDNS0{},
[]string{"subnet", "set", "32", "56"},
[]dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8,
Family: 0x1,
SourceNetmask: 0x20,
SourceScope: 0x0,
Address: []byte{0x0A, 0xF0, 0x00, 0x01},
DraftOption: false}},
},
{
&test.ResponseWriter{},
[]dns.EDNS0{},
[]string{"subnet", "set", "0", "56"},
[]dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8,
Family: 0x1,
SourceNetmask: 0x0,
SourceScope: 0x0,
Address: []byte{0x00, 0x00, 0x00, 0x00},
DraftOption: false}},
},
{
&test.ResponseWriter6{},
[]dns.EDNS0{},
[]string{"subnet", "set", "24", "56"},
[]dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8,
Family: 0x2,
SourceNetmask: 0x38,
SourceScope: 0x0,
Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
DraftOption: false}},
},
{
&test.ResponseWriter6{},
[]dns.EDNS0{},
[]string{"subnet", "set", "24", "128"},
[]dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8,
Family: 0x2,
SourceNetmask: 0x80,
SourceScope: 0x0,
Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x42, 0x00, 0xff, 0xfe, 0xca, 0x4c, 0x65},
DraftOption: false}},
},
{
&test.ResponseWriter6{},
[]dns.EDNS0{},
[]string{"subnet", "set", "24", "0"},
[]dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8,
Family: 0x2,
SourceNetmask: 0x0,
SourceScope: 0x0,
Address: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
DraftOption: false}},
},
}
ctx := context.TODO()
for i, tc := range tests {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET
r, err := newEdns0Rule(tc.args...)
if err != nil {
t.Errorf("Error creating test rule: %s", err)
continue
}
rw.Rules = []Rule{r}
rec := dnsrecorder.New(tc.writer)
rw.ServeDNS(ctx, rec, m)
resp := rec.Msg
o := resp.IsEdns0()
if o == nil {
t.Errorf("Test %d: EDNS0 options not set", i)
continue
}
if !optsEqual(o.Option, tc.toOpts) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o)
}
}
}

42
plugin/rewrite/setup.go Normal file
View File

@@ -0,0 +1,42 @@
package rewrite
import (
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/mholt/caddy"
)
func init() {
caddy.RegisterPlugin("rewrite", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
rewrites, err := rewriteParse(c)
if err != nil {
return plugin.Error("rewrite", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Rewrite{Next: next, Rules: rewrites}
})
return nil
}
func rewriteParse(c *caddy.Controller) ([]Rule, error) {
var rules []Rule
for c.Next() {
args := c.RemainingArgs()
rule, err := newRule(args...)
if err != nil {
return nil, err
}
rules = append(rules, rule)
}
return rules, nil
}

View File

@@ -0,0 +1,25 @@
package rewrite
import (
"testing"
"github.com/mholt/caddy"
)
func TestParse(t *testing.T) {
c := caddy.NewTestController("dns", `rewrite`)
_, err := rewriteParse(c)
if err == nil {
t.Errorf("Expected error but found nil for `rewrite`")
}
c = caddy.NewTestController("dns", `rewrite name`)
_, err = rewriteParse(c)
if err == nil {
t.Errorf("Expected error but found nil for `rewrite name`")
}
c = caddy.NewTestController("dns", `rewrite name a.com b.com`)
_, err = rewriteParse(c)
if err != nil {
t.Errorf("Expected success but found %s for `rewrite name a.com b.com`", err)
}
}

0
plugin/rewrite/testdata/testdir/empty vendored Normal file
View File

1
plugin/rewrite/testdata/testfile vendored Normal file
View File

@@ -0,0 +1 @@
empty

37
plugin/rewrite/type.go Normal file
View File

@@ -0,0 +1,37 @@
// Package rewrite is plugin for rewriting requests internally to something different.
package rewrite
import (
"fmt"
"strings"
"github.com/miekg/dns"
)
// typeRule is a type rewrite rule.
type typeRule struct {
fromType, toType uint16
}
func newTypeRule(fromS, toS string) (Rule, error) {
var from, to uint16
var ok bool
if from, ok = dns.StringToType[strings.ToUpper(fromS)]; !ok {
return nil, fmt.Errorf("invalid type %q", strings.ToUpper(fromS))
}
if to, ok = dns.StringToType[strings.ToUpper(toS)]; !ok {
return nil, fmt.Errorf("invalid type %q", strings.ToUpper(toS))
}
return &typeRule{fromType: from, toType: to}, nil
}
// Rewrite rewrites the the current request.
func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromType > 0 && rule.toType > 0 {
if r.Question[0].Qtype == rule.fromType {
r.Question[0].Qtype = rule.toType
return RewriteDone
}
}
return RewriteIgnored
}