Remove the word middleware (#1067)

* Rename middleware to plugin

first pass; mostly used 'sed', few spots where I manually changed
text.

This still builds a coredns binary.

* fmt error

* Rename AddMiddleware to AddPlugin

* Readd AddMiddleware to remain backwards compat
This commit is contained in:
Miek Gieben
2017-09-14 09:36:06 +01:00
committed by GitHub
parent b984aa4559
commit d8714e64e4
354 changed files with 974 additions and 969 deletions

129
plugin/pkg/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,129 @@
// Package cache implements a cache. The cache hold 256 shards, each shard
// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it
// just randomly evicts elements when it gets full.
package cache
import (
"hash/fnv"
"sync"
)
// Hash returns the FNV hash of what.
func Hash(what []byte) uint32 {
h := fnv.New32()
h.Write(what)
return h.Sum32()
}
// Cache is cache.
type Cache struct {
shards [shardSize]*shard
}
// shard is a cache with random eviction.
type shard struct {
items map[uint32]interface{}
size int
sync.RWMutex
}
// New returns a new cache.
func New(size int) *Cache {
ssize := size / shardSize
if ssize < 512 {
ssize = 512
}
c := &Cache{}
// Initialize all the shards
for i := 0; i < shardSize; i++ {
c.shards[i] = newShard(ssize)
}
return c
}
// Add adds a new element to the cache. If the element already exists it is overwritten.
func (c *Cache) Add(key uint32, el interface{}) {
shard := key & (shardSize - 1)
c.shards[shard].Add(key, el)
}
// Get looks up element index under key.
func (c *Cache) Get(key uint32) (interface{}, bool) {
shard := key & (shardSize - 1)
return c.shards[shard].Get(key)
}
// Remove removes the element indexed with key.
func (c *Cache) Remove(key uint32) {
shard := key & (shardSize - 1)
c.shards[shard].Remove(key)
}
// Len returns the number of elements in the cache.
func (c *Cache) Len() int {
l := 0
for _, s := range c.shards {
l += s.Len()
}
return l
}
// newShard returns a new shard with size.
func newShard(size int) *shard { return &shard{items: make(map[uint32]interface{}), size: size} }
// Add adds element indexed by key into the cache. Any existing element is overwritten
func (s *shard) Add(key uint32, el interface{}) {
l := s.Len()
if l+1 > s.size {
s.Evict()
}
s.Lock()
s.items[key] = el
s.Unlock()
}
// Remove removes the element indexed by key from the cache.
func (s *shard) Remove(key uint32) {
s.Lock()
delete(s.items, key)
s.Unlock()
}
// Evict removes a random element from the cache.
func (s *shard) Evict() {
s.Lock()
defer s.Unlock()
key := -1
for k := range s.items {
key = int(k)
break
}
if key == -1 {
// empty cache
return
}
delete(s.items, uint32(key))
}
// Get looks up the element indexed under key.
func (s *shard) Get(key uint32) (interface{}, bool) {
s.RLock()
el, found := s.items[key]
s.RUnlock()
return el, found
}
// Len returns the current length of the cache.
func (s *shard) Len() int {
s.RLock()
l := len(s.items)
s.RUnlock()
return l
}
const shardSize = 256

31
plugin/pkg/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,31 @@
package cache
import "testing"
func TestCacheAddAndGet(t *testing.T) {
c := New(4)
c.Add(1, 1)
if _, found := c.Get(1); !found {
t.Fatal("Failed to find inserted record")
}
}
func TestCacheLen(t *testing.T) {
c := New(4)
c.Add(1, 1)
if l := c.Len(); l != 1 {
t.Fatalf("Cache size should %d, got %d", 1, l)
}
c.Add(1, 1)
if l := c.Len(); l != 1 {
t.Fatalf("Cache size should %d, got %d", 1, l)
}
c.Add(2, 2)
if l := c.Len(); l != 2 {
t.Fatalf("Cache size should %d, got %d", 2, l)
}
}

60
plugin/pkg/cache/shard_test.go vendored Normal file
View File

@@ -0,0 +1,60 @@
package cache
import "testing"
func TestShardAddAndGet(t *testing.T) {
s := newShard(4)
s.Add(1, 1)
if _, found := s.Get(1); !found {
t.Fatal("Failed to find inserted record")
}
}
func TestShardLen(t *testing.T) {
s := newShard(4)
s.Add(1, 1)
if l := s.Len(); l != 1 {
t.Fatalf("Shard size should %d, got %d", 1, l)
}
s.Add(1, 1)
if l := s.Len(); l != 1 {
t.Fatalf("Shard size should %d, got %d", 1, l)
}
s.Add(2, 2)
if l := s.Len(); l != 2 {
t.Fatalf("Shard size should %d, got %d", 2, l)
}
}
func TestShardEvict(t *testing.T) {
s := newShard(1)
s.Add(1, 1)
s.Add(2, 2)
// 1 should be gone
if _, found := s.Get(1); found {
t.Fatal("Found item that should have been evicted")
}
}
func TestShardLenEvict(t *testing.T) {
s := newShard(4)
s.Add(1, 1)
s.Add(2, 1)
s.Add(3, 1)
s.Add(4, 1)
if l := s.Len(); l != 4 {
t.Fatalf("Shard size should %d, got %d", 4, l)
}
// This should evict one element
s.Add(5, 1)
if l := s.Len(); l != 4 {
t.Fatalf("Shard size should %d, got %d", 4, l)
}
}

View File

@@ -0,0 +1,58 @@
// Package dnsrecorder allows you to record a DNS response when it is send to the client.
package dnsrecorder
import (
"time"
"github.com/miekg/dns"
)
// Recorder is a type of ResponseWriter that captures
// the rcode code written to it and also the size of the message
// written in the response. A rcode code does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type Recorder struct {
dns.ResponseWriter
Rcode int
Len int
Msg *dns.Msg
Start time.Time
}
// New makes and returns a new Recorder,
// which captures the DNS rcode from the ResponseWriter
// and also the length of the response message written through it.
func New(w dns.ResponseWriter) *Recorder {
return &Recorder{
ResponseWriter: w,
Rcode: 0,
Msg: nil,
Start: time.Now(),
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *Recorder) WriteMsg(res *dns.Msg) error {
r.Rcode = res.Rcode
// We may get called multiple times (axfr for instance).
// Save the last message, but add the sizes.
r.Len += res.Len()
r.Msg = res
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the length of the message that gets written.
func (r *Recorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.Len += n
}
return n, err
}
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return }

View File

