plugin/metadata: metadata is just label=value (#1914)

This revert 17d807f0 and re-adds the metadata plugin as a plugin that
just sets a label to a value function.

Add package documentation on how to use the metadata package. Make it
clear that any caching is up to the Func implemented.

There are now - no in tree users. We could add the request metadata by
default under names that copy request.Request, i.e

request/ip - remote IP
request/port - remote port

Variables.go has been deleted.

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben
2018-07-01 20:01:17 +01:00
committed by GitHub
parent 0b326e2686
commit 99800a687c
16 changed files with 229 additions and 371 deletions

View File

@@ -30,7 +30,6 @@ var Directives = []string{
"rewrite",
"dnssec",
"autopath",
"reverse",
"template",
"hosts",
"route53",

View File

@@ -10,7 +10,6 @@ import (
_ "github.com/coredns/coredns/plugin/cache"
_ "github.com/coredns/coredns/plugin/chaos"
_ "github.com/coredns/coredns/plugin/debug"
_ "github.com/coredns/coredns/plugin/deprecated"
_ "github.com/coredns/coredns/plugin/dnssec"
_ "github.com/coredns/coredns/plugin/dnstap"
_ "github.com/coredns/coredns/plugin/erratic"

View File

@@ -4,7 +4,6 @@ import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/variables"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@@ -24,18 +23,13 @@ func (m *Metadata) Name() string { return "metadata" }
// ServeDNS implements the plugin.Handler interface.
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
ctx = context.WithValue(ctx, metadataKey{}, M{})
md, _ := FromContext(ctx)
ctx = context.WithValue(ctx, key{}, md{})
state := request.Request{W: w, Req: r}
if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
// Go through all Providers and collect metadata.
for _, provider := range m.Providers {
for _, varName := range provider.MetadataVarNames() {
if val, ok := provider.Metadata(ctx, state, varName); ok {
md.SetValue(varName, val)
}
}
for _, p := range m.Providers {
ctx = p.Metadata(ctx, state)
}
}
@@ -43,14 +37,3 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
return rcode, err
}
// MetadataVarNames implements the plugin.Provider interface.
func (m *Metadata) MetadataVarNames() []string { return variables.All }
// Metadata implements the plugin.Provider interface.
func (m *Metadata) Metadata(ctx context.Context, state request.Request, varName string) (interface{}, bool) {
if val, err := variables.GetValue(state, varName); err == nil {
return val, true
}
return nil, false
}

View File

@@ -10,26 +10,18 @@ import (
"github.com/miekg/dns"
)
// testProvider implements fake Providers. Plugins which inmplement Provider interface
type testProvider map[string]interface{}
type testProvider map[string]Func
func (m testProvider) MetadataVarNames() []string {
keys := []string{}
for k := range m {
keys = append(keys, k)
func (tp testProvider) Metadata(ctx context.Context, state request.Request) context.Context {
for k, v := range tp {
SetValueFunc(ctx, k, v)
}
return keys
return ctx
}
func (m testProvider) Metadata(ctx context.Context, state request.Request, key string) (val interface{}, ok bool) {
value, ok := m[key]
return value, ok
}
// testHandler implements plugin.Handler.
type testHandler struct{ ctx context.Context }
func (m *testHandler) Name() string { return "testHandler" }
func (m *testHandler) Name() string { return "test" }
func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m.ctx = ctx
@@ -38,8 +30,8 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
func TestMetadataServeDNS(t *testing.T) {
expectedMetadata := []testProvider{
testProvider{"testkey1": "testvalue1"},
testProvider{"testkey2": 2, "testkey3": "testvalue3"},
testProvider{"test/key1": func() string { return "testvalue1" }},
testProvider{"test/key2": func() string { return "two" }, "test/key3": func() string { return "testvalue3" }},
}
// Create fake Providers based on expectedMetadata
providers := []Provider{}
@@ -48,32 +40,22 @@ func TestMetadataServeDNS(t *testing.T) {
}
next := &testHandler{} // fake handler which stores the resulting context
metadata := Metadata{
m := Metadata{
Zones: []string{"."},
Providers: providers,
Next: next,
}
metadata.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg))
// Verify that next plugin can find metadata in context from all Providers
ctx := context.TODO()
m.ServeDNS(ctx, &test.ResponseWriter{}, new(dns.Msg))
nctx := next.ctx
for _, expected := range expectedMetadata {
md, ok := FromContext(next.ctx)
if !ok {
t.Fatalf("Metadata is expected but not present inside the context")
}
for expKey, expVal := range expected {
metadataVal, valOk := md.Value(expKey)
if !valOk {
t.Fatalf("Value by key %v can't be retrieved", expKey)
for label, expVal := range expected {
val := ValueFunc(nctx, label)
if val() != expVal() {
t.Errorf("Expected value %s for %s, but got %s", expVal(), label, val())
}
if metadataVal != expVal {
t.Errorf("Expected value %v, but got %v", expVal, metadataVal)
}
}
wrongKey := "wrong_key"
metadataVal, ok := md.Value(wrongKey)
if ok {
t.Fatalf("Value by key %v is not expected to be recieved, but got: %v", wrongKey, metadataVal)
}
}
}

View File

@@ -1,3 +1,33 @@
// Package metadata provides an API that allows plugins to add metadata to the context.
// Each metadata is stored under a label that has the form <plugin>/<name>. Each metadata
// is returned as a Func. When Func is called the metadata is returned. If Func is expensive to
// execute it is its responsibility to provide some form of caching. During the handling of a
// query it is expected the metadata stays constant.
//
// Basic example:
//
// Implement the Provder interface for a plugin:
//
// func (p P) Metadata(ctx context.Context, state request.Request) context.Context {
// cached := ""
// f := func() string {
// if cached != "" {
// return cached
// }
// cached = expensiveFunc()
// return cached
// }
// metadata.SetValueFunc(ctx, "test/something", f)
// return ctx
// }
//
// Check the metadata from another plugin:
//
// // ...
// valueFunc := metadata.ValueFunc(ctx, "test/something")
// value := valueFunc()
// // use 'value'
//
package metadata
import (
@@ -8,40 +38,62 @@ import (
// Provider interface needs to be implemented by each plugin willing to provide
// metadata information for other plugins.
// Note: this method should work quickly, because it is called for every request
// from the metadata plugin.
type Provider interface {
// List of variables which are provided by current Provider. Must remain constant.
MetadataVarNames() []string
// Metadata is expected to return a value with metadata information by the key
// from 4th argument. Value can be later retrieved from context by any other plugin.
// If value is not available by some reason returned boolean value should be false.
Metadata(ctx context.Context, state request.Request, variable string) (interface{}, bool)
// Metadata adds metadata to the context and returns a (potentially) new context.
// Note: this method should work quickly, because it is called for every request
// from the metadata plugin.
Metadata(ctx context.Context, state request.Request) context.Context
}
// M is metadata information storage.
type M map[string]interface{}
// Func is the type of function in the metadata, when called they return the value of the label.
type Func func() string
// FromContext retrieves the metadata from the context.
func FromContext(ctx context.Context) (M, bool) {
if metadata := ctx.Value(metadataKey{}); metadata != nil {
if m, ok := metadata.(M); ok {
return m, true
// Labels returns all metadata keys stored in the context. These label names should be named
// as: plugin/NAME, where NAME is something descriptive.
func Labels(ctx context.Context) []string {
if metadata := ctx.Value(key{}); metadata != nil {
if m, ok := metadata.(md); ok {
return keys(m)
}
}
return M{}, false
return nil
}
// Value returns metadata value by key.
func (m M) Value(key string) (value interface{}, ok bool) {
value, ok = m[key]
return value, ok
// ValueFunc returns the value function of label. If none can be found nil is returned. Calling the
// function returns the value of the label.
func ValueFunc(ctx context.Context, label string) Func {
if metadata := ctx.Value(key{}); metadata != nil {
if m, ok := metadata.(md); ok {
return m[label]
}
}
return nil
}
// SetValue sets the metadata value under key.
func (m M) SetValue(key string, val interface{}) {
m[key] = val
// SetValueFunc set the metadata label to the value function. If no metadata can be found this is a noop and
// false is returned. Any existing value is overwritten.
func SetValueFunc(ctx context.Context, label string, f Func) bool {
if metadata := ctx.Value(key{}); metadata != nil {
if m, ok := metadata.(md); ok {
m[label] = f
return true
}
}
return false
}
// metadataKey defines the type of key that is used to save metadata into the context.
type metadataKey struct{}
// md is metadata information storage.
type md map[string]Func
// key defines the type of key that is used to save metadata into the context.
type key struct{}
func keys(m map[string]Func) []string {
s := make([]string, len(m))
i := 0
for k := range m {
s[i] = k
i++
}
return s
}

View File

@@ -1,48 +0,0 @@
package metadata
import (
"context"
"reflect"
"testing"
)
func TestMD(t *testing.T) {
tests := []struct {
addValues map[string]interface{}
expectedValues map[string]interface{}
}{
{
// Add initial metadata key/vals
map[string]interface{}{"key1": "val1", "key2": 2},
map[string]interface{}{"key1": "val1", "key2": 2},
},
{
// Add additional key/vals.
map[string]interface{}{"key3": 3, "key4": 4.5},
map[string]interface{}{"key1": "val1", "key2": 2, "key3": 3, "key4": 4.5},
},
}
// Using one same md and ctx for all test cases
ctx := context.TODO()
ctx = context.WithValue(ctx, metadataKey{}, M{})
m, _ := FromContext(ctx)
for i, tc := range tests {
for k, v := range tc.addValues {
m.SetValue(k, v)
}
if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(m)) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, m)
}
// Make sure that md is recieved from context successfullly
mFromContext, ok := FromContext(ctx)
if !ok {
t.Errorf("Test %d: md is not recieved from the context", i)
}
if !reflect.DeepEqual(m, mFromContext) {
t.Errorf("Test %d: md recieved from context differs from initial. Initial: %v, from context: %v", i, m, mFromContext)
}
}
}

View File

@@ -1,8 +1,6 @@
package metadata
import (
"fmt"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
@@ -28,16 +26,8 @@ func setup(c *caddy.Controller) error {
c.OnStartup(func() error {
plugins := dnsserver.GetConfig(c).Handlers()
// Collect all plugins which implement Provider interface
metadataVariables := map[string]bool{}
for _, p := range plugins {
if met, ok := p.(Provider); ok {
for _, varName := range met.MetadataVarNames() {
if _, ok := metadataVariables[varName]; ok {
return fmt.Errorf("Metadata variable '%v' has duplicates", varName)
}
metadataVariables[varName] = true
}
m.Providers = append(m.Providers, met)
}
}

View File

@@ -1,104 +0,0 @@
package variables
import (
"encoding/binary"
"fmt"
"net"
"strconv"
"github.com/coredns/coredns/request"
)
const (
queryName = "qname"
queryType = "qtype"
clientIP = "client_ip"
clientPort = "client_port"
protocol = "protocol"
serverIP = "server_ip"
serverPort = "server_port"
)
// All is a list of available variables provided by GetMetadataValue
var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort}
// GetValue calculates and returns the data specified by the variable name.
// Supported varNames are listed in allProvidedVars.
func GetValue(state request.Request, varName string) ([]byte, error) {
switch varName {
case queryName:
return []byte(state.QName()), nil
case queryType:
return uint16ToWire(state.QType()), nil
case clientIP:
return ipToWire(state.Family(), state.IP())
case clientPort:
return portToWire(state.Port())
case protocol:
return []byte(state.Proto()), nil
case serverIP:
ip, _, err := net.SplitHostPort(state.W.LocalAddr().String())
if err != nil {
ip = state.W.RemoteAddr().String()
}
return ipToWire(state.Family(), ip)
case serverPort:
_, port, err := net.SplitHostPort(state.W.LocalAddr().String())
if err != nil {
port = "0"
}
return portToWire(port)
}
return nil, fmt.Errorf("unable to extract data for variable %s", varName)
}
// uint16ToWire writes unit16 to wire/binary format
func uint16ToWire(data uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(data))
return buf
}
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
func ipToWire(family int, ipAddr string) ([]byte, error) {
switch family {
case 1:
return net.ParseIP(ipAddr).To4(), nil
case 2:
return net.ParseIP(ipAddr).To16(), nil
}
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
}
// portToWire writes port to wire/binary format, 2 bytes
func portToWire(portStr string) ([]byte, error) {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
return uint16ToWire(uint16(port)), nil
}
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
func family(ip net.Addr) int {
var a net.IP
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
if i, ok := ip.(*net.TCPAddr); ok {
a = i.IP
}
if a.To4() != nil {
return 1
}
return 2
}

View File

@@ -1,83 +0,0 @@
package variables
import (
"bytes"
"testing"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
func TestGetValue(t *testing.T) {
// test.ResponseWriter has the following values:
// The remote will always be 10.240.0.1 and port 40212.
// The local address is always 127.0.0.1 and port 53.
tests := []struct {
varName string
expectedValue []byte
shouldErr bool
}{
{
queryName,
[]byte("example.com."),
false,
},
{
queryType,
[]byte{0x00, 0x01},
false,
},
{
clientIP,
[]byte{10, 240, 0, 1},
false,
},
{
clientPort,
[]byte{0x9D, 0x14},
false,
},
{
protocol,
[]byte("udp"),
false,
},
{
serverIP,
[]byte{127, 0, 0, 1},
false,
},
{
serverPort,
[]byte{0, 53},
false,
},
{
"wrong_var",
[]byte{},
true,
},
}
for i, tc := range tests {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET
state := request.Request{W: &test.ResponseWriter{}, Req: m}
value, err := GetValue(state, tc.varName)
if tc.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but didn't recieve", i)
}
if !tc.shouldErr && err != nil {
t.Errorf("Test %d: Expected no error, but got error: %v", i, err.Error())
}
if !bytes.Equal(tc.expectedValue, value) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValue, value)
}
}
}

View File

@@ -206,17 +206,13 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq
}
~~~
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables by default:
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables:
{qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}.
Any plugin that can provide it's own additional variables by implementing metadata.Provider interface. If you are going to use metadata variables then metadata plugin must be enabled.
Example:
~~~ corefile
. {
metadata
rewrite edns0 local set 0xffee {client_ip}
}
~~~
rewrite edns0 local set 0xffee {client_ip}
~~~
### EDNS0_NSID

View File

@@ -1,7 +1,6 @@
package rewrite
import (
"context"
"fmt"
"strings"
@@ -28,7 +27,7 @@ func newClassRule(nextAction string, args ...string) (Rule, error) {
}
// Rewrite rewrites the the current request.
func (rule *classRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromClass > 0 && rule.toClass > 0 {
if r.Question[0].Qclass == rule.fromClass {
r.Question[0].Qclass = rule.toClass

View File

@@ -2,15 +2,13 @@
package rewrite
import (
"context"
"encoding/binary"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/variables"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@@ -49,7 +47,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT {
}
// Rewrite will alter the request EDNS0 NSID option
func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@@ -86,7 +84,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule {
}
// Rewrite will alter the request EDNS0 local options
func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@@ -149,9 +147,7 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) {
}
//Check for variable option
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
// Remove first and last runes
variable := args[3][1 : len(args[3])-1]
return newEdns0VariableRule(mode, action, args[2], variable)
return newEdns0VariableRule(mode, action, args[2], args[3])
}
return newEdns0LocalRule(mode, action, args[2], args[3])
case "nsid":
@@ -191,29 +187,102 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu
if err != nil {
return nil, err
}
//Validate
if !isValidVariable(variable) {
return nil, fmt.Errorf("unsupported variable name %q", variable)
}
return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil
}
// ruleData returns the data specified by the variable
func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
if md, ok := metadata.FromContext(ctx); ok {
if value, ok := md.Value(rule.variable); ok {
if v, ok := value.([]byte); ok {
return v, nil
}
}
} else { // No metadata available means metadata plugin is disabled. Try to get the value directly.
state := request.Request{W: w, Req: r} // TODO(miek): every rule needs to take a request.Request.
return variables.GetValue(state, rule.variable)
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) {
switch family {
case 1:
return net.ParseIP(ipAddr).To4(), nil
case 2:
return net.ParseIP(ipAddr).To16(), nil
}
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
}
// uint16ToWire writes unit16 to wire/binary format
func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(data))
return buf
}
// portToWire writes port to wire/binary format, 2 bytes
func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
return rule.uint16ToWire(uint16(port)), nil
}
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
func (rule *edns0VariableRule) family(ip net.Addr) int {
var a net.IP
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
if i, ok := ip.(*net.TCPAddr); ok {
a = i.IP
}
if a.To4() != nil {
return 1
}
return 2
}
// ruleData returns the data specified by the variable
func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
req := request.Request{W: w, Req: r}
switch rule.variable {
case queryName:
//Query name is written as ascii string
return []byte(req.QName()), nil
case queryType:
return rule.uint16ToWire(req.QType()), nil
case clientIP:
return rule.ipToWire(req.Family(), req.IP())
case clientPort:
return rule.portToWire(req.Port())
case protocol:
// Proto is written as ascii string
return []byte(req.Proto()), nil
case serverIP:
ip, _, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
ip = w.RemoteAddr().String()
}
return rule.ipToWire(rule.family(w.RemoteAddr()), ip)
case serverPort:
_, port, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
port = "0"
}
return rule.portToWire(port)
}
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
}
// Rewrite will alter the request EDNS0 local options with specified variables
func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
data, err := rule.ruleData(ctx, w, r)
data, err := rule.ruleData(w, r)
if err != nil || data == nil {
return result
}
@@ -256,6 +325,21 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule {
return ResponseRule{}
}
func isValidVariable(variable string) bool {
switch variable {
case
queryName,
queryType,
clientIP,
clientPort,
protocol,
serverIP,
serverPort:
return true
}
return false
}
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options
type edns0SubnetRule struct {
mode string
@@ -316,7 +400,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, ecs *
}
// Rewrite will alter the request EDNS0 subnet option
func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@@ -362,6 +446,17 @@ const (
Append = "append"
)
// Supported local EDNS0 variables
const (
queryName = "{qname}"
queryType = "{qtype}"
clientIP = "{client_ip}"
clientPort = "{client_port}"
protocol = "{protocol}"
serverIP = "{server_ip}"
serverPort = "{server_port}"
)
// Subnet maximum bit mask length
const (
maxV4BitMaskLen = 32

View File

@@ -1,7 +1,6 @@
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
@@ -58,7 +57,7 @@ const (
// Rewrite rewrites the current request based upon exact match of the name
// in the question section of the request
func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.From == r.Question[0].Name {
r.Question[0].Name = rule.To
return RewriteDone
@@ -67,7 +66,7 @@ func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.
}
// Rewrite rewrites the current request when the name begins with the matching string
func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if strings.HasPrefix(r.Question[0].Name, rule.Prefix) {
r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix)
return RewriteDone
@@ -76,7 +75,7 @@ func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r
}
// Rewrite rewrites the current request when the name ends with the matching string
func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if strings.HasSuffix(r.Question[0].Name, rule.Suffix) {
r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement
return RewriteDone
@@ -86,7 +85,7 @@ func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r
// Rewrite rewrites the current request based upon partial match of the
// name in the question section of the request
func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if strings.Contains(r.Question[0].Name, rule.Substring) {
r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1)
return RewriteDone
@@ -96,7 +95,7 @@ func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter
// Rewrite rewrites the current request when the name in the question
// section of the request matches a regular expression
func (rule *regexNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *regexNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name)
if len(regexGroups) == 0 {
return RewriteIgnored

View File

@@ -39,7 +39,7 @@ type Rewrite struct {
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules {
switch result := rule.Rewrite(ctx, w, r); result {
switch result := rule.Rewrite(w, r); result {
case RewriteDone:
respRule := rule.GetResponseRule()
if respRule.Active == true {
@@ -68,7 +68,7 @@ func (rw Rewrite) Name() string { return "rewrite" }
// Rule describes a rewrite rule.
type Rule interface {
// Rewrite rewrites the current request.
Rewrite(context.Context, dns.ResponseWriter, *dns.Msg) Result
Rewrite(dns.ResponseWriter, *dns.Msg) Result
// Mode returns the processing mode stop or continue.
Mode() string
// GetResponseRule returns the rule to rewrite response with, if any.

View File

@@ -71,7 +71,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "foo"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
@@ -79,7 +79,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
@@ -87,7 +87,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},

View File

@@ -2,7 +2,6 @@
package rewrite
import (
"context"
"fmt"
"strings"
@@ -29,7 +28,7 @@ func newTypeRule(nextAction string, args ...string) (Rule, error) {
}
// Rewrite rewrites the the current request.
func (rule *typeRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromType > 0 && rule.toType > 0 {
if r.Question[0].Qtype == rule.fromType {
r.Question[0].Qtype = rule.toType