First commit

This commit is contained in:
Miek Gieben
2016-03-18 20:57:35 +00:00
commit 3ec0d9fe6b
131 changed files with 15193 additions and 0 deletions

120
middleware/commands.go Normal file
View File

@@ -0,0 +1,120 @@
package middleware
import (
"errors"
"runtime"
"unicode"
"github.com/flynn/go-shlex"
)
var runtimeGoos = runtime.GOOS
// SplitCommandAndArgs takes a command string and parses it
// shell-style into the command and its separate arguments.
func SplitCommandAndArgs(command string) (cmd string, args []string, err error) {
var parts []string
if runtimeGoos == "windows" {
parts = parseWindowsCommand(command) // parse it Windows-style
} else {
parts, err = parseUnixCommand(command) // parse it Unix-style
if err != nil {
err = errors.New("error parsing command: " + err.Error())
return
}
}
if len(parts) == 0 {
err = errors.New("no command contained in '" + command + "'")
return
}
cmd = parts[0]
if len(parts) > 1 {
args = parts[1:]
}
return
}
// parseUnixCommand parses a unix style command line and returns the
// command and its arguments or an error
func parseUnixCommand(cmd string) ([]string, error) {
return shlex.Split(cmd)
}
// parseWindowsCommand parses windows command lines and
// returns the command and the arguments as an array. It
// should be able to parse commonly used command lines.
// Only basic syntax is supported:
// - spaces in double quotes are not token delimiters
// - double quotes are escaped by either backspace or another double quote
// - except for the above case backspaces are path separators (not special)
//
// Many sources point out that escaping quotes using backslash can be unsafe.
// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 )
//
// This function has to be used on Windows instead
// of the shlex package because this function treats backslash
// characters properly.
func parseWindowsCommand(cmd string) []string {
const backslash = '\\'
const quote = '"'
var parts []string
var part string
var inQuotes bool
var lastRune rune
for i, ch := range cmd {
if i != 0 {
lastRune = rune(cmd[i-1])
}
if ch == backslash {
// put it in the part - for now we don't know if it's an
// escaping char or path separator
part += string(ch)
continue
}
if ch == quote {
if lastRune == backslash {
// remove the backslash from the part and add the escaped quote instead
part = part[:len(part)-1]
part += string(ch)
continue
}
if lastRune == quote {
// revert the last change of the inQuotes state
// it was an escaping quote
inQuotes = !inQuotes
part += string(ch)
continue
}
// normal escaping quotes
inQuotes = !inQuotes
continue
}
if unicode.IsSpace(ch) && !inQuotes && len(part) > 0 {
parts = append(parts, part)
part = ""
continue
}
part += string(ch)
}
if len(part) > 0 {
parts = append(parts, part)
part = ""
}
return parts
}

291
middleware/commands_test.go Normal file
View File

@@ -0,0 +1,291 @@
package middleware
import (
"fmt"
"runtime"
"strings"
"testing"
)
func TestParseUnixCommand(t *testing.T) {
tests := []struct {
input string
expected []string
}{
// 0 - emtpy command
{
input: ``,
expected: []string{},
},
// 1 - command without arguments
{
input: `command`,
expected: []string{`command`},
},
// 2 - command with single argument
{
input: `command arg1`,
expected: []string{`command`, `arg1`},
},
// 3 - command with multiple arguments
{
input: `command arg1 arg2`,
expected: []string{`command`, `arg1`, `arg2`},
},
// 4 - command with single argument with space character - in quotes
{
input: `command "arg1 arg1"`,
expected: []string{`command`, `arg1 arg1`},
},
// 5 - command with multiple spaces and tab character
{
input: "command arg1 arg2\targ3",
expected: []string{`command`, `arg1`, `arg2`, `arg3`},
},
// 6 - command with single argument with space character - escaped with backspace
{
input: `command arg1\ arg2`,
expected: []string{`command`, `arg1 arg2`},
},
// 7 - single quotes should escape special chars
{
input: `command 'arg1\ arg2'`,
expected: []string{`command`, `arg1\ arg2`},
},
}
for i, test := range tests {
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
actual, _ := parseUnixCommand(test.input)
if len(actual) != len(test.expected) {
t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
continue
}
for j := 0; j < len(actual); j++ {
if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
}
}
}
}
func TestParseWindowsCommand(t *testing.T) {
tests := []struct {
input string
expected []string
}{
{ // 0 - empty command - do not fail
input: ``,
expected: []string{},
},
{ // 1 - cmd without args
input: `cmd`,
expected: []string{`cmd`},
},
{ // 2 - multiple args
input: `cmd arg1 arg2`,
expected: []string{`cmd`, `arg1`, `arg2`},
},
{ // 3 - multiple args with space
input: `cmd "combined arg" arg2`,
expected: []string{`cmd`, `combined arg`, `arg2`},
},
{ // 4 - path without spaces
input: `mkdir C:\Windows\foo\bar`,
expected: []string{`mkdir`, `C:\Windows\foo\bar`},
},
{ // 5 - command with space in quotes
input: `"command here"`,
expected: []string{`command here`},
},
{ // 6 - argument with escaped quotes (two quotes)
input: `cmd ""arg""`,
expected: []string{`cmd`, `"arg"`},
},
{ // 7 - argument with escaped quotes (backslash)
input: `cmd \"arg\"`,
expected: []string{`cmd`, `"arg"`},
},
{ // 8 - two quotes (escaped) inside an inQuote element
input: `cmd "a ""quoted value"`,
expected: []string{`cmd`, `a "quoted value`},
},
// TODO - see how many quotes are dislayed if we use "", """, """""""
{ // 9 - two quotes outside an inQuote element
input: `cmd a ""quoted value`,
expected: []string{`cmd`, `a`, `"quoted`, `value`},
},
{ // 10 - path with space in quotes
input: `mkdir "C:\directory name\foobar"`,
expected: []string{`mkdir`, `C:\directory name\foobar`},
},
{ // 11 - space without quotes
input: `mkdir C:\ space`,
expected: []string{`mkdir`, `C:\`, `space`},
},
{ // 12 - space in quotes
input: `mkdir "C:\ space"`,
expected: []string{`mkdir`, `C:\ space`},
},
{ // 13 - UNC
input: `mkdir \\?\C:\Users`,
expected: []string{`mkdir`, `\\?\C:\Users`},
},
{ // 14 - UNC with space
input: `mkdir "\\?\C:\Program Files"`,
expected: []string{`mkdir`, `\\?\C:\Program Files`},
},
{ // 15 - unclosed quotes - treat as if the path ends with quote
input: `mkdir "c:\Program files`,
expected: []string{`mkdir`, `c:\Program files`},
},
{ // 16 - quotes used inside the argument
input: `mkdir "c:\P"rogra"m f"iles`,
expected: []string{`mkdir`, `c:\Program files`},
},
}
for i, test := range tests {
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
actual := parseWindowsCommand(test.input)
if len(actual) != len(test.expected) {
t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
continue
}
for j := 0; j < len(actual); j++ {
if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
}
}
}
}
func TestSplitCommandAndArgs(t *testing.T) {
// force linux parsing. It's more robust and covers error cases
runtimeGoos = "linux"
defer func() {
runtimeGoos = runtime.GOOS
}()
var parseErrorContent = "error parsing command:"
var noCommandErrContent = "no command contained in"
tests := []struct {
input string
expectedCommand string
expectedArgs []string
expectedErrContent string
}{
// 0 - emtpy command
{
input: ``,
expectedCommand: ``,
expectedArgs: nil,
expectedErrContent: noCommandErrContent,
},
// 1 - command without arguments
{
input: `command`,
expectedCommand: `command`,
expectedArgs: nil,
expectedErrContent: ``,
},
// 2 - command with single argument
{
input: `command arg1`,
expectedCommand: `command`,
expectedArgs: []string{`arg1`},
expectedErrContent: ``,
},
// 3 - command with multiple arguments
{
input: `command arg1 arg2`,
expectedCommand: `command`,
expectedArgs: []string{`arg1`, `arg2`},
expectedErrContent: ``,
},
// 4 - command with unclosed quotes
{
input: `command "arg1 arg2`,
expectedCommand: "",
expectedArgs: nil,
expectedErrContent: parseErrorContent,
},
// 5 - command with unclosed quotes
{
input: `command 'arg1 arg2"`,
expectedCommand: "",
expectedArgs: nil,
expectedErrContent: parseErrorContent,
},
}
for i, test := range tests {
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
actualCommand, actualArgs, actualErr := SplitCommandAndArgs(test.input)
// test if error matches expectation
if test.expectedErrContent != "" {
if actualErr == nil {
t.Errorf(errorPrefix+"Expected error with content [%s], found no error."+errorSuffix, test.expectedErrContent)
} else if !strings.Contains(actualErr.Error(), test.expectedErrContent) {
t.Errorf(errorPrefix+"Expected error with content [%s], found [%v]."+errorSuffix, test.expectedErrContent, actualErr)
}
} else if actualErr != nil {
t.Errorf(errorPrefix+"Expected no error, found [%v]."+errorSuffix, actualErr)
}
// test if command matches
if test.expectedCommand != actualCommand {
t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand)
}
// test if arguments match
if len(test.expectedArgs) != len(actualArgs) {
t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs)
} else {
// test args only if the count matches.
for j, actualArg := range actualArgs {
expectedArg := test.expectedArgs[j]
if actualArg != expectedArg {
t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg)
}
}
}
}
}
func ExampleSplitCommandAndArgs() {
var commandLine string
var command string
var args []string
// just for the test - change GOOS and reset it at the end of the test
runtimeGoos = "windows"
defer func() {
runtimeGoos = runtime.GOOS
}()
commandLine = `mkdir /P "C:\Program Files"`
command, args, _ = SplitCommandAndArgs(commandLine)
fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
// set GOOS to linux
runtimeGoos = "linux"
commandLine = `mkdir -p /path/with\ space`
command, args, _ = SplitCommandAndArgs(commandLine)
fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
// Output:
// Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files]
// Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space]
}

135
middleware/context.go Normal file
View File

@@ -0,0 +1,135 @@
package middleware
import (
"net"
"net/http"
"strings"
"time"
"github.com/miekg/dns"
)
// This file contains the context and functions available for
// use in the templates.
// Context is the context with which Caddy templates are executed.
type Context struct {
Root http.FileSystem // TODO(miek): needed
Req *dns.Msg
W dns.ResponseWriter
}
// Now returns the current timestamp in the specified format.
func (c Context) Now(format string) string {
return time.Now().Format(format)
}
// NowDate returns the current date/time that can be used
// in other time functions.
func (c Context) NowDate() time.Time {
return time.Now()
}
// Header gets the value of a header.
func (c Context) Header() *dns.RR_Header {
// TODO(miek)
return nil
}
// IP gets the (remote) IP address of the client making the request.
func (c Context) IP() string {
ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String())
if err != nil {
return c.W.RemoteAddr().String()
}
return ip
}
// Post gets the (remote) Port of the client making the request.
func (c Context) Port() (string, error) {
_, port, err := net.SplitHostPort(c.W.RemoteAddr().String())
if err != nil {
return "0", err
}
return port, nil
}
// Proto gets the protocol used as the transport. This
// will be udp or tcp.
func (c Context) Proto() string {
if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok {
return "udp"
}
if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok {
return "tcp"
}
return "udp"
}
// Family returns the family of the transport.
// 1 for IPv4 and 2 for IPv6.
func (c Context) Family() int {
var a net.IP
ip := c.W.RemoteAddr()
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
if i, ok := ip.(*net.TCPAddr); ok {
a = i.IP
}
if a.To4() != nil {
return 1
}
return 2
}
// Type returns the type of the question as a string.
func (c Context) Type() string {
return dns.Type(c.Req.Question[0].Qtype).String()
}
// QType returns the type of the question as a uint16.
func (c Context) QType() uint16 {
return c.Req.Question[0].Qtype
}
// Name returns the name of the question in the request. Note
// this name will always have a closing dot and will be lower cased.
func (c Context) Name() string {
return strings.ToLower(dns.Name(c.Req.Question[0].Name).String())
}
// QName returns the name of the question in the request.
func (c Context) QName() string {
return dns.Name(c.Req.Question[0].Name).String()
}
// Class returns the class of the question in the request.
func (c Context) Class() string {
return dns.Class(c.Req.Question[0].Qclass).String()
}
// QClass returns the class of the question in the request.
func (c Context) QClass() uint16 {
return c.Req.Question[0].Qclass
}
// More convience types for extracting stuff from a message?
// Header?
// ErrorMessage returns an error message suitable for sending
// back to the client.
func (c Context) ErrorMessage(rcode int) *dns.Msg {
m := new(dns.Msg)
m.SetRcode(c.Req, rcode)
return m
}
// AnswerMessage returns an error message suitable for sending
// back to the client.
func (c Context) AnswerMessage() *dns.Msg {
m := new(dns.Msg)
m.SetReply(c.Req)
return m
}