@@ -0,0 +1,28 @@
package dnsrecorder
/*
func TestNewResponseRecorder(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
if !(recordRequest.ResponseWriter == w) {
t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
}
if recordRequest.status != http.StatusOK {
t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
}
}
func TestWrite(t *testing.T) {
w := httptest.NewRecorder()
responseTestString := "test"
recordRequest := NewResponseRecorder(w)
buf := []byte(responseTestString)
recordRequest.Write(buf)
if recordRequest.size != len(buf) {
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
}
if w.Body.String() != responseTestString {
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
}
}
*/

View File

@@ -0,0 +1,15 @@
package dnsutil
import "github.com/miekg/dns"
// DuplicateCNAME returns true if r already exists in records.
func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,55 @@
package dnsutil
import (
"testing"
"github.com/miekg/dns"
)
func TestDuplicateCNAME(t *testing.T) {
tests := []struct {
cname string
records []string
expected bool
}{
{
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
[]string{
"US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534",
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
},
true,
},
{
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
[]string{
"US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534",
},
false,
},
{
"1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.",
[]string{},
false,
},
}
for i, test := range tests {
cnameRR, err := dns.NewRR(test.cname)
if err != nil {
t.Fatalf("Test %d, cname ('%s') error (%s)!", i, test.cname, err)
}
cname := cnameRR.(*dns.CNAME)
records := []dns.RR{}
for j, r := range test.records {
rr, err := dns.NewRR(r)
if err != nil {
t.Fatalf("Test %d, record %d ('%s') error (%s)!", i, j, r, err)
}
records = append(records, rr)
}
got := DuplicateCNAME(cname, records)
if got != test.expected {
t.Errorf("Test %d, expected '%v', got '%v' for CNAME ('%s') and RECORDS (%v)", i, test.expected, got, test.cname, test.records)
}
}
}

View File

@@ -0,0 +1,12 @@
package dnsutil
import "github.com/miekg/dns"
// Dedup de-duplicates a message.
func Dedup(m *dns.Msg) *dns.Msg {
// TODO(miek): expensive!
m.Answer = dns.Dedup(m.Answer, nil)
m.Ns = dns.Dedup(m.Ns, nil)
m.Extra = dns.Dedup(m.Extra, nil)
return m
}

View File

@@ -0,0 +1,2 @@
// Package dnsutil contains DNS related helper functions.
package dnsutil

View File

