mirror of
https://github.com/coredns/coredns.git
synced 2025-10-27 16:24:19 -04:00
plugin/metadata: metadata is just label=value (#1914)
This revert 17d807f0 and re-adds the metadata plugin as a plugin that
just sets a label to a value function.
Add package documentation on how to use the metadata package. Make it
clear that any caching is up to the Func implemented.
There are now - no in tree users. We could add the request metadata by
default under names that copy request.Request, i.e
request/ip - remote IP
request/port - remote port
Variables.go has been deleted.
Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
@@ -30,7 +30,6 @@ var Directives = []string{
|
||||
"rewrite",
|
||||
"dnssec",
|
||||
"autopath",
|
||||
"reverse",
|
||||
"template",
|
||||
"hosts",
|
||||
"route53",
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
_ "github.com/coredns/coredns/plugin/cache"
|
||||
_ "github.com/coredns/coredns/plugin/chaos"
|
||||
_ "github.com/coredns/coredns/plugin/debug"
|
||||
_ "github.com/coredns/coredns/plugin/deprecated"
|
||||
_ "github.com/coredns/coredns/plugin/dnssec"
|
||||
_ "github.com/coredns/coredns/plugin/dnstap"
|
||||
_ "github.com/coredns/coredns/plugin/erratic"
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/variables"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -24,18 +23,13 @@ func (m *Metadata) Name() string { return "metadata" }
|
||||
// ServeDNS implements the plugin.Handler interface.
|
||||
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
|
||||
ctx = context.WithValue(ctx, metadataKey{}, M{})
|
||||
md, _ := FromContext(ctx)
|
||||
ctx = context.WithValue(ctx, key{}, md{})
|
||||
|
||||
state := request.Request{W: w, Req: r}
|
||||
if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
|
||||
// Go through all Providers and collect metadata.
|
||||
for _, provider := range m.Providers {
|
||||
for _, varName := range provider.MetadataVarNames() {
|
||||
if val, ok := provider.Metadata(ctx, state, varName); ok {
|
||||
md.SetValue(varName, val)
|
||||
}
|
||||
}
|
||||
for _, p := range m.Providers {
|
||||
ctx = p.Metadata(ctx, state)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,14 +37,3 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||
|
||||
return rcode, err
|
||||
}
|
||||
|
||||
// MetadataVarNames implements the plugin.Provider interface.
|
||||
func (m *Metadata) MetadataVarNames() []string { return variables.All }
|
||||
|
||||
// Metadata implements the plugin.Provider interface.
|
||||
func (m *Metadata) Metadata(ctx context.Context, state request.Request, varName string) (interface{}, bool) {
|
||||
if val, err := variables.GetValue(state, varName); err == nil {
|
||||
return val, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -10,26 +10,18 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// testProvider implements fake Providers. Plugins which inmplement Provider interface
|
||||
type testProvider map[string]interface{}
|
||||
type testProvider map[string]Func
|
||||
|
||||
func (m testProvider) MetadataVarNames() []string {
|
||||
keys := []string{}
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
func (tp testProvider) Metadata(ctx context.Context, state request.Request) context.Context {
|
||||
for k, v := range tp {
|
||||
SetValueFunc(ctx, k, v)
|
||||
}
|
||||
return keys
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (m testProvider) Metadata(ctx context.Context, state request.Request, key string) (val interface{}, ok bool) {
|
||||
value, ok := m[key]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// testHandler implements plugin.Handler.
|
||||
type testHandler struct{ ctx context.Context }
|
||||
|
||||
func (m *testHandler) Name() string { return "testHandler" }
|
||||
func (m *testHandler) Name() string { return "test" }
|
||||
|
||||
func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
m.ctx = ctx
|
||||
@@ -38,8 +30,8 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
|
||||
|
||||
func TestMetadataServeDNS(t *testing.T) {
|
||||
expectedMetadata := []testProvider{
|
||||
testProvider{"testkey1": "testvalue1"},
|
||||
testProvider{"testkey2": 2, "testkey3": "testvalue3"},
|
||||
testProvider{"test/key1": func() string { return "testvalue1" }},
|
||||
testProvider{"test/key2": func() string { return "two" }, "test/key3": func() string { return "testvalue3" }},
|
||||
}
|
||||
// Create fake Providers based on expectedMetadata
|
||||
providers := []Provider{}
|
||||
@@ -48,32 +40,22 @@ func TestMetadataServeDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
next := &testHandler{} // fake handler which stores the resulting context
|
||||
metadata := Metadata{
|
||||
m := Metadata{
|
||||
Zones: []string{"."},
|
||||
Providers: providers,
|
||||
Next: next,
|
||||
}
|
||||
metadata.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg))
|
||||
|
||||
// Verify that next plugin can find metadata in context from all Providers
|
||||
ctx := context.TODO()
|
||||
m.ServeDNS(ctx, &test.ResponseWriter{}, new(dns.Msg))
|
||||
nctx := next.ctx
|
||||
|
||||
for _, expected := range expectedMetadata {
|
||||
md, ok := FromContext(next.ctx)
|
||||
if !ok {
|
||||
t.Fatalf("Metadata is expected but not present inside the context")
|
||||
}
|
||||
for expKey, expVal := range expected {
|
||||
metadataVal, valOk := md.Value(expKey)
|
||||
if !valOk {
|
||||
t.Fatalf("Value by key %v can't be retrieved", expKey)
|
||||
for label, expVal := range expected {
|
||||
val := ValueFunc(nctx, label)
|
||||
if val() != expVal() {
|
||||
t.Errorf("Expected value %s for %s, but got %s", expVal(), label, val())
|
||||
}
|
||||
if metadataVal != expVal {
|
||||
t.Errorf("Expected value %v, but got %v", expVal, metadataVal)
|
||||
}
|
||||
}
|
||||
wrongKey := "wrong_key"
|
||||
metadataVal, ok := md.Value(wrongKey)
|
||||
if ok {
|
||||
t.Fatalf("Value by key %v is not expected to be recieved, but got: %v", wrongKey, metadataVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,33 @@
|
||||
// Package metadata provides an API that allows plugins to add metadata to the context.
|
||||
// Each metadata is stored under a label that has the form <plugin>/<name>. Each metadata
|
||||
// is returned as a Func. When Func is called the metadata is returned. If Func is expensive to
|
||||
// execute it is its responsibility to provide some form of caching. During the handling of a
|
||||
// query it is expected the metadata stays constant.
|
||||
//
|
||||
// Basic example:
|
||||
//
|
||||
// Implement the Provder interface for a plugin:
|
||||
//
|
||||
// func (p P) Metadata(ctx context.Context, state request.Request) context.Context {
|
||||
// cached := ""
|
||||
// f := func() string {
|
||||
// if cached != "" {
|
||||
// return cached
|
||||
// }
|
||||
// cached = expensiveFunc()
|
||||
// return cached
|
||||
// }
|
||||
// metadata.SetValueFunc(ctx, "test/something", f)
|
||||
// return ctx
|
||||
// }
|
||||
//
|
||||
// Check the metadata from another plugin:
|
||||
//
|
||||
// // ...
|
||||
// valueFunc := metadata.ValueFunc(ctx, "test/something")
|
||||
// value := valueFunc()
|
||||
// // use 'value'
|
||||
//
|
||||
package metadata
|
||||
|
||||
import (
|
||||
@@ -8,40 +38,62 @@ import (
|
||||
|
||||
// Provider interface needs to be implemented by each plugin willing to provide
|
||||
// metadata information for other plugins.
|
||||
// Note: this method should work quickly, because it is called for every request
|
||||
// from the metadata plugin.
|
||||
type Provider interface {
|
||||
// List of variables which are provided by current Provider. Must remain constant.
|
||||
MetadataVarNames() []string
|
||||
// Metadata is expected to return a value with metadata information by the key
|
||||
// from 4th argument. Value can be later retrieved from context by any other plugin.
|
||||
// If value is not available by some reason returned boolean value should be false.
|
||||
Metadata(ctx context.Context, state request.Request, variable string) (interface{}, bool)
|
||||
// Metadata adds metadata to the context and returns a (potentially) new context.
|
||||
// Note: this method should work quickly, because it is called for every request
|
||||
// from the metadata plugin.
|
||||
Metadata(ctx context.Context, state request.Request) context.Context
|
||||
}
|
||||
|
||||
// M is metadata information storage.
|
||||
type M map[string]interface{}
|
||||
// Func is the type of function in the metadata, when called they return the value of the label.
|
||||
type Func func() string
|
||||
|
||||
// FromContext retrieves the metadata from the context.
|
||||
func FromContext(ctx context.Context) (M, bool) {
|
||||
if metadata := ctx.Value(metadataKey{}); metadata != nil {
|
||||
if m, ok := metadata.(M); ok {
|
||||
return m, true
|
||||
// Labels returns all metadata keys stored in the context. These label names should be named
|
||||
// as: plugin/NAME, where NAME is something descriptive.
|
||||
func Labels(ctx context.Context) []string {
|
||||
if metadata := ctx.Value(key{}); metadata != nil {
|
||||
if m, ok := metadata.(md); ok {
|
||||
return keys(m)
|
||||
}
|
||||
}
|
||||
return M{}, false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns metadata value by key.
|
||||
func (m M) Value(key string) (value interface{}, ok bool) {
|
||||
value, ok = m[key]
|
||||
return value, ok
|
||||
// ValueFunc returns the value function of label. If none can be found nil is returned. Calling the
|
||||
// function returns the value of the label.
|
||||
func ValueFunc(ctx context.Context, label string) Func {
|
||||
if metadata := ctx.Value(key{}); metadata != nil {
|
||||
if m, ok := metadata.(md); ok {
|
||||
return m[label]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetValue sets the metadata value under key.
|
||||
func (m M) SetValue(key string, val interface{}) {
|
||||
m[key] = val
|
||||
// SetValueFunc set the metadata label to the value function. If no metadata can be found this is a noop and
|
||||
// false is returned. Any existing value is overwritten.
|
||||
func SetValueFunc(ctx context.Context, label string, f Func) bool {
|
||||
if metadata := ctx.Value(key{}); metadata != nil {
|
||||
if m, ok := metadata.(md); ok {
|
||||
m[label] = f
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// metadataKey defines the type of key that is used to save metadata into the context.
|
||||
type metadataKey struct{}
|
||||
// md is metadata information storage.
|
||||
type md map[string]Func
|
||||
|
||||
// key defines the type of key that is used to save metadata into the context.
|
||||
type key struct{}
|
||||
|
||||
func keys(m map[string]Func) []string {
|
||||
s := make([]string, len(m))
|
||||
i := 0
|
||||
for k := range m {
|
||||
s[i] = k
|
||||
i++
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMD(t *testing.T) {
|
||||
tests := []struct {
|
||||
addValues map[string]interface{}
|
||||
expectedValues map[string]interface{}
|
||||
}{
|
||||
{
|
||||
// Add initial metadata key/vals
|
||||
map[string]interface{}{"key1": "val1", "key2": 2},
|
||||
map[string]interface{}{"key1": "val1", "key2": 2},
|
||||
},
|
||||
{
|
||||
// Add additional key/vals.
|
||||
map[string]interface{}{"key3": 3, "key4": 4.5},
|
||||
map[string]interface{}{"key1": "val1", "key2": 2, "key3": 3, "key4": 4.5},
|
||||
},
|
||||
}
|
||||
|
||||
// Using one same md and ctx for all test cases
|
||||
ctx := context.TODO()
|
||||
ctx = context.WithValue(ctx, metadataKey{}, M{})
|
||||
m, _ := FromContext(ctx)
|
||||
|
||||
for i, tc := range tests {
|
||||
for k, v := range tc.addValues {
|
||||
m.SetValue(k, v)
|
||||
}
|
||||
if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(m)) {
|
||||
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, m)
|
||||
}
|
||||
|
||||
// Make sure that md is recieved from context successfullly
|
||||
mFromContext, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
t.Errorf("Test %d: md is not recieved from the context", i)
|
||||
}
|
||||
if !reflect.DeepEqual(m, mFromContext) {
|
||||
t.Errorf("Test %d: md recieved from context differs from initial. Initial: %v, from context: %v", i, m, mFromContext)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
||||
@@ -28,16 +26,8 @@ func setup(c *caddy.Controller) error {
|
||||
|
||||
c.OnStartup(func() error {
|
||||
plugins := dnsserver.GetConfig(c).Handlers()
|
||||
// Collect all plugins which implement Provider interface
|
||||
metadataVariables := map[string]bool{}
|
||||
for _, p := range plugins {
|
||||
if met, ok := p.(Provider); ok {
|
||||
for _, varName := range met.MetadataVarNames() {
|
||||
if _, ok := metadataVariables[varName]; ok {
|
||||
return fmt.Errorf("Metadata variable '%v' has duplicates", varName)
|
||||
}
|
||||
metadataVariables[varName] = true
|
||||
}
|
||||
m.Providers = append(m.Providers, met)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
package variables
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/coredns/coredns/request"
|
||||
)
|
||||
|
||||
const (
|
||||
queryName = "qname"
|
||||
queryType = "qtype"
|
||||
clientIP = "client_ip"
|
||||
clientPort = "client_port"
|
||||
protocol = "protocol"
|
||||
serverIP = "server_ip"
|
||||
serverPort = "server_port"
|
||||
)
|
||||
|
||||
// All is a list of available variables provided by GetMetadataValue
|
||||
var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort}
|
||||
|
||||
// GetValue calculates and returns the data specified by the variable name.
|
||||
// Supported varNames are listed in allProvidedVars.
|
||||
func GetValue(state request.Request, varName string) ([]byte, error) {
|
||||
switch varName {
|
||||
case queryName:
|
||||
return []byte(state.QName()), nil
|
||||
|
||||
case queryType:
|
||||
return uint16ToWire(state.QType()), nil
|
||||
|
||||
case clientIP:
|
||||
return ipToWire(state.Family(), state.IP())
|
||||
|
||||
case clientPort:
|
||||
return portToWire(state.Port())
|
||||
|
||||
case protocol:
|
||||
return []byte(state.Proto()), nil
|
||||
|
||||
case serverIP:
|
||||
ip, _, err := net.SplitHostPort(state.W.LocalAddr().String())
|
||||
if err != nil {
|
||||
ip = state.W.RemoteAddr().String()
|
||||
}
|
||||
return ipToWire(state.Family(), ip)
|
||||
|
||||
case serverPort:
|
||||
_, port, err := net.SplitHostPort(state.W.LocalAddr().String())
|
||||
if err != nil {
|
||||
port = "0"
|
||||
}
|
||||
return portToWire(port)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to extract data for variable %s", varName)
|
||||
}
|
||||
|
||||
// uint16ToWire writes unit16 to wire/binary format
|
||||
func uint16ToWire(data uint16) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(data))
|
||||
return buf
|
||||
}
|
||||
|
||||
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
|
||||
func ipToWire(family int, ipAddr string) ([]byte, error) {
|
||||
|
||||
switch family {
|
||||
case 1:
|
||||
return net.ParseIP(ipAddr).To4(), nil
|
||||
case 2:
|
||||
return net.ParseIP(ipAddr).To16(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
|
||||
}
|
||||
|
||||
// portToWire writes port to wire/binary format, 2 bytes
|
||||
func portToWire(portStr string) ([]byte, error) {
|
||||
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return uint16ToWire(uint16(port)), nil
|
||||
}
|
||||
|
||||
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
|
||||
func family(ip net.Addr) int {
|
||||
var a net.IP
|
||||
if i, ok := ip.(*net.UDPAddr); ok {
|
||||
a = i.IP
|
||||
}
|
||||
if i, ok := ip.(*net.TCPAddr); ok {
|
||||
a = i.IP
|
||||
}
|
||||
if a.To4() != nil {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package variables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestGetValue(t *testing.T) {
|
||||
// test.ResponseWriter has the following values:
|
||||
// The remote will always be 10.240.0.1 and port 40212.
|
||||
// The local address is always 127.0.0.1 and port 53.
|
||||
tests := []struct {
|
||||
varName string
|
||||
expectedValue []byte
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
queryName,
|
||||
[]byte("example.com."),
|
||||
false,
|
||||
},
|
||||
{
|
||||
queryType,
|
||||
[]byte{0x00, 0x01},
|
||||
false,
|
||||
},
|
||||
{
|
||||
clientIP,
|
||||
[]byte{10, 240, 0, 1},
|
||||
false,
|
||||
},
|
||||
{
|
||||
clientPort,
|
||||
[]byte{0x9D, 0x14},
|
||||
false,
|
||||
},
|
||||
{
|
||||
protocol,
|
||||
[]byte("udp"),
|
||||
false,
|
||||
},
|
||||
{
|
||||
serverIP,
|
||||
[]byte{127, 0, 0, 1},
|
||||
false,
|
||||
},
|
||||
{
|
||||
serverPort,
|
||||
[]byte{0, 53},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"wrong_var",
|
||||
[]byte{},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.Question[0].Qclass = dns.ClassINET
|
||||
state := request.Request{W: &test.ResponseWriter{}, Req: m}
|
||||
|
||||
value, err := GetValue(state, tc.varName)
|
||||
|
||||
if tc.shouldErr && err == nil {
|
||||
t.Errorf("Test %d: Expected error, but didn't recieve", i)
|
||||
}
|
||||
if !tc.shouldErr && err != nil {
|
||||
t.Errorf("Test %d: Expected no error, but got error: %v", i, err.Error())
|
||||
}
|
||||
|
||||
if !bytes.Equal(tc.expectedValue, value) {
|
||||
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -206,17 +206,13 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq
|
||||
}
|
||||
~~~
|
||||
|
||||
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables by default:
|
||||
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables:
|
||||
{qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}.
|
||||
Any plugin that can provide it's own additional variables by implementing metadata.Provider interface. If you are going to use metadata variables then metadata plugin must be enabled.
|
||||
|
||||
Example:
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
metadata
|
||||
rewrite edns0 local set 0xffee {client_ip}
|
||||
}
|
||||
~~~
|
||||
rewrite edns0 local set 0xffee {client_ip}
|
||||
~~~
|
||||
|
||||
### EDNS0_NSID
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ func newClassRule(nextAction string, args ...string) (Rule, error) {
|
||||
}
|
||||
|
||||
// Rewrite rewrites the the current request.
|
||||
func (rule *classRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
if rule.fromClass > 0 && rule.toClass > 0 {
|
||||
if r.Question[0].Qclass == rule.fromClass {
|
||||
r.Question[0].Qclass = rule.toClass
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coredns/coredns/plugin/metadata"
|
||||
"github.com/coredns/coredns/plugin/pkg/variables"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -49,7 +47,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT {
|
||||
}
|
||||
|
||||
// Rewrite will alter the request EDNS0 NSID option
|
||||
func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
result := RewriteIgnored
|
||||
o := setupEdns0Opt(r)
|
||||
found := false
|
||||
@@ -86,7 +84,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule {
|
||||
}
|
||||
|
||||
// Rewrite will alter the request EDNS0 local options
|
||||
func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
result := RewriteIgnored
|
||||
o := setupEdns0Opt(r)
|
||||
found := false
|
||||
@@ -149,9 +147,7 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) {
|
||||
}
|
||||
//Check for variable option
|
||||
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
|
||||
// Remove first and last runes
|
||||
variable := args[3][1 : len(args[3])-1]
|
||||
return newEdns0VariableRule(mode, action, args[2], variable)
|
||||
return newEdns0VariableRule(mode, action, args[2], args[3])
|
||||
}
|
||||
return newEdns0LocalRule(mode, action, args[2], args[3])
|
||||
case "nsid":
|
||||
@@ -191,29 +187,102 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//Validate
|
||||
if !isValidVariable(variable) {
|
||||
return nil, fmt.Errorf("unsupported variable name %q", variable)
|
||||
}
|
||||
return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil
|
||||
}
|
||||
|
||||
// ruleData returns the data specified by the variable
|
||||
func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
if value, ok := md.Value(rule.variable); ok {
|
||||
if v, ok := value.([]byte); ok {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
} else { // No metadata available means metadata plugin is disabled. Try to get the value directly.
|
||||
state := request.Request{W: w, Req: r} // TODO(miek): every rule needs to take a request.Request.
|
||||
return variables.GetValue(state, rule.variable)
|
||||
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
|
||||
func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) {
|
||||
|
||||
switch family {
|
||||
case 1:
|
||||
return net.ParseIP(ipAddr).To4(), nil
|
||||
case 2:
|
||||
return net.ParseIP(ipAddr).To16(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
|
||||
}
|
||||
|
||||
// uint16ToWire writes unit16 to wire/binary format
|
||||
func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(data))
|
||||
return buf
|
||||
}
|
||||
|
||||
// portToWire writes port to wire/binary format, 2 bytes
|
||||
func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) {
|
||||
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rule.uint16ToWire(uint16(port)), nil
|
||||
}
|
||||
|
||||
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
|
||||
func (rule *edns0VariableRule) family(ip net.Addr) int {
|
||||
var a net.IP
|
||||
if i, ok := ip.(*net.UDPAddr); ok {
|
||||
a = i.IP
|
||||
}
|
||||
if i, ok := ip.(*net.TCPAddr); ok {
|
||||
a = i.IP
|
||||
}
|
||||
if a.To4() != nil {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
// ruleData returns the data specified by the variable
|
||||
func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
|
||||
|
||||
req := request.Request{W: w, Req: r}
|
||||
switch rule.variable {
|
||||
case queryName:
|
||||
//Query name is written as ascii string
|
||||
return []byte(req.QName()), nil
|
||||
|
||||
case queryType:
|
||||
return rule.uint16ToWire(req.QType()), nil
|
||||
|
||||
case clientIP:
|
||||
return rule.ipToWire(req.Family(), req.IP())
|
||||
|
||||
case clientPort:
|
||||
return rule.portToWire(req.Port())
|
||||
|
||||
case protocol:
|
||||
// Proto is written as ascii string
|
||||
return []byte(req.Proto()), nil
|
||||
|
||||
case serverIP:
|
||||
ip, _, err := net.SplitHostPort(w.LocalAddr().String())
|
||||
if err != nil {
|
||||
ip = w.RemoteAddr().String()
|
||||
}
|
||||
return rule.ipToWire(rule.family(w.RemoteAddr()), ip)
|
||||
|
||||
case serverPort:
|
||||
_, port, err := net.SplitHostPort(w.LocalAddr().String())
|
||||
if err != nil {
|
||||
port = "0"
|
||||
}
|
||||
return rule.portToWire(port)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
|
||||
}
|
||||
|
||||
// Rewrite will alter the request EDNS0 local options with specified variables
|
||||
func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
result := RewriteIgnored
|
||||
|
||||
data, err := rule.ruleData(ctx, w, r)
|
||||
data, err := rule.ruleData(w, r)
|
||||
if err != nil || data == nil {
|
||||
return result
|
||||
}
|
||||
@@ -256,6 +325,21 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule {
|
||||
return ResponseRule{}
|
||||
}
|
||||
|
||||
func isValidVariable(variable string) bool {
|
||||
switch variable {
|
||||
case
|
||||
queryName,
|
||||
queryType,
|
||||
clientIP,
|
||||
clientPort,
|
||||
protocol,
|
||||
serverIP,
|
||||
serverPort:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options
|
||||
type edns0SubnetRule struct {
|
||||
mode string
|
||||
@@ -316,7 +400,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, ecs *
|
||||
}
|
||||
|
||||
// Rewrite will alter the request EDNS0 subnet option
|
||||
func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
result := RewriteIgnored
|
||||
o := setupEdns0Opt(r)
|
||||
found := false
|
||||
@@ -362,6 +446,17 @@ const (
|
||||
Append = "append"
|
||||
)
|
||||
|
||||
// Supported local EDNS0 variables
|
||||
const (
|
||||
queryName = "{qname}"
|
||||
queryType = "{qtype}"
|
||||
clientIP = "{client_ip}"
|
||||
clientPort = "{client_port}"
|
||||
protocol = "{protocol}"
|
||||
serverIP = "{server_ip}"
|
||||
serverPort = "{server_port}"
|
||||
)
|
||||
|
||||
// Subnet maximum bit mask length
|
||||
const (
|
||||
maxV4BitMaskLen = 32
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -58,7 +57,7 @@ const (
|
||||
|
||||
// Rewrite rewrites the current request based upon exact match of the name
|
||||
// in the question section of the request
|
||||
func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
if rule.From == r.Question[0].Name {
|
||||
r.Question[0].Name = rule.To
|
||||
return RewriteDone
|
||||
@@ -67,7 +66,7 @@ func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.
|
||||
}
|
||||
|
||||
// Rewrite rewrites the current request when the name begins with the matching string
|
||||
func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
if strings.HasPrefix(r.Question[0].Name, rule.Prefix) {
|
||||
r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix)
|
||||
return RewriteDone
|
||||
@@ -76,7 +75,7 @@ func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r
|
||||
}
|
||||
|
||||
// Rewrite rewrites the current request when the name ends with the matching string
|
||||
func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
if strings.HasSuffix(r.Question[0].Name, rule.Suffix) {
|
||||
r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement
|
||||
return RewriteDone
|
||||
@@ -86,7 +85,7 @@ func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r
|
||||
|
||||
// Rewrite rewrites the current request based upon partial match of the
|
||||
// name in the question section of the request
|
||||
func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
if strings.Contains(r.Question[0].Name, rule.Substring) {
|
||||
r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1)
|
||||
return RewriteDone
|
||||
@@ -96,7 +95,7 @@ func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter
|
||||
|
||||
// Rewrite rewrites the current request when the name in the question
|
||||
// section of the request matches a regular expression
|
||||
func (rule *regexNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *regexNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name)
|
||||
if len(regexGroups) == 0 {
|
||||
return RewriteIgnored
|
||||
|
||||
@@ -39,7 +39,7 @@ type Rewrite struct {
|
||||
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
wr := NewResponseReverter(w, r)
|
||||
for _, rule := range rw.Rules {
|
||||
switch result := rule.Rewrite(ctx, w, r); result {
|
||||
switch result := rule.Rewrite(w, r); result {
|
||||
case RewriteDone:
|
||||
respRule := rule.GetResponseRule()
|
||||
if respRule.Active == true {
|
||||
@@ -68,7 +68,7 @@ func (rw Rewrite) Name() string { return "rewrite" }
|
||||
// Rule describes a rewrite rule.
|
||||
type Rule interface {
|
||||
// Rewrite rewrites the current request.
|
||||
Rewrite(context.Context, dns.ResponseWriter, *dns.Msg) Result
|
||||
Rewrite(dns.ResponseWriter, *dns.Msg) Result
|
||||
// Mode returns the processing mode stop or continue.
|
||||
Mode() string
|
||||
// GetResponseRule returns the rule to rewrite response with, if any.
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestNewRule(t *testing.T) {
|
||||
{[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})},
|
||||
{[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})},
|
||||
{[]string{"edns0", "nsid", "foo"}, true, nil},
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil},
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
@@ -79,7 +79,7 @@ func TestNewRule(t *testing.T) {
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil},
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
@@ -87,7 +87,7 @@ func TestNewRule(t *testing.T) {
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil},
|
||||
{[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
{[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -29,7 +28,7 @@ func newTypeRule(nextAction string, args ...string) (Rule, error) {
|
||||
}
|
||||
|
||||
// Rewrite rewrites the the current request.
|
||||
func (rule *typeRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||
if rule.fromType > 0 && rule.toType > 0 {
|
||||
if r.Question[0].Qtype == rule.fromType {
|
||||
r.Question[0].Qtype = rule.toType
|
||||
|
||||
Reference in New Issue
Block a user