613
middleware/context_test.go Normal file
View File

@@ -0,0 +1,613 @@
package middleware
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestInclude(t *testing.T) {
context := getContextOrFail(t)
inputFilename := "test_file"
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
defer func() {
err := os.Remove(absInFilePath)
if err != nil && !os.IsNotExist(err) {
t.Fatalf("Failed to clean test file!")
}
}()
tests := []struct {
fileContent string
expectedContent string
shouldErr bool
expectedErrorContent string
}{
// Test 0 - all good
{
fileContent: `str1 {{ .Root }} str2`,
expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
shouldErr: false,
expectedErrorContent: "",
},
// Test 1 - failure on template.Parse
{
fileContent: `str1 {{ .Root } str2`,
expectedContent: "",
shouldErr: true,
expectedErrorContent: `unexpected "}" in operand`,
},
// Test 3 - failure on template.Execute
{
fileContent: `str1 {{ .InvalidField }} str2`,
expectedContent: "",
shouldErr: true,
expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`,
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// WriteFile truncates the contentt
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
if err != nil {
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
}
content, err := context.Include(inputFilename)
if err != nil {
if !test.shouldErr {
t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
}
if !strings.Contains(err.Error(), test.expectedErrorContent) {
t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
}
}
if err == nil && test.shouldErr {
t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
}
if content != test.expectedContent {
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
}
}
}
func TestIncludeNotExisting(t *testing.T) {
context := getContextOrFail(t)
_, err := context.Include("not_existing")
if err == nil {
t.Errorf("Expected error but found nil!")
}
}
func TestMarkdown(t *testing.T) {
context := getContextOrFail(t)
inputFilename := "test_file"
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
defer func() {
err := os.Remove(absInFilePath)
if err != nil && !os.IsNotExist(err) {
t.Fatalf("Failed to clean test file!")
}
}()
tests := []struct {
fileContent string
expectedContent string
}{
// Test 0 - test parsing of markdown
{
fileContent: "* str1\n* str2\n",
expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// WriteFile truncates the contentt
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
if err != nil {
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
}
content, _ := context.Markdown(inputFilename)
if content != test.expectedContent {
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
}
}
}
func TestCookie(t *testing.T) {
tests := []struct {
cookie *http.Cookie
cookieName string
expectedValue string
}{
// Test 0 - happy path
{
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
cookieName: "cookieName",
expectedValue: "cookieValue",
},
// Test 1 - try to get a non-existing cookie
{
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
cookieName: "notExisting",
expectedValue: "",
},
// Test 2 - partial name match
{
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
cookieName: "cook",
expectedValue: "",
},
// Test 3 - cookie with optional fields
{
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
cookieName: "cookie",
expectedValue: "cookieValue",
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// reinitialize the context for each test
context := getContextOrFail(t)
context.Req.AddCookie(test.cookie)
actualCookieVal := context.Cookie(test.cookieName)
if actualCookieVal != test.expectedValue {
t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
}
}
}
func TestCookieMultipleCookies(t *testing.T) {
context := getContextOrFail(t)
cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
// make sure that there's no state and multiple requests for different cookies return the correct result
for i := 0; i < 10; i++ {
context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
}
for i := 0; i < 10; i++ {
expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
if actualCookieVal != expectedCookieVal {
t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
}
}
}
func TestHeader(t *testing.T) {
context := getContextOrFail(t)
headerKey, headerVal := "Header1", "HeaderVal1"
context.Req.Header.Add(headerKey, headerVal)
actualHeaderVal := context.Header(headerKey)
if actualHeaderVal != headerVal {
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
}
missingHeaderVal := context.Header("not-existing")
if missingHeaderVal != "" {
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
}
}
func TestIP(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
inputRemoteAddr string
expectedIP string
}{
// Test 0 - ipv4 with port
{"1.1.1.1:1111", "1.1.1.1"},
// Test 1 - ipv4 without port
{"1.1.1.1", "1.1.1.1"},
// Test 2 - ipv6 with port
{"[::1]:11", "::1"},
// Test 3 - ipv6 without port and brackets
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
// Test 4 - ipv6 with zone and port
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
context.Req.RemoteAddr = test.inputRemoteAddr
actualIP := context.IP()
if actualIP != test.expectedIP {
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
}
}
}
func TestURL(t *testing.T) {
context := getContextOrFail(t)
inputURL := "http://localhost"
context.Req.RequestURI = inputURL
if inputURL != context.URI() {
t.Errorf("Expected url %s, found %s", inputURL, context.URI())
}
}
func TestHost(t *testing.T) {
tests := []struct {
input string
expectedHost string
shouldErr bool
}{
{
input: "localhost:123",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "localhost",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "[::]",
expectedHost: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
}
}
func TestPort(t *testing.T) {
tests := []struct {
input string
expectedPort string
shouldErr bool
}{
{
input: "localhost:123",
expectedPort: "123",
shouldErr: false,
},
{
input: "localhost",
expectedPort: "80", // assuming 80 is the default port
shouldErr: false,
},
{
input: ":8080",
expectedPort: "8080",
shouldErr: false,
},
{
input: "[::]",
expectedPort: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
}
}
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
context := getContextOrFail(t)
context.Req.Host = input
var actualResult, testedObject string
var err error
if isTestingHost {
actualResult, err = context.Host()
testedObject = "host"
} else {
actualResult, err = context.Port()
testedObject = "port"
}
if shouldErr && err == nil {
t.Errorf("Expected error, found nil!")
return
}
if !shouldErr && err != nil {
t.Errorf("Expected no error, found %s", err)
return
}
if actualResult != expectedResult {
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
}
}
func TestMethod(t *testing.T) {
context := getContextOrFail(t)
method := "POST"
context.Req.Method = method
if method != context.Method() {
t.Errorf("Expected method %s, found %s", method, context.Method())
}
}
func TestPathMatches(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
urlStr string
pattern string
shouldMatch bool
}{
// Test 0
{
urlStr: "http://localhost/",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost/",
pattern: "/",
shouldMatch: true,
},
// Test 3
{
urlStr: "http://localhost/?param=val",
pattern: "/",
shouldMatch: true,
},
// Test 4
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir2",
shouldMatch: false,
},
// Test 5
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 6
{
urlStr: "http://localhost:444/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 7
{
urlStr: "http://localhost/dir1/dir2",
pattern: "*/dir2",
shouldMatch: false,
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
var err error
context.Req.URL, err = url.Parse(test.urlStr)
if err != nil {
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
}
matches := context.PathMatches(test.pattern)
if matches != test.shouldMatch {
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
}
}
}
func TestTruncate(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
inputString string
inputLength int
expected string
}{
// Test 0 - small length
{
inputString: "string",
inputLength: 1,
expected: "s",
},
// Test 1 - exact length
{
inputString: "string",
inputLength: 6,
expected: "string",
},
// Test 2 - bigger length
{
inputString: "string",
inputLength: 10,
expected: "string",
},
// Test 3 - zero length
{
inputString: "string",
inputLength: 0,
expected: "",
},
// Test 4 - negative, smaller length
{
inputString: "string",
inputLength: -5,
expected: "tring",
},
// Test 5 - negative, exact length
{
inputString: "string",
inputLength: -6,
expected: "string",
},
// Test 6 - negative, bigger length
{
inputString: "string",
inputLength: -7,
expected: "string",
},
}
for i, test := range tests {
actual := context.Truncate(test.inputString, test.inputLength)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
}
}
}
func TestStripHTML(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
input string
expected string
}{
// Test 0 - no tags
{
input: `h1`,
expected: `h1`,
},
// Test 1 - happy path
{
input: `<h1>h1</h1>`,
expected: `h1`,
},
// Test 2 - tag in quotes
{
input: `<h1">">h1</h1>`,
expected: `h1`,
},
// Test 3 - multiple tags
{
input: `<h1><b>h1</b></h1>`,
expected: `h1`,
},
// Test 4 - tags not closed
{
input: `<h1`,
expected: `<h1`,
},
// Test 5 - false start
{
input: `<h1<b>hi`,
expected: `<h1hi`,
},
}
for i, test := range tests {
actual := context.StripHTML(test.input)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input)
}
}
}
func TestStripExt(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
input string
expected string
}{
// Test 0 - empty input
{
input: "",
expected: "",
},
// Test 1 - relative file with ext
{
input: "file.ext",
expected: "file",
},
// Test 2 - relative file without ext
{
input: "file",
expected: "file",
},
// Test 3 - absolute file without ext
{
input: "/file",
expected: "/file",
},
// Test 4 - absolute file with ext
{
input: "/file.ext",
expected: "/file",
},
// Test 5 - with ext but ends with /
{
input: "/dir.ext/",
expected: "/dir.ext/",
},
// Test 6 - file with ext under dir with ext
{
input: "/dir.ext/file.ext",
expected: "/dir.ext/file",
},
}
for i, test := range tests {
actual := context.StripExt(test.input)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input)
}
}
}
func initTestContext() (Context, error) {
body := bytes.NewBufferString("request body")
request, err := http.NewRequest("GET", "https://localhost", body)
if err != nil {
return Context{}, err
}
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
}
func getContextOrFail(t *testing.T) Context {
context, err := initTestContext()
if err != nil {
t.Fatalf("Failed to prepare test context")
}
return context
}
func getTestPrefix(testN int) string {
return fmt.Sprintf("Test [%d]: ", testN)
}

100
middleware/errors/errors.go Normal file
View File

@@ -0,0 +1,100 @@
// Package errors implements an HTTP error handling middleware.
package errors
import (
"fmt"
"log"
"runtime"
"strings"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
// ErrorHandler handles DNS errors (and errors from other middleware).
type ErrorHandler struct {
Next middleware.Handler
LogFile string
Log *log.Logger
LogRoller *middleware.LogRoller
Debug bool // if true, errors are written out to client rather than to a log
}
func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
defer h.recovery(w, r)
rcode, err := h.Next.ServeDNS(w, r)
if err != nil {
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
if h.Debug {
// Write error to response as a txt message instead of to log
answer := debugMsg(rcode, r)
txt, _ := dns.NewRR(". IN 0 TXT " + errMsg)
answer.Answer = append(answer.Answer, txt)
w.WriteMsg(answer)
return 0, err
}
h.Log.Println(errMsg)
}
return rcode, err
}
func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) {
rec := recover()
if rec == nil {
return
}
// Obtain source of panic
// From: https://gist.github.com/swdunlop/9629168
var name, file string // function name, file name
var line int
var pc [16]uintptr
n := runtime.Callers(3, pc[:])
for _, pc := range pc[:n] {
fn := runtime.FuncForPC(pc)
if fn == nil {
continue
}
file, line = fn.FileLine(pc)
name = fn.Name()
if !strings.HasPrefix(name, "runtime.") {
break
}
}
// Trim file path
delim := "/coredns/"
pkgPathPos := strings.Index(file, delim)
if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) {
file = file[pkgPathPos+len(delim):]
}
panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec)
if h.Debug {
// Write error and stack trace to the response rather than to a log
var stackBuf [4096]byte
stack := stackBuf[:runtime.Stack(stackBuf[:], false)]
answer := debugMsg(dns.RcodeServerFailure, r)
// add stack buf in TXT, limited to 255 chars for now.
txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255]))
answer.Answer = append(answer.Answer, txt)
w.WriteMsg(answer)
} else {
// Currently we don't use the function name, since file:line is more conventional
h.Log.Printf(panicMsg)
}
}
// debugMsg creates a debug message that gets send back to the client.
func debugMsg(rcode int, r *dns.Msg) *dns.Msg {
answer := new(dns.Msg)
answer.SetRcode(r, rcode)
return answer
}
const timeFormat = "02/Jan/2006:15:04:05 -0700"

View File

@@ -0,0 +1,168 @@
package errors
import (
"bytes"
"errors"
"fmt"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/miekg/coredns/middleware"
)
func TestErrors(t *testing.T) {
// create a temporary page
path := filepath.Join(os.TempDir(), "errors_test.html")
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
defer os.Remove(path)
const content = "This is a error page"
_, err = f.WriteString(content)
if err != nil {
t.Fatal(err)
}
f.Close()
buf := bytes.Buffer{}
em := ErrorHandler{
ErrorPages: map[int]string{
http.StatusNotFound: path,
http.StatusForbidden: "not_exist_file",
},
Log: log.New(&buf, "", 0),
}
_, notExistErr := os.Open("not_exist_file")
testErr := errors.New("test error")
tests := []struct {
next middleware.Handler
expectedCode int
expectedBody string
expectedLog string
expectedErr error
}{
{
next: genErrorHandler(http.StatusOK, nil, "normal"),
expectedCode: http.StatusOK,
expectedBody: "normal",
expectedLog: "",
expectedErr: nil,
},
{
next: genErrorHandler(http.StatusMovedPermanently, testErr, ""),
expectedCode: http.StatusMovedPermanently,
expectedBody: "",
expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr),
expectedErr: testErr,
},
{
next: genErrorHandler(http.StatusBadRequest, nil, ""),
expectedCode: 0,
expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest,
http.StatusText(http.StatusBadRequest)),
expectedLog: "",
expectedErr: nil,
},
{
next: genErrorHandler(http.StatusNotFound, nil, ""),
expectedCode: 0,
expectedBody: content,
expectedLog: "",
expectedErr: nil,
},
{
next: genErrorHandler(http.StatusForbidden, nil, ""),
expectedCode: 0,
expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden,
http.StatusText(http.StatusForbidden)),
expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n",
http.StatusForbidden, notExistErr),
expectedErr: nil,
},
}
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
for i, test := range tests {
em.Next = test.next
buf.Reset()
rec := httptest.NewRecorder()
code, err := em.ServeHTTP(rec, req)
if err != test.expectedErr {
t.Errorf("Test %d: Expected error %v, but got %v",
i, test.expectedErr, err)
}
if code != test.expectedCode {
t.Errorf("Test %d: Expected status code %d, but got %d",
i, test.expectedCode, code)
}
if body := rec.Body.String(); body != test.expectedBody {
t.Errorf("Test %d: Expected body %q, but got %q",
i, test.expectedBody, body)
}
if log := buf.String(); !strings.Contains(log, test.expectedLog) {
t.Errorf("Test %d: Expected log %q, but got %q",
i, test.expectedLog, log)
}
}
}
func TestVisibleErrorWithPanic(t *testing.T) {
const panicMsg = "I'm a panic"
eh := ErrorHandler{
ErrorPages: make(map[int]string),
Debug: true,
Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
panic(panicMsg)
}),
}
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
code, err := eh.ServeHTTP(rec, req)
if code != 0 {
t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code)
}
if err != nil {
t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err)
}
body := rec.Body.String()
if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") {
t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body)
}
if !strings.Contains(body, panicMsg) {
t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body)
}
if len(body) < 500 {
t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body))
}
}
func genErrorHandler(status int, err error, body string) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
if len(body) > 0 {
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
fmt.Fprint(w, body)
}
return status, err
})
}

0
middleware/etcd/TODO Normal file
View File

10
middleware/exchange.go Normal file
View File

@@ -0,0 +1,10 @@
package middleware
import "github.com/miekg/dns"
// Exchang sends message m to the server.
// TODO(miek): optionally it can do retries of other silly stuff.
func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) {
r, _, err := c.Exchange(m, server)
return r, err
}

89
middleware/file/file.go Normal file
View File

@@ -0,0 +1,89 @@
package file
// TODO(miek): the zone's implementation is basically non-existent
// we return a list and when searching for an answer we iterate
// over the list. This must be moved to a tree-like structure and
// have some fluff for DNSSEC (and be memory efficient).
import (
"strings"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
type (
File struct {
Next middleware.Handler
Zones Zones
// Maybe a list of all zones as well, as a []string?
}
Zone []dns.RR
Zones struct {
Z map[string]Zone // utterly braindead impl. TODO(miek): fix
Names []string
}
)
func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{W: w, Req: r}
qname := context.Name()
zone := middleware.Zones(f.Zones.Names).Matches(qname)
if zone == "" {
return f.Next.ServeDNS(w, r)
}
names, nodata := f.Zones.Z[zone].lookup(qname, context.QType())
var answer *dns.Msg
switch {
case nodata:
answer = context.AnswerMessage()
answer.Ns = names
case len(names) == 0:
answer = context.AnswerMessage()
answer.Ns = names
answer.Rcode = dns.RcodeNameError
case len(names) > 0:
answer = context.AnswerMessage()
answer.Answer = names
default:
answer = context.ErrorMessage(dns.RcodeServerFailure)
}
// Check return size, etc. TODO(miek)
w.WriteMsg(answer)
return 0, nil
}
// Lookup will try to find qname and qtype in z. It returns the
// records found *or* a boolean saying NODATA. If the answer
// is NODATA then the RR returned is the SOA record.
//
// TODO(miek): EXTREMELY STUPID IMPLEMENTATION.
// Doesn't do much, no delegation, no cname, nothing really, etc.
// TODO(miek): even NODATA looks broken
func (z Zone) lookup(qname string, qtype uint16) ([]dns.RR, bool) {
var (
nodata bool
rep []dns.RR
soa dns.RR
)
for _, rr := range z {
if rr.Header().Rrtype == dns.TypeSOA {
soa = rr
}
// Match function in Go DNS?
if strings.ToLower(rr.Header().Name) == qname {
if rr.Header().Rrtype == qtype {
rep = append(rep, rr)
nodata = false
}
}
}
if nodata {
return []dns.RR{soa}, true
}
return rep, false
}

View File

@@ -0,0 +1,325 @@
package file
import (
"errors"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
var ErrCustom = errors.New("Custom Error")
// testFiles is a map with relative paths to test files as keys and file content as values.
// The map represents the following structure:
// - $TEMP/caddy_testdir/
// '-- file1.html
// '-- dirwithindex/
// '---- index.html
// '-- dir/
// '---- file2.html
// '---- hidden.html
var testFiles = map[string]string{
"file1.html": "<h1>file1.html</h1>",
filepath.Join("dirwithindex", "index.html"): "<h1>dirwithindex/index.html</h1>",
filepath.Join("dir", "file2.html"): "<h1>dir/file2.html</h1>",
filepath.Join("dir", "hidden.html"): "<h1>dir/hidden.html</h1>",
}
// TestServeHTTP covers positive scenarios when serving files.
func TestServeHTTP(t *testing.T) {
beforeServeHTTPTest(t)
defer afterServeHTTPTest(t)
fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"})
movedPermanently := "Moved Permanently"
tests := []struct {
url string
expectedStatus int
expectedBodyContent string
}{
// Test 0 - access without any path
{
url: "https://foo",
expectedStatus: http.StatusNotFound,
},
// Test 1 - access root (without index.html)
{
url: "https://foo/",
expectedStatus: http.StatusNotFound,
},
// Test 2 - access existing file
{
url: "https://foo/file1.html",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles["file1.html"],
},
// Test 3 - access folder with index file with trailing slash
{
url: "https://foo/dirwithindex/",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
},
// Test 4 - access folder with index file without trailing slash
{
url: "https://foo/dirwithindex",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
// Test 5 - access folder without index file
{
url: "https://foo/dir/",
expectedStatus: http.StatusNotFound,
},
// Test 6 - access folder without trailing slash
{
url: "https://foo/dir",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
// Test 6 - access file with trailing slash
{
url: "https://foo/file1.html/",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
// Test 7 - access not existing path
{
url: "https://foo/not_existing",
expectedStatus: http.StatusNotFound,
},
// Test 8 - access a file, marked as hidden
{
url: "https://foo/dir/hidden.html",
expectedStatus: http.StatusNotFound,
},
// Test 9 - access a index file directly
{
url: "https://foo/dirwithindex/index.html",
expectedStatus: http.StatusOK,
expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
},
// Test 10 - send a request with query params
{
url: "https://foo/dir?param1=val",
expectedStatus: http.StatusMovedPermanently,
expectedBodyContent: movedPermanently,
},
}
for i, test := range tests {
responseRecorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.url, strings.NewReader(""))
status, err := fileserver.ServeHTTP(responseRecorder, request)
// check if error matches expectations
if err != nil {
t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err)
}
// check status code
if test.expectedStatus != status {
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
}
// check body content
if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) {
t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String())
}
}
}
// beforeServeHTTPTest creates a test directory with the structure, defined in the variable testFiles
func beforeServeHTTPTest(t *testing.T) {
// make the root test dir
err := os.Mkdir(testDir, os.ModePerm)
if err != nil {
if !os.IsExist(err) {
t.Fatalf("Failed to create test dir. Error was: %v", err)
return
}
}
for relFile, fileContent := range testFiles {
absFile := filepath.Join(testDir, relFile)
// make sure the parent directories exist
parentDir := filepath.Dir(absFile)
_, err = os.Stat(parentDir)
if err != nil {
os.MkdirAll(parentDir, os.ModePerm)
}
// now create the test files
f, err := os.Create(absFile)
if err != nil {
t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err)
return
}
// and fill them with content
_, err = f.WriteString(fileContent)
if err != nil {
t.Fatalf("Failed to write to %s. Error was: %v", absFile, err)
return
}
f.Close()
}
}
// afterServeHTTPTest removes the test dir and all its content
func afterServeHTTPTest(t *testing.T) {
// cleans up everything under the test dir. No need to clean the individual files.
err := os.RemoveAll(testDir)
if err != nil {
t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err)
}
}
// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err
type failingFS struct {
err error // the error to return when Open is called
fileImpl http.File // inject the file implementation
}
// Open returns the assigned failingFile and error
func (f failingFS) Open(path string) (http.File, error) {
return f.fileImpl, f.err
}
// failingFile implements http.File but returns a predefined error on every Stat() method call.
type failingFile struct {
http.File
err error
}
// Stat returns nil FileInfo and the provided error on every call
func (ff failingFile) Stat() (os.FileInfo, error) {
return nil, ff.err
}
// Close is noop and returns no error
func (ff failingFile) Close() error {
return nil
}
// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors.
func TestServeHTTPFailingFS(t *testing.T) {
tests := []struct {
fsErr error
expectedStatus int
expectedErr error
expectedHeaders map[string]string
}{
{
fsErr: os.ErrNotExist,
expectedStatus: http.StatusNotFound,
expectedErr: nil,
},
{
fsErr: os.ErrPermission,
expectedStatus: http.StatusForbidden,
expectedErr: os.ErrPermission,
},
{
fsErr: ErrCustom,
expectedStatus: http.StatusServiceUnavailable,
expectedErr: ErrCustom,
expectedHeaders: map[string]string{"Retry-After": "5"},
},
}
for i, test := range tests {
// initialize a file server with the failing FileSystem
fileserver := FileServer(failingFS{err: test.fsErr}, nil)
// prepare the request and response
request, err := http.NewRequest("GET", "https://foo/", nil)
if err != nil {
t.Fatalf("Failed to build request. Error was: %v", err)
}
responseRecorder := httptest.NewRecorder()
status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
// check the status
if status != test.expectedStatus {
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
}
// check the error
if actualErr != test.expectedErr {
t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
}
// check the headers - a special case for server under load
if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 {
for expectedKey, expectedVal := range test.expectedHeaders {
actualVal := responseRecorder.Header().Get(expectedKey)
if expectedVal != actualVal {
t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal)
}
}
}
}
}
// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails.
func TestServeHTTPFailingStat(t *testing.T) {
tests := []struct {
statErr error
expectedStatus int
expectedErr error
}{
{
statErr: os.ErrNotExist,
expectedStatus: http.StatusNotFound,
expectedErr: nil,
},
{
statErr: os.ErrPermission,
expectedStatus: http.StatusForbidden,
expectedErr: os.ErrPermission,
},
{
statErr: ErrCustom,
expectedStatus: http.StatusInternalServerError,
expectedErr: ErrCustom,
},
}
for i, test := range tests {
// initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will
fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil)
// prepare the request and response
request, err := http.NewRequest("GET", "https://foo/", nil)
if err != nil {
t.Fatalf("Failed to build request. Error was: %v", err)
}
responseRecorder := httptest.NewRecorder()
status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
// check the status
if status != test.expectedStatus {
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
}
// check the error
if actualErr != test.expectedErr {
t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
}
}
}

22
middleware/host.go Normal file
View File

@@ -0,0 +1,22 @@
package middleware
import (
"net"
"strings"
"github.com/miekg/dns"
)
// Host represents a host from the Caddyfile, may contain port.
type Host string
// Standard host will return the host portion of host, stripping
// of any port. The host will also be fully qualified and lowercased.
func (h Host) StandardHost() string {
// separate host and port
host, _, err := net.SplitHostPort(string(h))
if err != nil {
host, _, _ = net.SplitHostPort(string(h) + ":")
}
return strings.ToLower(dns.Fqdn(host))
}

66
middleware/log/log.go Normal file
View File

@@ -0,0 +1,66 @@
// Package log implements basic but useful request (access) logging middleware.
package log
import (
"log"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
// Logger is a basic request logging middleware.
type Logger struct {
Next middleware.Handler
Rules []Rule
ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
}
func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, rule := range l.Rules {
/*
if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
responseRecorder := middleware.NewResponseRecorder(w)
status, err := l.Next.ServeHTTP(responseRecorder, r)
if status >= 400 {
// There was an error up the chain, but no response has been written yet.
// The error must be handled here so the log entry will record the response size.
if l.ErrorFunc != nil {
l.ErrorFunc(responseRecorder, r, status)
} else {
// Default failover error handler
responseRecorder.WriteHeader(status)
fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status))
}
status = 0
}
rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue)
rule.Log.Println(rep.Replace(rule.Format))
return status, err
}
*/
rule = rule
}
return l.Next.ServeDNS(w, r)
}
// Rule configures the logging middleware.
type Rule struct {
PathScope string
OutputFile string
Format string
Log *log.Logger
Roller *middleware.LogRoller
}
const (
// DefaultLogFilename is the default log filename.
DefaultLogFilename = "access.log"
// CommonLogFormat is the common log format.
CommonLogFormat = `{remote} ` + CommonLogEmptyValue + ` [{when}] "{type} {name} {proto}" {rcode} {size}`
// CommonLogEmptyValue is the common empty log value.
CommonLogEmptyValue = "-"
// CombinedLogFormat is the combined log format.
CombinedLogFormat = CommonLogFormat + ` "{>Referer}" "{>User-Agent}"` // Something here as well
// DefaultLogFormat is the default log format.
DefaultLogFormat = CommonLogFormat
)

View File

@@ -0,0 +1,48 @@
package log
import (
"bytes"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type erroringMiddleware struct{}
func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
return http.StatusNotFound, nil
}
func TestLoggedStatus(t *testing.T) {
var f bytes.Buffer
var next erroringMiddleware
rule := Rule{
PathScope: "/",
Format: DefaultLogFormat,
Log: log.New(&f, "", 0),
}
logger := Logger{
Rules: []Rule{rule},
Next: next,
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
status, err := logger.ServeHTTP(rec, r)
if status != 0 {
t.Error("Expected status to be 0 - was", status)
}
logged := f.String()
if !strings.Contains(logged, "404 13") {
t.Error("Expected 404 to be logged. Logged string -", logged)
}
}

105
middleware/middleware.go Normal file
View File

@@ -0,0 +1,105 @@
// Package middleware provides some types and functions common among middleware.
package middleware
import (
"time"
"github.com/miekg/dns"
)
type (
// Middleware is the middle layer which represents the traditional
// idea of middleware: it chains one Handler to the next by being
// passed the next Handler in the chain.
Middleware func(Handler) Handler
// Handler is like dns.Handler except ServeDNS may return an rcode
// and/or error.
//
// If ServeDNS writes to the response body, it should return a status
// code of 0. This signals to other handlers above it that the response
// body is already written, and that they should not write to it also.
//
// If ServeDNS encounters an error, it should return the error value
// so it can be logged by designated error-handling middleware.
//
// If writing a response after calling another ServeDNS method, the
// returned rcode SHOULD be used when writing the response.
//
// If handling errors after calling another ServeDNS method, the
// returned error value SHOULD be logged or handled accordingly.
//
// Otherwise, return values should be propagated down the middleware
// chain by returning them unchanged.
Handler interface {
ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error)
}
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error. See Handler
// documentation for more information.
HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error)
)
// ServeDNS implements the Handler interface.
func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(w, r)
}
// IndexFile looks for a file in /root/fpath/indexFile for each string
// in indexFiles. If an index file is found, it returns the root-relative
// path to the file and true. If no index file is found, empty string
// and false is returned. fpath must end in a forward slash '/'
// otherwise no index files will be tried (directory paths must end
// in a forward slash according to HTTP).
//
// All paths passed into and returned from this function use '/' as the
// path separator, just like URLs. IndexFle handles path manipulation
// internally for systems that use different path separators.
/*
func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, bool) {
if fpath[len(fpath)-1] != '/' || root == nil {
return "", false
}
for _, indexFile := range indexFiles {
// func (http.FileSystem).Open wants all paths separated by "/",
// regardless of operating system convention, so use
// path.Join instead of filepath.Join
fp := path.Join(fpath, indexFile)
f, err := root.Open(fp)
if err == nil {
f.Close()
return fp, true
}
}
return "", false
}
// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it
// as a Last-Modified header to the ResponseWriter. If the modTime is in the future
// the current time is used instead.
func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) {
if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) {
// the time does not appear to be valid. Don't put it in the response
return
}
// RFC 2616 - Section 14.29 - Last-Modified:
// An origin server MUST NOT send a Last-Modified date which is later than the
// server's time of message origination. In such cases, where the resource's last
// modification would indicate some time in the future, the server MUST replace
// that date with the message origination date.
now := currentTime()
if modTime.After(now) {
modTime = now
}
w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat))
}
*/
// currentTime, as it is defined here, returns time.Now().
// It's defined as a variable for mocking time in tests.
var currentTime = func() time.Time {
return time.Now()
}

View File

@@ -0,0 +1,108 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestIndexfile(t *testing.T) {
tests := []struct {
rootDir http.FileSystem
fpath string
indexFiles []string
shouldErr bool
expectedFilePath string //retun value
expectedBoolValue bool //return value
}{
{
http.Dir("./templates/testdata"),
"/images/",
[]string{"img.htm"},
false,
"/images/img.htm",
true,
},
}
for i, test := range tests {
actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
if actualBoolValue == true && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if actualBoolValue != true && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
}
if actualFilePath != test.expectedFilePath {
t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
i, test.expectedFilePath, actualFilePath)
}
if actualBoolValue != test.expectedBoolValue {
t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
i, test.expectedBoolValue, actualBoolValue)
}
}
}
func TestSetLastModified(t *testing.T) {
nowTime := time.Now()
// ovewrite the function to return reliable time
originalGetCurrentTimeFunc := currentTime
currentTime = func() time.Time {
return nowTime
}
defer func() {
currentTime = originalGetCurrentTimeFunc
}()
pastTime := nowTime.Truncate(1 * time.Hour)
futureTime := nowTime.Add(1 * time.Hour)
tests := []struct {
inputModTime time.Time
expectedIsHeaderSet bool
expectedLastModified string
}{
{
inputModTime: pastTime,
expectedIsHeaderSet: true,
expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: nowTime,
expectedIsHeaderSet: true,
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: futureTime,
expectedIsHeaderSet: true,
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: time.Time{},
expectedIsHeaderSet: false,
},
}
for i, test := range tests {
responseRecorder := httptest.NewRecorder()
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
SetLastModifiedHeader(responseRecorder, test.inputModTime)
actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
}
if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
}
if test.expectedLastModified != actualLastModifiedHeader {
t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
}
}
}

18
middleware/path.go Normal file
View File

@@ -0,0 +1,18 @@
package middleware
import "strings"
// TODO(miek): matches for names.
// Path represents a URI path, maybe with pattern characters.
type Path string
// Matches checks to see if other matches p.
//
// Path matching will probably not always be a direct
// comparison; this method assures that paths can be
// easily and consistently matched.
func (p Path) Matches(other string) bool {
return strings.HasPrefix(string(p), other)
}

View File

@@ -0,0 +1,31 @@
package metrics
import (
"strconv"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{W: w, Req: r}
qname := context.Name()
qtype := context.Type()
zone := middleware.Zones(m.ZoneNames).Matches(qname)
if zone == "" {
zone = "."
}
// Record response to get status code and size of the reply.
rw := middleware.NewResponseRecorder(w)
status, err := m.Next.ServeDNS(rw, r)
requestCount.WithLabelValues(zone, qtype).Inc()
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
responseSize.WithLabelValues(zone).Observe(float64(rw.Size()))
responseRcode.WithLabelValues(zone, strconv.Itoa(rw.Rcode())).Inc()
return status, err
}

View File

@@ -0,0 +1,80 @@
package metrics
import (
"fmt"
"net/http"
"sync"
"github.com/miekg/coredns/middleware"
"github.com/prometheus/client_golang/prometheus"
)
const namespace = "daddy"
var (
requestCount *prometheus.CounterVec
requestDuration *prometheus.HistogramVec
responseSize *prometheus.HistogramVec
responseRcode *prometheus.CounterVec
)
const path = "/metrics"
// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics
type Metrics struct {
Next middleware.Handler
Addr string // where to we listen
Once sync.Once
ZoneNames []string
}
func (m *Metrics) Start() error {
m.Once.Do(func() {
define("")
prometheus.MustRegister(requestCount)
prometheus.MustRegister(requestDuration)
prometheus.MustRegister(responseSize)
prometheus.MustRegister(responseRcode)
http.Handle(path, prometheus.Handler())
go func() {
fmt.Errorf("%s", http.ListenAndServe(m.Addr, nil))
}()
})
return nil
}
func define(subsystem string) {
if subsystem == "" {
subsystem = "dns"
}
requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "request_count_total",
Help: "Counter of DNS requests made per zone and type.",
}, []string{"zone", "qtype"})
requestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "request_duration_seconds",
Help: "Histogram of the time (in seconds) each request took.",
}, []string{"zone"})
responseSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "response_size_bytes",
Help: "Size of the returns response in bytes.",
Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3},
}, []string{"zone"})
responseRcode = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "rcode_code_count_total",
Help: "Counter of response status codes.",
}, []string{"zone", "rcode"})
}

101
middleware/proxy/policy.go Normal file
View File

@@ -0,0 +1,101 @@
package proxy
import (
"math/rand"
"sync/atomic"
)
// HostPool is a collection of UpstreamHosts.
type HostPool []*UpstreamHost
// Policy decides how a host will be selected from a pool.
type Policy interface {
Select(pool HostPool) *UpstreamHost
}
func init() {
RegisterPolicy("random", func() Policy { return &Random{} })
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
}
// Random is a policy that selects up hosts from a pool at random.
type Random struct{}
// Select selects an up host at random from the specified pool.
func (r *Random) Select(pool HostPool) *UpstreamHost {
// instead of just generating a random index
// this is done to prevent selecting a down host
var randHost *UpstreamHost
count := 0
for _, host := range pool {
if host.Down() {
continue
}
count++
if count == 1 {
randHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
randHost = host
}
}
}
return randHost
}
// LeastConn is a policy that selects the host with the least connections.
type LeastConn struct{}
// Select selects the up host with the least number of connections in the
// pool. If more than one host has the same least number of connections,
// one of the hosts is chosen at random.
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
var bestHost *UpstreamHost
count := 0
leastConn := int64(1<<63 - 1)
for _, host := range pool {
if host.Down() {
continue
}
hostConns := host.Conns
if hostConns < leastConn {
bestHost = host
leastConn = hostConns
count = 1
} else if hostConns == leastConn {
// randomly select host among hosts with least connections
count++
if count == 1 {
bestHost = host
} else {
r := rand.Int() % count
if r == (count - 1) {
bestHost = host
}
}
}
}
return bestHost
}
// RoundRobin is a policy that selects hosts based on round robin ordering.
type RoundRobin struct {
Robin uint32
}
// Select selects an up host from the pool using a round robin ordering scheme.
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
poolLen := uint32(len(pool))
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
host := pool[selection]
// if the currently selected host is down, just ffwd to up host
for i := uint32(1); host.Down() && i < poolLen; i++ {
host = pool[(selection+i)%poolLen]
}
if host.Down() {
return nil
}
return host
}

View File

@@ -0,0 +1,87 @@
package proxy
import (
"net/http"
"net/http/httptest"
"os"
"testing"
)
var workableServer *httptest.Server
func TestMain(m *testing.M) {
workableServer = httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// do nothing
}))
r := m.Run()
workableServer.Close()
os.Exit(r)
}
type customPolicy struct{}
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
return pool[0]
}
func testPool() HostPool {
pool := []*UpstreamHost{
{
Name: workableServer.URL, // this should resolve (healthcheck test)
},
{
Name: "http://shouldnot.resolve", // this shouldn't
},
{
Name: "http://C",
},
}
return HostPool(pool)
}
func TestRoundRobinPolicy(t *testing.T) {
pool := testPool()
rrPolicy := &RoundRobin{}
h := rrPolicy.Select(pool)
// First selected host is 1, because counter starts at 0
// and increments before host is selected
if h != pool[1] {
t.Error("Expected first round robin host to be second host in the pool.")
}
h = rrPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected second round robin host to be third host in the pool.")
}
// mark host as down
pool[0].Unhealthy = true
h = rrPolicy.Select(pool)
if h != pool[1] {
t.Error("Expected third round robin host to be first host in the pool.")
}
}
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
lcPolicy := &LeastConn{}
pool[0].Conns = 10
pool[1].Conns = 10
h := lcPolicy.Select(pool)
if h != pool[2] {
t.Error("Expected least connection host to be third host.")
}
pool[2].Conns = 100
h = lcPolicy.Select(pool)
if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.")
}
}
func TestCustomPolicy(t *testing.T) {
pool := testPool()
customPolicy := &customPolicy{}
h := customPolicy.Select(pool)
if h != pool[0] {
t.Error("Expected custom policy host to be the first host.")
}
}

120
middleware/proxy/proxy.go Normal file
View File

@@ -0,0 +1,120 @@
// Package proxy is middleware that proxies requests.
package proxy
import (
"errors"
"net/http"
"sync/atomic"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
var errUnreachable = errors.New("unreachable backend")
// Proxy represents a middleware instance that can proxy requests.
type Proxy struct {
Next middleware.Handler
Client Client
Upstreams []Upstream
}
type Client struct {
UDP *dns.Client
TCP *dns.Client
}
// Upstream manages a pool of proxy upstream hosts. Select should return a
// suitable upstream host, or nil if no such hosts are available.
type Upstream interface {
// The domain name this upstream host should be routed on.
From() string
// Selects an upstream host to be routed to.
Select() *UpstreamHost
// Checks if subpdomain is not an ignored.
IsAllowedPath(string) bool
}
// UpstreamHostDownFunc can be used to customize how Down behaves.
type UpstreamHostDownFunc func(*UpstreamHost) bool
// UpstreamHost represents a single proxy upstream
type UpstreamHost struct {
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
Name string // IP address (and port) of this upstream host
Fails int32
FailTimeout time.Duration
Unhealthy bool
ExtraHeaders http.Header
CheckDown UpstreamHostDownFunc
WithoutPathPrefix string
}
// Down checks whether the upstream host is down or not.
// Down will try to use uh.CheckDown first, and will fall
// back to some default criteria if necessary.
func (uh *UpstreamHost) Down() bool {
if uh.CheckDown == nil {
// Default settings
return uh.Unhealthy || uh.Fails > 0
}
return uh.CheckDown(uh)
}
// tryDuration is how long to try upstream hosts; failures result in
// immediate retries until this duration ends or we get a nil host.
var tryDuration = 60 * time.Second
// ServeDNS satisfies the middleware.Handler interface.
func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, upstream := range p.Upstreams {
// allowed bla bla bla TODO(miek): fix full proxy spec from caddy
start := time.Now()
// Since Select() should give us "up" hosts, keep retrying
// hosts until timeout (or until we get a nil host).
for time.Now().Sub(start) < tryDuration {
host := upstream.Select()
if host == nil {
return dns.RcodeServerFailure, errUnreachable
}
// TODO(miek): PORT!
reverseproxy := ReverseProxy{Host: host.Name, Client: p.Client}
atomic.AddInt64(&host.Conns, 1)
backendErr := reverseproxy.ServeDNS(w, r, nil)
atomic.AddInt64(&host.Conns, -1)
if backendErr == nil {
return 0, nil
}
timeout := host.FailTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
atomic.AddInt32(&host.Fails, 1)
go func(host *UpstreamHost, timeout time.Duration) {
time.Sleep(timeout)
atomic.AddInt32(&host.Fails, -1)
}(host, timeout)
}
return dns.RcodeServerFailure, errUnreachable
}
return p.Next.ServeDNS(w, r)
}
func Clients() Client {
udp := newClient("udp", defaultTimeout)
tcp := newClient("tcp", defaultTimeout)
return Client{UDP: udp, TCP: tcp}
}
// newClient returns a new client for proxy requests.
func newClient(net string, timeout time.Duration) *dns.Client {
if timeout == 0 {
timeout = defaultTimeout
}
return &dns.Client{Net: net, ReadTimeout: timeout, WriteTimeout: timeout, SingleInflight: true}
}
const defaultTimeout = 5 * time.Second

View File

@@ -0,0 +1,317 @@
package proxy
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"golang.org/x/net/websocket"
)
func init() {
tryDuration = 50 * time.Millisecond // prevent tests from hanging
}
func TestReverseProxy(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var requestReceived bool
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestReceived = true
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if !requestReceived {
t.Error("Expected backend to receive request, but it didn't")
}
}
func TestReverseProxyInsecureSkipVerify(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var requestReceived bool
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestReceived = true
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if !requestReceived {
t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't")
}
}
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
// No-op websocket backend simply allows the WS connection to be
// accepted then it will be immediately closed. Perfect for testing.
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL)
// Create client request
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
r.Header = http.Header{
"Connection": {"Upgrade"},
"Upgrade": {"websocket"},
"Origin": {wsNop.URL},
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
"Sec-WebSocket-Version": {"13"},
}
// Capture the request
w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}
// Booya! Do the test.
p.ServeHTTP(w, r)
// Make sure the backend accepted the WS connection.
// Mostly interested in the Upgrade and Connection response headers
// and the 101 status code.
expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
actual := w.fakeConn.writeBuf.Bytes()
if !bytes.Equal(actual, expected) {
t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
}
}
func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
// Echo server allows us to test that socket bytes are properly
// being proxied.
wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
io.Copy(ws, ws)
}))
defer wsEcho.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL)
// This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our
// WS client will connect to this test server, not
// the echo client directly.
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
// Set up WebSocket client
url := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
ws, err := websocket.Dial(url, "", echoProxy.URL)
if err != nil {
t.Fatal(err)
}
defer ws.Close()
// Send test message
trialMsg := "Is it working?"
websocket.Message.Send(ws, trialMsg)
// It should be echoed back to us
var actualMsg string
websocket.Message.Receive(ws, &actualMsg)
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
trialMsg := "Is it working?"
var proxySuccess bool
// This is our fake "application" we want to proxy to
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Request was proxied when this is called
proxySuccess = true
fmt.Fprint(w, trialMsg)
}))
// Get absolute path for unix: socket
socketPath, err := filepath.Abs("./test_socket")
if err != nil {
t.Fatalf("Unable to get absolute path: %v", err)
}
// Change httptest.Server listener to listen to unix: socket
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL)
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
actualMsg := fmt.Sprintf("%s", greeting)
if !proxySuccess {
t.Errorf("Expected request to be proxied, but it wasn't")
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{
name: name,
host: &UpstreamHost{
Name: name,
ReverseProxy: NewSingleHostReverseProxy(uri, ""),
},
}
if insecure {
u.host.ReverseProxy.Transport = InsecureTransport
}
return u
}
type fakeUpstream struct {
name string
host *UpstreamHost
}
func (u *fakeUpstream) From() string {
return "/"
}
func (u *fakeUpstream) Select() *UpstreamHost {
return u.host
}
func (u *fakeUpstream) IsAllowedPath(requestPath string) bool {
return true
}
// newWebSocketTestProxy returns a test proxy that will
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string) *Proxy {
return &Proxy{
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}},
}
}
type fakeWsUpstream struct {
name string
}
func (u *fakeWsUpstream) From() string {
return "/"
}
func (u *fakeWsUpstream) Select() *UpstreamHost {
uri, _ := url.Parse(u.name)
return &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, ""),
ExtraHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}},
}
}
func (u *fakeWsUpstream) IsAllowedPath(requestPath string) bool {
return true
}
// recorderHijacker is a ResponseRecorder that can
// be hijacked.
type recorderHijacker struct {
*httptest.ResponseRecorder
fakeConn *fakeConn
}
func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return rh.fakeConn, nil, nil
}
type fakeConn struct {
readBuf bytes.Buffer
writeBuf bytes.Buffer
}
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
func (c *fakeConn) SetDeadline(t time.Time) error { return nil }
func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil }
func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }

View File

@@ -0,0 +1,36 @@
// Package proxy is middleware that proxies requests.
package proxy
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
type ReverseProxy struct {
Host string
Client Client
}
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
// TODO(miek): use extra!
var (
reply *dns.Msg
err error
)
context := middleware.Context{W: w, Req: r}
// tls+tcp ?
if context.Proto() == "tcp" {
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
} else {
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
}
if err != nil {
return err
}
reply.Compress = true
reply.Id = r.Id
w.WriteMsg(reply)
return nil
}

View File

@@ -0,0 +1,235 @@
package proxy
import (
"io"
"io/ioutil"
"net/http"
"path"
"strconv"
"time"
"github.com/miekg/coredns/core/parse"
"github.com/miekg/coredns/middleware"
)
var (
supportedPolicies = make(map[string]func() Policy)
)
type staticUpstream struct {
from string
// TODO(miek): allows use to added headers
proxyHeaders http.Header // TODO(miek): kill
Hosts HostPool
Policy Policy
FailTimeout time.Duration
MaxFails int32
HealthCheck struct {
Path string
Interval time.Duration
}
WithoutPathPrefix string
IgnoredSubPaths []string
}
// NewStaticUpstreams parses the configuration input and sets up
// static upstreams for the proxy middleware.
func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
var upstreams []Upstream
for c.Next() {
upstream := &staticUpstream{
from: "",
proxyHeaders: make(http.Header),
Hosts: nil,
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
if !c.Args(&upstream.from) {
return upstreams, c.ArgErr()
}
to := c.RemainingArgs()
if len(to) == 0 {
return upstreams, c.ArgErr()
}
for c.NextBlock() {
if err := parseBlock(&c, upstream); err != nil {
return upstreams, err
}
}
upstream.Hosts = make([]*UpstreamHost, len(to))
for i, host := range to {
uh := &UpstreamHost{
Name: host,
Conns: 0,
Fails: 0,
FailTimeout: upstream.FailTimeout,
Unhealthy: false,
ExtraHeaders: upstream.proxyHeaders,
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool {
if uh.Unhealthy {
return true
}
if uh.Fails >= upstream.MaxFails &&
upstream.MaxFails != 0 {
return true
}
return false
}
}(upstream),
WithoutPathPrefix: upstream.WithoutPathPrefix,
}
upstream.Hosts[i] = uh
}
if upstream.HealthCheck.Path != "" {
go upstream.HealthCheckWorker(nil)
}
upstreams = append(upstreams, upstream)
}
return upstreams, nil
}
// RegisterPolicy adds a custom policy to the proxy.
func RegisterPolicy(name string, policy func() Policy) {
supportedPolicies[name] = policy
}
func (u *staticUpstream) From() string {
return u.from
}
func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
switch c.Val() {
case "policy":
if !c.NextArg() {
return c.ArgErr()
}
policyCreateFunc, ok := supportedPolicies[c.Val()]
if !ok {
return c.ArgErr()
}
u.Policy = policyCreateFunc()
case "fail_timeout":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
u.FailTimeout = dur
case "max_fails":
if !c.NextArg() {
return c.ArgErr()
}
n, err := strconv.Atoi(c.Val())
if err != nil {
return err
}
u.MaxFails = int32(n)
case "health_check":
if !c.NextArg() {
return c.ArgErr()
}
u.HealthCheck.Path = c.Val()
u.HealthCheck.Interval = 30 * time.Second
if c.NextArg() {
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
u.HealthCheck.Interval = dur
}
case "proxy_header":
var header, value string
if !c.Args(&header, &value) {
return c.ArgErr()
}
u.proxyHeaders.Add(header, value)
case "websocket":
u.proxyHeaders.Add("Connection", "{>Connection}")
u.proxyHeaders.Add("Upgrade", "{>Upgrade}")
case "without":
if !c.NextArg() {
return c.ArgErr()
}
u.WithoutPathPrefix = c.Val()
case "except":
ignoredPaths := c.RemainingArgs()
if len(ignoredPaths) == 0 {
return c.ArgErr()
}
u.IgnoredSubPaths = ignoredPaths
default:
return c.Errf("unknown property '%s'", c.Val())
}
return nil
}
func (u *staticUpstream) healthCheck() {
for _, host := range u.Hosts {
hostURL := host.Name + u.HealthCheck.Path
if r, err := http.Get(hostURL); err == nil {
io.Copy(ioutil.Discard, r.Body)
r.Body.Close()
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
} else {
host.Unhealthy = true
}
}
}
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
ticker := time.NewTicker(u.HealthCheck.Interval)
u.healthCheck()
for {
select {
case <-ticker.C:
u.healthCheck()
case <-stop:
// TODO: the library should provide a stop channel and global
// waitgroup to allow goroutines started by plugins a chance
// to clean themselves up.
}
}
}
func (u *staticUpstream) Select() *UpstreamHost {
pool := u.Hosts
if len(pool) == 1 {
if pool[0].Down() {
return nil
}
return pool[0]
}
allDown := true
for _, host := range pool {
if !host.Down() {
allDown = false
break
}
}
if allDown {
return nil
}
if u.Policy == nil {
return (&Random{}).Select(pool)
}
return u.Policy.Select(pool)
}
func (u *staticUpstream) IsAllowedPath(requestPath string) bool {
for _, ignoredSubPath := range u.IgnoredSubPaths {
if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
return false
}
}
return true
}

View File

@@ -0,0 +1,83 @@
package proxy
import (
"testing"
"time"
)
func TestHealthCheck(t *testing.T) {
upstream := &staticUpstream{
from: "",
Hosts: testPool(),
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
upstream.healthCheck()
if upstream.Hosts[0].Down() {
t.Error("Expected first host in testpool to not fail healthcheck.")
}
if !upstream.Hosts[1].Down() {
t.Error("Expected second host in testpool to fail healthcheck.")
}
}
func TestSelect(t *testing.T) {
upstream := &staticUpstream{
from: "",
Hosts: testPool()[:3],
Policy: &Random{},
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
upstream.Hosts[0].Unhealthy = true
upstream.Hosts[1].Unhealthy = true
upstream.Hosts[2].Unhealthy = true
if h := upstream.Select(); h != nil {
t.Error("Expected select to return nil as all host are down")
}
upstream.Hosts[2].Unhealthy = false
if h := upstream.Select(); h == nil {
t.Error("Expected select to not return nil")
}
}
func TestRegisterPolicy(t *testing.T) {
name := "custom"
customPolicy := &customPolicy{}
RegisterPolicy(name, func() Policy { return customPolicy })
if _, ok := supportedPolicies[name]; !ok {
t.Error("Expected supportedPolicies to have a custom policy.")
}
}
func TestAllowedPaths(t *testing.T) {
upstream := &staticUpstream{
from: "/proxy",
IgnoredSubPaths: []string{"/download", "/static"},
}
tests := []struct {
url string
expected bool
}{
{"/proxy", true},
{"/proxy/dl", true},
{"/proxy/download", false},
{"/proxy/download/static", false},
{"/proxy/static", false},
{"/proxy/static/download", false},
{"/proxy/something/download", true},
{"/proxy/something/static", true},
{"/proxy//static", false},
{"/proxy//static//download", false},
{"/proxy//download", false},
}
for i, test := range tests {
isAllowed := upstream.IsAllowedPath(test.url)
if test.expected != isAllowed {
t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed)
}
}
}

70
middleware/recorder.go Normal file
View File

@@ -0,0 +1,70 @@
package middleware
import (
"time"
"github.com/miekg/dns"
)
// ResponseRecorder is a type of ResponseWriter that captures
// the rcode code written to it and also the size of the message
// written in the response. A rcode code does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type ResponseRecorder struct {
dns.ResponseWriter
rcode int
size int
start time.Time
}
// NewResponseRecorder makes and returns a new responseRecorder,
// which captures the DNS rcode from the ResponseWriter
// and also the length of the response message written through it.
func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder {
return &ResponseRecorder{
ResponseWriter: w,
rcode: 0,
start: time.Now(),
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error {
r.rcode = res.Rcode
r.size = res.Len()
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the size of the message that gets written.
func (r *ResponseRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.size += n
}
return n, err
}
// Size returns the size.
func (r *ResponseRecorder) Size() int {
return r.size
}
// Rcode returns the rcode.
func (r *ResponseRecorder) Rcode() int {
return r.rcode
}
// Start returns the start time of the ResponseRecorder.
func (r *ResponseRecorder) Start() time.Time {
return r.start
}
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *ResponseRecorder) Hijack() {
r.ResponseWriter.Hijack()
return
}

View File

@@ -0,0 +1,32 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestNewResponseRecorder(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
if !(recordRequest.ResponseWriter == w) {
t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
}
if recordRequest.status != http.StatusOK {
t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
}
}
func TestWrite(t *testing.T) {
w := httptest.NewRecorder()
responseTestString := "test"
recordRequest := NewResponseRecorder(w)
buf := []byte(responseTestString)
recordRequest.Write(buf)
if recordRequest.size != len(buf) {
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
}
if w.Body.String() != responseTestString {
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
}
}

View File

@@ -0,0 +1,84 @@
// Reflect provides middleware that reflects back some client properties.
// This is the default middleware when Caddy is run without configuration.
//
// The left-most label must be `who`.
// When queried for type A (resp. AAAA), it sends back the IPv4 (resp. v6) address.
// In the additional section the port number and transport are shown.
// Basic use pattern:
//
// dig @localhost -p 1053 who.miek.nl A
//
// ;; ANSWER SECTION:
// who.miek.nl. 0 IN A 127.0.0.1
//
// ;; ADDITIONAL SECTION:
// who.miek.nl. 0 IN TXT "Port: 56195 (udp)"
package reflect
import (
"errors"
"net"
"strings"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
type Reflect struct {
Next middleware.Handler
}
func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{Req: r, W: w}
class := r.Question[0].Qclass
qname := r.Question[0].Name
i, ok := dns.NextLabel(qname, 0)
if strings.ToLower(qname[:i]) != who || ok {
err := context.ErrorMessage(dns.RcodeFormatError)
w.WriteMsg(err)
return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
}
answer := new(dns.Msg)
answer.SetReply(r)
answer.Compress = true
answer.Authoritative = true
ip := context.IP()
proto := context.Proto()
port, _ := context.Port()
family := context.Family()
var rr dns.RR
switch family {
case 1:
rr = new(dns.A)
rr.(*dns.A).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeA, Class: class, Ttl: 0}
rr.(*dns.A).A = net.ParseIP(ip).To4()
case 2:
rr = new(dns.AAAA)
rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeAAAA, Class: class, Ttl: 0}
rr.(*dns.AAAA).AAAA = net.ParseIP(ip)
}
t := new(dns.TXT)
t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
t.Txt = []string{"Port: " + port + " (" + proto + ")"}
switch context.Type() {
case "TXT":
answer.Answer = append(answer.Answer, t)
answer.Extra = append(answer.Extra, rr)
default:
fallthrough
case "AAAA", "A":
answer.Answer = append(answer.Answer, rr)
answer.Extra = append(answer.Extra, t)
}
w.WriteMsg(answer)
return 0, nil
}
const who = "who."

View File

@@ -0,0 +1 @@
package reflect

98
middleware/replacer.go Normal file
View File

@@ -0,0 +1,98 @@
package middleware
import (
"strconv"
"strings"
"time"
"github.com/miekg/dns"
)
// Replacer is a type which can replace placeholder
// substrings in a string with actual values from a
// http.Request and responseRecorder. Always use
// NewReplacer to get one of these.
type Replacer interface {
Replace(string) string
Set(key, value string)
}
type replacer struct {
replacements map[string]string
emptyValue string
}
// NewReplacer makes a new replacer based on r and rr.
// Do not create a new replacer until r and rr have all
// the needed values, because this function copies those
// values into the replacer. rr may be nil if it is not
// available. emptyValue should be the string that is used
// in place of empty string (can still be empty string).
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
context := Context{W: rr, Req: r}
rep := replacer{
replacements: map[string]string{
"{type}": context.Type(),
"{name}": context.Name(),
"{class}": context.Class(),
"{proto}": context.Proto(),
"{when}": func() string {
return time.Now().Format(timeFormat)
}(),
"{remote}": context.IP(),
"{port}": func() string {
p, _ := context.Port()
return p
}(),
},
emptyValue: emptyValue,
}
if rr != nil {
rep.replacements["{rcode}"] = strconv.Itoa(rr.rcode)
rep.replacements["{size}"] = strconv.Itoa(rr.size)
rep.replacements["{latency}"] = time.Since(rr.start).String()
}
return rep
}
// Replace performs a replacement of values on s and returns
// the string with the replaced values.
func (r replacer) Replace(s string) string {
// Header replacements - these are case-insensitive, so we can't just use strings.Replace()
for strings.Contains(s, headerReplacer) {
idxStart := strings.Index(s, headerReplacer)
endOffset := idxStart + len(headerReplacer)
idxEnd := strings.Index(s[endOffset:], "}")
if idxEnd > -1 {
placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
replacement := r.replacements[placeholder]
if replacement == "" {
replacement = r.emptyValue
}
s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
} else {
break
}
}
// Regular replacements - these are easier because they're case-sensitive
for placeholder, replacement := range r.replacements {
if replacement == "" {
replacement = r.emptyValue
}
s = strings.Replace(s, placeholder, replacement, -1)
}
return s
}
// Set sets key to value in the replacements map.
func (r replacer) Set(key, value string) {
r.replacements["{"+key+"}"] = value
}
const (
timeFormat = "02/Jan/2006:15:04:05 -0700"
headerReplacer = "{>"
)

124
middleware/replacer_test.go Normal file
View File

@@ -0,0 +1,124 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestNewReplacer(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost", reader)
if err != nil {
t.Fatal("Request Formation Failed\n")
}
replaceValues := NewReplacer(request, recordRequest, "")
switch v := replaceValues.(type) {
case replacer:
if v.replacements["{host}"] != "localhost" {
t.Error("Expected host to be localhost")
}
if v.replacements["{method}"] != "POST" {
t.Error("Expected request method to be POST")
}
if v.replacements["{status}"] != "200" {
t.Error("Expected status to be 200")
}
default:
t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
}
}
func TestReplace(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost", reader)
if err != nil {
t.Fatal("Request Formation Failed\n")
}
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
if expected, actual := "This host is localhost.", repl.Replace("This host is {host}."); expected != actual {
t.Errorf("{host} replacement: expected '%s', got '%s'", expected, actual)
}
if expected, actual := "This request method is POST.", repl.Replace("This request method is {method}."); expected != actual {
t.Errorf("{method} replacement: expected '%s', got '%s'", expected, actual)
}
if expected, actual := "The response status is 200.", repl.Replace("The response status is {status}."); expected != actual {
t.Errorf("{status} replacement: expected '%s', got '%s'", expected, actual)
}
if expected, actual := "The Custom header is foobarbaz.", repl.Replace("The Custom header is {>Custom}."); expected != actual {
t.Errorf("{>Custom} replacement: expected '%s', got '%s'", expected, actual)
}
// Test header case-insensitivity
if expected, actual := "The cUsToM header is foobarbaz...", repl.Replace("The cUsToM header is {>cUsToM}..."); expected != actual {
t.Errorf("{>cUsToM} replacement: expected '%s', got '%s'", expected, actual)
}
// Test non-existent header/value
if expected, actual := "The Non-Existent header is -.", repl.Replace("The Non-Existent header is {>Non-Existent}."); expected != actual {
t.Errorf("{>Non-Existent} replacement: expected '%s', got '%s'", expected, actual)
}
// Test bad placeholder
if expected, actual := "Bad {host placeholder...", repl.Replace("Bad {host placeholder..."); expected != actual {
t.Errorf("bad placeholder: expected '%s', got '%s'", expected, actual)
}
// Test bad header placeholder
if expected, actual := "Bad {>Custom placeholder", repl.Replace("Bad {>Custom placeholder"); expected != actual {
t.Errorf("bad header placeholder: expected '%s', got '%s'", expected, actual)
}
// Test bad header placeholder with valid one later
if expected, actual := "Bad -", repl.Replace("Bad {>Custom placeholder {>ShorterVal}"); expected != actual {
t.Errorf("bad header placeholders: expected '%s', got '%s'", expected, actual)
}
// Test shorter header value with multiple placeholders
if expected, actual := "Short value 1 then foobarbaz.", repl.Replace("Short value {>ShorterVal} then {>Custom}."); expected != actual {
t.Errorf("short value: expected '%s', got '%s'", expected, actual)
}
}
func TestSet(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost", reader)
if err != nil {
t.Fatalf("Request Formation Failed \n")
}
repl := NewReplacer(request, recordRequest, "")
repl.Set("host", "getcaddy.com")
repl.Set("method", "GET")
repl.Set("status", "201")
repl.Set("variable", "value")
if repl.Replace("This host is {host}") != "This host is getcaddy.com" {
t.Error("Expected host replacement failed")
}
if repl.Replace("This request method is {method}") != "This request method is GET" {
t.Error("Expected method replacement failed")
}
if repl.Replace("The response status is {status}") != "The response status is 201" {
t.Error("Expected status replacement failed")
}
if repl.Replace("The value of variable is {variable}") != "The value of variable is value" {
t.Error("Expected variable replacement failed")
}
}

View File

@@ -0,0 +1,130 @@
package rewrite
import (
"fmt"
"regexp"
"strings"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
// Operators
const (
Is = "is"
Not = "not"
Has = "has"
NotHas = "not_has"
StartsWith = "starts_with"
EndsWith = "ends_with"
Match = "match"
NotMatch = "not_match"
)
func operatorError(operator string) error {
return fmt.Errorf("Invalid operator %v", operator)
}
func newReplacer(r *dns.Msg) middleware.Replacer {
return middleware.NewReplacer(r, nil, "")
}
// condition is a rewrite condition.
type condition func(string, string) bool
var conditions = map[string]condition{
Is: isFunc,
Not: notFunc,
Has: hasFunc,
NotHas: notHasFunc,
StartsWith: startsWithFunc,
EndsWith: endsWithFunc,
Match: matchFunc,
NotMatch: notMatchFunc,
}
// isFunc is condition for Is operator.
// It checks for equality.
func isFunc(a, b string) bool {
return a == b
}
// notFunc is condition for Not operator.
// It checks for inequality.
func notFunc(a, b string) bool {
return a != b
}
// hasFunc is condition for Has operator.
// It checks if b is a substring of a.
func hasFunc(a, b string) bool {
return strings.Contains(a, b)
}
// notHasFunc is condition for NotHas operator.
// It checks if b is not a substring of a.
func notHasFunc(a, b string) bool {
return !strings.Contains(a, b)
}
// startsWithFunc is condition for StartsWith operator.
// It checks if b is a prefix of a.
func startsWithFunc(a, b string) bool {
return strings.HasPrefix(a, b)
}
// endsWithFunc is condition for EndsWith operator.
// It checks if b is a suffix of a.
func endsWithFunc(a, b string) bool {
return strings.HasSuffix(a, b)
}
// matchFunc is condition for Match operator.
// It does regexp matching of a against pattern in b
// and returns if they match.
func matchFunc(a, b string) bool {
matched, _ := regexp.MatchString(b, a)
return matched
}
// notMatchFunc is condition for NotMatch operator.
// It does regexp matching of a against pattern in b
// and returns if they do not match.
func notMatchFunc(a, b string) bool {
matched, _ := regexp.MatchString(b, a)
return !matched
}
// If is statement for a rewrite condition.
type If struct {
A string
Operator string
B string
}
// True returns true if the condition is true and false otherwise.
// If r is not nil, it replaces placeholders before comparison.
func (i If) True(r *dns.Msg) bool {
if c, ok := conditions[i.Operator]; ok {
a, b := i.A, i.B
if r != nil {
replacer := newReplacer(r)
a = replacer.Replace(i.A)
b = replacer.Replace(i.B)
}
return c(a, b)
}
return false
}
// NewIf creates a new If condition.
func NewIf(a, operator, b string) (If, error) {
if _, ok := conditions[operator]; !ok {
return If{}, operatorError(operator)
}
return If{
A: a,
Operator: operator,
B: b,
}, nil
}

View File

@@ -0,0 +1,106 @@
package rewrite
import (
"net/http"
"strings"
"testing"
)
func TestConditions(t *testing.T) {
tests := []struct {
condition string
isTrue bool
}{
{"a is b", false},
{"a is a", true},
{"a not b", true},
{"a not a", false},
{"a has a", true},
{"a has b", false},
{"ba has b", true},
{"bab has b", true},
{"bab has bb", false},
{"a not_has a", false},
{"a not_has b", true},
{"ba not_has b", false},
{"bab not_has b", false},
{"bab not_has bb", true},
{"bab starts_with bb", false},
{"bab starts_with ba", true},
{"bab starts_with bab", true},
{"bab ends_with bb", false},
{"bab ends_with bab", true},
{"bab ends_with ab", true},
{"a match *", false},
{"a match a", true},
{"a match .*", true},
{"a match a.*", true},
{"a match b.*", false},
{"ba match b.*", true},
{"ba match b[a-z]", true},
{"b0 match b[a-z]", false},
{"b0a match b[a-z]", false},
{"b0a match b[a-z]+", false},
{"b0a match b[a-z0-9]+", true},
{"a not_match *", true},
{"a not_match a", false},
{"a not_match .*", false},
{"a not_match a.*", false},
{"a not_match b.*", true},
{"ba not_match b.*", false},
{"ba not_match b[a-z]", false},
{"b0 not_match b[a-z]", true},
{"b0a not_match b[a-z]", true},
{"b0a not_match b[a-z]+", true},
{"b0a not_match b[a-z0-9]+", false},
}
for i, test := range tests {
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(nil)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
invalidOperators := []string{"ss", "and", "if"}
for _, op := range invalidOperators {
_, err := NewIf("a", op, "b")
if err == nil {
t.Errorf("Invalid operator %v used, expected error.", op)
}
}
replaceTests := []struct {
url string
condition string
isTrue bool
}{
{"/home", "{uri} match /home", true},
{"/hom", "{uri} match /home", false},
{"/hom", "{uri} starts_with /home", false},
{"/hom", "{uri} starts_with /h", true},
{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
}
for i, test := range replaceTests {
r, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Error(err)
}
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(r)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
}

View File

@@ -0,0 +1,38 @@
package rewrite
import "github.com/miekg/dns"
// ResponseRevert reverses the operations done on the question section of a packet.
// This is need because the client will otherwise disregards the response, i.e.
// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN'
type ResponseReverter struct {
dns.ResponseWriter
original dns.Question
}
func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter {
return &ResponseReverter{
ResponseWriter: w,
original: r.Question[0],
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *ResponseReverter) WriteMsg(res *dns.Msg) error {
res.Question[0] = r.original
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the size of the message that gets written.
func (r *ResponseReverter) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
return n, err
}
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *ResponseReverter) Hijack() {
r.ResponseWriter.Hijack()
return
}

View File

@@ -0,0 +1,223 @@
// Package rewrite is middleware for rewriting requests internally to
// something different.
package rewrite
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
// Result is the result of a rewrite
type Result int
const (
// RewriteIgnored is returned when rewrite is not done on request.
RewriteIgnored Result = iota
// RewriteDone is returned when rewrite is done on request.
RewriteDone
// RewriteStatus is returned when rewrite is not needed and status code should be set
// for the request.
RewriteStatus
)
// Rewrite is middleware to rewrite requests internally before being handled.
type Rewrite struct {
Next middleware.Handler
Rules []Rule
}
// ServeHTTP implements the middleware.Handler interface.
func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules {
switch result := rule.Rewrite(r); result {
case RewriteDone:
return rw.Next.ServeDNS(wr, r)
case RewriteIgnored:
break
case RewriteStatus:
// only valid for complex rules.
// if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 {
// return cRule.Status, nil
// }
}
}
return rw.Next.ServeDNS(w, r)
}
// Rule describes an internal location rewrite rule.
type Rule interface {
// Rewrite rewrites the internal location of the current request.
Rewrite(*dns.Msg) Result
}
// SimpleRule is a simple rewrite rule. If the From and To look like a type
// the type of the request is rewritten, otherwise the name is.
// Note: TSIG signed requests will be invalid.
type SimpleRule struct {
From, To string
fromType, toType uint16
}
// NewSimpleRule creates a new Simple Rule
func NewSimpleRule(from, to string) SimpleRule {
tpf := dns.StringToType[from]
tpt := dns.StringToType[to]
return SimpleRule{From: from, To: to, fromType: tpf, toType: tpt}
}
// Rewrite rewrites the the current request.
func (s SimpleRule) Rewrite(r *dns.Msg) Result {
if s.fromType > 0 && s.toType > 0 {
if r.Question[0].Qtype == s.fromType {
r.Question[0].Qtype = s.toType
return RewriteDone
}
}
// if the question name matches the full name, or subset rewrite that
// s.Question[0].Name
return RewriteIgnored
}
/*
// ComplexRule is a rewrite rule based on a regular expression
type ComplexRule struct {
// Path base. Request to this path and subpaths will be rewritten
Base string
// Path to rewrite to
To string
// If set, neither performs rewrite nor proceeds
// with request. Only returns code.
Status int
// Extensions to filter by
Exts []string
// Rewrite conditions
Ifs []If
*regexp.Regexp
}
// NewComplexRule creates a new RegexpRule. It returns an error if regexp
// pattern (pattern) or extensions (ext) are invalid.
func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
// validate regexp if present
var r *regexp.Regexp
if pattern != "" {
var err error
r, err = regexp.Compile(pattern)
if err != nil {
return nil, err
}
}
// validate extensions if present
for _, v := range ext {
if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
// check if no extension is specified
if v != "/" && v != "!/" {
return nil, fmt.Errorf("invalid extension %v", v)
}
}
}
return &ComplexRule{
Base: base,
To: to,
Status: status,
Exts: ext,
Ifs: ifs,
Regexp: r,
}, nil
}
// Rewrite rewrites the internal location of the current request.
func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) {
rPath := req.URL.Path
replacer := newReplacer(req)
// validate base
if !middleware.Path(rPath).Matches(r.Base) {
return
}
// validate extensions
if !r.matchExt(rPath) {
return
}
// validate regexp if present
if r.Regexp != nil {
// include trailing slash in regexp if present
start := len(r.Base)
if strings.HasSuffix(r.Base, "/") {
start--
}
matches := r.FindStringSubmatch(rPath[start:])
switch len(matches) {
case 0:
// no match
return
default:
// set regexp match variables {1}, {2} ...
for i := 1; i < len(matches); i++ {
replacer.Set(fmt.Sprint(i), matches[i])
}
}
}
// validate rewrite conditions
for _, i := range r.Ifs {
if !i.True(req) {
return
}
}
// if status is present, stop rewrite and return it.
if r.Status != 0 {
return RewriteStatus
}
// attempt rewrite
return To(fs, req, r.To, replacer)
}
// matchExt matches rPath against registered file extensions.
// Returns true if a match is found and false otherwise.
func (r *ComplexRule) matchExt(rPath string) bool {
f := filepath.Base(rPath)
ext := path.Ext(f)
if ext == "" {
ext = "/"
}
mustUse := false
for _, v := range r.Exts {
use := true
if v[0] == '!' {
use = false
v = v[1:]
}
if use {
mustUse = true
}
if ext == v {
return use
}
}
if mustUse {
return false
}
return true
}
*/

View File

@@ -0,0 +1,159 @@
package rewrite
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/miekg/coredns/middleware"
)
func TestRewrite(t *testing.T) {
rw := Rewrite{
Next: middleware.HandlerFunc(urlPrinter),
Rules: []Rule{
NewSimpleRule("/from", "/to"),
NewSimpleRule("/a", "/b"),
NewSimpleRule("/b", "/b{uri}"),
},
FileSys: http.Dir("."),
}
regexps := [][]string{
{"/reg/", ".*", "/to", ""},
{"/r/", "[a-z]+", "/toaz", "!.html|"},
{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
{"/ab/", "ab", "/ab?{query}", ".txt|"},
{"/ab/", "ab", "/ab?type=html&{query}", ".html|"},
{"/abc/", "ab", "/abc/{file}", ".html|"},
{"/abcd/", "ab", "/a/{dir}/{file}", ".html|"},
{"/abcde/", "ab", "/a#{fragment}", ".html|"},
{"/ab/", `.*\.jpg`, "/ajpg", ""},
{"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""},
{"/reg2grp", `(.*)`, "/{1}", ""},
{"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""},
}
for _, regexpRule := range regexps {
var ext []string
if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
ext = s[:len(s)-1]
}
rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil)
if err != nil {
t.Fatal(err)
}
rw.Rules = append(rw.Rules, rule)
}
tests := []struct {
from string
expectedTo string
}{
{"/from", "/to"},
{"/a", "/b"},
{"/b", "/b/b"},
{"/aa", "/aa"},
{"/", "/"},
{"/a?foo=bar", "/b?foo=bar"},
{"/asdf?foo=bar", "/asdf?foo=bar"},
{"/foo#bar", "/foo#bar"},
{"/a#foo", "/b#foo"},
{"/reg/foo", "/to"},
{"/re", "/re"},
{"/r/", "/r/"},
{"/r/123", "/r/123"},
{"/r/a123", "/toaz"},
{"/r/abcz", "/toaz"},
{"/r/z", "/toaz"},
{"/r/z.html", "/r/z.html"},
{"/r/z.js", "/toaz"},
{"/url/asAB", "/to/url/asAB"},
{"/url/aBsAB", "/url/aBsAB"},
{"/url/a00sAB", "/to/url/a00sAB"},
{"/url/a0z0sAB", "/to/url/a0z0sAB"},
{"/ab/aa", "/ab/aa"},
{"/ab/ab", "/ab/ab"},
{"/ab/ab.txt", "/ab"},
{"/ab/ab.txt?name=name", "/ab?name=name"},
{"/ab/ab.html?name=name", "/ab?type=html&name=name"},
{"/abc/ab.html", "/abc/ab.html"},
{"/abcd/abcd.html", "/a/abcd/abcd.html"},
{"/abcde/abcde.html", "/a"},
{"/abcde/abcde.html#1234", "/a#1234"},
{"/ab/ab.jpg", "/ajpg"},
{"/reggrp/ad/12", "/a12"},
{"/reggrp/ad/124a", "/a124/a"},
{"/reggrp/ad/124abc", "/a124/abc"},
{"/reg2grp/ad/124abc", "/ad/124abc"},
{"/reg3grp/ad/aa/66", "/adaa66"},
{"/reg3grp/ad612/n1n/ab", "/ad612n1nab"},
}
for i, test := range tests {
req, err := http.NewRequest("GET", test.from, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
rec := httptest.NewRecorder()
rw.ServeHTTP(rec, req)
if rec.Body.String() != test.expectedTo {
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'",
i, test.expectedTo, rec.Body.String())
}
}
statusTests := []struct {
status int
base string
to string
regexp string
statusExpected bool
}{
{400, "/status", "", "", true},
{400, "/ignore", "", "", false},
{400, "/", "", "^/ignore", false},
{400, "/", "", "(.*)", true},
{400, "/status", "", "", true},
}
for i, s := range statusTests {
urlPath := fmt.Sprintf("/status%d", i)
rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil)
if err != nil {
t.Fatalf("Test %d: No error expected for rule but found %v", i, err)
}
rw.Rules = []Rule{rule}
req, err := http.NewRequest("GET", urlPath, nil)
if err != nil {
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
}
rec := httptest.NewRecorder()
code, err := rw.ServeHTTP(rec, req)
if err != nil {
t.Fatalf("Test %d: No error expected for handler but found %v", i, err)
}
if s.statusExpected {
if rec.Body.String() != "" {
t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String())
}
if code != s.status {
t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code)
}
} else {
if code != 0 {
t.Errorf("Test %d: Expected no status code found %d", i, code)
}
}
}
}
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprintf(w, r.URL.String())
return 0, nil
}

View File

1
middleware/rewrite/testdata/testfile vendored Normal file
View File

@@ -0,0 +1 @@
empty

27
middleware/roller.go Normal file
View File

@@ -0,0 +1,27 @@
package middleware
import (
"io"
"gopkg.in/natefinch/lumberjack.v2"
)
// LogRoller implements a middleware that provides a rolling logger.
type LogRoller struct {
Filename string
MaxSize int
MaxAge int
MaxBackups int
LocalTime bool
}
// GetLogWriter returns an io.Writer that writes to a rolling logger.
func (l LogRoller) GetLogWriter() io.Writer {
return &lumberjack.Logger{
Filename: l.Filename,
MaxSize: l.MaxSize,
MaxAge: l.MaxAge,
MaxBackups: l.MaxBackups,
LocalTime: l.LocalTime,
}
}

21
middleware/zone.go Normal file
View File

@@ -0,0 +1,21 @@
package middleware
import "strings"
type Zones []string
// Matches checks to see if other matches p.
// The match will return the most specific zones
// that matches other. The empty string signals a not found
// condition.
func (z Zones) Matches(qname string) string {
zone := ""
for _, zname := range z {
if strings.HasSuffix(qname, zname) {
if len(zname) > len(zone) {
zone = zname
}
}
}
return zone
}