@@ -0,0 +1,82 @@
package dnsutil
import (
"fmt"
"net"
"os"
"github.com/miekg/dns"
)
// ParseHostPortOrFile parses the strings in s, each string can either be a address,
// address:port or a filename. The address part is checked and the filename case a
// resolv.conf like file is parsed and the nameserver found are returned.
func ParseHostPortOrFile(s ...string) ([]string, error) {
var servers []string
for _, host := range s {
addr, _, err := net.SplitHostPort(host)
if err != nil {
// Parse didn't work, it is not a addr:port combo
if net.ParseIP(host) == nil {
// Not an IP address.
ss, err := tryFile(host)
if err == nil {
servers = append(servers, ss...)
continue
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
ss := net.JoinHostPort(host, "53")
servers = append(servers, ss)
continue
}
if net.ParseIP(addr) == nil {
// No an IP address.
ss, err := tryFile(host)
if err == nil {
servers = append(servers, ss...)
continue
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
servers = append(servers, host)
}
return servers, nil
}
// Try to open this is a file first.
func tryFile(s string) ([]string, error) {
c, err := dns.ClientConfigFromFile(s)
if err == os.ErrNotExist {
return nil, fmt.Errorf("failed to open file %q: %q", s, err)
} else if err != nil {
return nil, err
}
servers := []string{}
for _, s := range c.Servers {
servers = append(servers, net.JoinHostPort(s, c.Port))
}
return servers, nil
}
// ParseHostPort will check if the host part is a valid IP address, if the
// IP address is valid, but no port is found, defaultPort is added.
func ParseHostPort(s, defaultPort string) (string, error) {
addr, port, err := net.SplitHostPort(s)
if port == "" {
port = defaultPort
}
if err != nil {
if net.ParseIP(s) == nil {
return "", fmt.Errorf("must specify an IP address: `%s'", s)
}
return net.JoinHostPort(s, port), nil
}
if net.ParseIP(addr) == nil {
return "", fmt.Errorf("must specify an IP address: `%s'", addr)
}
return net.JoinHostPort(addr, port), nil
}

View File

@@ -0,0 +1,85 @@
package dnsutil
import (
"io/ioutil"
"os"
"testing"
)
func TestParseHostPortOrFile(t *testing.T) {
tests := []struct {
in string
expected string
shouldErr bool
}{
{
"8.8.8.8",
"8.8.8.8:53",
false,
},
{
"8.8.8.8:153",
"8.8.8.8:153",
false,
},
{
"/etc/resolv.conf:53",
"",
true,
},
{
"resolv.conf",
"127.0.0.1:53",
false,
},
}
err := ioutil.WriteFile("resolv.conf", []byte("nameserver 127.0.0.1\n"), 0600)
if err != nil {
t.Fatalf("Failed to write test resolv.conf")
}
defer os.Remove("resolv.conf")
for i, tc := range tests {
got, err := ParseHostPortOrFile(tc.in)
if err == nil && tc.shouldErr {
t.Errorf("Test %d, expected error, got nil", i)
continue
}
if err != nil && tc.shouldErr {
continue
}
if got[0] != tc.expected {
t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got[0])
}
}
}
func TestParseHostPort(t *testing.T) {
tests := []struct {
in string
expected string
shouldErr bool
}{
{"8.8.8.8:53", "8.8.8.8:53", false},
{"a.a.a.a:153", "", true},
{"8.8.8.8", "8.8.8.8:53", false},
{"8.8.8.8:", "8.8.8.8:53", false},
{"8.8.8.8::53", "", true},
{"resolv.conf", "", true},
}
for i, tc := range tests {
got, err := ParseHostPort(tc.in, "53")
if err == nil && tc.shouldErr {
t.Errorf("Test %d, expected error, got nil", i)
continue
}
if err != nil && !tc.shouldErr {
t.Errorf("Test %d, expected no error, got %q", i, err)
}
if got != tc.expected {
t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got)
}
}
}

View File

@@ -0,0 +1,19 @@
package dnsutil
import (
"strings"
"github.com/miekg/dns"
)
// Join joins labels to form a fully qualified domain name. If the last label is
// the root label it is ignored. Not other syntax checks are performed.
func Join(labels []string) string {
ll := len(labels)
if labels[ll-1] == "." {
s := strings.Join(labels[:ll-1], ".")
return dns.Fqdn(s)
}
s := strings.Join(labels, ".")
return dns.Fqdn(s)
}

View File

@@ -0,0 +1,20 @@
package dnsutil
import "testing"
func TestJoin(t *testing.T) {
tests := []struct {
in []string
out string
}{
{[]string{"bla", "bliep", "example", "org"}, "bla.bliep.example.org."},
{[]string{"example", "."}, "example."},
{[]string{"."}, "."},
}
for i, tc := range tests {
if x := Join(tc.in); x != tc.out {
t.Errorf("Test %d, expected %s, got %s", i, tc.out, x)
}
}
}

View File

@@ -0,0 +1,68 @@
package dnsutil
import (
"net"
"strings"
)
// ExtractAddressFromReverse turns a standard PTR reverse record name
// into an IP address. This works for ipv4 or ipv6.
//
// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
// failes the empty string is returned.
func ExtractAddressFromReverse(reverseName string) string {
search := ""
f := reverse
switch {
case strings.HasSuffix(reverseName, v4arpaSuffix):
search = strings.TrimSuffix(reverseName, v4arpaSuffix)
case strings.HasSuffix(reverseName, v6arpaSuffix):
search = strings.TrimSuffix(reverseName, v6arpaSuffix)
f = reverse6
default:
return ""
}
// Reverse the segments and then combine them.
return f(strings.Split(search, "."))
}
func reverse(slice []string) string {
for i := 0; i < len(slice)/2; i++ {
j := len(slice) - i - 1
slice[i], slice[j] = slice[j], slice[i]
}
ip := net.ParseIP(strings.Join(slice, ".")).To4()
if ip == nil {
return ""
}
return ip.String()
}
// reverse6 reverse the segments and combine them according to RFC3596:
// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2
// is reversed to 2001:db8::567:89ab
func reverse6(slice []string) string {
for i := 0; i < len(slice)/2; i++ {
j := len(slice) - i - 1
slice[i], slice[j] = slice[j], slice[i]
}
slice6 := []string{}
for i := 0; i < len(slice)/4; i++ {
slice6 = append(slice6, strings.Join(slice[i*4:i*4+4], ""))
}
ip := net.ParseIP(strings.Join(slice6, ":")).To16()
if ip == nil {
return ""
}
return ip.String()
}
const (
// v4arpaSuffix is the reverse tree suffix for v4 IP addresses.
v4arpaSuffix = ".in-addr.arpa."
// v6arpaSuffix is the reverse tree suffix for v6 IP addresses.
v6arpaSuffix = ".ip6.arpa."
)

View File

@@ -0,0 +1,51 @@
package dnsutil
import (
"testing"
)
func TestExtractAddressFromReverse(t *testing.T) {
tests := []struct {
reverseName string
expectedAddress string
}{
{
"54.119.58.176.in-addr.arpa.",
"176.58.119.54",
},
{
".58.176.in-addr.arpa.",
"",
},
{
"b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.in-addr.arpa.",
"",
},
{
"b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
"2001:db8::567:89ab",
},
{
"d.0.1.0.0.2.ip6.arpa.",
"",
},
{
"54.119.58.176.ip6.arpa.",
"",
},
{
"NONAME",
"",
},
{
"",
"",
},
}
for i, test := range tests {
got := ExtractAddressFromReverse(test.reverseName)
if got != test.expectedAddress {
t.Errorf("Test %d, expected '%s', got '%s'", i, test.expectedAddress, got)
}
}
}

View File

@@ -0,0 +1,20 @@
package dnsutil
import (
"errors"
"github.com/miekg/dns"
)
// TrimZone removes the zone component from q. It returns the trimmed
// name or an error is zone is longer then qname. The trimmed name will be returned
// without a trailing dot.
func TrimZone(q string, z string) (string, error) {
zl := dns.CountLabel(z)
i, ok := dns.PrevLabel(q, zl)
if ok || i-1 < 0 {
return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z)
}
// This includes the '.', remove on return
return q[:i-1], nil
}

View File

@@ -0,0 +1,39 @@
package dnsutil
import (
"errors"
"testing"
"github.com/miekg/dns"
)
func TestTrimZone(t *testing.T) {
tests := []struct {
qname string
zone string
expected string
err error
}{
{"a.example.org", "example.org", "a", nil},
{"a.b.example.org", "example.org", "a.b", nil},
{"b.", ".", "b", nil},
{"example.org", "example.org", "", errors.New("should err")},
{"org", "example.org", "", errors.New("should err")},
}
for i, tc := range tests {
got, err := TrimZone(dns.Fqdn(tc.qname), dns.Fqdn(tc.zone))
if tc.err != nil && err == nil {
t.Errorf("Test %d, expected error got nil", i)
continue
}
if tc.err == nil && err != nil {
t.Errorf("Test %d, expected no error got %v", i, err)
continue
}
if got != tc.expected {
t.Errorf("Test %d, expected %s, got %s", i, tc.expected, got)
continue
}
}
}

46
plugin/pkg/edns/edns.go Normal file
View File

@@ -0,0 +1,46 @@
// Package edns provides function useful for adding/inspecting OPT records to/in messages.
package edns
import (
"errors"
"github.com/miekg/dns"
)
// Version checks the EDNS version in the request. If error
// is nil everything is OK and we can invoke the plugin. If non-nil, the
// returned Msg is valid to be returned to the client (and should). For some
// reason this response should not contain a question RR in the question section.
func Version(req *dns.Msg) (*dns.Msg, error) {
opt := req.IsEdns0()
if opt == nil {
return nil, nil
}
if opt.Version() == 0 {
return nil, nil
}
m := new(dns.Msg)
m.SetReply(req)
// zero out question section, wtf.
m.Question = nil
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetVersion(0)
o.SetExtendedRcode(dns.RcodeBadVers)
m.Extra = []dns.RR{o}
return m, errors.New("EDNS0 BADVERS")
}
// Size returns a normalized size based on proto.
func Size(proto string, size int) int {
if proto == "tcp" {
return dns.MaxMsgSize
}
if size < dns.MinMsgSize {
return dns.MinMsgSize
}
return size
}

View File

@@ -0,0 +1,37 @@
package edns
import (
"testing"
"github.com/miekg/dns"
)
func TestVersion(t *testing.T) {
m := ednsMsg()
m.Extra[0].(*dns.OPT).SetVersion(2)
_, err := Version(m)
if err == nil {
t.Errorf("expected wrong version, but got OK")
}
}
func TestVersionNoEdns(t *testing.T) {
m := ednsMsg()
m.Extra = nil
_, err := Version(m)
if err != nil {
t.Errorf("expected no error, but got one: %s", err)
}
}
func ednsMsg() *dns.Msg {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
m.Extra = append(m.Extra, o)
return m
}

View File

@@ -0,0 +1,243 @@
package healthcheck
import (
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
)
// UpstreamHostDownFunc can be used to customize how Down behaves.
type UpstreamHostDownFunc func(*UpstreamHost) bool
// UpstreamHost represents a single proxy upstream
type UpstreamHost struct {
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
Name string // IP address (and port) of this upstream host
Network string // Network (tcp, unix, etc) of the host, default "" is "tcp"
Fails int32
FailTimeout time.Duration
OkUntil time.Time
CheckDown UpstreamHostDownFunc
CheckURL string
WithoutPathPrefix string
Checking bool
CheckMu sync.Mutex
}
// Down checks whether the upstream host is down or not.
// Down will try to use uh.CheckDown first, and will fall
// back to some default criteria if necessary.
func (uh *UpstreamHost) Down() bool {
if uh.CheckDown == nil {
// Default settings
fails := atomic.LoadInt32(&uh.Fails)
after := false
uh.CheckMu.Lock()
until := uh.OkUntil
uh.CheckMu.Unlock()
if !until.IsZero() && time.Now().After(until) {
after = true
}
return after || fails > 0
}
return uh.CheckDown(uh)
}
// HostPool is a collection of UpstreamHosts.
type HostPool []*UpstreamHost
// HealthCheck is used for performing healthcheck
// on a collection of upstream hosts and select
// one based on the policy.
type HealthCheck struct {
wg sync.WaitGroup // Used to wait for running goroutines to stop.
stop chan struct{} // Signals running goroutines to stop.
Hosts HostPool
Policy Policy
Spray Policy
FailTimeout time.Duration
MaxFails int32
Future time.Duration
Path string
Port string
Interval time.Duration
}
// Start starts the healthcheck
func (u *HealthCheck) Start() {
u.stop = make(chan struct{})
if u.Path != "" {
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.healthCheckWorker(u.stop)
}()
}
}
// Stop sends a signal to all goroutines started by this staticUpstream to exit
// and waits for them to finish before returning.
func (u *HealthCheck) Stop() error {
close(u.stop)
u.wg.Wait()
return nil
}
// This was moved into a thread so that each host could throw a health
// check at the same time. The reason for this is that if we are checking
// 3 hosts, and the first one is gone, and we spend minutes timing out to
// fail it, we would not have been doing any other health checks in that
// time. So we now have a per-host lock and a threaded health check.
//
// We use the Checking bool to avoid concurrent checks against the same
// host; if one is taking a long time, the next one will find a check in
// progress and simply return before trying.
//
// We are carefully avoiding having the mutex locked while we check,
// otherwise checks will back up, potentially a lot of them if a host is
// absent for a long time. This arrangement makes checks quickly see if
// they are the only one running and abort otherwise.
func healthCheckURL(nextTs time.Time, host *UpstreamHost) {
// lock for our bool check. We don't just defer the unlock because
// we don't want the lock held while http.Get runs
host.CheckMu.Lock()
// are we mid check? Don't run another one
if host.Checking {
host.CheckMu.Unlock()
return
}
host.Checking = true
host.CheckMu.Unlock()
//log.Printf("[DEBUG] Healthchecking %s, nextTs is %s\n", url, nextTs.Local())
// fetch that url. This has been moved into a go func because
// when the remote host is not merely not serving, but actually
// absent, then tcp syn timeouts can be very long, and so one
// fetch could last several check intervals
if r, err := http.Get(host.CheckURL); err == nil {
io.Copy(ioutil.Discard, r.Body)
r.Body.Close()
if r.StatusCode < 200 || r.StatusCode >= 400 {
log.Printf("[WARNING] Host %s health check returned HTTP code %d\n",
host.Name, r.StatusCode)
nextTs = time.Unix(0, 0)
}
} else {
log.Printf("[WARNING] Host %s health check probe failed: %v\n", host.Name, err)
nextTs = time.Unix(0, 0)
}
host.CheckMu.Lock()
host.Checking = false
host.OkUntil = nextTs
host.CheckMu.Unlock()
}
func (u *HealthCheck) healthCheck() {
for _, host := range u.Hosts {
if host.CheckURL == "" {
var hostName, checkPort string
// The DNS server might be an HTTP server. If so, extract its name.
ret, err := url.Parse(host.Name)
if err == nil && len(ret.Host) > 0 {
hostName = ret.Host
} else {
hostName = host.Name
}
// Extract the port number from the parsed server name.
checkHostName, checkPort, err := net.SplitHostPort(hostName)
if err != nil {
checkHostName = hostName
}
if u.Port != "" {
checkPort = u.Port
}
host.CheckURL = "http://" + net.JoinHostPort(checkHostName, checkPort) + u.Path
}
// calculate this before the get
nextTs := time.Now().Add(u.Future)
// locks/bools should prevent requests backing up
go healthCheckURL(nextTs, host)
}
}
func (u *HealthCheck) healthCheckWorker(stop chan struct{}) {
ticker := time.NewTicker(u.Interval)
u.healthCheck()
for {
select {
case <-ticker.C:
u.healthCheck()
case <-stop:
ticker.Stop()
return
}
}
}
// Select selects an upstream host based on the policy
// and the healthcheck result.
func (u *HealthCheck) Select() *UpstreamHost {
pool := u.Hosts
if len(pool) == 1 {
if pool[0].Down() && u.Spray == nil {
return nil
}
return pool[0]
}
allDown := true
for _, host := range pool {
if !host.Down() {
allDown = false
break
}
}
if allDown {
if u.Spray == nil {
return nil
}
return u.Spray.Select(pool)
}
if u.Policy == nil {
h := (&Random{}).Select(pool)
if h != nil {
return h
}
if h == nil && u.Spray == nil {
return nil
}
return u.Spray.Select(pool)
}
h := u.Policy.Select(pool)
if h != nil {
return h
}
if u.Spray == nil {
return nil
}
return u.Spray.Select(pool)
}

