mirror of
https://github.com/coredns/coredns.git
synced 2025-11-01 10:43:17 -04:00
Cleanup: put middleware helper functions in pkgs (#245)
Move all (almost all) Go files in middleware into their own packages. This makes for better naming and discoverability. Lot of changes elsewhere to make this change. The middleware.State was renamed to request.Request which is better, but still does not cover all use-cases. It was also moved out middleware because it is used by `dnsserver` as well. A pkg/dnsutil packages was added for shared, handy, dns util functions. All normalize functions are now put in normalize.go
This commit is contained in:
57
middleware/pkg/dnsrecorder/recorder.go
Normal file
57
middleware/pkg/dnsrecorder/recorder.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package dnsrecorder
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Recorder is a type of ResponseWriter that captures
|
||||
// the rcode code written to it and also the size of the message
|
||||
// written in the response. A rcode code does not have
|
||||
// to be written, however, in which case 0 must be assumed.
|
||||
// It is best to have the constructor initialize this type
|
||||
// with that default status code.
|
||||
type Recorder struct {
|
||||
dns.ResponseWriter
|
||||
Rcode int
|
||||
Size int
|
||||
Msg *dns.Msg
|
||||
Start time.Time
|
||||
}
|
||||
|
||||
// New makes and returns a new Recorder,
|
||||
// which captures the DNS rcode from the ResponseWriter
|
||||
// and also the length of the response message written through it.
|
||||
func New(w dns.ResponseWriter) *Recorder {
|
||||
return &Recorder{
|
||||
ResponseWriter: w,
|
||||
Rcode: 0,
|
||||
Msg: nil,
|
||||
Start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMsg records the status code and calls the
|
||||
// underlying ResponseWriter's WriteMsg method.
|
||||
func (r *Recorder) WriteMsg(res *dns.Msg) error {
|
||||
r.Rcode = res.Rcode
|
||||
// We may get called multiple times (axfr for instance).
|
||||
// Save the last message, but add the sizes.
|
||||
r.Size += res.Len()
|
||||
r.Msg = res
|
||||
return r.ResponseWriter.WriteMsg(res)
|
||||
}
|
||||
|
||||
// Write is a wrapper that records the size of the message that gets written.
|
||||
func (r *Recorder) Write(buf []byte) (int, error) {
|
||||
n, err := r.ResponseWriter.Write(buf)
|
||||
if err == nil {
|
||||
r.Size += n
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Hijack implements dns.Hijacker. It simply wraps the underlying
|
||||
// ResponseWriter's Hijack method if there is one, or returns an error.
|
||||
func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return }
|
||||
28
middleware/pkg/dnsrecorder/recorder_test.go
Normal file
28
middleware/pkg/dnsrecorder/recorder_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package dnsrecorder
|
||||
|
||||
/*
|
||||
func TestNewResponseRecorder(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
if !(recordRequest.ResponseWriter == w) {
|
||||
t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
|
||||
}
|
||||
if recordRequest.status != http.StatusOK {
|
||||
t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
responseTestString := "test"
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
buf := []byte(responseTestString)
|
||||
recordRequest.Write(buf)
|
||||
if recordRequest.size != len(buf) {
|
||||
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
|
||||
}
|
||||
if w.Body.String() != responseTestString {
|
||||
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
|
||||
}
|
||||
}
|
||||
*/
|
||||
15
middleware/pkg/dnsutil/cname.go
Normal file
15
middleware/pkg/dnsutil/cname.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package dnsutil
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// DuplicateCNAME returns true if r already exists in records.
|
||||
func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
|
||||
for _, rec := range records {
|
||||
if v, ok := rec.(*dns.CNAME); ok {
|
||||
if v.Target == r.Target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
40
middleware/pkg/dnsutil/reverse.go
Normal file
40
middleware/pkg/dnsutil/reverse.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package dnsutil
|
||||
|
||||
import "strings"
|
||||
|
||||
// ExtractAddressFromReverse turns a standard PTR reverse record name
|
||||
// into an IP address. This works for ipv4 or ipv6.
|
||||
//
|
||||
// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
|
||||
// failes the empty string is returned.
|
||||
func ExtractAddressFromReverse(reverseName string) string {
|
||||
search := ""
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(reverseName, v4arpaSuffix):
|
||||
search = strings.TrimSuffix(reverseName, v4arpaSuffix)
|
||||
case strings.HasSuffix(reverseName, v6arpaSuffix):
|
||||
search = strings.TrimSuffix(reverseName, v6arpaSuffix)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
// Reverse the segments and then combine them.
|
||||
segments := reverse(strings.Split(search, "."))
|
||||
return strings.Join(segments, ".")
|
||||
}
|
||||
|
||||
func reverse(slice []string) []string {
|
||||
for i := 0; i < len(slice)/2; i++ {
|
||||
j := len(slice) - i - 1
|
||||
slice[i], slice[j] = slice[j], slice[i]
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
const (
|
||||
// v4arpaSuffix is the reverse tree suffix for v4 IP addresses.
|
||||
v4arpaSuffix = ".in-addr.arpa."
|
||||
// v6arpaSuffix is the reverse tree suffix for v6 IP addresses.
|
||||
v6arpaSuffix = ".ip6.arpa."
|
||||
)
|
||||
45
middleware/pkg/edns/edns.go
Normal file
45
middleware/pkg/edns/edns.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package edns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Version checks the EDNS version in the request. If error
|
||||
// is nil everything is OK and we can invoke the middleware. If non-nil, the
|
||||
// returned Msg is valid to be returned to the client (and should). For some
|
||||
// reason this response should not contain a question RR in the question section.
|
||||
func Version(req *dns.Msg) (*dns.Msg, error) {
|
||||
opt := req.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if opt.Version() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
// zero out question section, wtf.
|
||||
m.Question = nil
|
||||
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
o.SetVersion(0)
|
||||
o.SetExtendedRcode(dns.RcodeBadVers)
|
||||
m.Extra = []dns.RR{o}
|
||||
|
||||
return m, errors.New("EDNS0 BADVERS")
|
||||
}
|
||||
|
||||
// Size returns a normalized size based on proto.
|
||||
func Size(proto string, size int) int {
|
||||
if proto == "tcp" {
|
||||
return dns.MaxMsgSize
|
||||
}
|
||||
if size < dns.MinMsgSize {
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
return size
|
||||
}
|
||||
37
middleware/pkg/edns/edns_test.go
Normal file
37
middleware/pkg/edns/edns_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package edns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
m := ednsMsg()
|
||||
m.Extra[0].(*dns.OPT).SetVersion(2)
|
||||
|
||||
_, err := Version(m)
|
||||
if err == nil {
|
||||
t.Errorf("expected wrong version, but got OK")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionNoEdns(t *testing.T) {
|
||||
m := ednsMsg()
|
||||
m.Extra = nil
|
||||
|
||||
_, err := Version(m)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but got one: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ednsMsg() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
14
middleware/pkg/rcode/rcode.go
Normal file
14
middleware/pkg/rcode/rcode.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package rcode
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func ToString(rcode int) string {
|
||||
if str, ok := dns.RcodeToString[rcode]; ok {
|
||||
return str
|
||||
}
|
||||
return "RCODE" + strconv.Itoa(rcode)
|
||||
}
|
||||
115
middleware/pkg/replacer/replacer.go
Normal file
115
middleware/pkg/replacer/replacer.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package replacer
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
|
||||
"github.com/miekg/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Replacer is a type which can replace placeholder
|
||||
// substrings in a string with actual values from a
|
||||
// dns.Msg and responseRecorder. Always use
|
||||
// NewReplacer to get one of these.
|
||||
type Replacer interface {
|
||||
Replace(string) string
|
||||
Set(key, value string)
|
||||
}
|
||||
|
||||
type replacer struct {
|
||||
replacements map[string]string
|
||||
emptyValue string
|
||||
}
|
||||
|
||||
// New makes a new replacer based on r and rr.
|
||||
// Do not create a new replacer until r and rr have all
|
||||
// the needed values, because this function copies those
|
||||
// values into the replacer. rr may be nil if it is not
|
||||
// available. emptyValue should be the string that is used
|
||||
// in place of empty string (can still be empty string).
|
||||
func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer {
|
||||
req := request.Request{W: rr, Req: r}
|
||||
rep := replacer{
|
||||
replacements: map[string]string{
|
||||
"{type}": req.Type(),
|
||||
"{name}": req.Name(),
|
||||
"{class}": req.Class(),
|
||||
"{proto}": req.Proto(),
|
||||
"{when}": func() string {
|
||||
return time.Now().Format(timeFormat)
|
||||
}(),
|
||||
"{remote}": req.IP(),
|
||||
"{port}": req.Port(),
|
||||
},
|
||||
emptyValue: emptyValue,
|
||||
}
|
||||
if rr != nil {
|
||||
rcode := dns.RcodeToString[rr.Rcode]
|
||||
if rcode == "" {
|
||||
rcode = strconv.Itoa(rr.Rcode)
|
||||
}
|
||||
rep.replacements["{rcode}"] = rcode
|
||||
rep.replacements["{size}"] = strconv.Itoa(rr.Size)
|
||||
rep.replacements["{duration}"] = time.Since(rr.Start).String()
|
||||
}
|
||||
|
||||
// Header placeholders (case-insensitive)
|
||||
rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id))
|
||||
rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(int(r.Opcode))
|
||||
rep.replacements[headerReplacer+"do}"] = boolToString(req.Do())
|
||||
rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size())
|
||||
|
||||
return rep
|
||||
}
|
||||
|
||||
// Replace performs a replacement of values on s and returns
|
||||
// the string with the replaced values.
|
||||
func (r replacer) Replace(s string) string {
|
||||
// Header replacements - these are case-insensitive, so we can't just use strings.Replace()
|
||||
for strings.Contains(s, headerReplacer) {
|
||||
idxStart := strings.Index(s, headerReplacer)
|
||||
endOffset := idxStart + len(headerReplacer)
|
||||
idxEnd := strings.Index(s[endOffset:], "}")
|
||||
if idxEnd > -1 {
|
||||
placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
|
||||
replacement := r.replacements[placeholder]
|
||||
if replacement == "" {
|
||||
replacement = r.emptyValue
|
||||
}
|
||||
s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Regular replacements - these are easier because they're case-sensitive
|
||||
for placeholder, replacement := range r.replacements {
|
||||
if replacement == "" {
|
||||
replacement = r.emptyValue
|
||||
}
|
||||
s = strings.Replace(s, placeholder, replacement, -1)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Set sets key to value in the replacements map.
|
||||
func (r replacer) Set(key, value string) {
|
||||
r.replacements["{"+key+"}"] = value
|
||||
}
|
||||
|
||||
func boolToString(b bool) string {
|
||||
if b {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
const (
|
||||
timeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
headerReplacer = "{>"
|
||||
)
|
||||
119
middleware/pkg/replacer/replacer_test.go
Normal file
119
middleware/pkg/replacer/replacer_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package replacer
|
||||
|
||||
/*
|
||||
func TestNewReplacer(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost", reader)
|
||||
if err != nil {
|
||||
t.Fatal("Request Formation Failed\n")
|
||||
}
|
||||
replaceValues := NewReplacer(request, recordRequest, "")
|
||||
|
||||
switch v := replaceValues.(type) {
|
||||
case replacer:
|
||||
|
||||
if v.replacements["{host}"] != "localhost" {
|
||||
t.Error("Expected host to be localhost")
|
||||
}
|
||||
if v.replacements["{method}"] != "POST" {
|
||||
t.Error("Expected request method to be POST")
|
||||
}
|
||||
if v.replacements["{status}"] != "200" {
|
||||
t.Error("Expected status to be 200")
|
||||
}
|
||||
|
||||
default:
|
||||
t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplace(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost", reader)
|
||||
if err != nil {
|
||||
t.Fatal("Request Formation Failed\n")
|
||||
}
|
||||
request.Header.Set("Custom", "foobarbaz")
|
||||
request.Header.Set("ShorterVal", "1")
|
||||
repl := NewReplacer(request, recordRequest, "-")
|
||||
|
||||
if expected, actual := "This host is localhost.", repl.Replace("This host is {host}."); expected != actual {
|
||||
t.Errorf("{host} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := "This request method is POST.", repl.Replace("This request method is {method}."); expected != actual {
|
||||
t.Errorf("{method} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := "The response status is 200.", repl.Replace("The response status is {status}."); expected != actual {
|
||||
t.Errorf("{status} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := "The Custom header is foobarbaz.", repl.Replace("The Custom header is {>Custom}."); expected != actual {
|
||||
t.Errorf("{>Custom} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test header case-insensitivity
|
||||
if expected, actual := "The cUsToM header is foobarbaz...", repl.Replace("The cUsToM header is {>cUsToM}..."); expected != actual {
|
||||
t.Errorf("{>cUsToM} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test non-existent header/value
|
||||
if expected, actual := "The Non-Existent header is -.", repl.Replace("The Non-Existent header is {>Non-Existent}."); expected != actual {
|
||||
t.Errorf("{>Non-Existent} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test bad placeholder
|
||||
if expected, actual := "Bad {host placeholder...", repl.Replace("Bad {host placeholder..."); expected != actual {
|
||||
t.Errorf("bad placeholder: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test bad header placeholder
|
||||
if expected, actual := "Bad {>Custom placeholder", repl.Replace("Bad {>Custom placeholder"); expected != actual {
|
||||
t.Errorf("bad header placeholder: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test bad header placeholder with valid one later
|
||||
if expected, actual := "Bad -", repl.Replace("Bad {>Custom placeholder {>ShorterVal}"); expected != actual {
|
||||
t.Errorf("bad header placeholders: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test shorter header value with multiple placeholders
|
||||
if expected, actual := "Short value 1 then foobarbaz.", repl.Replace("Short value {>ShorterVal} then {>Custom}."); expected != actual {
|
||||
t.Errorf("short value: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost", reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Request Formation Failed \n")
|
||||
}
|
||||
repl := NewReplacer(request, recordRequest, "")
|
||||
|
||||
repl.Set("host", "getcaddy.com")
|
||||
repl.Set("method", "GET")
|
||||
repl.Set("status", "201")
|
||||
repl.Set("variable", "value")
|
||||
|
||||
if repl.Replace("This host is {host}") != "This host is getcaddy.com" {
|
||||
t.Error("Expected host replacement failed")
|
||||
}
|
||||
if repl.Replace("This request method is {method}") != "This request method is GET" {
|
||||
t.Error("Expected method replacement failed")
|
||||
}
|
||||
if repl.Replace("The response status is {status}") != "The response status is 201" {
|
||||
t.Error("Expected status replacement failed")
|
||||
}
|
||||
if repl.Replace("The value of variable is {variable}") != "The value of variable is value" {
|
||||
t.Error("Expected variable replacement failed")
|
||||
}
|
||||
}
|
||||
*/
|
||||
52
middleware/pkg/response/classify.go
Normal file
52
middleware/pkg/response/classify.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package response
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
type Type int
|
||||
|
||||
const (
|
||||
Success Type = iota
|
||||
NameError // NXDOMAIN in header, SOA in auth.
|
||||
NoData // NOERROR in header, SOA in auth.
|
||||
Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked).
|
||||
OtherError // Don't cache these.
|
||||
)
|
||||
|
||||
// Classify classifies a message, it returns the Type.
|
||||
func Classify(m *dns.Msg) (Type, *dns.OPT) {
|
||||
opt := m.IsEdns0()
|
||||
|
||||
if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {
|
||||
return Success, opt
|
||||
}
|
||||
|
||||
soa := false
|
||||
ns := 0
|
||||
for _, r := range m.Ns {
|
||||
if r.Header().Rrtype == dns.TypeSOA {
|
||||
soa = true
|
||||
continue
|
||||
}
|
||||
if r.Header().Rrtype == dns.TypeNS {
|
||||
ns++
|
||||
}
|
||||
}
|
||||
|
||||
// Check length of different sections, and drop stuff that is just to large? TODO(miek).
|
||||
if soa && m.Rcode == dns.RcodeSuccess {
|
||||
return NoData, opt
|
||||
}
|
||||
if soa && m.Rcode == dns.RcodeNameError {
|
||||
return NameError, opt
|
||||
}
|
||||
|
||||
if ns > 0 && ns == len(m.Ns) && m.Rcode == dns.RcodeSuccess {
|
||||
return Delegation, opt
|
||||
}
|
||||
|
||||
if m.Rcode == dns.RcodeSuccess {
|
||||
return Success, opt
|
||||
}
|
||||
|
||||
return OtherError, opt
|
||||
}
|
||||
31
middleware/pkg/response/classify_test.go
Normal file
31
middleware/pkg/response/classify_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/coredns/middleware/test"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestClassifyDelegation(t *testing.T) {
|
||||
m := delegationMsg()
|
||||
mt, _ := Classify(m)
|
||||
if mt != Delegation {
|
||||
t.Errorf("message is wrongly classified, expected delegation, got %d", mt)
|
||||
}
|
||||
}
|
||||
|
||||
func delegationMsg() *dns.Msg {
|
||||
return &dns.Msg{
|
||||
Ns: []dns.RR{
|
||||
test.NS("miek.nl. 3600 IN NS linode.atoom.net."),
|
||||
test.NS("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."),
|
||||
test.NS("miek.nl. 3600 IN NS omval.tednet.nl."),
|
||||
},
|
||||
Extra: []dns.RR{
|
||||
test.A("omval.tednet.nl. 3600 IN A 185.49.141.42"),
|
||||
test.AAAA("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"),
|
||||
},
|
||||
}
|
||||
}
|
||||
62
middleware/pkg/roller/roller.go
Normal file
62
middleware/pkg/roller/roller.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package roller
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
func Parse(c *caddy.Controller) (*LogRoller, error) {
|
||||
var size, age, keep int
|
||||
// This is kind of a hack to support nested blocks:
|
||||
// As we are already in a block: either log or errors,
|
||||
// c.nesting > 0 but, as soon as c meets a }, it thinks
|
||||
// the block is over and return false for c.NextBlock.
|
||||
for c.NextBlock() {
|
||||
what := c.Val()
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
value := c.Val()
|
||||
var err error
|
||||
switch what {
|
||||
case "size":
|
||||
size, err = strconv.Atoi(value)
|
||||
case "age":
|
||||
age, err = strconv.Atoi(value)
|
||||
case "keep":
|
||||
keep, err = strconv.Atoi(value)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &LogRoller{
|
||||
MaxSize: size,
|
||||
MaxAge: age,
|
||||
MaxBackups: keep,
|
||||
LocalTime: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LogRoller implements a middleware that provides a rolling logger.
|
||||
type LogRoller struct {
|
||||
Filename string
|
||||
MaxSize int
|
||||
MaxAge int
|
||||
MaxBackups int
|
||||
LocalTime bool
|
||||
}
|
||||
|
||||
// GetLogWriter returns an io.Writer that writes to a rolling logger.
|
||||
func (l LogRoller) GetLogWriter() io.Writer {
|
||||
return &lumberjack.Logger{
|
||||
Filename: l.Filename,
|
||||
MaxSize: l.MaxSize,
|
||||
MaxAge: l.MaxAge,
|
||||
MaxBackups: l.MaxBackups,
|
||||
LocalTime: l.LocalTime,
|
||||
}
|
||||
}
|
||||
64
middleware/pkg/singleflight/singleflight.go
Normal file
64
middleware/pkg/singleflight/singleflight.go
Normal file
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copyright 2012 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package singleflight provides a duplicate function call suppression
|
||||
// mechanism.
|
||||
package singleflight
|
||||
|
||||
import "sync"
|
||||
|
||||
// call is an in-flight or completed Do call
|
||||
type call struct {
|
||||
wg sync.WaitGroup
|
||||
val interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
// Group represents a class of work and forms a namespace in which
|
||||
// units of work can be executed with duplicate suppression.
|
||||
type Group struct {
|
||||
mu sync.Mutex // protects m
|
||||
m map[string]*call // lazily initialized
|
||||
}
|
||||
|
||||
// Do executes and returns the results of the given function, making
|
||||
// sure that only one execution is in-flight for a given key at a
|
||||
// time. If a duplicate comes in, the duplicate caller waits for the
|
||||
// original to complete and receives the same results.
|
||||
func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
g.mu.Lock()
|
||||
if g.m == nil {
|
||||
g.m = make(map[string]*call)
|
||||
}
|
||||
if c, ok := g.m[key]; ok {
|
||||
g.mu.Unlock()
|
||||
c.wg.Wait()
|
||||
return c.val, c.err
|
||||
}
|
||||
c := new(call)
|
||||
c.wg.Add(1)
|
||||
g.m[key] = c
|
||||
g.mu.Unlock()
|
||||
|
||||
c.val, c.err = fn()
|
||||
c.wg.Done()
|
||||
|
||||
g.mu.Lock()
|
||||
delete(g.m, key)
|
||||
g.mu.Unlock()
|
||||
|
||||
return c.val, c.err
|
||||
}
|
||||
85
middleware/pkg/singleflight/singleflight_test.go
Normal file
85
middleware/pkg/singleflight/singleflight_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
Copyright 2012 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package singleflight
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
var g Group
|
||||
v, err := g.Do("key", func() (interface{}, error) {
|
||||
return "bar", nil
|
||||
})
|
||||
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
|
||||
t.Errorf("Do = %v; want %v", got, want)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Do error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoErr(t *testing.T) {
|
||||
var g Group
|
||||
someErr := errors.New("Some error")
|
||||
v, err := g.Do("key", func() (interface{}, error) {
|
||||
return nil, someErr
|
||||
})
|
||||
if err != someErr {
|
||||
t.Errorf("Do error = %v; want someErr", err)
|
||||
}
|
||||
if v != nil {
|
||||
t.Errorf("unexpected non-nil value %#v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoDupSuppress(t *testing.T) {
|
||||
var g Group
|
||||
c := make(chan string)
|
||||
var calls int32
|
||||
fn := func() (interface{}, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return <-c, nil
|
||||
}
|
||||
|
||||
const n = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < n; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
v, err := g.Do("key", fn)
|
||||
if err != nil {
|
||||
t.Errorf("Do error: %v", err)
|
||||
}
|
||||
if v.(string) != "bar" {
|
||||
t.Errorf("got %q; want %q", v, "bar")
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond) // let goroutines above block
|
||||
c <- "bar"
|
||||
wg.Wait()
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Errorf("number of calls = %d; want 1", got)
|
||||
}
|
||||
}
|
||||
59
middleware/pkg/storage/fs.go
Normal file
59
middleware/pkg/storage/fs.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// dir wraps an http.Dir that restrict file access to a specific directory tree, see http.Dir's documentation
|
||||
// for methods for accessing files.
|
||||
type dir http.Dir
|
||||
|
||||
// CoreDir is the directory where middleware can store assets, like zone files after a zone transfer
|
||||
// or public and private keys or anything else a middleware might need. The convention is to place
|
||||
// assets in a subdirectory named after the zone prefixed with "D", to prevent the root zone become a hidden directory.
|
||||
//
|
||||
// Dexample.org/Kexample.org<something>.key
|
||||
//
|
||||
// Note that subzone(s) under example.org are places in the own directory under CoreDir:
|
||||
//
|
||||
// Dexample.org/...
|
||||
// Db.example.org/...
|
||||
//
|
||||
// CoreDir will default to "$HOME/.coredns" on Unix, but it's location can be overriden with the COREDNSPATH
|
||||
// environment variable.
|
||||
var CoreDir dir = dir(fsPath())
|
||||
|
||||
func (d dir) Zone(z string) dir {
|
||||
if z != "." && z[len(z)-2] == '.' {
|
||||
return dir(path.Join(string(d), "D"+z[:len(z)-1]))
|
||||
}
|
||||
return dir(path.Join(string(d), "D"+z))
|
||||
}
|
||||
|
||||
// fsPath returns the path to the directory where the application may store data.
|
||||
// If COREDNSPATH env variable. is set, that value is used. Otherwise, the path is
|
||||
// the result of evaluating "$HOME/.coredns".
|
||||
func fsPath() string {
|
||||
if corePath := os.Getenv("COREDNSPATH"); corePath != "" {
|
||||
return corePath
|
||||
}
|
||||
return filepath.Join(userHomeDir(), ".coredns")
|
||||
}
|
||||
|
||||
// userHomeDir returns the user's home directory according to environment variables.
|
||||
//
|
||||
// Credit: http://stackoverflow.com/a/7922977/1048862
|
||||
func userHomeDir() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
|
||||
if home == "" {
|
||||
home = os.Getenv("USERPROFILE")
|
||||
}
|
||||
return home
|
||||
}
|
||||
return os.Getenv("HOME")
|
||||
}
|
||||
42
middleware/pkg/storage/fs_test.go
Normal file
42
middleware/pkg/storage/fs_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestfsPath(t *testing.T) {
|
||||
if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") {
|
||||
t.Errorf("Expected path to be a .coredns folder, got: %v", actual)
|
||||
}
|
||||
|
||||
os.Setenv("COREDNSPATH", "testpath")
|
||||
defer os.Setenv("COREDNSPATH", "")
|
||||
if actual, expected := fsPath(), "testpath"; actual != expected {
|
||||
t.Errorf("Expected path to be %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZone(t *testing.T) {
|
||||
for _, ts := range []string{"example.org.", "example.org"} {
|
||||
d := CoreDir.Zone(ts)
|
||||
actual := path.Base(string(d))
|
||||
expected := "D" + ts
|
||||
if actual != expected {
|
||||
t.Errorf("Expected path to be %v, got %v", actual, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestZoneRoot(t *testing.T) {
|
||||
for _, ts := range []string{"."} {
|
||||
d := CoreDir.Zone(ts)
|
||||
actual := path.Base(string(d))
|
||||
expected := "D" + ts
|
||||
if actual != expected {
|
||||
t.Errorf("Expected path to be %v, got %v", actual, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user