mirror of
https://github.com/coredns/coredns.git
synced 2025-11-19 10:22:17 -05:00
Remove the word middleware (#1067)
* Rename middleware to plugin first pass; mostly used 'sed', few spots where I manually changed text. This still builds a coredns binary. * fmt error * Rename AddMiddleware to AddPlugin * Readd AddMiddleware to remain backwards compat
This commit is contained in:
129
plugin/pkg/cache/cache.go
vendored
Normal file
129
plugin/pkg/cache/cache.go
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
// Package cache implements a cache. The cache hold 256 shards, each shard
|
||||
// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it
|
||||
// just randomly evicts elements when it gets full.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Hash returns the FNV hash of what.
|
||||
func Hash(what []byte) uint32 {
|
||||
h := fnv.New32()
|
||||
h.Write(what)
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
// Cache is cache.
|
||||
type Cache struct {
|
||||
shards [shardSize]*shard
|
||||
}
|
||||
|
||||
// shard is a cache with random eviction.
|
||||
type shard struct {
|
||||
items map[uint32]interface{}
|
||||
size int
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new cache.
|
||||
func New(size int) *Cache {
|
||||
ssize := size / shardSize
|
||||
if ssize < 512 {
|
||||
ssize = 512
|
||||
}
|
||||
|
||||
c := &Cache{}
|
||||
|
||||
// Initialize all the shards
|
||||
for i := 0; i < shardSize; i++ {
|
||||
c.shards[i] = newShard(ssize)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Add adds a new element to the cache. If the element already exists it is overwritten.
|
||||
func (c *Cache) Add(key uint32, el interface{}) {
|
||||
shard := key & (shardSize - 1)
|
||||
c.shards[shard].Add(key, el)
|
||||
}
|
||||
|
||||
// Get looks up element index under key.
|
||||
func (c *Cache) Get(key uint32) (interface{}, bool) {
|
||||
shard := key & (shardSize - 1)
|
||||
return c.shards[shard].Get(key)
|
||||
}
|
||||
|
||||
// Remove removes the element indexed with key.
|
||||
func (c *Cache) Remove(key uint32) {
|
||||
shard := key & (shardSize - 1)
|
||||
c.shards[shard].Remove(key)
|
||||
}
|
||||
|
||||
// Len returns the number of elements in the cache.
|
||||
func (c *Cache) Len() int {
|
||||
l := 0
|
||||
for _, s := range c.shards {
|
||||
l += s.Len()
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// newShard returns a new shard with size.
|
||||
func newShard(size int) *shard { return &shard{items: make(map[uint32]interface{}), size: size} }
|
||||
|
||||
// Add adds element indexed by key into the cache. Any existing element is overwritten
|
||||
func (s *shard) Add(key uint32, el interface{}) {
|
||||
l := s.Len()
|
||||
if l+1 > s.size {
|
||||
s.Evict()
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
s.items[key] = el
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Remove removes the element indexed by key from the cache.
|
||||
func (s *shard) Remove(key uint32) {
|
||||
s.Lock()
|
||||
delete(s.items, key)
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Evict removes a random element from the cache.
|
||||
func (s *shard) Evict() {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
key := -1
|
||||
for k := range s.items {
|
||||
key = int(k)
|
||||
break
|
||||
}
|
||||
if key == -1 {
|
||||
// empty cache
|
||||
return
|
||||
}
|
||||
delete(s.items, uint32(key))
|
||||
}
|
||||
|
||||
// Get looks up the element indexed under key.
|
||||
func (s *shard) Get(key uint32) (interface{}, bool) {
|
||||
s.RLock()
|
||||
el, found := s.items[key]
|
||||
s.RUnlock()
|
||||
return el, found
|
||||
}
|
||||
|
||||
// Len returns the current length of the cache.
|
||||
func (s *shard) Len() int {
|
||||
s.RLock()
|
||||
l := len(s.items)
|
||||
s.RUnlock()
|
||||
return l
|
||||
}
|
||||
|
||||
const shardSize = 256
|
||||
31
plugin/pkg/cache/cache_test.go
vendored
Normal file
31
plugin/pkg/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
package cache
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCacheAddAndGet(t *testing.T) {
|
||||
c := New(4)
|
||||
c.Add(1, 1)
|
||||
|
||||
if _, found := c.Get(1); !found {
|
||||
t.Fatal("Failed to find inserted record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheLen(t *testing.T) {
|
||||
c := New(4)
|
||||
|
||||
c.Add(1, 1)
|
||||
if l := c.Len(); l != 1 {
|
||||
t.Fatalf("Cache size should %d, got %d", 1, l)
|
||||
}
|
||||
|
||||
c.Add(1, 1)
|
||||
if l := c.Len(); l != 1 {
|
||||
t.Fatalf("Cache size should %d, got %d", 1, l)
|
||||
}
|
||||
|
||||
c.Add(2, 2)
|
||||
if l := c.Len(); l != 2 {
|
||||
t.Fatalf("Cache size should %d, got %d", 2, l)
|
||||
}
|
||||
}
|
||||
60
plugin/pkg/cache/shard_test.go
vendored
Normal file
60
plugin/pkg/cache/shard_test.go
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
package cache
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShardAddAndGet(t *testing.T) {
|
||||
s := newShard(4)
|
||||
s.Add(1, 1)
|
||||
|
||||
if _, found := s.Get(1); !found {
|
||||
t.Fatal("Failed to find inserted record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShardLen(t *testing.T) {
|
||||
s := newShard(4)
|
||||
|
||||
s.Add(1, 1)
|
||||
if l := s.Len(); l != 1 {
|
||||
t.Fatalf("Shard size should %d, got %d", 1, l)
|
||||
}
|
||||
|
||||
s.Add(1, 1)
|
||||
if l := s.Len(); l != 1 {
|
||||
t.Fatalf("Shard size should %d, got %d", 1, l)
|
||||
}
|
||||
|
||||
s.Add(2, 2)
|
||||
if l := s.Len(); l != 2 {
|
||||
t.Fatalf("Shard size should %d, got %d", 2, l)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShardEvict(t *testing.T) {
|
||||
s := newShard(1)
|
||||
s.Add(1, 1)
|
||||
s.Add(2, 2)
|
||||
// 1 should be gone
|
||||
|
||||
if _, found := s.Get(1); found {
|
||||
t.Fatal("Found item that should have been evicted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShardLenEvict(t *testing.T) {
|
||||
s := newShard(4)
|
||||
s.Add(1, 1)
|
||||
s.Add(2, 1)
|
||||
s.Add(3, 1)
|
||||
s.Add(4, 1)
|
||||
|
||||
if l := s.Len(); l != 4 {
|
||||
t.Fatalf("Shard size should %d, got %d", 4, l)
|
||||
}
|
||||
|
||||
// This should evict one element
|
||||
s.Add(5, 1)
|
||||
if l := s.Len(); l != 4 {
|
||||
t.Fatalf("Shard size should %d, got %d", 4, l)
|
||||
}
|
||||
}
|
||||
58
plugin/pkg/dnsrecorder/recorder.go
Normal file
58
plugin/pkg/dnsrecorder/recorder.go
Normal file
@@ -0,0 +1,58 @@
|
||||
// Package dnsrecorder allows you to record a DNS response when it is send to the client.
|
||||
package dnsrecorder
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Recorder is a type of ResponseWriter that captures
|
||||
// the rcode code written to it and also the size of the message
|
||||
// written in the response. A rcode code does not have
|
||||
// to be written, however, in which case 0 must be assumed.
|
||||
// It is best to have the constructor initialize this type
|
||||
// with that default status code.
|
||||
type Recorder struct {
|
||||
dns.ResponseWriter
|
||||
Rcode int
|
||||
Len int
|
||||
Msg *dns.Msg
|
||||
Start time.Time
|
||||
}
|
||||
|
||||
// New makes and returns a new Recorder,
|
||||
// which captures the DNS rcode from the ResponseWriter
|
||||
// and also the length of the response message written through it.
|
||||
func New(w dns.ResponseWriter) *Recorder {
|
||||
return &Recorder{
|
||||
ResponseWriter: w,
|
||||
Rcode: 0,
|
||||
Msg: nil,
|
||||
Start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMsg records the status code and calls the
|
||||
// underlying ResponseWriter's WriteMsg method.
|
||||
func (r *Recorder) WriteMsg(res *dns.Msg) error {
|
||||
r.Rcode = res.Rcode
|
||||
// We may get called multiple times (axfr for instance).
|
||||
// Save the last message, but add the sizes.
|
||||
r.Len += res.Len()
|
||||
r.Msg = res
|
||||
return r.ResponseWriter.WriteMsg(res)
|
||||
}
|
||||
|
||||
// Write is a wrapper that records the length of the message that gets written.
|
||||
func (r *Recorder) Write(buf []byte) (int, error) {
|
||||
n, err := r.ResponseWriter.Write(buf)
|
||||
if err == nil {
|
||||
r.Len += n
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Hijack implements dns.Hijacker. It simply wraps the underlying
|
||||
// ResponseWriter's Hijack method if there is one, or returns an error.
|
||||
func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return }
|
||||
28
plugin/pkg/dnsrecorder/recorder_test.go
Normal file
28
plugin/pkg/dnsrecorder/recorder_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package dnsrecorder
|
||||
|
||||
/*
|
||||
func TestNewResponseRecorder(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
if !(recordRequest.ResponseWriter == w) {
|
||||
t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
|
||||
}
|
||||
if recordRequest.status != http.StatusOK {
|
||||
t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
responseTestString := "test"
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
buf := []byte(responseTestString)
|
||||
recordRequest.Write(buf)
|
||||
if recordRequest.size != len(buf) {
|
||||
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
|
||||
}
|
||||
if w.Body.String() != responseTestString {
|
||||
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
|
||||
}
|
||||
}
|
||||
*/
|
||||
15
plugin/pkg/dnsutil/cname.go
Normal file
15
plugin/pkg/dnsutil/cname.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package dnsutil
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// DuplicateCNAME returns true if r already exists in records.
|
||||
func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
|
||||
for _, rec := range records {
|
||||
if v, ok := rec.(*dns.CNAME); ok {
|
||||
if v.Target == r.Target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
55
plugin/pkg/dnsutil/cname_test.go
Normal file
55
plugin/pkg/dnsutil/cname_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDuplicateCNAME(t *testing.T) {
|
||||
tests := []struct {
|
||||
cname string
|
||||
records []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
|
||||
[]string{
|
||||
"US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534",
|
||||
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
|
||||
[]string{
|
||||
"US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
|
||||
[]string{},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for i, test := range tests {
|
||||
cnameRR, err := dns.NewRR(test.cname)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d, cname ('%s') error (%s)!", i, test.cname, err)
|
||||
}
|
||||
cname := cnameRR.(*dns.CNAME)
|
||||
records := []dns.RR{}
|
||||
for j, r := range test.records {
|
||||
rr, err := dns.NewRR(r)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d, record %d ('%s') error (%s)!", i, j, r, err)
|
||||
}
|
||||
records = append(records, rr)
|
||||
}
|
||||
got := DuplicateCNAME(cname, records)
|
||||
if got != test.expected {
|
||||
t.Errorf("Test %d, expected '%v', got '%v' for CNAME ('%s') and RECORDS (%v)", i, test.expected, got, test.cname, test.records)
|
||||
}
|
||||
}
|
||||
}
|
||||
12
plugin/pkg/dnsutil/dedup.go
Normal file
12
plugin/pkg/dnsutil/dedup.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package dnsutil
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// Dedup de-duplicates a message.
|
||||
func Dedup(m *dns.Msg) *dns.Msg {
|
||||
// TODO(miek): expensive!
|
||||
m.Answer = dns.Dedup(m.Answer, nil)
|
||||
m.Ns = dns.Dedup(m.Ns, nil)
|
||||
m.Extra = dns.Dedup(m.Extra, nil)
|
||||
return m
|
||||
}
|
||||
2
plugin/pkg/dnsutil/doc.go
Normal file
2
plugin/pkg/dnsutil/doc.go
Normal file
@@ -0,0 +1,2 @@
|
||||
// Package dnsutil contains DNS related helper functions.
|
||||
package dnsutil
|
||||
82
plugin/pkg/dnsutil/host.go
Normal file
82
plugin/pkg/dnsutil/host.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ParseHostPortOrFile parses the strings in s, each string can either be a address,
|
||||
// address:port or a filename. The address part is checked and the filename case a
|
||||
// resolv.conf like file is parsed and the nameserver found are returned.
|
||||
func ParseHostPortOrFile(s ...string) ([]string, error) {
|
||||
var servers []string
|
||||
for _, host := range s {
|
||||
addr, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
// Parse didn't work, it is not a addr:port combo
|
||||
if net.ParseIP(host) == nil {
|
||||
// Not an IP address.
|
||||
ss, err := tryFile(host)
|
||||
if err == nil {
|
||||
servers = append(servers, ss...)
|
||||
continue
|
||||
}
|
||||
return servers, fmt.Errorf("not an IP address or file: %q", host)
|
||||
}
|
||||
ss := net.JoinHostPort(host, "53")
|
||||
servers = append(servers, ss)
|
||||
continue
|
||||
}
|
||||
|
||||
if net.ParseIP(addr) == nil {
|
||||
// No an IP address.
|
||||
ss, err := tryFile(host)
|
||||
if err == nil {
|
||||
servers = append(servers, ss...)
|
||||
continue
|
||||
}
|
||||
return servers, fmt.Errorf("not an IP address or file: %q", host)
|
||||
}
|
||||
servers = append(servers, host)
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// Try to open this is a file first.
|
||||
func tryFile(s string) ([]string, error) {
|
||||
c, err := dns.ClientConfigFromFile(s)
|
||||
if err == os.ErrNotExist {
|
||||
return nil, fmt.Errorf("failed to open file %q: %q", s, err)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
servers := []string{}
|
||||
for _, s := range c.Servers {
|
||||
servers = append(servers, net.JoinHostPort(s, c.Port))
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// ParseHostPort will check if the host part is a valid IP address, if the
|
||||
// IP address is valid, but no port is found, defaultPort is added.
|
||||
func ParseHostPort(s, defaultPort string) (string, error) {
|
||||
addr, port, err := net.SplitHostPort(s)
|
||||
if port == "" {
|
||||
port = defaultPort
|
||||
}
|
||||
if err != nil {
|
||||
if net.ParseIP(s) == nil {
|
||||
return "", fmt.Errorf("must specify an IP address: `%s'", s)
|
||||
}
|
||||
return net.JoinHostPort(s, port), nil
|
||||
}
|
||||
|
||||
if net.ParseIP(addr) == nil {
|
||||
return "", fmt.Errorf("must specify an IP address: `%s'", addr)
|
||||
}
|
||||
return net.JoinHostPort(addr, port), nil
|
||||
}
|
||||
85
plugin/pkg/dnsutil/host_test.go
Normal file
85
plugin/pkg/dnsutil/host_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseHostPortOrFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
expected string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
"8.8.8.8",
|
||||
"8.8.8.8:53",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"8.8.8.8:153",
|
||||
"8.8.8.8:153",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"/etc/resolv.conf:53",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"resolv.conf",
|
||||
"127.0.0.1:53",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
err := ioutil.WriteFile("resolv.conf", []byte("nameserver 127.0.0.1\n"), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test resolv.conf")
|
||||
}
|
||||
defer os.Remove("resolv.conf")
|
||||
|
||||
for i, tc := range tests {
|
||||
got, err := ParseHostPortOrFile(tc.in)
|
||||
if err == nil && tc.shouldErr {
|
||||
t.Errorf("Test %d, expected error, got nil", i)
|
||||
continue
|
||||
}
|
||||
if err != nil && tc.shouldErr {
|
||||
continue
|
||||
}
|
||||
if got[0] != tc.expected {
|
||||
t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHostPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
expected string
|
||||
shouldErr bool
|
||||
}{
|
||||
{"8.8.8.8:53", "8.8.8.8:53", false},
|
||||
{"a.a.a.a:153", "", true},
|
||||
{"8.8.8.8", "8.8.8.8:53", false},
|
||||
{"8.8.8.8:", "8.8.8.8:53", false},
|
||||
{"8.8.8.8::53", "", true},
|
||||
{"resolv.conf", "", true},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
got, err := ParseHostPort(tc.in, "53")
|
||||
if err == nil && tc.shouldErr {
|
||||
t.Errorf("Test %d, expected error, got nil", i)
|
||||
continue
|
||||
}
|
||||
if err != nil && !tc.shouldErr {
|
||||
t.Errorf("Test %d, expected no error, got %q", i, err)
|
||||
}
|
||||
if got != tc.expected {
|
||||
t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
plugin/pkg/dnsutil/join.go
Normal file
19
plugin/pkg/dnsutil/join.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Join joins labels to form a fully qualified domain name. If the last label is
|
||||
// the root label it is ignored. Not other syntax checks are performed.
|
||||
func Join(labels []string) string {
|
||||
ll := len(labels)
|
||||
if labels[ll-1] == "." {
|
||||
s := strings.Join(labels[:ll-1], ".")
|
||||
return dns.Fqdn(s)
|
||||
}
|
||||
s := strings.Join(labels, ".")
|
||||
return dns.Fqdn(s)
|
||||
}
|
||||
20
plugin/pkg/dnsutil/join_test.go
Normal file
20
plugin/pkg/dnsutil/join_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package dnsutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
tests := []struct {
|
||||
in []string
|
||||
out string
|
||||
}{
|
||||
{[]string{"bla", "bliep", "example", "org"}, "bla.bliep.example.org."},
|
||||
{[]string{"example", "."}, "example."},
|
||||
{[]string{"."}, "."},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
if x := Join(tc.in); x != tc.out {
|
||||
t.Errorf("Test %d, expected %s, got %s", i, tc.out, x)
|
||||
}
|
||||
}
|
||||
}
|
||||
68
plugin/pkg/dnsutil/reverse.go
Normal file
68
plugin/pkg/dnsutil/reverse.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExtractAddressFromReverse turns a standard PTR reverse record name
|
||||
// into an IP address. This works for ipv4 or ipv6.
|
||||
//
|
||||
// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
|
||||
// failes the empty string is returned.
|
||||
func ExtractAddressFromReverse(reverseName string) string {
|
||||
search := ""
|
||||
|
||||
f := reverse
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(reverseName, v4arpaSuffix):
|
||||
search = strings.TrimSuffix(reverseName, v4arpaSuffix)
|
||||
case strings.HasSuffix(reverseName, v6arpaSuffix):
|
||||
search = strings.TrimSuffix(reverseName, v6arpaSuffix)
|
||||
f = reverse6
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
// Reverse the segments and then combine them.
|
||||
return f(strings.Split(search, "."))
|
||||
}
|
||||
|
||||
func reverse(slice []string) string {
|
||||
for i := 0; i < len(slice)/2; i++ {
|
||||
j := len(slice) - i - 1
|
||||
slice[i], slice[j] = slice[j], slice[i]
|
||||
}
|
||||
ip := net.ParseIP(strings.Join(slice, ".")).To4()
|
||||
if ip == nil {
|
||||
return ""
|
||||
}
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
// reverse6 reverse the segments and combine them according to RFC3596:
|
||||
// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2
|
||||
// is reversed to 2001:db8::567:89ab
|
||||
func reverse6(slice []string) string {
|
||||
for i := 0; i < len(slice)/2; i++ {
|
||||
j := len(slice) - i - 1
|
||||
slice[i], slice[j] = slice[j], slice[i]
|
||||
}
|
||||
slice6 := []string{}
|
||||
for i := 0; i < len(slice)/4; i++ {
|
||||
slice6 = append(slice6, strings.Join(slice[i*4:i*4+4], ""))
|
||||
}
|
||||
ip := net.ParseIP(strings.Join(slice6, ":")).To16()
|
||||
if ip == nil {
|
||||
return ""
|
||||
}
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
const (
|
||||
// v4arpaSuffix is the reverse tree suffix for v4 IP addresses.
|
||||
v4arpaSuffix = ".in-addr.arpa."
|
||||
// v6arpaSuffix is the reverse tree suffix for v6 IP addresses.
|
||||
v6arpaSuffix = ".ip6.arpa."
|
||||
)
|
||||
51
plugin/pkg/dnsutil/reverse_test.go
Normal file
51
plugin/pkg/dnsutil/reverse_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractAddressFromReverse(t *testing.T) {
|
||||
tests := []struct {
|
||||
reverseName string
|
||||
expectedAddress string
|
||||
}{
|
||||
{
|
||||
"54.119.58.176.in-addr.arpa.",
|
||||
"176.58.119.54",
|
||||
},
|
||||
{
|
||||
".58.176.in-addr.arpa.",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.in-addr.arpa.",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
"2001:db8::567:89ab",
|
||||
},
|
||||
{
|
||||
"d.0.1.0.0.2.ip6.arpa.",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"54.119.58.176.ip6.arpa.",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"NONAME",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"",
|
||||
"",
|
||||
},
|
||||
}
|
||||
for i, test := range tests {
|
||||
got := ExtractAddressFromReverse(test.reverseName)
|
||||
if got != test.expectedAddress {
|
||||
t.Errorf("Test %d, expected '%s', got '%s'", i, test.expectedAddress, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
20
plugin/pkg/dnsutil/zone.go
Normal file
20
plugin/pkg/dnsutil/zone.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// TrimZone removes the zone component from q. It returns the trimmed
|
||||
// name or an error is zone is longer then qname. The trimmed name will be returned
|
||||
// without a trailing dot.
|
||||
func TrimZone(q string, z string) (string, error) {
|
||||
zl := dns.CountLabel(z)
|
||||
i, ok := dns.PrevLabel(q, zl)
|
||||
if ok || i-1 < 0 {
|
||||
return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z)
|
||||
}
|
||||
// This includes the '.', remove on return
|
||||
return q[:i-1], nil
|
||||
}
|
||||
39
plugin/pkg/dnsutil/zone_test.go
Normal file
39
plugin/pkg/dnsutil/zone_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package dnsutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestTrimZone(t *testing.T) {
|
||||
tests := []struct {
|
||||
qname string
|
||||
zone string
|
||||
expected string
|
||||
err error
|
||||
}{
|
||||
{"a.example.org", "example.org", "a", nil},
|
||||
{"a.b.example.org", "example.org", "a.b", nil},
|
||||
{"b.", ".", "b", nil},
|
||||
{"example.org", "example.org", "", errors.New("should err")},
|
||||
{"org", "example.org", "", errors.New("should err")},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
got, err := TrimZone(dns.Fqdn(tc.qname), dns.Fqdn(tc.zone))
|
||||
if tc.err != nil && err == nil {
|
||||
t.Errorf("Test %d, expected error got nil", i)
|
||||
continue
|
||||
}
|
||||
if tc.err == nil && err != nil {
|
||||
t.Errorf("Test %d, expected no error got %v", i, err)
|
||||
continue
|
||||
}
|
||||
if got != tc.expected {
|
||||
t.Errorf("Test %d, expected %s, got %s", i, tc.expected, got)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
46
plugin/pkg/edns/edns.go
Normal file
46
plugin/pkg/edns/edns.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Package edns provides function useful for adding/inspecting OPT records to/in messages.
|
||||
package edns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Version checks the EDNS version in the request. If error
|
||||
// is nil everything is OK and we can invoke the plugin. If non-nil, the
|
||||
// returned Msg is valid to be returned to the client (and should). For some
|
||||
// reason this response should not contain a question RR in the question section.
|
||||
func Version(req *dns.Msg) (*dns.Msg, error) {
|
||||
opt := req.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if opt.Version() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
// zero out question section, wtf.
|
||||
m.Question = nil
|
||||
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
o.SetVersion(0)
|
||||
o.SetExtendedRcode(dns.RcodeBadVers)
|
||||
m.Extra = []dns.RR{o}
|
||||
|
||||
return m, errors.New("EDNS0 BADVERS")
|
||||
}
|
||||
|
||||
// Size returns a normalized size based on proto.
|
||||
func Size(proto string, size int) int {
|
||||
if proto == "tcp" {
|
||||
return dns.MaxMsgSize
|
||||
}
|
||||
if size < dns.MinMsgSize {
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
37
plugin/pkg/edns/edns_test.go
Normal file
37
plugin/pkg/edns/edns_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package edns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
m := ednsMsg()
|
||||
m.Extra[0].(*dns.OPT).SetVersion(2)
|
||||
|
||||
_, err := Version(m)
|
||||
if err == nil {
|
||||
t.Errorf("expected wrong version, but got OK")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionNoEdns(t *testing.T) {
|
||||
m := ednsMsg()
|
||||
m.Extra = nil
|
||||
|
||||
_, err := Version(m)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but got one: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ednsMsg() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
243
plugin/pkg/healthcheck/healthcheck.go
Normal file
243
plugin/pkg/healthcheck/healthcheck.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpstreamHostDownFunc can be used to customize how Down behaves.
|
||||
type UpstreamHostDownFunc func(*UpstreamHost) bool
|
||||
|
||||
// UpstreamHost represents a single proxy upstream
|
||||
type UpstreamHost struct {
|
||||
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
|
||||
Name string // IP address (and port) of this upstream host
|
||||
Network string // Network (tcp, unix, etc) of the host, default "" is "tcp"
|
||||
Fails int32
|
||||
FailTimeout time.Duration
|
||||
OkUntil time.Time
|
||||
CheckDown UpstreamHostDownFunc
|
||||
CheckURL string
|
||||
WithoutPathPrefix string
|
||||
Checking bool
|
||||
CheckMu sync.Mutex
|
||||
}
|
||||
|
||||
// Down checks whether the upstream host is down or not.
|
||||
// Down will try to use uh.CheckDown first, and will fall
|
||||
// back to some default criteria if necessary.
|
||||
func (uh *UpstreamHost) Down() bool {
|
||||
if uh.CheckDown == nil {
|
||||
// Default settings
|
||||
fails := atomic.LoadInt32(&uh.Fails)
|
||||
after := false
|
||||
|
||||
uh.CheckMu.Lock()
|
||||
until := uh.OkUntil
|
||||
uh.CheckMu.Unlock()
|
||||
|
||||
if !until.IsZero() && time.Now().After(until) {
|
||||
after = true
|
||||
}
|
||||
|
||||
return after || fails > 0
|
||||
}
|
||||
return uh.CheckDown(uh)
|
||||
}
|
||||
|
||||
// HostPool is a collection of UpstreamHosts.
|
||||
type HostPool []*UpstreamHost
|
||||
|
||||
// HealthCheck is used for performing healthcheck
|
||||
// on a collection of upstream hosts and select
|
||||
// one based on the policy.
|
||||
type HealthCheck struct {
|
||||
wg sync.WaitGroup // Used to wait for running goroutines to stop.
|
||||
stop chan struct{} // Signals running goroutines to stop.
|
||||
Hosts HostPool
|
||||
Policy Policy
|
||||
Spray Policy
|
||||
FailTimeout time.Duration
|
||||
MaxFails int32
|
||||
Future time.Duration
|
||||
Path string
|
||||
Port string
|
||||
Interval time.Duration
|
||||
}
|
||||
|
||||
// Start starts the healthcheck
|
||||
func (u *HealthCheck) Start() {
|
||||
u.stop = make(chan struct{})
|
||||
if u.Path != "" {
|
||||
u.wg.Add(1)
|
||||
go func() {
|
||||
defer u.wg.Done()
|
||||
u.healthCheckWorker(u.stop)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop sends a signal to all goroutines started by this staticUpstream to exit
|
||||
// and waits for them to finish before returning.
|
||||
func (u *HealthCheck) Stop() error {
|
||||
close(u.stop)
|
||||
u.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// This was moved into a thread so that each host could throw a health
|
||||
// check at the same time. The reason for this is that if we are checking
|
||||
// 3 hosts, and the first one is gone, and we spend minutes timing out to
|
||||
// fail it, we would not have been doing any other health checks in that
|
||||
// time. So we now have a per-host lock and a threaded health check.
|
||||
//
|
||||
// We use the Checking bool to avoid concurrent checks against the same
|
||||
// host; if one is taking a long time, the next one will find a check in
|
||||
// progress and simply return before trying.
|
||||
//
|
||||
// We are carefully avoiding having the mutex locked while we check,
|
||||
// otherwise checks will back up, potentially a lot of them if a host is
|
||||
// absent for a long time. This arrangement makes checks quickly see if
|
||||
// they are the only one running and abort otherwise.
|
||||
func healthCheckURL(nextTs time.Time, host *UpstreamHost) {
|
||||
|
||||
// lock for our bool check. We don't just defer the unlock because
|
||||
// we don't want the lock held while http.Get runs
|
||||
host.CheckMu.Lock()
|
||||
|
||||
// are we mid check? Don't run another one
|
||||
if host.Checking {
|
||||
host.CheckMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
host.Checking = true
|
||||
host.CheckMu.Unlock()
|
||||
|
||||
//log.Printf("[DEBUG] Healthchecking %s, nextTs is %s\n", url, nextTs.Local())
|
||||
|
||||
// fetch that url. This has been moved into a go func because
|
||||
// when the remote host is not merely not serving, but actually
|
||||
// absent, then tcp syn timeouts can be very long, and so one
|
||||
// fetch could last several check intervals
|
||||
if r, err := http.Get(host.CheckURL); err == nil {
|
||||
io.Copy(ioutil.Discard, r.Body)
|
||||
r.Body.Close()
|
||||
|
||||
if r.StatusCode < 200 || r.StatusCode >= 400 {
|
||||
log.Printf("[WARNING] Host %s health check returned HTTP code %d\n",
|
||||
host.Name, r.StatusCode)
|
||||
nextTs = time.Unix(0, 0)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[WARNING] Host %s health check probe failed: %v\n", host.Name, err)
|
||||
nextTs = time.Unix(0, 0)
|
||||
}
|
||||
|
||||
host.CheckMu.Lock()
|
||||
host.Checking = false
|
||||
host.OkUntil = nextTs
|
||||
host.CheckMu.Unlock()
|
||||
}
|
||||
|
||||
func (u *HealthCheck) healthCheck() {
|
||||
for _, host := range u.Hosts {
|
||||
|
||||
if host.CheckURL == "" {
|
||||
var hostName, checkPort string
|
||||
|
||||
// The DNS server might be an HTTP server. If so, extract its name.
|
||||
ret, err := url.Parse(host.Name)
|
||||
if err == nil && len(ret.Host) > 0 {
|
||||
hostName = ret.Host
|
||||
} else {
|
||||
hostName = host.Name
|
||||
}
|
||||
|
||||
// Extract the port number from the parsed server name.
|
||||
checkHostName, checkPort, err := net.SplitHostPort(hostName)
|
||||
if err != nil {
|
||||
checkHostName = hostName
|
||||
}
|
||||
|
||||
if u.Port != "" {
|
||||
checkPort = u.Port
|
||||
}
|
||||
|
||||
host.CheckURL = "http://" + net.JoinHostPort(checkHostName, checkPort) + u.Path
|
||||
}
|
||||
|
||||
// calculate this before the get
|
||||
nextTs := time.Now().Add(u.Future)
|
||||
|
||||
// locks/bools should prevent requests backing up
|
||||
go healthCheckURL(nextTs, host)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *HealthCheck) healthCheckWorker(stop chan struct{}) {
|
||||
ticker := time.NewTicker(u.Interval)
|
||||
u.healthCheck()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
u.healthCheck()
|
||||
case <-stop:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Select selects an upstream host based on the policy
|
||||
// and the healthcheck result.
|
||||
func (u *HealthCheck) Select() *UpstreamHost {
|
||||
pool := u.Hosts
|
||||
if len(pool) == 1 {
|
||||
if pool[0].Down() && u.Spray == nil {
|
||||
return nil
|
||||
}
|
||||
return pool[0]
|
||||
}
|
||||
allDown := true
|
||||
for _, host := range pool {
|
||||
if !host.Down() {
|
||||
allDown = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allDown {
|
||||
if u.Spray == nil {
|
||||
return nil
|
||||
}
|
||||
return u.Spray.Select(pool)
|
||||
}
|
||||
|
||||
if u.Policy == nil {
|
||||
h := (&Random{}).Select(pool)
|
||||
if h != nil {
|
||||
return h
|
||||
}
|
||||
if h == nil && u.Spray == nil {
|
||||
return nil
|
||||
}
|
||||
return u.Spray.Select(pool)
|
||||
}
|
||||
|
||||
h := u.Policy.Select(pool)
|
||||
if h != nil {
|
||||
return h
|
||||
}
|
||||
|
||||
if u.Spray == nil {
|
||||
return nil
|
||||
}
|
||||
return u.Spray.Select(pool)
|
||||
}
|
||||
120
plugin/pkg/healthcheck/policy.go
Normal file
120
plugin/pkg/healthcheck/policy.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
// SupportedPolicies is the collection of policies registered
|
||||
SupportedPolicies = make(map[string]func() Policy)
|
||||
)
|
||||
|
||||
// RegisterPolicy adds a custom policy to the proxy.
|
||||
func RegisterPolicy(name string, policy func() Policy) {
|
||||
SupportedPolicies[name] = policy
|
||||
}
|
||||
|
||||
// Policy decides how a host will be selected from a pool. When all hosts are unhealthy, it is assumed the
|
||||
// healthchecking failed. In this case each policy will *randomly* return a host from the pool to prevent
|
||||
// no traffic to go through at all.
|
||||
type Policy interface {
|
||||
Select(pool HostPool) *UpstreamHost
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterPolicy("random", func() Policy { return &Random{} })
|
||||
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
|
||||
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
|
||||
}
|
||||
|
||||
// Random is a policy that selects up hosts from a pool at random.
|
||||
type Random struct{}
|
||||
|
||||
// Select selects an up host at random from the specified pool.
|
||||
func (r *Random) Select(pool HostPool) *UpstreamHost {
|
||||
// instead of just generating a random index
|
||||
// this is done to prevent selecting a down host
|
||||
var randHost *UpstreamHost
|
||||
count := 0
|
||||
for _, host := range pool {
|
||||
if host.Down() {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
if count == 1 {
|
||||
randHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
randHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
return randHost
|
||||
}
|
||||
|
||||
// Spray is a policy that selects a host from a pool at random. This should be used as a last ditch
|
||||
// attempt to get a host when all hosts are reporting unhealthy.
|
||||
type Spray struct{}
|
||||
|
||||
// Select selects an up host at random from the specified pool.
|
||||
func (r *Spray) Select(pool HostPool) *UpstreamHost {
|
||||
rnd := rand.Int() % len(pool)
|
||||
randHost := pool[rnd]
|
||||
log.Printf("[WARNING] All hosts reported as down, spraying to target: %s", randHost.Name)
|
||||
return randHost
|
||||
}
|
||||
|
||||
// LeastConn is a policy that selects the host with the least connections.
|
||||
type LeastConn struct{}
|
||||
|
||||
// Select selects the up host with the least number of connections in the
|
||||
// pool. If more than one host has the same least number of connections,
|
||||
// one of the hosts is chosen at random.
|
||||
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
|
||||
var bestHost *UpstreamHost
|
||||
count := 0
|
||||
leastConn := int64(1<<63 - 1)
|
||||
for _, host := range pool {
|
||||
if host.Down() {
|
||||
continue
|
||||
}
|
||||
hostConns := host.Conns
|
||||
if hostConns < leastConn {
|
||||
bestHost = host
|
||||
leastConn = hostConns
|
||||
count = 1
|
||||
} else if hostConns == leastConn {
|
||||
// randomly select host among hosts with least connections
|
||||
count++
|
||||
if count == 1 {
|
||||
bestHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
bestHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return bestHost
|
||||
}
|
||||
|
||||
// RoundRobin is a policy that selects hosts based on round robin ordering.
|
||||
type RoundRobin struct {
|
||||
Robin uint32
|
||||
}
|
||||
|
||||
// Select selects an up host from the pool using a round robin ordering scheme.
|
||||
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
|
||||
poolLen := uint32(len(pool))
|
||||
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
|
||||
host := pool[selection]
|
||||
// if the currently selected host is down, just ffwd to up host
|
||||
for i := uint32(1); host.Down() && i < poolLen; i++ {
|
||||
host = pool[(selection+i)%poolLen]
|
||||
}
|
||||
return host
|
||||
}
|
||||
143
plugin/pkg/healthcheck/policy_test.go
Normal file
143
plugin/pkg/healthcheck/policy_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var workableServer *httptest.Server
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
workableServer = httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// do nothing
|
||||
}))
|
||||
r := m.Run()
|
||||
workableServer.Close()
|
||||
os.Exit(r)
|
||||
}
|
||||
|
||||
type customPolicy struct{}
|
||||
|
||||
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
|
||||
return pool[0]
|
||||
}
|
||||
|
||||
func testPool() HostPool {
|
||||
pool := []*UpstreamHost{
|
||||
{
|
||||
Name: workableServer.URL, // this should resolve (healthcheck test)
|
||||
},
|
||||
{
|
||||
Name: "http://shouldnot.resolve", // this shouldn't
|
||||
},
|
||||
{
|
||||
Name: "http://C",
|
||||
},
|
||||
}
|
||||
return HostPool(pool)
|
||||
}
|
||||
|
||||
func TestRegisterPolicy(t *testing.T) {
|
||||
name := "custom"
|
||||
customPolicy := &customPolicy{}
|
||||
RegisterPolicy(name, func() Policy { return customPolicy })
|
||||
if _, ok := SupportedPolicies[name]; !ok {
|
||||
t.Error("Expected supportedPolicies to have a custom policy.")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO(miek): Disabled for now, we should get out of the habit of using
|
||||
// realtime in these tests .
|
||||
func testHealthCheck(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
|
||||
u := &HealthCheck{
|
||||
Hosts: testPool(),
|
||||
FailTimeout: 10 * time.Second,
|
||||
Future: 60 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
|
||||
u.healthCheck()
|
||||
// sleep a bit, it's async now
|
||||
time.Sleep(time.Duration(2 * time.Second))
|
||||
|
||||
if u.Hosts[0].Down() {
|
||||
t.Error("Expected first host in testpool to not fail healthcheck.")
|
||||
}
|
||||
if !u.Hosts[1].Down() {
|
||||
t.Error("Expected second host in testpool to fail healthcheck.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelect(t *testing.T) {
|
||||
u := &HealthCheck{
|
||||
Hosts: testPool()[:3],
|
||||
FailTimeout: 10 * time.Second,
|
||||
Future: 60 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
u.Hosts[0].OkUntil = time.Unix(0, 0)
|
||||
u.Hosts[1].OkUntil = time.Unix(0, 0)
|
||||
u.Hosts[2].OkUntil = time.Unix(0, 0)
|
||||
if h := u.Select(); h != nil {
|
||||
t.Error("Expected select to return nil as all host are down")
|
||||
}
|
||||
u.Hosts[2].OkUntil = time.Time{}
|
||||
if h := u.Select(); h == nil {
|
||||
t.Error("Expected select to not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
rrPolicy := &RoundRobin{}
|
||||
h := rrPolicy.Select(pool)
|
||||
// First selected host is 1, because counter starts at 0
|
||||
// and increments before host is selected
|
||||
if h != pool[1] {
|
||||
t.Error("Expected first round robin host to be second host in the pool.")
|
||||
}
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected second round robin host to be third host in the pool.")
|
||||
}
|
||||
// mark host as down
|
||||
pool[0].OkUntil = time.Unix(0, 0)
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected third round robin host to be first host in the pool.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
lcPolicy := &LeastConn{}
|
||||
pool[0].Conns = 10
|
||||
pool[1].Conns = 10
|
||||
h := lcPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected least connection host to be third host.")
|
||||
}
|
||||
pool[2].Conns = 100
|
||||
h = lcPolicy.Select(pool)
|
||||
if h != pool[0] && h != pool[1] {
|
||||
t.Error("Expected least connection host to be first or second host.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
customPolicy := &customPolicy{}
|
||||
h := customPolicy.Select(pool)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected custom policy host to be the first host.")
|
||||
}
|
||||
}
|
||||
23
plugin/pkg/nonwriter/nonwriter.go
Normal file
23
plugin/pkg/nonwriter/nonwriter.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written.
|
||||
package nonwriter
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Writer is a type of ResponseWriter that captures the message, but never writes to the client.
|
||||
type Writer struct {
|
||||
dns.ResponseWriter
|
||||
Msg *dns.Msg
|
||||
}
|
||||
|
||||
// New makes and returns a new NonWriter.
|
||||
func New(w dns.ResponseWriter) *Writer { return &Writer{ResponseWriter: w} }
|
||||
|
||||
// WriteMsg records the message, but doesn't write it itself.
|
||||
func (w *Writer) WriteMsg(res *dns.Msg) error {
|
||||
w.Msg = res
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Writer) Write(buf []byte) (int, error) { return len(buf), nil }
|
||||
19
plugin/pkg/nonwriter/nonwriter_test.go
Normal file
19
plugin/pkg/nonwriter/nonwriter_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package nonwriter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestNonWriter(t *testing.T) {
|
||||
nw := New(nil)
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.org.", dns.TypeA)
|
||||
if err := nw.WriteMsg(m); err != nil {
|
||||
t.Errorf("Got error when writing to nonwriter: %s", err)
|
||||
}
|
||||
if x := nw.Msg.Question[0].Name; x != "example.org." {
|
||||
t.Errorf("Expacted 'example.org.' got %q:", x)
|
||||
}
|
||||
}
|
||||
16
plugin/pkg/rcode/rcode.go
Normal file
16
plugin/pkg/rcode/rcode.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package rcode
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ToString convert the rcode to the official DNS string, or to "RCODE"+value if the RCODE
|
||||
// value is unknown.
|
||||
func ToString(rcode int) string {
|
||||
if str, ok := dns.RcodeToString[rcode]; ok {
|
||||
return str
|
||||
}
|
||||
return "RCODE" + strconv.Itoa(rcode)
|
||||
}
|
||||
29
plugin/pkg/rcode/rcode_test.go
Normal file
29
plugin/pkg/rcode/rcode_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package rcode
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
in int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
dns.RcodeSuccess,
|
||||
"NOERROR",
|
||||
},
|
||||
{
|
||||
28,
|
||||
"RCODE28",
|
||||
},
|
||||
}
|
||||
for i, test := range tests {
|
||||
got := ToString(test.in)
|
||||
if got != test.expected {
|
||||
t.Errorf("Test %d, expected %s, got %s", i, test.expected, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
161
plugin/pkg/replacer/replacer.go
Normal file
161
plugin/pkg/replacer/replacer.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package replacer
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsrecorder"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Replacer is a type which can replace placeholder
|
||||
// substrings in a string with actual values from a
|
||||
// dns.Msg and responseRecorder. Always use
|
||||
// NewReplacer to get one of these.
|
||||
type Replacer interface {
|
||||
Replace(string) string
|
||||
Set(key, value string)
|
||||
}
|
||||
|
||||
type replacer struct {
|
||||
replacements map[string]string
|
||||
emptyValue string
|
||||
}
|
||||
|
||||
// New makes a new replacer based on r and rr.
|
||||
// Do not create a new replacer until r and rr have all
|
||||
// the needed values, because this function copies those
|
||||
// values into the replacer. rr may be nil if it is not
|
||||
// available. emptyValue should be the string that is used
|
||||
// in place of empty string (can still be empty string).
|
||||
func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer {
|
||||
req := request.Request{W: rr, Req: r}
|
||||
rep := replacer{
|
||||
replacements: map[string]string{
|
||||
"{type}": req.Type(),
|
||||
"{name}": req.Name(),
|
||||
"{class}": req.Class(),
|
||||
"{proto}": req.Proto(),
|
||||
"{when}": func() string {
|
||||
return time.Now().Format(timeFormat)
|
||||
}(),
|
||||
"{size}": strconv.Itoa(req.Len()),
|
||||
"{remote}": req.IP(),
|
||||
"{port}": req.Port(),
|
||||
},
|
||||
emptyValue: emptyValue,
|
||||
}
|
||||
if rr != nil {
|
||||
rcode := dns.RcodeToString[rr.Rcode]
|
||||
if rcode == "" {
|
||||
rcode = strconv.Itoa(rr.Rcode)
|
||||
}
|
||||
rep.replacements["{rcode}"] = rcode
|
||||
rep.replacements["{rsize}"] = strconv.Itoa(rr.Len)
|
||||
rep.replacements["{duration}"] = time.Since(rr.Start).String()
|
||||
if rr.Msg != nil {
|
||||
rep.replacements[headerReplacer+"rflags}"] = flagsToString(rr.Msg.MsgHdr)
|
||||
}
|
||||
}
|
||||
|
||||
// Header placeholders (case-insensitive)
|
||||
rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id))
|
||||
rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(r.Opcode)
|
||||
rep.replacements[headerReplacer+"do}"] = boolToString(req.Do())
|
||||
rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size())
|
||||
|
||||
return rep
|
||||
}
|
||||
|
||||
// Replace performs a replacement of values on s and returns
|
||||
// the string with the replaced values.
|
||||
func (r replacer) Replace(s string) string {
|
||||
// Header replacements - these are case-insensitive, so we can't just use strings.Replace()
|
||||
for strings.Contains(s, headerReplacer) {
|
||||
idxStart := strings.Index(s, headerReplacer)
|
||||
endOffset := idxStart + len(headerReplacer)
|
||||
idxEnd := strings.Index(s[endOffset:], "}")
|
||||
if idxEnd > -1 {
|
||||
placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
|
||||
replacement := r.replacements[placeholder]
|
||||
if replacement == "" {
|
||||
replacement = r.emptyValue
|
||||
}
|
||||
s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Regular replacements - these are easier because they're case-sensitive
|
||||
for placeholder, replacement := range r.replacements {
|
||||
if replacement == "" {
|
||||
replacement = r.emptyValue
|
||||
}
|
||||
s = strings.Replace(s, placeholder, replacement, -1)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Set sets key to value in the replacements map.
|
||||
func (r replacer) Set(key, value string) {
|
||||
r.replacements["{"+key+"}"] = value
|
||||
}
|
||||
|
||||
func boolToString(b bool) string {
|
||||
if b {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
// flagsToString checks all header flags and returns those
|
||||
// that are set as a string separated with commas
|
||||
func flagsToString(h dns.MsgHdr) string {
|
||||
flags := make([]string, 7)
|
||||
i := 0
|
||||
|
||||
if h.Response {
|
||||
flags[i] = "qr"
|
||||
i++
|
||||
}
|
||||
|
||||
if h.Authoritative {
|
||||
flags[i] = "aa"
|
||||
i++
|
||||
}
|
||||
if h.Truncated {
|
||||
flags[i] = "tc"
|
||||
i++
|
||||
}
|
||||
if h.RecursionDesired {
|
||||
flags[i] = "rd"
|
||||
i++
|
||||
}
|
||||
if h.RecursionAvailable {
|
||||
flags[i] = "ra"
|
||||
i++
|
||||
}
|
||||
if h.Zero {
|
||||
flags[i] = "z"
|
||||
i++
|
||||
}
|
||||
if h.AuthenticatedData {
|
||||
flags[i] = "ad"
|
||||
i++
|
||||
}
|
||||
if h.CheckingDisabled {
|
||||
flags[i] = "cd"
|
||||
i++
|
||||
}
|
||||
return strings.Join(flags[:i], ",")
|
||||
}
|
||||
|
||||
const (
|
||||
timeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
headerReplacer = "{>"
|
||||
)
|
||||
61
plugin/pkg/replacer/replacer_test.go
Normal file
61
plugin/pkg/replacer/replacer_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package replacer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsrecorder"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestNewReplacer(t *testing.T) {
|
||||
w := dnsrecorder.New(&test.ResponseWriter{})
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.org.", dns.TypeHINFO)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
|
||||
replaceValues := New(r, w, "")
|
||||
|
||||
switch v := replaceValues.(type) {
|
||||
case replacer:
|
||||
|
||||
if v.replacements["{type}"] != "HINFO" {
|
||||
t.Errorf("Expected type to be HINFO, got %q", v.replacements["{type}"])
|
||||
}
|
||||
if v.replacements["{name}"] != "example.org." {
|
||||
t.Errorf("Expected request name to be example.org., got %q", v.replacements["{name}"])
|
||||
}
|
||||
if v.replacements["{size}"] != "29" { // size of request
|
||||
t.Errorf("Expected size to be 29, got %q", v.replacements["{size}"])
|
||||
}
|
||||
|
||||
default:
|
||||
t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
w := dnsrecorder.New(&test.ResponseWriter{})
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.org.", dns.TypeHINFO)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
|
||||
repl := New(r, w, "")
|
||||
|
||||
repl.Set("name", "coredns.io.")
|
||||
repl.Set("type", "A")
|
||||
repl.Set("size", "20")
|
||||
|
||||
if repl.Replace("This name is {name}") != "This name is coredns.io." {
|
||||
t.Error("Expected name replacement failed")
|
||||
}
|
||||
if repl.Replace("This type is {type}") != "This type is A" {
|
||||
t.Error("Expected type replacement failed")
|
||||
}
|
||||
if repl.Replace("The request size is {size}") != "The request size is 20" {
|
||||
t.Error("Expected size replacement failed")
|
||||
}
|
||||
}
|
||||
61
plugin/pkg/response/classify.go
Normal file
61
plugin/pkg/response/classify.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package response
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Class holds sets of Types
|
||||
type Class int
|
||||
|
||||
const (
|
||||
// All is a meta class encompassing all the classes.
|
||||
All Class = iota
|
||||
// Success is a class for a successful response.
|
||||
Success
|
||||
// Denial is a class for denying existence (NXDOMAIN, or a nodata: type does not exist)
|
||||
Denial
|
||||
// Error is a class for errors, right now defined as not Success and not Denial
|
||||
Error
|
||||
)
|
||||
|
||||
func (c Class) String() string {
|
||||
switch c {
|
||||
case All:
|
||||
return "all"
|
||||
case Success:
|
||||
return "success"
|
||||
case Denial:
|
||||
return "denial"
|
||||
case Error:
|
||||
return "error"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ClassFromString returns the class from the string s. If not class matches
|
||||
// the All class and an error are returned
|
||||
func ClassFromString(s string) (Class, error) {
|
||||
switch s {
|
||||
case "all":
|
||||
return All, nil
|
||||
case "success":
|
||||
return Success, nil
|
||||
case "denial":
|
||||
return Denial, nil
|
||||
case "error":
|
||||
return Error, nil
|
||||
}
|
||||
return All, fmt.Errorf("invalid Class: %s", s)
|
||||
}
|
||||
|
||||
// Classify classifies the Type t, it returns its Class.
|
||||
func Classify(t Type) Class {
|
||||
switch t {
|
||||
case NoError, Delegation:
|
||||
return Success
|
||||
case NameError, NoData:
|
||||
return Denial
|
||||
case OtherError:
|
||||
fallthrough
|
||||
default:
|
||||
return Error
|
||||
}
|
||||
}
|
||||
146
plugin/pkg/response/typify.go
Normal file
146
plugin/pkg/response/typify.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Type is the type of the message.
|
||||
type Type int
|
||||
|
||||
const (
|
||||
// NoError indicates a positive reply
|
||||
NoError Type = iota
|
||||
// NameError is a NXDOMAIN in header, SOA in auth.
|
||||
NameError
|
||||
// NoData indicates name found, but not the type: NOERROR in header, SOA in auth.
|
||||
NoData
|
||||
// Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked).
|
||||
Delegation
|
||||
// Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR.
|
||||
Meta
|
||||
// Update is an dynamic update message.
|
||||
Update
|
||||
// OtherError indicates any other error: don't cache these.
|
||||
OtherError
|
||||
)
|
||||
|
||||
var toString = map[Type]string{
|
||||
NoError: "NOERROR",
|
||||
NameError: "NXDOMAIN",
|
||||
NoData: "NODATA",
|
||||
Delegation: "DELEGATION",
|
||||
Meta: "META",
|
||||
Update: "UPDATE",
|
||||
OtherError: "OTHERERROR",
|
||||
}
|
||||
|
||||
func (t Type) String() string { return toString[t] }
|
||||
|
||||
// TypeFromString returns the type from the string s. If not type matches
|
||||
// the OtherError type and an error are returned.
|
||||
func TypeFromString(s string) (Type, error) {
|
||||
for t, str := range toString {
|
||||
if s == str {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return NoError, fmt.Errorf("invalid Type: %s", s)
|
||||
}
|
||||
|
||||
// Typify classifies a message, it returns the Type.
|
||||
func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) {
|
||||
if m == nil {
|
||||
return OtherError, nil
|
||||
}
|
||||
opt := m.IsEdns0()
|
||||
do := false
|
||||
if opt != nil {
|
||||
do = opt.Do()
|
||||
}
|
||||
|
||||
if m.Opcode == dns.OpcodeUpdate {
|
||||
return Update, opt
|
||||
}
|
||||
|
||||
// Check transfer and update first
|
||||
if m.Opcode == dns.OpcodeNotify {
|
||||
return Meta, opt
|
||||
}
|
||||
|
||||
if len(m.Question) > 0 {
|
||||
if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR {
|
||||
return Meta, opt
|
||||
}
|
||||
}
|
||||
|
||||
// If our message contains any expired sigs and we care about that, we should return expired
|
||||
if do {
|
||||
if expired := typifyExpired(m, t); expired {
|
||||
return OtherError, opt
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {
|
||||
return NoError, opt
|
||||
}
|
||||
|
||||
soa := false
|
||||
ns := 0
|
||||
for _, r := range m.Ns {
|
||||
if r.Header().Rrtype == dns.TypeSOA {
|
||||
soa = true
|
||||
continue
|
||||
}
|
||||
if r.Header().Rrtype == dns.TypeNS {
|
||||
ns++
|
||||
}
|
||||
}
|
||||
|
||||
// Check length of different sections, and drop stuff that is just to large? TODO(miek).
|
||||
|
||||
if soa && m.Rcode == dns.RcodeSuccess {
|
||||
return NoData, opt
|
||||
}
|
||||
if soa && m.Rcode == dns.RcodeNameError {
|
||||
return NameError, opt
|
||||
}
|
||||
|
||||
if ns > 0 && m.Rcode == dns.RcodeSuccess {
|
||||
return Delegation, opt
|
||||
}
|
||||
|
||||
if m.Rcode == dns.RcodeSuccess {
|
||||
return NoError, opt
|
||||
}
|
||||
|
||||
return OtherError, opt
|
||||
}
|
||||
|
||||
func typifyExpired(m *dns.Msg, t time.Time) bool {
|
||||
if expired := typifyExpiredRRSIG(m.Answer, t); expired {
|
||||
return true
|
||||
}
|
||||
if expired := typifyExpiredRRSIG(m.Ns, t); expired {
|
||||
return true
|
||||
}
|
||||
if expired := typifyExpiredRRSIG(m.Extra, t); expired {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool {
|
||||
for _, r := range rrs {
|
||||
if r.Header().Rrtype != dns.TypeRRSIG {
|
||||
continue
|
||||
}
|
||||
ok := r.(*dns.RRSIG).ValidityPeriod(t)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
84
plugin/pkg/response/typify_test.go
Normal file
84
plugin/pkg/response/typify_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestTypifyNilMsg(t *testing.T) {
|
||||
var m *dns.Msg
|
||||
|
||||
ty, _ := Typify(m, time.Now().UTC())
|
||||
if ty != OtherError {
|
||||
t.Errorf("message wrongly typified, expected OtherError, got %s", ty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypifyDelegation(t *testing.T) {
|
||||
m := delegationMsg()
|
||||
mt, _ := Typify(m, time.Now().UTC())
|
||||
if mt != Delegation {
|
||||
t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypifyRRSIG(t *testing.T) {
|
||||
now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017")
|
||||
utc := now.UTC()
|
||||
|
||||
m := delegationMsgRRSIGOK()
|
||||
if mt, _ := Typify(m, utc); mt != Delegation {
|
||||
t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
|
||||
}
|
||||
|
||||
// Still a Delegation because EDNS0 OPT DO bool is not set, so we won't check the sigs.
|
||||
m = delegationMsgRRSIGFail()
|
||||
if mt, _ := Typify(m, utc); mt != Delegation {
|
||||
t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
|
||||
}
|
||||
|
||||
m = delegationMsgRRSIGFail()
|
||||
m = addOpt(m)
|
||||
if mt, _ := Typify(m, utc); mt != OtherError {
|
||||
t.Errorf("message is wrongly typified, expected OtherError, got %s", mt)
|
||||
}
|
||||
}
|
||||
|
||||
func delegationMsg() *dns.Msg {
|
||||
return &dns.Msg{
|
||||
Ns: []dns.RR{
|
||||
test.NS("miek.nl. 3600 IN NS linode.atoom.net."),
|
||||
test.NS("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."),
|
||||
test.NS("miek.nl. 3600 IN NS omval.tednet.nl."),
|
||||
},
|
||||
Extra: []dns.RR{
|
||||
test.A("omval.tednet.nl. 3600 IN A 185.49.141.42"),
|
||||
test.AAAA("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func delegationMsgRRSIGOK() *dns.Msg {
|
||||
del := delegationMsg()
|
||||
del.Ns = append(del.Ns,
|
||||
test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20170521031301 20170421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
|
||||
)
|
||||
return del
|
||||
}
|
||||
|
||||
func delegationMsgRRSIGFail() *dns.Msg {
|
||||
del := delegationMsg()
|
||||
del.Ns = append(del.Ns,
|
||||
test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
|
||||
)
|
||||
return del
|
||||
}
|
||||
|
||||
func addOpt(m *dns.Msg) *dns.Msg {
|
||||
m.Extra = append(m.Extra, test.OPT(4096, true))
|
||||
return m
|
||||
}
|
||||
64
plugin/pkg/singleflight/singleflight.go
Normal file
64
plugin/pkg/singleflight/singleflight.go
Normal file
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copyright 2012 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package singleflight provides a duplicate function call suppression
|
||||
// mechanism.
|
||||
package singleflight
|
||||
|
||||
import "sync"
|
||||
|
||||
// call is an in-flight or completed Do call
|
||||
type call struct {
|
||||
wg sync.WaitGroup
|
||||
val interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
// Group represents a class of work and forms a namespace in which
|
||||
// units of work can be executed with duplicate suppression.
|
||||
type Group struct {
|
||||
mu sync.Mutex // protects m
|
||||
m map[uint32]*call // lazily initialized
|
||||
}
|
||||
|
||||
// Do executes and returns the results of the given function, making
|
||||
// sure that only one execution is in-flight for a given key at a
|
||||
// time. If a duplicate comes in, the duplicate caller waits for the
|
||||
// original to complete and receives the same results.
|
||||
func (g *Group) Do(key uint32, fn func() (interface{}, error)) (interface{}, error) {
|
||||
g.mu.Lock()
|
||||
if g.m == nil {
|
||||
g.m = make(map[uint32]*call)
|
||||
}
|
||||
if c, ok := g.m[key]; ok {
|
||||
g.mu.Unlock()
|
||||
c.wg.Wait()
|
||||
return c.val, c.err
|
||||
}
|
||||
c := new(call)
|
||||
c.wg.Add(1)
|
||||
g.m[key] = c
|
||||
g.mu.Unlock()
|
||||
|
||||
c.val, c.err = fn()
|
||||
c.wg.Done()
|
||||
|
||||
g.mu.Lock()
|
||||
delete(g.m, key)
|
||||
g.mu.Unlock()
|
||||
|
||||
return c.val, c.err
|
||||
}
|
||||
85
plugin/pkg/singleflight/singleflight_test.go
Normal file
85
plugin/pkg/singleflight/singleflight_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
Copyright 2012 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package singleflight
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
var g Group
|
||||
v, err := g.Do(1, func() (interface{}, error) {
|
||||
return "bar", nil
|
||||
})
|
||||
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
|
||||
t.Errorf("Do = %v; want %v", got, want)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Do error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoErr(t *testing.T) {
|
||||
var g Group
|
||||
someErr := errors.New("Some error")
|
||||
v, err := g.Do(1, func() (interface{}, error) {
|
||||
return nil, someErr
|
||||
})
|
||||
if err != someErr {
|
||||
t.Errorf("Do error = %v; want someErr", err)
|
||||
}
|
||||
if v != nil {
|
||||
t.Errorf("unexpected non-nil value %#v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoDupSuppress(t *testing.T) {
|
||||
var g Group
|
||||
c := make(chan string)
|
||||
var calls int32
|
||||
fn := func() (interface{}, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return <-c, nil
|
||||
}
|
||||
|
||||
const n = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < n; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
v, err := g.Do(1, fn)
|
||||
if err != nil {
|
||||
t.Errorf("Do error: %v", err)
|
||||
}
|
||||
if v.(string) != "bar" {
|
||||
t.Errorf("got %q; want %q", v, "bar")
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond) // let goroutines above block
|
||||
c <- "bar"
|
||||
wg.Wait()
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Errorf("number of calls = %d; want 1", got)
|
||||
}
|
||||
}
|
||||
128
plugin/pkg/tls/tls.go
Normal file
128
plugin/pkg/tls/tls.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewTLSConfigFromArgs returns a TLS config based upon the passed
|
||||
// in list of arguments. Typically these come straight from the
|
||||
// Corefile.
|
||||
// no args
|
||||
// - creates a Config with no cert and using system CAs
|
||||
// - use for a client that talks to a server with a public signed cert (CA installed in system)
|
||||
// - the client will not be authenticated by the server since there is no cert
|
||||
// one arg: the path to CA PEM file
|
||||
// - creates a Config with no cert using a specific CA
|
||||
// - use for a client that talks to a server with a private signed cert (CA not installed in system)
|
||||
// - the client will not be authenticated by the server since there is no cert
|
||||
// two args: path to cert PEM file, the path to private key PEM file
|
||||
// - creates a Config with a cert, using system CAs to validate the other end
|
||||
// - use for:
|
||||
// - a server; or,
|
||||
// - a client that talks to a server with a public cert and needs certificate-based authentication
|
||||
// - the other end will authenticate this end via the provided cert
|
||||
// - the cert of the other end will be verified via system CAs
|
||||
// three args: path to cert PEM file, path to client private key PEM file, path to CA PEM file
|
||||
// - creates a Config with the cert, using specified CA to validate the other end
|
||||
// - use for:
|
||||
// - a server; or,
|
||||
// - a client that talks to a server with a privately signed cert and needs certificate-based
|
||||
// authentication
|
||||
// - the other end will authenticate this end via the provided cert
|
||||
// - this end will verify the other end's cert using the specified CA
|
||||
func NewTLSConfigFromArgs(args ...string) (*tls.Config, error) {
|
||||
var err error
|
||||
var c *tls.Config
|
||||
switch len(args) {
|
||||
case 0:
|
||||
// No client cert, use system CA
|
||||
c, err = NewTLSClientConfig("")
|
||||
case 1:
|
||||
// No client cert, use specified CA
|
||||
c, err = NewTLSClientConfig(args[0])
|
||||
case 2:
|
||||
// Client cert, use system CA
|
||||
c, err = NewTLSConfig(args[0], args[1], "")
|
||||
case 3:
|
||||
// Client cert, use specified CA
|
||||
c, err = NewTLSConfig(args[0], args[1], args[2])
|
||||
default:
|
||||
err = fmt.Errorf("maximum of three arguments allowed for TLS config, found %d", len(args))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// NewTLSConfig returns a TLS config that includes a certificate
|
||||
// Use for server TLS config or when using a client certificate
|
||||
// If caPath is empty, system CAs will be used
|
||||
func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not load TLS cert: %s", err)
|
||||
}
|
||||
|
||||
roots, err := loadRoots(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots}, nil
|
||||
}
|
||||
|
||||
// NewTLSClientConfig returns a TLS config for a client connection
|
||||
// If caPath is empty, system CAs will be used
|
||||
func NewTLSClientConfig(caPath string) (*tls.Config, error) {
|
||||
roots, err := loadRoots(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Config{RootCAs: roots}, nil
|
||||
}
|
||||
|
||||
func loadRoots(caPath string) (*x509.CertPool, error) {
|
||||
if caPath == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
roots := x509.NewCertPool()
|
||||
pem, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading %s: %s", caPath, err)
|
||||
}
|
||||
ok := roots.AppendCertsFromPEM(pem)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not read root certs: %s", err)
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
// NewHTTPSTransport returns an HTTP transport configured using tls.Config
|
||||
func NewHTTPSTransport(cc *tls.Config) *http.Transport {
|
||||
// this seems like a bad idea but was here in the previous version
|
||||
if cc != nil {
|
||||
cc.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: cc,
|
||||
MaxIdleConnsPerHost: 25,
|
||||
}
|
||||
|
||||
return tr
|
||||
}
|
||||
101
plugin/pkg/tls/tls_test.go
Normal file
101
plugin/pkg/tls/tls_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
)
|
||||
|
||||
func getPEMFiles(t *testing.T) (rmFunc func(), cert, key, ca string) {
|
||||
tempDir, rmFunc, err := test.WritePEMFiles("")
|
||||
if err != nil {
|
||||
t.Fatalf("Could not write PEM files: %s", err)
|
||||
}
|
||||
|
||||
cert = filepath.Join(tempDir, "cert.pem")
|
||||
key = filepath.Join(tempDir, "key.pem")
|
||||
ca = filepath.Join(tempDir, "ca.pem")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestNewTLSConfig(t *testing.T) {
|
||||
rmFunc, cert, key, ca := getPEMFiles(t)
|
||||
defer rmFunc()
|
||||
|
||||
_, err := NewTLSConfig(cert, key, ca)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create TLSConfig: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSClientConfig(t *testing.T) {
|
||||
rmFunc, _, _, ca := getPEMFiles(t)
|
||||
defer rmFunc()
|
||||
|
||||
_, err := NewTLSClientConfig(ca)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create TLSConfig: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSConfigFromArgs(t *testing.T) {
|
||||
rmFunc, cert, key, ca := getPEMFiles(t)
|
||||
defer rmFunc()
|
||||
|
||||
_, err := NewTLSConfigFromArgs()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create TLSConfig: %s", err)
|
||||
}
|
||||
|
||||
c, err := NewTLSConfigFromArgs(ca)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create TLSConfig: %s", err)
|
||||
}
|
||||
if c.RootCAs == nil {
|
||||
t.Error("RootCAs should not be nil when one arg passed")
|
||||
}
|
||||
|
||||
c, err = NewTLSConfigFromArgs(cert, key)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create TLSConfig: %s", err)
|
||||
}
|
||||
if c.RootCAs != nil {
|
||||
t.Error("RootCAs should be nil when two args passed")
|
||||
}
|
||||
if len(c.Certificates) != 1 {
|
||||
t.Error("Certificates should have a single entry when two args passed")
|
||||
}
|
||||
args := []string{cert, key, ca}
|
||||
c, err = NewTLSConfigFromArgs(args...)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create TLSConfig: %s", err)
|
||||
}
|
||||
if c.RootCAs == nil {
|
||||
t.Error("RootCAs should not be nil when three args passed")
|
||||
}
|
||||
if len(c.Certificates) != 1 {
|
||||
t.Error("Certificateis should have a single entry when three args passed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSTransport(t *testing.T) {
|
||||
rmFunc, _, _, ca := getPEMFiles(t)
|
||||
defer rmFunc()
|
||||
|
||||
cc, err := NewTLSClientConfig(ca)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create TLSConfig: %s", err)
|
||||
}
|
||||
|
||||
tr := NewHTTPSTransport(cc)
|
||||
if tr == nil {
|
||||
t.Errorf("Failed to create https transport with cc")
|
||||
}
|
||||
|
||||
tr = NewHTTPSTransport(nil)
|
||||
if tr == nil {
|
||||
t.Errorf("Failed to create https transport without cc")
|
||||
}
|
||||
}
|
||||
12
plugin/pkg/trace/trace.go
Normal file
12
plugin/pkg/trace/trace.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package trace
|
||||
|
||||
import (
|
||||
"github.com/coredns/coredns/plugin"
|
||||
ot "github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
// Trace holds the tracer and endpoint info
|
||||
type Trace interface {
|
||||
plugin.Handler
|
||||
Tracer() ot.Tracer
|
||||
}
|
||||
Reference in New Issue
Block a user