View File

@@ -0,0 +1,120 @@
package healthcheck
import (
"log"
"math/rand"
"sync/atomic"
)
var (
// SupportedPolicies is the collection of policies registered
SupportedPolicies = make(map[string]func() Policy)
)
// RegisterPolicy adds a custom policy to the proxy.
func RegisterPolicy(name string, policy func() Policy) {
SupportedPolicies[name] = policy
}
// Policy decides how a host will be selected from a pool. When all hosts are unhealthy, it is assumed the
// healthchecking failed. In this case each policy will *randomly* return a host from the pool to prevent
// no traffic to go through at all.
type Policy interface {
Select(pool HostPool) *UpstreamHost
}
func init() {
RegisterPolicy("random", func() Policy { return &Random{} })
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
}
// Random is a policy that selects up hosts from a pool at random.
type Random struct{}
// Select selects an up host at random from the specified pool.
func (r *Random) Select(pool HostPool) *UpstreamHost {
// instead of just generating a random index
// this is done to prevent selecting a down host
var randHost *UpstreamHost
count := 0
for _, host := range pool {
if host.Down() {
continue
}
count++
if count == 1 {
randHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
randHost = host
}
}
}
return randHost
}
// Spray is a policy that selects a host from a pool at random. This should be used as a last ditch
// attempt to get a host when all hosts are reporting unhealthy.
type Spray struct{}
// Select selects an up host at random from the specified pool.
func (r *Spray) Select(pool HostPool) *UpstreamHost {
rnd := rand.Int() % len(pool)
randHost := pool[rnd]
log.Printf("[WARNING] All hosts reported as down, spraying to target: %s", randHost.Name)
return randHost
}
// LeastConn is a policy that selects the host with the least connections.
type LeastConn struct{}
// Select selects the up host with the least number of connections in the
// pool. If more than one host has the same least number of connections,
// one of the hosts is chosen at random.
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
var bestHost *UpstreamHost
count := 0
leastConn := int64(1<<63 - 1)
for _, host := range pool {
if host.Down() {
continue
}
hostConns := host.Conns
if hostConns < leastConn {
bestHost = host
leastConn = hostConns
count = 1
} else if hostConns == leastConn {
// randomly select host among hosts with least connections
count++
if count == 1 {
bestHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
bestHost = host
}
}
}
}
return bestHost
}
// RoundRobin is a policy that selects hosts based on round robin ordering.
type RoundRobin struct {
Robin uint32
}
// Select selects an up host from the pool using a round robin ordering scheme.
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
poolLen := uint32(len(pool))
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
host := pool[selection]
// if the currently selected host is down, just ffwd to up host
for i := uint32(1); host.Down() && i < poolLen; i++ {
host = pool[(selection+i)%poolLen]
}
return host
}

