mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-30 17:53:21 -04:00 
			
		
		
		
	Use context.Context
Rename the old Context to State and use context.Context in the middleware for intra-middleware communication and more.
This commit is contained in:
		| @@ -1,16 +1,6 @@ | |||||||
| package https | package https | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" |  | ||||||
| 	"os" |  | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware/redirect" |  | ||||||
| 	"github.com/miekg/coredns/server" |  | ||||||
| 	"github.com/xenolf/lego/acme" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestHostQualifies(t *testing.T) { | func TestHostQualifies(t *testing.T) { | ||||||
| 	for i, test := range []struct { | 	for i, test := range []struct { | ||||||
| 		host   string | 		host   string | ||||||
| @@ -330,3 +320,4 @@ func TestMarkQualified(t *testing.T) { | |||||||
| 		t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) | 		t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/core/parse" | 	"github.com/miekg/coredns/core/parse" | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/coredns/server" | 	"github.com/miekg/coredns/server" | ||||||
| @@ -70,7 +72,7 @@ func NewTestController(input string) *Controller { | |||||||
| // | // | ||||||
| // Used primarily for testing but needs to be exported so | // Used primarily for testing but needs to be exported so | ||||||
| // add-ons can use this as a convenience. | // add-ons can use this as a convenience. | ||||||
| var EmptyNext = middleware.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) (int, error) { | var EmptyNext = middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	return 0, nil | 	return 0, nil | ||||||
| }) | }) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,12 +1,6 @@ | |||||||
| package setup | package setup | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" |  | ||||||
| 	"github.com/miekg/coredns/middleware/errors" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestErrors(t *testing.T) { | func TestErrors(t *testing.T) { | ||||||
| 	c := NewTestController(`errors`) | 	c := NewTestController(`errors`) | ||||||
| 	mid, err := Errors(c) | 	mid, err := Errors(c) | ||||||
| @@ -154,5 +148,5 @@ func TestErrorsParse(t *testing.T) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -1,13 +1,6 @@ | |||||||
| package setup | package setup | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"fmt" |  | ||||||
| 	"regexp" |  | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware/rewrite" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestRewrite(t *testing.T) { | func TestRewrite(t *testing.T) { | ||||||
| 	c := NewTestController(`rewrite /from /to`) | 	c := NewTestController(`rewrite /from /to`) | ||||||
|  |  | ||||||
| @@ -237,5 +230,5 @@ func TestRewriteParse(t *testing.T) { | |||||||
|  |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -1,613 +0,0 @@ | |||||||
| package middleware |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"fmt" |  | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestInclude(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	inputFilename := "test_file" |  | ||||||
| 	absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) |  | ||||||
| 	defer func() { |  | ||||||
| 		err := os.Remove(absInFilePath) |  | ||||||
| 		if err != nil && !os.IsNotExist(err) { |  | ||||||
| 			t.Fatalf("Failed to clean test file!") |  | ||||||
| 		} |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	tests := []struct { |  | ||||||
| 		fileContent          string |  | ||||||
| 		expectedContent      string |  | ||||||
| 		shouldErr            bool |  | ||||||
| 		expectedErrorContent string |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 - all good |  | ||||||
| 		{ |  | ||||||
| 			fileContent:          `str1 {{ .Root }} str2`, |  | ||||||
| 			expectedContent:      fmt.Sprintf("str1 %s str2", context.Root), |  | ||||||
| 			shouldErr:            false, |  | ||||||
| 			expectedErrorContent: "", |  | ||||||
| 		}, |  | ||||||
| 		// Test 1 - failure on template.Parse |  | ||||||
| 		{ |  | ||||||
| 			fileContent:          `str1 {{ .Root } str2`, |  | ||||||
| 			expectedContent:      "", |  | ||||||
| 			shouldErr:            true, |  | ||||||
| 			expectedErrorContent: `unexpected "}" in operand`, |  | ||||||
| 		}, |  | ||||||
| 		// Test 3 - failure on template.Execute |  | ||||||
| 		{ |  | ||||||
| 			fileContent:          `str1 {{ .InvalidField }} str2`, |  | ||||||
| 			expectedContent:      "", |  | ||||||
| 			shouldErr:            true, |  | ||||||
| 			expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		testPrefix := getTestPrefix(i) |  | ||||||
|  |  | ||||||
| 		// WriteFile truncates the contentt |  | ||||||
| 		err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		content, err := context.Include(inputFilename) |  | ||||||
| 		if err != nil { |  | ||||||
| 			if !test.shouldErr { |  | ||||||
| 				t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error()) |  | ||||||
| 			} |  | ||||||
| 			if !strings.Contains(err.Error(), test.expectedErrorContent) { |  | ||||||
| 				t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error()) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if err == nil && test.shouldErr { |  | ||||||
| 			t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if content != test.expectedContent { |  | ||||||
| 			t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestIncludeNotExisting(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	_, err := context.Include("not_existing") |  | ||||||
| 	if err == nil { |  | ||||||
| 		t.Errorf("Expected error but found nil!") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMarkdown(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	inputFilename := "test_file" |  | ||||||
| 	absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) |  | ||||||
| 	defer func() { |  | ||||||
| 		err := os.Remove(absInFilePath) |  | ||||||
| 		if err != nil && !os.IsNotExist(err) { |  | ||||||
| 			t.Fatalf("Failed to clean test file!") |  | ||||||
| 		} |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	tests := []struct { |  | ||||||
| 		fileContent     string |  | ||||||
| 		expectedContent string |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 - test parsing of markdown |  | ||||||
| 		{ |  | ||||||
| 			fileContent:     "* str1\n* str2\n", |  | ||||||
| 			expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		testPrefix := getTestPrefix(i) |  | ||||||
|  |  | ||||||
| 		// WriteFile truncates the contentt |  | ||||||
| 		err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		content, _ := context.Markdown(inputFilename) |  | ||||||
| 		if content != test.expectedContent { |  | ||||||
| 			t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestCookie(t *testing.T) { |  | ||||||
|  |  | ||||||
| 	tests := []struct { |  | ||||||
| 		cookie        *http.Cookie |  | ||||||
| 		cookieName    string |  | ||||||
| 		expectedValue string |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 - happy path |  | ||||||
| 		{ |  | ||||||
| 			cookie:        &http.Cookie{Name: "cookieName", Value: "cookieValue"}, |  | ||||||
| 			cookieName:    "cookieName", |  | ||||||
| 			expectedValue: "cookieValue", |  | ||||||
| 		}, |  | ||||||
| 		// Test 1 - try to get a non-existing cookie |  | ||||||
| 		{ |  | ||||||
| 			cookie:        &http.Cookie{Name: "cookieName", Value: "cookieValue"}, |  | ||||||
| 			cookieName:    "notExisting", |  | ||||||
| 			expectedValue: "", |  | ||||||
| 		}, |  | ||||||
| 		// Test 2 - partial name match |  | ||||||
| 		{ |  | ||||||
| 			cookie:        &http.Cookie{Name: "cookie", Value: "cookieValue"}, |  | ||||||
| 			cookieName:    "cook", |  | ||||||
| 			expectedValue: "", |  | ||||||
| 		}, |  | ||||||
| 		// Test 3 - cookie with optional fields |  | ||||||
| 		{ |  | ||||||
| 			cookie:        &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120}, |  | ||||||
| 			cookieName:    "cookie", |  | ||||||
| 			expectedValue: "cookieValue", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		testPrefix := getTestPrefix(i) |  | ||||||
|  |  | ||||||
| 		// reinitialize the context for each test |  | ||||||
| 		context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 		context.Req.AddCookie(test.cookie) |  | ||||||
|  |  | ||||||
| 		actualCookieVal := context.Cookie(test.cookieName) |  | ||||||
|  |  | ||||||
| 		if actualCookieVal != test.expectedValue { |  | ||||||
| 			t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestCookieMultipleCookies(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	cookieNameBase, cookieValueBase := "cookieName", "cookieValue" |  | ||||||
|  |  | ||||||
| 	// make sure that there's no state and multiple requests for different cookies return the correct result |  | ||||||
| 	for i := 0; i < 10; i++ { |  | ||||||
| 		context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)}) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i := 0; i < 10; i++ { |  | ||||||
| 		expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i) |  | ||||||
| 		actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i)) |  | ||||||
| 		if actualCookieVal != expectedCookieVal { |  | ||||||
| 			t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestHeader(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	headerKey, headerVal := "Header1", "HeaderVal1" |  | ||||||
| 	context.Req.Header.Add(headerKey, headerVal) |  | ||||||
|  |  | ||||||
| 	actualHeaderVal := context.Header(headerKey) |  | ||||||
| 	if actualHeaderVal != headerVal { |  | ||||||
| 		t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	missingHeaderVal := context.Header("not-existing") |  | ||||||
| 	if missingHeaderVal != "" { |  | ||||||
| 		t.Errorf("Expected empty header value, found %s", missingHeaderVal) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestIP(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	tests := []struct { |  | ||||||
| 		inputRemoteAddr string |  | ||||||
| 		expectedIP      string |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 - ipv4 with port |  | ||||||
| 		{"1.1.1.1:1111", "1.1.1.1"}, |  | ||||||
| 		// Test 1 - ipv4 without port |  | ||||||
| 		{"1.1.1.1", "1.1.1.1"}, |  | ||||||
| 		// Test 2 - ipv6 with port |  | ||||||
| 		{"[::1]:11", "::1"}, |  | ||||||
| 		// Test 3 - ipv6 without port and brackets |  | ||||||
| 		{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, |  | ||||||
| 		// Test 4 - ipv6 with zone and port |  | ||||||
| 		{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		testPrefix := getTestPrefix(i) |  | ||||||
|  |  | ||||||
| 		context.Req.RemoteAddr = test.inputRemoteAddr |  | ||||||
| 		actualIP := context.IP() |  | ||||||
|  |  | ||||||
| 		if actualIP != test.expectedIP { |  | ||||||
| 			t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestURL(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	inputURL := "http://localhost" |  | ||||||
| 	context.Req.RequestURI = inputURL |  | ||||||
|  |  | ||||||
| 	if inputURL != context.URI() { |  | ||||||
| 		t.Errorf("Expected url %s, found %s", inputURL, context.URI()) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestHost(t *testing.T) { |  | ||||||
| 	tests := []struct { |  | ||||||
| 		input        string |  | ||||||
| 		expectedHost string |  | ||||||
| 		shouldErr    bool |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			input:        "localhost:123", |  | ||||||
| 			expectedHost: "localhost", |  | ||||||
| 			shouldErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			input:        "localhost", |  | ||||||
| 			expectedHost: "localhost", |  | ||||||
| 			shouldErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			input:        "[::]", |  | ||||||
| 			expectedHost: "", |  | ||||||
| 			shouldErr:    true, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, test := range tests { |  | ||||||
| 		testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestPort(t *testing.T) { |  | ||||||
| 	tests := []struct { |  | ||||||
| 		input        string |  | ||||||
| 		expectedPort string |  | ||||||
| 		shouldErr    bool |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			input:        "localhost:123", |  | ||||||
| 			expectedPort: "123", |  | ||||||
| 			shouldErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			input:        "localhost", |  | ||||||
| 			expectedPort: "80", // assuming 80 is the default port |  | ||||||
| 			shouldErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			input:        ":8080", |  | ||||||
| 			expectedPort: "8080", |  | ||||||
| 			shouldErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			input:        "[::]", |  | ||||||
| 			expectedPort: "", |  | ||||||
| 			shouldErr:    true, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, test := range tests { |  | ||||||
| 		testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	context.Req.Host = input |  | ||||||
| 	var actualResult, testedObject string |  | ||||||
| 	var err error |  | ||||||
|  |  | ||||||
| 	if isTestingHost { |  | ||||||
| 		actualResult, err = context.Host() |  | ||||||
| 		testedObject = "host" |  | ||||||
| 	} else { |  | ||||||
| 		actualResult, err = context.Port() |  | ||||||
| 		testedObject = "port" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if shouldErr && err == nil { |  | ||||||
| 		t.Errorf("Expected error, found nil!") |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if !shouldErr && err != nil { |  | ||||||
| 		t.Errorf("Expected no error, found %s", err) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if actualResult != expectedResult { |  | ||||||
| 		t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestMethod(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	method := "POST" |  | ||||||
| 	context.Req.Method = method |  | ||||||
|  |  | ||||||
| 	if method != context.Method() { |  | ||||||
| 		t.Errorf("Expected method %s, found %s", method, context.Method()) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestPathMatches(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
|  |  | ||||||
| 	tests := []struct { |  | ||||||
| 		urlStr      string |  | ||||||
| 		pattern     string |  | ||||||
| 		shouldMatch bool |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost/", |  | ||||||
| 			pattern:     "", |  | ||||||
| 			shouldMatch: true, |  | ||||||
| 		}, |  | ||||||
| 		// Test 1 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost", |  | ||||||
| 			pattern:     "", |  | ||||||
| 			shouldMatch: true, |  | ||||||
| 		}, |  | ||||||
| 		// Test 1 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost/", |  | ||||||
| 			pattern:     "/", |  | ||||||
| 			shouldMatch: true, |  | ||||||
| 		}, |  | ||||||
| 		// Test 3 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost/?param=val", |  | ||||||
| 			pattern:     "/", |  | ||||||
| 			shouldMatch: true, |  | ||||||
| 		}, |  | ||||||
| 		// Test 4 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost/dir1/dir2", |  | ||||||
| 			pattern:     "/dir2", |  | ||||||
| 			shouldMatch: false, |  | ||||||
| 		}, |  | ||||||
| 		// Test 5 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost/dir1/dir2", |  | ||||||
| 			pattern:     "/dir1", |  | ||||||
| 			shouldMatch: true, |  | ||||||
| 		}, |  | ||||||
| 		// Test 6 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost:444/dir1/dir2", |  | ||||||
| 			pattern:     "/dir1", |  | ||||||
| 			shouldMatch: true, |  | ||||||
| 		}, |  | ||||||
| 		// Test 7 |  | ||||||
| 		{ |  | ||||||
| 			urlStr:      "http://localhost/dir1/dir2", |  | ||||||
| 			pattern:     "*/dir2", |  | ||||||
| 			shouldMatch: false, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		testPrefix := getTestPrefix(i) |  | ||||||
| 		var err error |  | ||||||
| 		context.Req.URL, err = url.Parse(test.urlStr) |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		matches := context.PathMatches(test.pattern) |  | ||||||
| 		if matches != test.shouldMatch { |  | ||||||
| 			t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestTruncate(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
| 	tests := []struct { |  | ||||||
| 		inputString string |  | ||||||
| 		inputLength int |  | ||||||
| 		expected    string |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 - small length |  | ||||||
| 		{ |  | ||||||
| 			inputString: "string", |  | ||||||
| 			inputLength: 1, |  | ||||||
| 			expected:    "s", |  | ||||||
| 		}, |  | ||||||
| 		// Test 1 - exact length |  | ||||||
| 		{ |  | ||||||
| 			inputString: "string", |  | ||||||
| 			inputLength: 6, |  | ||||||
| 			expected:    "string", |  | ||||||
| 		}, |  | ||||||
| 		// Test 2 - bigger length |  | ||||||
| 		{ |  | ||||||
| 			inputString: "string", |  | ||||||
| 			inputLength: 10, |  | ||||||
| 			expected:    "string", |  | ||||||
| 		}, |  | ||||||
| 		// Test 3 - zero length |  | ||||||
| 		{ |  | ||||||
| 			inputString: "string", |  | ||||||
| 			inputLength: 0, |  | ||||||
| 			expected:    "", |  | ||||||
| 		}, |  | ||||||
| 		// Test 4 - negative, smaller length |  | ||||||
| 		{ |  | ||||||
| 			inputString: "string", |  | ||||||
| 			inputLength: -5, |  | ||||||
| 			expected:    "tring", |  | ||||||
| 		}, |  | ||||||
| 		// Test 5 - negative, exact length |  | ||||||
| 		{ |  | ||||||
| 			inputString: "string", |  | ||||||
| 			inputLength: -6, |  | ||||||
| 			expected:    "string", |  | ||||||
| 		}, |  | ||||||
| 		// Test 6 - negative, bigger length |  | ||||||
| 		{ |  | ||||||
| 			inputString: "string", |  | ||||||
| 			inputLength: -7, |  | ||||||
| 			expected:    "string", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		actual := context.Truncate(test.inputString, test.inputLength) |  | ||||||
| 		if actual != test.expected { |  | ||||||
| 			t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestStripHTML(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
| 	tests := []struct { |  | ||||||
| 		input    string |  | ||||||
| 		expected string |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 - no tags |  | ||||||
| 		{ |  | ||||||
| 			input:    `h1`, |  | ||||||
| 			expected: `h1`, |  | ||||||
| 		}, |  | ||||||
| 		// Test 1 - happy path |  | ||||||
| 		{ |  | ||||||
| 			input:    `<h1>h1</h1>`, |  | ||||||
| 			expected: `h1`, |  | ||||||
| 		}, |  | ||||||
| 		// Test 2 - tag in quotes |  | ||||||
| 		{ |  | ||||||
| 			input:    `<h1">">h1</h1>`, |  | ||||||
| 			expected: `h1`, |  | ||||||
| 		}, |  | ||||||
| 		// Test 3 - multiple tags |  | ||||||
| 		{ |  | ||||||
| 			input:    `<h1><b>h1</b></h1>`, |  | ||||||
| 			expected: `h1`, |  | ||||||
| 		}, |  | ||||||
| 		// Test 4 - tags not closed |  | ||||||
| 		{ |  | ||||||
| 			input:    `<h1`, |  | ||||||
| 			expected: `<h1`, |  | ||||||
| 		}, |  | ||||||
| 		// Test 5 - false start |  | ||||||
| 		{ |  | ||||||
| 			input:    `<h1<b>hi`, |  | ||||||
| 			expected: `<h1hi`, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		actual := context.StripHTML(test.input) |  | ||||||
| 		if actual != test.expected { |  | ||||||
| 			t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestStripExt(t *testing.T) { |  | ||||||
| 	context := getContextOrFail(t) |  | ||||||
| 	tests := []struct { |  | ||||||
| 		input    string |  | ||||||
| 		expected string |  | ||||||
| 	}{ |  | ||||||
| 		// Test 0 - empty input |  | ||||||
| 		{ |  | ||||||
| 			input:    "", |  | ||||||
| 			expected: "", |  | ||||||
| 		}, |  | ||||||
| 		// Test 1 - relative file with ext |  | ||||||
| 		{ |  | ||||||
| 			input:    "file.ext", |  | ||||||
| 			expected: "file", |  | ||||||
| 		}, |  | ||||||
| 		// Test 2 - relative file without ext |  | ||||||
| 		{ |  | ||||||
| 			input:    "file", |  | ||||||
| 			expected: "file", |  | ||||||
| 		}, |  | ||||||
| 		// Test 3 - absolute file without ext |  | ||||||
| 		{ |  | ||||||
| 			input:    "/file", |  | ||||||
| 			expected: "/file", |  | ||||||
| 		}, |  | ||||||
| 		// Test 4 - absolute file with ext |  | ||||||
| 		{ |  | ||||||
| 			input:    "/file.ext", |  | ||||||
| 			expected: "/file", |  | ||||||
| 		}, |  | ||||||
| 		// Test 5 - with ext but ends with / |  | ||||||
| 		{ |  | ||||||
| 			input:    "/dir.ext/", |  | ||||||
| 			expected: "/dir.ext/", |  | ||||||
| 		}, |  | ||||||
| 		// Test 6 - file with ext under dir with ext |  | ||||||
| 		{ |  | ||||||
| 			input:    "/dir.ext/file.ext", |  | ||||||
| 			expected: "/dir.ext/file", |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		actual := context.StripExt(test.input) |  | ||||||
| 		if actual != test.expected { |  | ||||||
| 			t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func initTestContext() (Context, error) { |  | ||||||
| 	body := bytes.NewBufferString("request body") |  | ||||||
| 	request, err := http.NewRequest("GET", "https://localhost", body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return Context{}, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return Context{Root: http.Dir(os.TempDir()), Req: request}, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getContextOrFail(t *testing.T) Context { |  | ||||||
| 	context, err := initTestContext() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatalf("Failed to prepare test context") |  | ||||||
| 	} |  | ||||||
| 	return context |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getTestPrefix(testN int) string { |  | ||||||
| 	return fmt.Sprintf("Test [%d]: ", testN) |  | ||||||
| } |  | ||||||
| @@ -8,6 +8,8 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
| @@ -21,10 +23,10 @@ type ErrorHandler struct { | |||||||
| 	Debug     bool // if true, errors are written out to client rather than to a log | 	Debug     bool // if true, errors are written out to client rather than to a log | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	defer h.recovery(w, r) | 	defer h.recovery(w, r) | ||||||
|  |  | ||||||
| 	rcode, err := h.Next.ServeDNS(w, r) | 	rcode, err := h.Next.ServeDNS(ctx, w, r) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err) | 		errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err) | ||||||
|   | |||||||
| @@ -1,21 +1,6 @@ | |||||||
| package errors | package errors | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"bytes" |  | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"log" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestErrors(t *testing.T) { | func TestErrors(t *testing.T) { | ||||||
| 	// create a temporary page | 	// create a temporary page | ||||||
| 	path := filepath.Join(os.TempDir(), "errors_test.html") | 	path := filepath.Join(os.TempDir(), "errors_test.html") | ||||||
| @@ -166,3 +151,4 @@ func genErrorHandler(status int, err error, body string) middleware.Handler { | |||||||
| 		return status, err | 		return status, err | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -8,6 +8,8 @@ package file | |||||||
| import ( | import ( | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
| @@ -26,29 +28,29 @@ type ( | |||||||
| 	} | 	} | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	context := middleware.Context{W: w, Req: r} | 	state := middleware.State{W: w, Req: r} | ||||||
| 	qname := context.Name() | 	qname := state.Name() | ||||||
| 	zone := middleware.Zones(f.Zones.Names).Matches(qname) | 	zone := middleware.Zones(f.Zones.Names).Matches(qname) | ||||||
| 	if zone == "" { | 	if zone == "" { | ||||||
| 		return f.Next.ServeDNS(w, r) | 		return f.Next.ServeDNS(ctx, w, r) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	names, nodata := f.Zones.Z[zone].lookup(qname, context.QType()) | 	names, nodata := f.Zones.Z[zone].lookup(qname, state.QType()) | ||||||
| 	var answer *dns.Msg | 	var answer *dns.Msg | ||||||
| 	switch { | 	switch { | ||||||
| 	case nodata: | 	case nodata: | ||||||
| 		answer = context.AnswerMessage() | 		answer = state.AnswerMessage() | ||||||
| 		answer.Ns = names | 		answer.Ns = names | ||||||
| 	case len(names) == 0: | 	case len(names) == 0: | ||||||
| 		answer = context.AnswerMessage() | 		answer = state.AnswerMessage() | ||||||
| 		answer.Ns = names | 		answer.Ns = names | ||||||
| 		answer.Rcode = dns.RcodeNameError | 		answer.Rcode = dns.RcodeNameError | ||||||
| 	case len(names) > 0: | 	case len(names) > 0: | ||||||
| 		answer = context.AnswerMessage() | 		answer = state.AnswerMessage() | ||||||
| 		answer.Answer = names | 		answer.Answer = names | ||||||
| 	default: | 	default: | ||||||
| 		answer = context.ErrorMessage(dns.RcodeServerFailure) | 		answer = state.ErrorMessage(dns.RcodeServerFailure) | ||||||
| 	} | 	} | ||||||
| 	// Check return size, etc. TODO(miek) | 	// Check return size, etc. TODO(miek) | ||||||
| 	w.WriteMsg(answer) | 	w.WriteMsg(answer) | ||||||
|   | |||||||
| @@ -1,15 +1,6 @@ | |||||||
| package file | package file | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"errors" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var testDir = filepath.Join(os.TempDir(), "caddy_testdir") | var testDir = filepath.Join(os.TempDir(), "caddy_testdir") | ||||||
| var ErrCustom = errors.New("Custom Error") | var ErrCustom = errors.New("Custom Error") | ||||||
|  |  | ||||||
| @@ -323,3 +314,4 @@ func TestServeHTTPFailingStat(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ package log | |||||||
| import ( | import ( | ||||||
| 	"log" | 	"log" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
| @@ -15,7 +17,7 @@ type Logger struct { | |||||||
| 	ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler | 	ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler | ||||||
| } | } | ||||||
|  |  | ||||||
| func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	for _, rule := range l.Rules { | 	for _, rule := range l.Rules { | ||||||
| 		/* | 		/* | ||||||
| 			if middleware.Path(r.URL.Path).Matches(rule.PathScope) { | 			if middleware.Path(r.URL.Path).Matches(rule.PathScope) { | ||||||
| @@ -40,7 +42,7 @@ func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | |||||||
| 		*/ | 		*/ | ||||||
| 		rule = rule | 		rule = rule | ||||||
| 	} | 	} | ||||||
| 	return l.Next.ServeDNS(w, r) | 	return l.Next.ServeDNS(ctx, w, r) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Rule configures the logging middleware. | // Rule configures the logging middleware. | ||||||
|   | |||||||
| @@ -1,17 +1,9 @@ | |||||||
| package log | package log | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"bytes" |  | ||||||
| 	"log" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type erroringMiddleware struct{} | type erroringMiddleware struct{} | ||||||
|  |  | ||||||
| func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { | func (erroringMiddleware) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	return http.StatusNotFound, nil | 	return http.StatusNotFound, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -46,3 +38,4 @@ func TestLoggedStatus(t *testing.T) { | |||||||
| 		t.Error("Expected 404 to be logged. Logged string -", logged) | 		t.Error("Expected 404 to be logged. Logged string -", logged) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
|  | 	"golang.org/x/net/context" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type ( | type ( | ||||||
| @@ -32,18 +33,18 @@ type ( | |||||||
| 	// Otherwise, return values should be propagated down the middleware | 	// Otherwise, return values should be propagated down the middleware | ||||||
| 	// chain by returning them unchanged. | 	// chain by returning them unchanged. | ||||||
| 	Handler interface { | 	Handler interface { | ||||||
| 		ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error) | 		ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// HandlerFunc is a convenience type like dns.HandlerFunc, except | 	// HandlerFunc is a convenience type like dns.HandlerFunc, except | ||||||
| 	// ServeDNS returns an rcode and an error. See Handler | 	// ServeDNS returns an rcode and an error. See Handler | ||||||
| 	// documentation for more information. | 	// documentation for more information. | ||||||
| 	HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error) | 	HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ServeDNS implements the Handler interface. | // ServeDNS implements the Handler interface. | ||||||
| func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	return f(w, r) | 	return f(ctx, w, r) | ||||||
| } | } | ||||||
|  |  | ||||||
| // IndexFile looks for a file in /root/fpath/indexFile for each string | // IndexFile looks for a file in /root/fpath/indexFile for each string | ||||||
|   | |||||||
| @@ -1,108 +1 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestIndexfile(t *testing.T) { |  | ||||||
| 	tests := []struct { |  | ||||||
| 		rootDir           http.FileSystem |  | ||||||
| 		fpath             string |  | ||||||
| 		indexFiles        []string |  | ||||||
| 		shouldErr         bool |  | ||||||
| 		expectedFilePath  string //retun value |  | ||||||
| 		expectedBoolValue bool   //return value |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			http.Dir("./templates/testdata"), |  | ||||||
| 			"/images/", |  | ||||||
| 			[]string{"img.htm"}, |  | ||||||
| 			false, |  | ||||||
| 			"/images/img.htm", |  | ||||||
| 			true, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles) |  | ||||||
| 		if actualBoolValue == true && test.shouldErr { |  | ||||||
| 			t.Errorf("Test %d didn't error, but it should have", i) |  | ||||||
| 		} else if actualBoolValue != true && !test.shouldErr { |  | ||||||
| 			t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist") |  | ||||||
| 		} |  | ||||||
| 		if actualFilePath != test.expectedFilePath { |  | ||||||
| 			t.Fatalf("Test %d expected returned filepath to be %s, but got %s ", |  | ||||||
| 				i, test.expectedFilePath, actualFilePath) |  | ||||||
|  |  | ||||||
| 		} |  | ||||||
| 		if actualBoolValue != test.expectedBoolValue { |  | ||||||
| 			t.Fatalf("Test %d expected returned bool value to be %v, but got %v ", |  | ||||||
| 				i, test.expectedBoolValue, actualBoolValue) |  | ||||||
|  |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestSetLastModified(t *testing.T) { |  | ||||||
| 	nowTime := time.Now() |  | ||||||
|  |  | ||||||
| 	// ovewrite the function to return reliable time |  | ||||||
| 	originalGetCurrentTimeFunc := currentTime |  | ||||||
| 	currentTime = func() time.Time { |  | ||||||
| 		return nowTime |  | ||||||
| 	} |  | ||||||
| 	defer func() { |  | ||||||
| 		currentTime = originalGetCurrentTimeFunc |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	pastTime := nowTime.Truncate(1 * time.Hour) |  | ||||||
| 	futureTime := nowTime.Add(1 * time.Hour) |  | ||||||
|  |  | ||||||
| 	tests := []struct { |  | ||||||
| 		inputModTime         time.Time |  | ||||||
| 		expectedIsHeaderSet  bool |  | ||||||
| 		expectedLastModified string |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			inputModTime:         pastTime, |  | ||||||
| 			expectedIsHeaderSet:  true, |  | ||||||
| 			expectedLastModified: pastTime.UTC().Format(http.TimeFormat), |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			inputModTime:         nowTime, |  | ||||||
| 			expectedIsHeaderSet:  true, |  | ||||||
| 			expectedLastModified: nowTime.UTC().Format(http.TimeFormat), |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			inputModTime:         futureTime, |  | ||||||
| 			expectedIsHeaderSet:  true, |  | ||||||
| 			expectedLastModified: nowTime.UTC().Format(http.TimeFormat), |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			inputModTime:        time.Time{}, |  | ||||||
| 			expectedIsHeaderSet: false, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for i, test := range tests { |  | ||||||
| 		responseRecorder := httptest.NewRecorder() |  | ||||||
| 		errorPrefix := fmt.Sprintf("Test [%d]: ", i) |  | ||||||
| 		SetLastModifiedHeader(responseRecorder, test.inputModTime) |  | ||||||
| 		actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified") |  | ||||||
|  |  | ||||||
| 		if test.expectedIsHeaderSet && actualLastModifiedHeader == "" { |  | ||||||
| 			t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" { |  | ||||||
| 			t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if test.expectedLastModified != actualLastModifiedHeader { |  | ||||||
| 			t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -4,15 +4,17 @@ import ( | |||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	context := middleware.Context{W: w, Req: r} | 	state := middleware.State{W: w, Req: r} | ||||||
|  |  | ||||||
| 	qname := context.Name() | 	qname := state.Name() | ||||||
| 	qtype := context.Type() | 	qtype := state.Type() | ||||||
| 	zone := middleware.Zones(m.ZoneNames).Matches(qname) | 	zone := middleware.Zones(m.ZoneNames).Matches(qname) | ||||||
| 	if zone == "" { | 	if zone == "" { | ||||||
| 		zone = "." | 		zone = "." | ||||||
| @@ -20,7 +22,7 @@ func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | |||||||
|  |  | ||||||
| 	// Record response to get status code and size of the reply. | 	// Record response to get status code and size of the reply. | ||||||
| 	rw := middleware.NewResponseRecorder(w) | 	rw := middleware.NewResponseRecorder(w) | ||||||
| 	status, err := m.Next.ServeDNS(rw, r) | 	status, err := m.Next.ServeDNS(ctx, rw, r) | ||||||
|  |  | ||||||
| 	requestCount.WithLabelValues(zone, qtype).Inc() | 	requestCount.WithLabelValues(zone, qtype).Inc() | ||||||
| 	requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second)) | 	requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second)) | ||||||
|   | |||||||
| @@ -7,6 +7,8 @@ import ( | |||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
| @@ -67,7 +69,7 @@ func (uh *UpstreamHost) Down() bool { | |||||||
| var tryDuration = 60 * time.Second | var tryDuration = 60 * time.Second | ||||||
|  |  | ||||||
| // ServeDNS satisfies the middleware.Handler interface. | // ServeDNS satisfies the middleware.Handler interface. | ||||||
| func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (p Proxy) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	for _, upstream := range p.Upstreams { | 	for _, upstream := range p.Upstreams { | ||||||
| 		// allowed bla bla bla TODO(miek): fix full proxy spec from caddy | 		// allowed bla bla bla TODO(miek): fix full proxy spec from caddy | ||||||
| 		start := time.Now() | 		start := time.Now() | ||||||
| @@ -100,7 +102,7 @@ func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | |||||||
| 		} | 		} | ||||||
| 		return dns.RcodeServerFailure, errUnreachable | 		return dns.RcodeServerFailure, errUnreachable | ||||||
| 	} | 	} | ||||||
| 	return p.Next.ServeDNS(w, r) | 	return p.Next.ServeDNS(ctx, w, r) | ||||||
| } | } | ||||||
|  |  | ||||||
| func Clients() Client { | func Clients() Client { | ||||||
|   | |||||||
| @@ -1,26 +1,6 @@ | |||||||
| package proxy | package proxy | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"bufio" |  | ||||||
| 	"bytes" |  | ||||||
| 	"fmt" |  | ||||||
| 	"io" |  | ||||||
| 	"io/ioutil" |  | ||||||
| 	"log" |  | ||||||
| 	"net" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"net/url" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"runtime" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"golang.org/x/net/websocket" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| 	tryDuration = 50 * time.Millisecond // prevent tests from hanging | 	tryDuration = 50 * time.Millisecond // prevent tests from hanging | ||||||
| } | } | ||||||
| @@ -315,3 +295,4 @@ func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } | |||||||
| func (c *fakeConn) Close() error                       { return nil } | func (c *fakeConn) Close() error                       { return nil } | ||||||
| func (c *fakeConn) Read(b []byte) (int, error)         { return c.readBuf.Read(b) } | func (c *fakeConn) Read(b []byte) (int, error)         { return c.readBuf.Read(b) } | ||||||
| func (c *fakeConn) Write(b []byte) (int, error)        { return c.writeBuf.Write(b) } | func (c *fakeConn) Write(b []byte) (int, error)        { return c.writeBuf.Write(b) } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -12,15 +12,15 @@ type ReverseProxy struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error { | func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error { | ||||||
| 	// TODO(miek): use extra! | 	// TODO(miek): use extra to EDNS0. | ||||||
| 	var ( | 	var ( | ||||||
| 		reply *dns.Msg | 		reply *dns.Msg | ||||||
| 		err   error | 		err   error | ||||||
| 	) | 	) | ||||||
| 	context := middleware.Context{W: w, Req: r} | 	state := middleware.State{W: w, Req: r} | ||||||
|  |  | ||||||
| 	// tls+tcp ? | 	// tls+tcp ? | ||||||
| 	if context.Proto() == "tcp" { | 	if state.Proto() == "tcp" { | ||||||
| 		reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) | 		reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) | ||||||
| 	} else { | 	} else { | ||||||
| 		reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) | 		reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) | ||||||
|   | |||||||
| @@ -1,11 +1,6 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestNewResponseRecorder(t *testing.T) { | func TestNewResponseRecorder(t *testing.T) { | ||||||
| 	w := httptest.NewRecorder() | 	w := httptest.NewRecorder() | ||||||
| 	recordRequest := NewResponseRecorder(w) | 	recordRequest := NewResponseRecorder(w) | ||||||
| @@ -30,3 +25,4 @@ func TestWrite(t *testing.T) { | |||||||
| 		t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String()) | 		t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -20,6 +20,8 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
| @@ -28,15 +30,15 @@ type Reflect struct { | |||||||
| 	Next middleware.Handler | 	Next middleware.Handler | ||||||
| } | } | ||||||
|  |  | ||||||
| func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (rl Reflect) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	context := middleware.Context{Req: r, W: w} | 	state := middleware.State{Req: r, W: w} | ||||||
|  |  | ||||||
| 	class := r.Question[0].Qclass | 	class := r.Question[0].Qclass | ||||||
| 	qname := r.Question[0].Name | 	qname := r.Question[0].Name | ||||||
| 	i, ok := dns.NextLabel(qname, 0) | 	i, ok := dns.NextLabel(qname, 0) | ||||||
|  |  | ||||||
| 	if strings.ToLower(qname[:i]) != who || ok { | 	if strings.ToLower(qname[:i]) != who || ok { | ||||||
| 		err := context.ErrorMessage(dns.RcodeFormatError) | 		err := state.ErrorMessage(dns.RcodeFormatError) | ||||||
| 		w.WriteMsg(err) | 		w.WriteMsg(err) | ||||||
| 		return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError]) | 		return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError]) | ||||||
| 	} | 	} | ||||||
| @@ -46,10 +48,10 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | |||||||
| 	answer.Compress = true | 	answer.Compress = true | ||||||
| 	answer.Authoritative = true | 	answer.Authoritative = true | ||||||
|  |  | ||||||
| 	ip := context.IP() | 	ip := state.IP() | ||||||
| 	proto := context.Proto() | 	proto := state.Proto() | ||||||
| 	port, _ := context.Port() | 	port, _ := state.Port() | ||||||
| 	family := context.Family() | 	family := state.Family() | ||||||
| 	var rr dns.RR | 	var rr dns.RR | ||||||
|  |  | ||||||
| 	switch family { | 	switch family { | ||||||
| @@ -67,7 +69,7 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | |||||||
| 	t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0} | 	t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0} | ||||||
| 	t.Txt = []string{"Port: " + port + " (" + proto + ")"} | 	t.Txt = []string{"Port: " + port + " (" + proto + ")"} | ||||||
|  |  | ||||||
| 	switch context.Type() { | 	switch state.Type() { | ||||||
| 	case "TXT": | 	case "TXT": | ||||||
| 		answer.Answer = append(answer.Answer, t) | 		answer.Answer = append(answer.Answer, t) | ||||||
| 		answer.Extra = append(answer.Extra, rr) | 		answer.Extra = append(answer.Extra, rr) | ||||||
|   | |||||||
| @@ -29,19 +29,19 @@ type replacer struct { | |||||||
| // available. emptyValue should be the string that is used | // available. emptyValue should be the string that is used | ||||||
| // in place of empty string (can still be empty string). | // in place of empty string (can still be empty string). | ||||||
| func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { | func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { | ||||||
| 	context := Context{W: rr, Req: r} | 	state := State{W: rr, Req: r} | ||||||
| 	rep := replacer{ | 	rep := replacer{ | ||||||
| 		replacements: map[string]string{ | 		replacements: map[string]string{ | ||||||
| 			"{type}":  context.Type(), | 			"{type}":  state.Type(), | ||||||
| 			"{name}":  context.Name(), | 			"{name}":  state.Name(), | ||||||
| 			"{class}": context.Class(), | 			"{class}": state.Class(), | ||||||
| 			"{proto}": context.Proto(), | 			"{proto}": state.Proto(), | ||||||
| 			"{when}": func() string { | 			"{when}": func() string { | ||||||
| 				return time.Now().Format(timeFormat) | 				return time.Now().Format(timeFormat) | ||||||
| 			}(), | 			}(), | ||||||
| 			"{remote}": context.IP(), | 			"{remote}": state.IP(), | ||||||
| 			"{port}": func() string { | 			"{port}": func() string { | ||||||
| 				p, _ := context.Port() | 				p, _ := state.Port() | ||||||
| 				return p | 				return p | ||||||
| 			}(), | 			}(), | ||||||
| 		}, | 		}, | ||||||
|   | |||||||
| @@ -1,12 +1,6 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestNewReplacer(t *testing.T) { | func TestNewReplacer(t *testing.T) { | ||||||
| 	w := httptest.NewRecorder() | 	w := httptest.NewRecorder() | ||||||
| 	recordRequest := NewResponseRecorder(w) | 	recordRequest := NewResponseRecorder(w) | ||||||
| @@ -122,3 +116,4 @@ func TestSet(t *testing.T) { | |||||||
| 		t.Error("Expected variable replacement failed") | 		t.Error("Expected variable replacement failed") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -1,11 +1,6 @@ | |||||||
| package rewrite | package rewrite | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"net/http" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestConditions(t *testing.T) { | func TestConditions(t *testing.T) { | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| 		condition string | 		condition string | ||||||
| @@ -104,3 +99,4 @@ func TestConditions(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ package rewrite | |||||||
| import ( | import ( | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
|  | 	"golang.org/x/net/context" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Result is the result of a rewrite | // Result is the result of a rewrite | ||||||
| @@ -27,12 +28,12 @@ type Rewrite struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // ServeHTTP implements the middleware.Handler interface. | // ServeHTTP implements the middleware.Handler interface. | ||||||
| func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 	wr := NewResponseReverter(w, r) | 	wr := NewResponseReverter(w, r) | ||||||
| 	for _, rule := range rw.Rules { | 	for _, rule := range rw.Rules { | ||||||
| 		switch result := rule.Rewrite(r); result { | 		switch result := rule.Rewrite(r); result { | ||||||
| 		case RewriteDone: | 		case RewriteDone: | ||||||
| 			return rw.Next.ServeDNS(wr, r) | 			return rw.Next.ServeDNS(ctx, wr, r) | ||||||
| 		case RewriteIgnored: | 		case RewriteIgnored: | ||||||
| 			break | 			break | ||||||
| 		case RewriteStatus: | 		case RewriteStatus: | ||||||
| @@ -42,7 +43,7 @@ func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { | |||||||
| 			// } | 			// } | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return rw.Next.ServeDNS(w, r) | 	return rw.Next.ServeDNS(ctx, w, r) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Rule describes an internal location rewrite rule. | // Rule describes an internal location rewrite rule. | ||||||
|   | |||||||
| @@ -1,15 +1,6 @@ | |||||||
| package rewrite | package rewrite | ||||||
|  |  | ||||||
| import ( | /* | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestRewrite(t *testing.T) { | func TestRewrite(t *testing.T) { | ||||||
| 	rw := Rewrite{ | 	rw := Rewrite{ | ||||||
| 		Next: middleware.HandlerFunc(urlPrinter), | 		Next: middleware.HandlerFunc(urlPrinter), | ||||||
| @@ -157,3 +148,4 @@ func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { | |||||||
| 	fmt.Fprintf(w, r.URL.String()) | 	fmt.Fprintf(w, r.URL.String()) | ||||||
| 	return 0, nil | 	return 0, nil | ||||||
| } | } | ||||||
|  | */ | ||||||
|   | |||||||
| @@ -9,45 +9,44 @@ import ( | |||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // This file contains the context and functions available for | // This file contains the state nd functions available for use in the templates. | ||||||
| // use in the templates. |  | ||||||
| 
 | 
 | ||||||
| // Context is the context with which Caddy templates are executed. | // State contains some connection state and is useful in middleware. | ||||||
| type Context struct { | type State struct { | ||||||
| 	Root http.FileSystem // TODO(miek): needed | 	Root http.FileSystem // TODO(miek): needed? | ||||||
| 	Req  *dns.Msg | 	Req  *dns.Msg | ||||||
| 	W    dns.ResponseWriter | 	W    dns.ResponseWriter | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Now returns the current timestamp in the specified format. | // Now returns the current timestamp in the specified format. | ||||||
| func (c Context) Now(format string) string { | func (s State) Now(format string) string { | ||||||
| 	return time.Now().Format(format) | 	return time.Now().Format(format) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NowDate returns the current date/time that can be used | // NowDate returns the current date/time that can be used | ||||||
| // in other time functions. | // in other time functions. | ||||||
| func (c Context) NowDate() time.Time { | func (s State) NowDate() time.Time { | ||||||
| 	return time.Now() | 	return time.Now() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Header gets the value of a header. | // Header gets the value of a header. | ||||||
| func (c Context) Header() *dns.RR_Header { | func (s State) Header() *dns.RR_Header { | ||||||
| 	// TODO(miek) | 	// TODO(miek) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // IP gets the (remote) IP address of the client making the request. | // IP gets the (remote) IP address of the client making the request. | ||||||
| func (c Context) IP() string { | func (s State) IP() string { | ||||||
| 	ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String()) | 	ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return c.W.RemoteAddr().String() | 		return s.W.RemoteAddr().String() | ||||||
| 	} | 	} | ||||||
| 	return ip | 	return ip | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Post gets the (remote) Port of the client making the request. | // Post gets the (remote) Port of the client making the request. | ||||||
| func (c Context) Port() (string, error) { | func (s State) Port() (string, error) { | ||||||
| 	_, port, err := net.SplitHostPort(c.W.RemoteAddr().String()) | 	_, port, err := net.SplitHostPort(s.W.RemoteAddr().String()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "0", err | 		return "0", err | ||||||
| 	} | 	} | ||||||
| @@ -56,11 +55,11 @@ func (c Context) Port() (string, error) { | |||||||
| 
 | 
 | ||||||
| // Proto gets the protocol used as the transport. This | // Proto gets the protocol used as the transport. This | ||||||
| // will be udp or tcp. | // will be udp or tcp. | ||||||
| func (c Context) Proto() string { | func (s State) Proto() string { | ||||||
| 	if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok { | 	if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok { | ||||||
| 		return "udp" | 		return "udp" | ||||||
| 	} | 	} | ||||||
| 	if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok { | 	if _, ok := s.W.RemoteAddr().(*net.TCPAddr); ok { | ||||||
| 		return "tcp" | 		return "tcp" | ||||||
| 	} | 	} | ||||||
| 	return "udp" | 	return "udp" | ||||||
| @@ -68,9 +67,9 @@ func (c Context) Proto() string { | |||||||
| 
 | 
 | ||||||
| // Family returns the family of the transport. | // Family returns the family of the transport. | ||||||
| // 1 for IPv4 and 2 for IPv6. | // 1 for IPv4 and 2 for IPv6. | ||||||
| func (c Context) Family() int { | func (s State) Family() int { | ||||||
| 	var a net.IP | 	var a net.IP | ||||||
| 	ip := c.W.RemoteAddr() | 	ip := s.W.RemoteAddr() | ||||||
| 	if i, ok := ip.(*net.UDPAddr); ok { | 	if i, ok := ip.(*net.UDPAddr); ok { | ||||||
| 		a = i.IP | 		a = i.IP | ||||||
| 	} | 	} | ||||||
| @@ -85,51 +84,48 @@ func (c Context) Family() int { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Type returns the type of the question as a string. | // Type returns the type of the question as a string. | ||||||
| func (c Context) Type() string { | func (s State) Type() string { | ||||||
| 	return dns.Type(c.Req.Question[0].Qtype).String() | 	return dns.Type(s.Req.Question[0].Qtype).String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // QType returns the type of the question as a uint16. | // QType returns the type of the question as a uint16. | ||||||
| func (c Context) QType() uint16 { | func (s State) QType() uint16 { | ||||||
| 	return c.Req.Question[0].Qtype | 	return s.Req.Question[0].Qtype | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Name returns the name of the question in the request. Note | // Name returns the name of the question in the request. Note | ||||||
| // this name will always have a closing dot and will be lower cased. | // this name will always have a closing dot and will be lower cased. | ||||||
| func (c Context) Name() string { | func (s State) Name() string { | ||||||
| 	return strings.ToLower(dns.Name(c.Req.Question[0].Name).String()) | 	return strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // QName returns the name of the question in the request. | // QName returns the name of the question in the request. | ||||||
| func (c Context) QName() string { | func (s State) QName() string { | ||||||
| 	return dns.Name(c.Req.Question[0].Name).String() | 	return dns.Name(s.Req.Question[0].Name).String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Class returns the class of the question in the request. | // Class returns the class of the question in the request. | ||||||
| func (c Context) Class() string { | func (s State) Class() string { | ||||||
| 	return dns.Class(c.Req.Question[0].Qclass).String() | 	return dns.Class(s.Req.Question[0].Qclass).String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // QClass returns the class of the question in the request. | // QClass returns the class of the question in the request. | ||||||
| func (c Context) QClass() uint16 { | func (s State) QClass() uint16 { | ||||||
| 	return c.Req.Question[0].Qclass | 	return s.Req.Question[0].Qclass | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // More convience types for extracting stuff from a message? |  | ||||||
| // Header? |  | ||||||
| 
 |  | ||||||
| // ErrorMessage returns an error message suitable for sending | // ErrorMessage returns an error message suitable for sending | ||||||
| // back to the client. | // back to the client. | ||||||
| func (c Context) ErrorMessage(rcode int) *dns.Msg { | func (s State) ErrorMessage(rcode int) *dns.Msg { | ||||||
| 	m := new(dns.Msg) | 	m := new(dns.Msg) | ||||||
| 	m.SetRcode(c.Req, rcode) | 	m.SetRcode(s.Req, rcode) | ||||||
| 	return m | 	return m | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // AnswerMessage returns an error message suitable for sending | // AnswerMessage returns an error message suitable for sending | ||||||
| // back to the client. | // back to the client. | ||||||
| func (c Context) AnswerMessage() *dns.Msg { | func (s State) AnswerMessage() *dns.Msg { | ||||||
| 	m := new(dns.Msg) | 	m := new(dns.Msg) | ||||||
| 	m.SetReply(c.Req) | 	m.SetReply(s.Req) | ||||||
| 	return m | 	return m | ||||||
| } | } | ||||||
							
								
								
									
										235
									
								
								middleware/state_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								middleware/state_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,235 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | func TestHeader(t *testing.T) { | ||||||
|  | 	state := getContextOrFail(t) | ||||||
|  |  | ||||||
|  | 	headerKey, headerVal := "Header1", "HeaderVal1" | ||||||
|  | 	state.Req.Header.Add(headerKey, headerVal) | ||||||
|  |  | ||||||
|  | 	actualHeaderVal := state.Header(headerKey) | ||||||
|  | 	if actualHeaderVal != headerVal { | ||||||
|  | 		t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	missingHeaderVal := state.Header("not-existing") | ||||||
|  | 	if missingHeaderVal != "" { | ||||||
|  | 		t.Errorf("Expected empty header value, found %s", missingHeaderVal) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIP(t *testing.T) { | ||||||
|  | 	state := getContextOrFail(t) | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		inputRemoteAddr string | ||||||
|  | 		expectedIP      string | ||||||
|  | 	}{ | ||||||
|  | 		// Test 0 - ipv4 with port | ||||||
|  | 		{"1.1.1.1:1111", "1.1.1.1"}, | ||||||
|  | 		// Test 1 - ipv4 without port | ||||||
|  | 		{"1.1.1.1", "1.1.1.1"}, | ||||||
|  | 		// Test 2 - ipv6 with port | ||||||
|  | 		{"[::1]:11", "::1"}, | ||||||
|  | 		// Test 3 - ipv6 without port and brackets | ||||||
|  | 		{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, | ||||||
|  | 		// Test 4 - ipv6 with zone and port | ||||||
|  | 		{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i, test := range tests { | ||||||
|  | 		testPrefix := getTestPrefix(i) | ||||||
|  |  | ||||||
|  | 		state.Req.RemoteAddr = test.inputRemoteAddr | ||||||
|  | 		actualIP := state.IP() | ||||||
|  |  | ||||||
|  | 		if actualIP != test.expectedIP { | ||||||
|  | 			t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestURL(t *testing.T) { | ||||||
|  | 	state := getContextOrFail(t) | ||||||
|  |  | ||||||
|  | 	inputURL := "http://localhost" | ||||||
|  | 	state.Req.RequestURI = inputURL | ||||||
|  |  | ||||||
|  | 	if inputURL != state.URI() { | ||||||
|  | 		t.Errorf("Expected url %s, found %s", inputURL, state.URI()) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestHost(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		input        string | ||||||
|  | 		expectedHost string | ||||||
|  | 		shouldErr    bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			input:        "localhost:123", | ||||||
|  | 			expectedHost: "localhost", | ||||||
|  | 			shouldErr:    false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input:        "localhost", | ||||||
|  | 			expectedHost: "localhost", | ||||||
|  | 			shouldErr:    false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input:        "[::]", | ||||||
|  | 			expectedHost: "", | ||||||
|  | 			shouldErr:    true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPort(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		input        string | ||||||
|  | 		expectedPort string | ||||||
|  | 		shouldErr    bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			input:        "localhost:123", | ||||||
|  | 			expectedPort: "123", | ||||||
|  | 			shouldErr:    false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input:        "localhost", | ||||||
|  | 			expectedPort: "80", // assuming 80 is the default port | ||||||
|  | 			shouldErr:    false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input:        ":8080", | ||||||
|  | 			expectedPort: "8080", | ||||||
|  | 			shouldErr:    false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input:        "[::]", | ||||||
|  | 			expectedPort: "", | ||||||
|  | 			shouldErr:    true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { | ||||||
|  | 	state := getContextOrFail(t) | ||||||
|  |  | ||||||
|  | 	state.Req.Host = input | ||||||
|  | 	var actualResult, testedObject string | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	if isTestingHost { | ||||||
|  | 		actualResult, err = state.Host() | ||||||
|  | 		testedObject = "host" | ||||||
|  | 	} else { | ||||||
|  | 		actualResult, err = state.Port() | ||||||
|  | 		testedObject = "port" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if shouldErr && err == nil { | ||||||
|  | 		t.Errorf("Expected error, found nil!") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !shouldErr && err != nil { | ||||||
|  | 		t.Errorf("Expected no error, found %s", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if actualResult != expectedResult { | ||||||
|  | 		t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPathMatches(t *testing.T) { | ||||||
|  | 	state := getContextOrFail(t) | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		urlStr      string | ||||||
|  | 		pattern     string | ||||||
|  | 		shouldMatch bool | ||||||
|  | 	}{ | ||||||
|  | 		// Test 0 | ||||||
|  | 		{ | ||||||
|  | 			urlStr:      "http://localhost/", | ||||||
|  | 			pattern:     "", | ||||||
|  | 			shouldMatch: true, | ||||||
|  | 		}, | ||||||
|  | 		// Test 1 | ||||||
|  | 		{ | ||||||
|  | 			urlStr:      "http://localhost", | ||||||
|  | 			pattern:     "", | ||||||
|  | 			shouldMatch: true, | ||||||
|  | 		}, | ||||||
|  | 		// Test 1 | ||||||
|  | 		{ | ||||||
|  | 			urlStr:      "http://localhost/", | ||||||
|  | 			pattern:     "/", | ||||||
|  | 			shouldMatch: true, | ||||||
|  | 		}, | ||||||
|  | 		// Test 3 | ||||||
|  | 		{ | ||||||
|  | 			urlStr:      "http://localhost/?param=val", | ||||||
|  | 			pattern:     "/", | ||||||
|  | 			shouldMatch: true, | ||||||
|  | 		}, | ||||||
|  | 		// Test 4 | ||||||
|  | 		{ | ||||||
|  | 			urlStr:      "http://localhost/dir1/dir2", | ||||||
|  | 			pattern:     "/dir2", | ||||||
|  | 			shouldMatch: false, | ||||||
|  | 		}, | ||||||
|  | 		// Test 5 | ||||||
|  | 		{ | ||||||
|  | 			urlStr:      "http://localhost/dir1/dir2", | ||||||
|  | 			pattern:     "/dir1", | ||||||
|  | 			shouldMatch: true, | ||||||
|  | 		}, | ||||||
|  | 		// Test 6 | ||||||
|  | 		{ | ||||||
|  | 			urlStr:      "http://localhost:444/dir1/dir2", | ||||||
|  | 			pattern:     "/dir1", | ||||||
|  | 			shouldMatch: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i, test := range tests { | ||||||
|  | 		testPrefix := getTestPrefix(i) | ||||||
|  | 		var err error | ||||||
|  | 		state.Req.URL, err = url.Parse(test.urlStr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		matches := state.PathMatches(test.pattern) | ||||||
|  | 		if matches != test.shouldMatch { | ||||||
|  | 			t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func initTestContext() (Context, error) { | ||||||
|  | 	body := bytes.NewBufferString("request body") | ||||||
|  | 	request, err := http.NewRequest("GET", "https://localhost", body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return Context{}, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return Context{Root: http.Dir(os.TempDir()), Req: request}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | func getTestPrefix(testN int) string { | ||||||
|  | 	return fmt.Sprintf("Test [%d]: ", testN) | ||||||
|  | } | ||||||
|  | */ | ||||||
| @@ -15,6 +15,8 @@ import ( | |||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -285,6 +287,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||||||
| 	q := r.Question[0].Name | 	q := r.Question[0].Name | ||||||
| 	b := make([]byte, len(q)) | 	b := make([]byte, len(q)) | ||||||
| 	off, end := 0, false | 	off, end := 0, false | ||||||
|  | 	ctx := context.Background() | ||||||
| 	for { | 	for { | ||||||
| 		l := len(q[off:]) | 		l := len(q[off:]) | ||||||
| 		for i := 0; i < l; i++ { | 		for i := 0; i < l; i++ { | ||||||
| @@ -297,7 +300,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||||||
|  |  | ||||||
| 		if h, ok := s.zones[string(b[:l])]; ok { | 		if h, ok := s.zones[string(b[:l])]; ok { | ||||||
| 			if r.Question[0].Qtype != dns.TypeDS { | 			if r.Question[0].Qtype != dns.TypeDS { | ||||||
| 				rcode, _ := h.stack.ServeDNS(w, r) | 				rcode, _ := h.stack.ServeDNS(ctx, w, r) | ||||||
| 				if rcode > 0 { | 				if rcode > 0 { | ||||||
| 					DefaultErrorFunc(w, r, rcode) | 					DefaultErrorFunc(w, r, rcode) | ||||||
| 				} | 				} | ||||||
| @@ -311,7 +314,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||||||
| 	} | 	} | ||||||
| 	// Wildcard match, if we have found nothing try the root zone as a last resort. | 	// Wildcard match, if we have found nothing try the root zone as a last resort. | ||||||
| 	if h, ok := s.zones["."]; ok { | 	if h, ok := s.zones["."]; ok { | ||||||
| 		rcode, _ := h.stack.ServeDNS(w, r) | 		rcode, _ := h.stack.ServeDNS(ctx, w, r) | ||||||
| 		if rcode > 0 { | 		if rcode > 0 { | ||||||
| 			DefaultErrorFunc(w, r, rcode) | 			DefaultErrorFunc(w, r, rcode) | ||||||
| 		} | 		} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user