View File

@@ -0,0 +1,143 @@
package healthcheck
import (
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
)
var workableServer *httptest.Server
func TestMain(m *testing.M) {
workableServer = httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// do nothing
}))
r := m.Run()
workableServer.Close()
os.Exit(r)
}
type customPolicy struct{}
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
return pool[0]
}
func testPool() HostPool {
pool := []*UpstreamHost{
{
Name: workableServer.URL, // this should resolve (healthcheck test)
},
{
Name: "http://shouldnot.resolve", // this shouldn't
},
{
Name: "http://C",
},
}
return HostPool(pool)
}
func TestRegisterPolicy(t *testing.T) {
name := "custom"
customPolicy := &customPolicy{}
RegisterPolicy(name, func() Policy { return customPolicy })
if _, ok := SupportedPolicies[name]; !ok {
t.Error("Expected supportedPolicies to have a custom policy.")
}
}
// TODO(miek): Disabled for now, we should get out of the habit of using
// realtime in these tests .
func testHealthCheck(t *testing.T) {
log.SetOutput(ioutil.Discard)
u := &HealthCheck{
Hosts: testPool(),
FailTimeout: 10 * time.Second,
Future: 60 * time.Second,
MaxFails: 1,
}
u.healthCheck()
// sleep a bit, it's async now
time.Sleep(time.Duration(2 * time.Second))
if u.Hosts[0].Down() {
t.Error("Expected first host in testpool to not fail healthcheck.")
}
if !u.Hosts[1].Down() {
t.Error("Expected second host in testpool to fail healthcheck.")
}
}
func TestSelect(t *testing.T) {
u := &HealthCheck{
Hosts: testPool()[:3],
FailTimeout: 10 * time.Second,
Future: 60 * time.Second,
MaxFails: 1,
}
u.Hosts[0].OkUntil = time.Unix(0, 0)
u.Hosts[1].OkUntil = time.Unix(0, 0)
u.Hosts[2].OkUntil = time.Unix(0, 0)
if h := u.Select(); h != nil {
t.Error("Expected select to return nil as all host are down")
}
u.Hosts[2].OkUntil = time.Time{}
if h := u.Select(); h == nil {
t.Error("Expected select to not return nil")
}
}
func TestRoundRobinPolicy(t *testing.T) {
pool := testPool()
rrPolicy := &RoundRobin{}
h := rrPolicy.Select(pool)
// First selected host is 1, because counter starts at 0
// and increments before host is selected
if h != pool[1] {
t.Error("Expected first round robin host to be second host in the pool.")
}
h = rrPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected second round robin host to be third host in the pool.")
}
// mark host as down
pool[0].OkUntil = time.Unix(0, 0)
h = rrPolicy.Select(pool)
if h != pool[1] {
t.Error("Expected third round robin host to be first host in the pool.")
}
}
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
lcPolicy := &LeastConn{}
pool[0].Conns = 10
pool[1].Conns = 10
h := lcPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected least connection host to be third host.")
}
pool[2].Conns = 100
h = lcPolicy.Select(pool)
if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.")
}
}
func TestCustomPolicy(t *testing.T) {
pool := testPool()
customPolicy := &customPolicy{}
h := customPolicy.Select(pool)
if h != pool[0] {
t.Error("Expected custom policy host to be the first host.")
}
}

View File

@@ -0,0 +1,23 @@
// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written.
package nonwriter
import (
"github.com/miekg/dns"
)
// Writer is a type of ResponseWriter that captures the message, but never writes to the client.
type Writer struct {
dns.ResponseWriter
Msg *dns.Msg
}
// New makes and returns a new NonWriter.
func New(w dns.ResponseWriter) *Writer { return &Writer{ResponseWriter: w} }
// WriteMsg records the message, but doesn't write it itself.
func (w *Writer) WriteMsg(res *dns.Msg) error {
w.Msg = res
return nil
}
func (w *Writer) Write(buf []byte) (int, error) { return len(buf), nil }

View File

@@ -0,0 +1,19 @@
package nonwriter
import (
"testing"
"github.com/miekg/dns"
)
func TestNonWriter(t *testing.T) {
nw := New(nil)
m := new(dns.Msg)
m.SetQuestion("example.org.", dns.TypeA)
if err := nw.WriteMsg(m); err != nil {
t.Errorf("Got error when writing to nonwriter: %s", err)
}
if x := nw.Msg.Question[0].Name; x != "example.org." {
t.Errorf("Expacted 'example.org.' got %q:", x)
}
}

16
plugin/pkg/rcode/rcode.go Normal file
View File

@@ -0,0 +1,16 @@
package rcode
import (
"strconv"
"github.com/miekg/dns"
)
// ToString convert the rcode to the official DNS string, or to "RCODE"+value if the RCODE
// value is unknown.
func ToString(rcode int) string {
if str, ok := dns.RcodeToString[rcode]; ok {
return str
}
return "RCODE" + strconv.Itoa(rcode)
}

View File

@@ -0,0 +1,29 @@
package rcode
import (
"testing"
"github.com/miekg/dns"
)
func TestToString(t *testing.T) {
tests := []struct {
in int
expected string
}{
{
dns.RcodeSuccess,
"NOERROR",
},
{
28,
"RCODE28",
},
}
for i, test := range tests {
got := ToString(test.in)
if got != test.expected {
t.Errorf("Test %d, expected %s, got %s", i, test.expected, got)
}
}
}

View File

@@ -0,0 +1,161 @@
package replacer
import (
"strconv"
"strings"
"time"
"github.com/coredns/coredns/plugin/pkg/dnsrecorder"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Replacer is a type which can replace placeholder
// substrings in a string with actual values from a
// dns.Msg and responseRecorder. Always use
// NewReplacer to get one of these.
type Replacer interface {
Replace(string) string
Set(key, value string)
}
type replacer struct {
replacements map[string]string
emptyValue string
}
// New makes a new replacer based on r and rr.
// Do not create a new replacer until r and rr have all
// the needed values, because this function copies those
// values into the replacer. rr may be nil if it is not
// available. emptyValue should be the string that is used
// in place of empty string (can still be empty string).
func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer {
req := request.Request{W: rr, Req: r}
rep := replacer{
replacements: map[string]string{
"{type}": req.Type(),
"{name}": req.Name(),
"{class}": req.Class(),
"{proto}": req.Proto(),
"{when}": func() string {
return time.Now().Format(timeFormat)
}(),
"{size}": strconv.Itoa(req.Len()),
"{remote}": req.IP(),
"{port}": req.Port(),
},
emptyValue: emptyValue,
}
if rr != nil {
rcode := dns.RcodeToString[rr.Rcode]
if rcode == "" {
rcode = strconv.Itoa(rr.Rcode)
}
rep.replacements["{rcode}"] = rcode
rep.replacements["{rsize}"] = strconv.Itoa(rr.Len)
rep.replacements["{duration}"] = time.Since(rr.Start).String()
if rr.Msg != nil {
rep.replacements[headerReplacer+"rflags}"] = flagsToString(rr.Msg.MsgHdr)
}
}
// Header placeholders (case-insensitive)
rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id))
rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(r.Opcode)
rep.replacements[headerReplacer+"do}"] = boolToString(req.Do())
rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size())
return rep
}
// Replace performs a replacement of values on s and returns
// the string with the replaced values.
func (r replacer) Replace(s string) string {
// Header replacements - these are case-insensitive, so we can't just use strings.Replace()
for strings.Contains(s, headerReplacer) {
idxStart := strings.Index(s, headerReplacer)
endOffset := idxStart + len(headerReplacer)
idxEnd := strings.Index(s[endOffset:], "}")
if idxEnd > -1 {
placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
replacement := r.replacements[placeholder]
if replacement == "" {
replacement = r.emptyValue
}
s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
} else {
break
}
}
// Regular replacements - these are easier because they're case-sensitive
for placeholder, replacement := range r.replacements {
if replacement == "" {
replacement = r.emptyValue
}
s = strings.Replace(s, placeholder, replacement, -1)
}
return s
}
// Set sets key to value in the replacements map.
func (r replacer) Set(key, value string) {
r.replacements["{"+key+"}"] = value
}
func boolToString(b bool) string {
if b {
return "true"
}
return "false"
}
// flagsToString checks all header flags and returns those
// that are set as a string separated with commas
func flagsToString(h dns.MsgHdr) string {
flags := make([]string, 7)
i := 0
if h.Response {
flags[i] = "qr"
i++
}
if h.Authoritative {
flags[i] = "aa"
i++
}
if h.Truncated {
flags[i] = "tc"
i++
}
if h.RecursionDesired {
flags[i] = "rd"
i++
}
if h.RecursionAvailable {
flags[i] = "ra"
i++
}
if h.Zero {
flags[i] = "z"
i++
}
if h.AuthenticatedData {
flags[i] = "ad"
i++
}
if h.CheckingDisabled {
flags[i] = "cd"
i++
}
return strings.Join(flags[:i], ",")
}
const (
timeFormat = "02/Jan/2006:15:04:05 -0700"
headerReplacer = "{>"
)

View File

@@ -0,0 +1,61 @@
package replacer
import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnsrecorder"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
func TestNewReplacer(t *testing.T) {
w := dnsrecorder.New(&test.ResponseWriter{})
r := new(dns.Msg)
r.SetQuestion("example.org.", dns.TypeHINFO)
r.MsgHdr.AuthenticatedData = true
replaceValues := New(r, w, "")
switch v := replaceValues.(type) {
case replacer:
if v.replacements["{type}"] != "HINFO" {
t.Errorf("Expected type to be HINFO, got %q", v.replacements["{type}"])
}
if v.replacements["{name}"] != "example.org." {
t.Errorf("Expected request name to be example.org., got %q", v.replacements["{name}"])
}
if v.replacements["{size}"] != "29" { // size of request
t.Errorf("Expected size to be 29, got %q", v.replacements["{size}"])
}
default:
t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
}
}
func TestSet(t *testing.T) {
w := dnsrecorder.New(&test.ResponseWriter{})
r := new(dns.Msg)
r.SetQuestion("example.org.", dns.TypeHINFO)
r.MsgHdr.AuthenticatedData = true
repl := New(r, w, "")
repl.Set("name", "coredns.io.")
repl.Set("type", "A")
repl.Set("size", "20")
if repl.Replace("This name is {name}") != "This name is coredns.io." {
t.Error("Expected name replacement failed")
}
if repl.Replace("This type is {type}") != "This type is A" {
t.Error("Expected type replacement failed")
}
if repl.Replace("The request size is {size}") != "The request size is 20" {
t.Error("Expected size replacement failed")
}
}

View File

@@ -0,0 +1,61 @@
package response
import "fmt"
// Class holds sets of Types
type Class int
const (
// All is a meta class encompassing all the classes.
All Class = iota
// Success is a class for a successful response.
Success
// Denial is a class for denying existence (NXDOMAIN, or a nodata: type does not exist)
Denial
// Error is a class for errors, right now defined as not Success and not Denial
Error
)
func (c Class) String() string {
switch c {
case All:
return "all"
case Success:
return "success"
case Denial:
return "denial"
case Error:
return "error"
}
return ""
}
// ClassFromString returns the class from the string s. If not class matches
// the All class and an error are returned
func ClassFromString(s string) (Class, error) {
switch s {
case "all":
return All, nil
case "success":
return Success, nil
case "denial":
return Denial, nil
case "error":
return Error, nil
}
return All, fmt.Errorf("invalid Class: %s", s)
}
// Classify classifies the Type t, it returns its Class.
func Classify(t Type) Class {
switch t {
case NoError, Delegation:
return Success
case NameError, NoData:
return Denial
case OtherError:
fallthrough
default:
return Error
}
}

View File

@@ -0,0 +1,146 @@
package response
import (
"fmt"
"time"
"github.com/miekg/dns"
)
// Type is the type of the message.
type Type int
const (
// NoError indicates a positive reply
NoError Type = iota
// NameError is a NXDOMAIN in header, SOA in auth.
NameError
// NoData indicates name found, but not the type: NOERROR in header, SOA in auth.
NoData
// Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked).
Delegation
// Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR.
Meta
// Update is an dynamic update message.
Update
// OtherError indicates any other error: don't cache these.
OtherError
)
var toString = map[Type]string{
NoError: "NOERROR",
NameError: "NXDOMAIN",
NoData: "NODATA",
Delegation: "DELEGATION",
Meta: "META",
Update: "UPDATE",
OtherError: "OTHERERROR",
}
func (t Type) String() string { return toString[t] }
// TypeFromString returns the type from the string s. If not type matches
// the OtherError type and an error are returned.
func TypeFromString(s string) (Type, error) {
for t, str := range toString {
if s == str {
return t, nil
}
}
return NoError, fmt.Errorf("invalid Type: %s", s)
}
// Typify classifies a message, it returns the Type.
func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) {
if m == nil {
return OtherError, nil
}
opt := m.IsEdns0()
do := false
if opt != nil {
do = opt.Do()
}
if m.Opcode == dns.OpcodeUpdate {
return Update, opt
}
// Check transfer and update first
if m.Opcode == dns.OpcodeNotify {
return Meta, opt
}
if len(m.Question) > 0 {
if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR {
return Meta, opt
}
}
// If our message contains any expired sigs and we care about that, we should return expired
if do {
if expired := typifyExpired(m, t); expired {
return OtherError, opt
}
}
if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {
return NoError, opt
}
soa := false
ns := 0
for _, r := range m.Ns {
if r.Header().Rrtype == dns.TypeSOA {
soa = true
continue
}
if r.Header().Rrtype == dns.TypeNS {
ns++
}
}
// Check length of different sections, and drop stuff that is just to large? TODO(miek).
if soa && m.Rcode == dns.RcodeSuccess {
return NoData, opt
}
if soa && m.Rcode == dns.RcodeNameError {
return NameError, opt
}
if ns > 0 && m.Rcode == dns.RcodeSuccess {
return Delegation, opt
}
if m.Rcode == dns.RcodeSuccess {
return NoError, opt
}
return OtherError, opt
}
func typifyExpired(m *dns.Msg, t time.Time) bool {
if expired := typifyExpiredRRSIG(m.Answer, t); expired {
return true
}
if expired := typifyExpiredRRSIG(m.Ns, t); expired {
return true
}
if expired := typifyExpiredRRSIG(m.Extra, t); expired {
return true
}
return false
}
func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool {
for _, r := range rrs {
if r.Header().Rrtype != dns.TypeRRSIG {
continue
}
ok := r.(*dns.RRSIG).ValidityPeriod(t)
if !ok {
return true
}
}
return false
}

View File

@@ -0,0 +1,84 @@
package response
import (
"testing"
"time"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
func TestTypifyNilMsg(t *testing.T) {
var m *dns.Msg
ty, _ := Typify(m, time.Now().UTC())
if ty != OtherError {
t.Errorf("message wrongly typified, expected OtherError, got %s", ty)
}
}
func TestTypifyDelegation(t *testing.T) {
m := delegationMsg()
mt, _ := Typify(m, time.Now().UTC())
if mt != Delegation {
t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
}
}
func TestTypifyRRSIG(t *testing.T) {
now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017")
utc := now.UTC()
m := delegationMsgRRSIGOK()
if mt, _ := Typify(m, utc); mt != Delegation {
t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
}
// Still a Delegation because EDNS0 OPT DO bool is not set, so we won't check the sigs.
m = delegationMsgRRSIGFail()
if mt, _ := Typify(m, utc); mt != Delegation {
t.Errorf("message is wrongly typified, expected Delegation, got %s", mt)
}
m = delegationMsgRRSIGFail()
m = addOpt(m)
if mt, _ := Typify(m, utc); mt != OtherError {
t.Errorf("message is wrongly typified, expected OtherError, got %s", mt)
}
}
func delegationMsg() *dns.Msg {
return &dns.Msg{
Ns: []dns.RR{
test.NS("miek.nl. 3600 IN NS linode.atoom.net."),
test.NS("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."),
test.NS("miek.nl. 3600 IN NS omval.tednet.nl."),
},
Extra: []dns.RR{
test.A("omval.tednet.nl. 3600 IN A 185.49.141.42"),
test.AAAA("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"),
},
}
}
func delegationMsgRRSIGOK() *dns.Msg {
del := delegationMsg()
del.Ns = append(del.Ns,
test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20170521031301 20170421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
)
return del
}
func delegationMsgRRSIGFail() *dns.Msg {
del := delegationMsg()
del.Ns = append(del.Ns,
test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
)
return del
}
func addOpt(m *dns.Msg) *dns.Msg {
m.Extra = append(m.Extra, test.OPT(4096, true))
return m
}

View File

@@ -0,0 +1,64 @@
/*
Copyright 2012 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight
import "sync"
// call is an in-flight or completed Do call
type call struct {
wg sync.WaitGroup
val interface{}
err error
}
// Group represents a class of work and forms a namespace in which
// units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[uint32]*call // lazily initialized
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
func (g *Group) Do(key uint32, fn func() (interface{}, error)) (interface{}, error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[uint32]*call)
}
if c, ok := g.m[key]; ok {
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.err
}

View File

@@ -0,0 +1,85 @@
/*
Copyright 2012 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package singleflight
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestDo(t *testing.T) {
var g Group
v, err := g.Do(1, func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}
func TestDoErr(t *testing.T) {
var g Group
someErr := errors.New("Some error")
v, err := g.Do(1, func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr", err)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}
func TestDoDupSuppress(t *testing.T) {
var g Group
c := make(chan string)
var calls int32
fn := func() (interface{}, error) {
atomic.AddInt32(&calls, 1)
return <-c, nil
}
const n = 10
var wg sync.WaitGroup
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
v, err := g.Do(1, fn)
if err != nil {
t.Errorf("Do error: %v", err)
}
if v.(string) != "bar" {
t.Errorf("got %q; want %q", v, "bar")
}
wg.Done()
}()
}
time.Sleep(100 * time.Millisecond) // let goroutines above block
c <- "bar"
wg.Wait()
if got := atomic.LoadInt32(&calls); got != 1 {
t.Errorf("number of calls = %d; want 1", got)
}
}

128
plugin/pkg/tls/tls.go Normal file
View File

@@ -0,0 +1,128 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"time"
)
// NewTLSConfigFromArgs returns a TLS config based upon the passed
// in list of arguments. Typically these come straight from the
// Corefile.
// no args
// - creates a Config with no cert and using system CAs
// - use for a client that talks to a server with a public signed cert (CA installed in system)
// - the client will not be authenticated by the server since there is no cert
// one arg: the path to CA PEM file
// - creates a Config with no cert using a specific CA
// - use for a client that talks to a server with a private signed cert (CA not installed in system)
// - the client will not be authenticated by the server since there is no cert
// two args: path to cert PEM file, the path to private key PEM file
// - creates a Config with a cert, using system CAs to validate the other end
// - use for:
// - a server; or,
// - a client that talks to a server with a public cert and needs certificate-based authentication
// - the other end will authenticate this end via the provided cert
// - the cert of the other end will be verified via system CAs
// three args: path to cert PEM file, path to client private key PEM file, path to CA PEM file
// - creates a Config with the cert, using specified CA to validate the other end
// - use for:
// - a server; or,
// - a client that talks to a server with a privately signed cert and needs certificate-based
// authentication
// - the other end will authenticate this end via the provided cert
// - this end will verify the other end's cert using the specified CA
func NewTLSConfigFromArgs(args ...string) (*tls.Config, error) {
var err error
var c *tls.Config
switch len(args) {
case 0:
// No client cert, use system CA
c, err = NewTLSClientConfig("")
case 1:
// No client cert, use specified CA
c, err = NewTLSClientConfig(args[0])
case 2:
// Client cert, use system CA
c, err = NewTLSConfig(args[0], args[1], "")
case 3:
// Client cert, use specified CA
c, err = NewTLSConfig(args[0], args[1], args[2])
default:
err = fmt.Errorf("maximum of three arguments allowed for TLS config, found %d", len(args))
}
if err != nil {
return nil, err
}
return c, nil
}
// NewTLSConfig returns a TLS config that includes a certificate
// Use for server TLS config or when using a client certificate
// If caPath is empty, system CAs will be used
func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("could not load TLS cert: %s", err)
}
roots, err := loadRoots(caPath)
if err != nil {
return nil, err
}
return &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots}, nil
}
// NewTLSClientConfig returns a TLS config for a client connection
// If caPath is empty, system CAs will be used
func NewTLSClientConfig(caPath string) (*tls.Config, error) {
roots, err := loadRoots(caPath)
if err != nil {
return nil, err
}
return &tls.Config{RootCAs: roots}, nil
}
func loadRoots(caPath string) (*x509.CertPool, error) {
if caPath == "" {
return nil, nil
}
roots := x509.NewCertPool()
pem, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("error reading %s: %s", caPath, err)
}
ok := roots.AppendCertsFromPEM(pem)
if !ok {
return nil, fmt.Errorf("could not read root certs: %s", err)
}
return roots, nil
}
// NewHTTPSTransport returns an HTTP transport configured using tls.Config
func NewHTTPSTransport(cc *tls.Config) *http.Transport {
// this seems like a bad idea but was here in the previous version
if cc != nil {
cc.InsecureSkipVerify = true
}
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: cc,
MaxIdleConnsPerHost: 25,
}
return tr
}

101
plugin/pkg/tls/tls_test.go Normal file
View File

@@ -0,0 +1,101 @@
package tls
import (
"path/filepath"
"testing"
"github.com/coredns/coredns/plugin/test"
)
func getPEMFiles(t *testing.T) (rmFunc func(), cert, key, ca string) {
tempDir, rmFunc, err := test.WritePEMFiles("")
if err != nil {
t.Fatalf("Could not write PEM files: %s", err)
}
cert = filepath.Join(tempDir, "cert.pem")
key = filepath.Join(tempDir, "key.pem")
ca = filepath.Join(tempDir, "ca.pem")
return
}
func TestNewTLSConfig(t *testing.T) {
rmFunc, cert, key, ca := getPEMFiles(t)
defer rmFunc()
_, err := NewTLSConfig(cert, key, ca)
if err != nil {
t.Errorf("Failed to create TLSConfig: %s", err)
}
}
func TestNewTLSClientConfig(t *testing.T) {
rmFunc, _, _, ca := getPEMFiles(t)
defer rmFunc()
_, err := NewTLSClientConfig(ca)
if err != nil {
t.Errorf("Failed to create TLSConfig: %s", err)
}
}
func TestNewTLSConfigFromArgs(t *testing.T) {
rmFunc, cert, key, ca := getPEMFiles(t)
defer rmFunc()
_, err := NewTLSConfigFromArgs()
if err != nil {
t.Errorf("Failed to create TLSConfig: %s", err)
}
c, err := NewTLSConfigFromArgs(ca)
if err != nil {
t.Errorf("Failed to create TLSConfig: %s", err)
}
if c.RootCAs == nil {
t.Error("RootCAs should not be nil when one arg passed")
}
c, err = NewTLSConfigFromArgs(cert, key)
if err != nil {
t.Errorf("Failed to create TLSConfig: %s", err)
}
if c.RootCAs != nil {
t.Error("RootCAs should be nil when two args passed")
}
if len(c.Certificates) != 1 {
t.Error("Certificates should have a single entry when two args passed")
}
args := []string{cert, key, ca}
c, err = NewTLSConfigFromArgs(args...)
if err != nil {
t.Errorf("Failed to create TLSConfig: %s", err)
}
if c.RootCAs == nil {
t.Error("RootCAs should not be nil when three args passed")
}
if len(c.Certificates) != 1 {
t.Error("Certificateis should have a single entry when three args passed")
}
}
func TestNewHTTPSTransport(t *testing.T) {
rmFunc, _, _, ca := getPEMFiles(t)
defer rmFunc()
cc, err := NewTLSClientConfig(ca)
if err != nil {
t.Errorf("Failed to create TLSConfig: %s", err)
}
tr := NewHTTPSTransport(cc)
if tr == nil {
t.Errorf("Failed to create https transport with cc")
}
tr = NewHTTPSTransport(nil)
if tr == nil {
t.Errorf("Failed to create https transport without cc")
}
}

12
plugin/pkg/trace/trace.go Normal file
View File

@@ -0,0 +1,12 @@
package trace
import (
"github.com/coredns/coredns/plugin"
ot "github.com/opentracing/opentracing-go"
)
// Trace holds the tracer and endpoint info
type Trace interface {
plugin.Handler
Tracer() ot.Tracer
}