mirror of
https://github.com/coredns/coredns.git
synced 2026-03-01 02:13:10 -05:00
First commit
This commit is contained in:
120
middleware/commands.go
Normal file
120
middleware/commands.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"unicode"
|
||||
|
||||
"github.com/flynn/go-shlex"
|
||||
)
|
||||
|
||||
var runtimeGoos = runtime.GOOS
|
||||
|
||||
// SplitCommandAndArgs takes a command string and parses it
|
||||
// shell-style into the command and its separate arguments.
|
||||
func SplitCommandAndArgs(command string) (cmd string, args []string, err error) {
|
||||
var parts []string
|
||||
|
||||
if runtimeGoos == "windows" {
|
||||
parts = parseWindowsCommand(command) // parse it Windows-style
|
||||
} else {
|
||||
parts, err = parseUnixCommand(command) // parse it Unix-style
|
||||
if err != nil {
|
||||
err = errors.New("error parsing command: " + err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
err = errors.New("no command contained in '" + command + "'")
|
||||
return
|
||||
}
|
||||
|
||||
cmd = parts[0]
|
||||
if len(parts) > 1 {
|
||||
args = parts[1:]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// parseUnixCommand parses a unix style command line and returns the
|
||||
// command and its arguments or an error
|
||||
func parseUnixCommand(cmd string) ([]string, error) {
|
||||
return shlex.Split(cmd)
|
||||
}
|
||||
|
||||
// parseWindowsCommand parses windows command lines and
|
||||
// returns the command and the arguments as an array. It
|
||||
// should be able to parse commonly used command lines.
|
||||
// Only basic syntax is supported:
|
||||
// - spaces in double quotes are not token delimiters
|
||||
// - double quotes are escaped by either backspace or another double quote
|
||||
// - except for the above case backspaces are path separators (not special)
|
||||
//
|
||||
// Many sources point out that escaping quotes using backslash can be unsafe.
|
||||
// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 )
|
||||
//
|
||||
// This function has to be used on Windows instead
|
||||
// of the shlex package because this function treats backslash
|
||||
// characters properly.
|
||||
func parseWindowsCommand(cmd string) []string {
|
||||
const backslash = '\\'
|
||||
const quote = '"'
|
||||
|
||||
var parts []string
|
||||
var part string
|
||||
var inQuotes bool
|
||||
var lastRune rune
|
||||
|
||||
for i, ch := range cmd {
|
||||
|
||||
if i != 0 {
|
||||
lastRune = rune(cmd[i-1])
|
||||
}
|
||||
|
||||
if ch == backslash {
|
||||
// put it in the part - for now we don't know if it's an
|
||||
// escaping char or path separator
|
||||
part += string(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == quote {
|
||||
if lastRune == backslash {
|
||||
// remove the backslash from the part and add the escaped quote instead
|
||||
part = part[:len(part)-1]
|
||||
part += string(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
if lastRune == quote {
|
||||
// revert the last change of the inQuotes state
|
||||
// it was an escaping quote
|
||||
inQuotes = !inQuotes
|
||||
part += string(ch)
|
||||
continue
|
||||
}
|
||||
|
||||
// normal escaping quotes
|
||||
inQuotes = !inQuotes
|
||||
continue
|
||||
|
||||
}
|
||||
|
||||
if unicode.IsSpace(ch) && !inQuotes && len(part) > 0 {
|
||||
parts = append(parts, part)
|
||||
part = ""
|
||||
continue
|
||||
}
|
||||
|
||||
part += string(ch)
|
||||
}
|
||||
|
||||
if len(part) > 0 {
|
||||
parts = append(parts, part)
|
||||
part = ""
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
291
middleware/commands_test.go
Normal file
291
middleware/commands_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseUnixCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
// 0 - emtpy command
|
||||
{
|
||||
input: ``,
|
||||
expected: []string{},
|
||||
},
|
||||
// 1 - command without arguments
|
||||
{
|
||||
input: `command`,
|
||||
expected: []string{`command`},
|
||||
},
|
||||
// 2 - command with single argument
|
||||
{
|
||||
input: `command arg1`,
|
||||
expected: []string{`command`, `arg1`},
|
||||
},
|
||||
// 3 - command with multiple arguments
|
||||
{
|
||||
input: `command arg1 arg2`,
|
||||
expected: []string{`command`, `arg1`, `arg2`},
|
||||
},
|
||||
// 4 - command with single argument with space character - in quotes
|
||||
{
|
||||
input: `command "arg1 arg1"`,
|
||||
expected: []string{`command`, `arg1 arg1`},
|
||||
},
|
||||
// 5 - command with multiple spaces and tab character
|
||||
{
|
||||
input: "command arg1 arg2\targ3",
|
||||
expected: []string{`command`, `arg1`, `arg2`, `arg3`},
|
||||
},
|
||||
// 6 - command with single argument with space character - escaped with backspace
|
||||
{
|
||||
input: `command arg1\ arg2`,
|
||||
expected: []string{`command`, `arg1 arg2`},
|
||||
},
|
||||
// 7 - single quotes should escape special chars
|
||||
{
|
||||
input: `command 'arg1\ arg2'`,
|
||||
expected: []string{`command`, `arg1\ arg2`},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
|
||||
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
|
||||
actual, _ := parseUnixCommand(test.input)
|
||||
if len(actual) != len(test.expected) {
|
||||
t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
|
||||
continue
|
||||
}
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
|
||||
t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWindowsCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{ // 0 - empty command - do not fail
|
||||
input: ``,
|
||||
expected: []string{},
|
||||
},
|
||||
{ // 1 - cmd without args
|
||||
input: `cmd`,
|
||||
expected: []string{`cmd`},
|
||||
},
|
||||
{ // 2 - multiple args
|
||||
input: `cmd arg1 arg2`,
|
||||
expected: []string{`cmd`, `arg1`, `arg2`},
|
||||
},
|
||||
{ // 3 - multiple args with space
|
||||
input: `cmd "combined arg" arg2`,
|
||||
expected: []string{`cmd`, `combined arg`, `arg2`},
|
||||
},
|
||||
{ // 4 - path without spaces
|
||||
input: `mkdir C:\Windows\foo\bar`,
|
||||
expected: []string{`mkdir`, `C:\Windows\foo\bar`},
|
||||
},
|
||||
{ // 5 - command with space in quotes
|
||||
input: `"command here"`,
|
||||
expected: []string{`command here`},
|
||||
},
|
||||
{ // 6 - argument with escaped quotes (two quotes)
|
||||
input: `cmd ""arg""`,
|
||||
expected: []string{`cmd`, `"arg"`},
|
||||
},
|
||||
{ // 7 - argument with escaped quotes (backslash)
|
||||
input: `cmd \"arg\"`,
|
||||
expected: []string{`cmd`, `"arg"`},
|
||||
},
|
||||
{ // 8 - two quotes (escaped) inside an inQuote element
|
||||
input: `cmd "a ""quoted value"`,
|
||||
expected: []string{`cmd`, `a "quoted value`},
|
||||
},
|
||||
// TODO - see how many quotes are dislayed if we use "", """, """""""
|
||||
{ // 9 - two quotes outside an inQuote element
|
||||
input: `cmd a ""quoted value`,
|
||||
expected: []string{`cmd`, `a`, `"quoted`, `value`},
|
||||
},
|
||||
{ // 10 - path with space in quotes
|
||||
input: `mkdir "C:\directory name\foobar"`,
|
||||
expected: []string{`mkdir`, `C:\directory name\foobar`},
|
||||
},
|
||||
{ // 11 - space without quotes
|
||||
input: `mkdir C:\ space`,
|
||||
expected: []string{`mkdir`, `C:\`, `space`},
|
||||
},
|
||||
{ // 12 - space in quotes
|
||||
input: `mkdir "C:\ space"`,
|
||||
expected: []string{`mkdir`, `C:\ space`},
|
||||
},
|
||||
{ // 13 - UNC
|
||||
input: `mkdir \\?\C:\Users`,
|
||||
expected: []string{`mkdir`, `\\?\C:\Users`},
|
||||
},
|
||||
{ // 14 - UNC with space
|
||||
input: `mkdir "\\?\C:\Program Files"`,
|
||||
expected: []string{`mkdir`, `\\?\C:\Program Files`},
|
||||
},
|
||||
|
||||
{ // 15 - unclosed quotes - treat as if the path ends with quote
|
||||
input: `mkdir "c:\Program files`,
|
||||
expected: []string{`mkdir`, `c:\Program files`},
|
||||
},
|
||||
{ // 16 - quotes used inside the argument
|
||||
input: `mkdir "c:\P"rogra"m f"iles`,
|
||||
expected: []string{`mkdir`, `c:\Program files`},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
|
||||
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
|
||||
|
||||
actual := parseWindowsCommand(test.input)
|
||||
if len(actual) != len(test.expected) {
|
||||
t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
|
||||
continue
|
||||
}
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
|
||||
t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitCommandAndArgs(t *testing.T) {
|
||||
|
||||
// force linux parsing. It's more robust and covers error cases
|
||||
runtimeGoos = "linux"
|
||||
defer func() {
|
||||
runtimeGoos = runtime.GOOS
|
||||
}()
|
||||
|
||||
var parseErrorContent = "error parsing command:"
|
||||
var noCommandErrContent = "no command contained in"
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedCommand string
|
||||
expectedArgs []string
|
||||
expectedErrContent string
|
||||
}{
|
||||
// 0 - emtpy command
|
||||
{
|
||||
input: ``,
|
||||
expectedCommand: ``,
|
||||
expectedArgs: nil,
|
||||
expectedErrContent: noCommandErrContent,
|
||||
},
|
||||
// 1 - command without arguments
|
||||
{
|
||||
input: `command`,
|
||||
expectedCommand: `command`,
|
||||
expectedArgs: nil,
|
||||
expectedErrContent: ``,
|
||||
},
|
||||
// 2 - command with single argument
|
||||
{
|
||||
input: `command arg1`,
|
||||
expectedCommand: `command`,
|
||||
expectedArgs: []string{`arg1`},
|
||||
expectedErrContent: ``,
|
||||
},
|
||||
// 3 - command with multiple arguments
|
||||
{
|
||||
input: `command arg1 arg2`,
|
||||
expectedCommand: `command`,
|
||||
expectedArgs: []string{`arg1`, `arg2`},
|
||||
expectedErrContent: ``,
|
||||
},
|
||||
// 4 - command with unclosed quotes
|
||||
{
|
||||
input: `command "arg1 arg2`,
|
||||
expectedCommand: "",
|
||||
expectedArgs: nil,
|
||||
expectedErrContent: parseErrorContent,
|
||||
},
|
||||
// 5 - command with unclosed quotes
|
||||
{
|
||||
input: `command 'arg1 arg2"`,
|
||||
expectedCommand: "",
|
||||
expectedArgs: nil,
|
||||
expectedErrContent: parseErrorContent,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
|
||||
errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
|
||||
actualCommand, actualArgs, actualErr := SplitCommandAndArgs(test.input)
|
||||
|
||||
// test if error matches expectation
|
||||
if test.expectedErrContent != "" {
|
||||
if actualErr == nil {
|
||||
t.Errorf(errorPrefix+"Expected error with content [%s], found no error."+errorSuffix, test.expectedErrContent)
|
||||
} else if !strings.Contains(actualErr.Error(), test.expectedErrContent) {
|
||||
t.Errorf(errorPrefix+"Expected error with content [%s], found [%v]."+errorSuffix, test.expectedErrContent, actualErr)
|
||||
}
|
||||
} else if actualErr != nil {
|
||||
t.Errorf(errorPrefix+"Expected no error, found [%v]."+errorSuffix, actualErr)
|
||||
}
|
||||
|
||||
// test if command matches
|
||||
if test.expectedCommand != actualCommand {
|
||||
t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand)
|
||||
}
|
||||
|
||||
// test if arguments match
|
||||
if len(test.expectedArgs) != len(actualArgs) {
|
||||
t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs)
|
||||
} else {
|
||||
// test args only if the count matches.
|
||||
for j, actualArg := range actualArgs {
|
||||
expectedArg := test.expectedArgs[j]
|
||||
if actualArg != expectedArg {
|
||||
t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleSplitCommandAndArgs() {
|
||||
var commandLine string
|
||||
var command string
|
||||
var args []string
|
||||
|
||||
// just for the test - change GOOS and reset it at the end of the test
|
||||
runtimeGoos = "windows"
|
||||
defer func() {
|
||||
runtimeGoos = runtime.GOOS
|
||||
}()
|
||||
|
||||
commandLine = `mkdir /P "C:\Program Files"`
|
||||
command, args, _ = SplitCommandAndArgs(commandLine)
|
||||
|
||||
fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
|
||||
|
||||
// set GOOS to linux
|
||||
runtimeGoos = "linux"
|
||||
|
||||
commandLine = `mkdir -p /path/with\ space`
|
||||
command, args, _ = SplitCommandAndArgs(commandLine)
|
||||
|
||||
fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))
|
||||
|
||||
// Output:
|
||||
// Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files]
|
||||
// Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space]
|
||||
}
|
||||
135
middleware/context.go
Normal file
135
middleware/context.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// This file contains the context and functions available for
|
||||
// use in the templates.
|
||||
|
||||
// Context is the context with which Caddy templates are executed.
|
||||
type Context struct {
|
||||
Root http.FileSystem // TODO(miek): needed
|
||||
Req *dns.Msg
|
||||
W dns.ResponseWriter
|
||||
}
|
||||
|
||||
// Now returns the current timestamp in the specified format.
|
||||
func (c Context) Now(format string) string {
|
||||
return time.Now().Format(format)
|
||||
}
|
||||
|
||||
// NowDate returns the current date/time that can be used
|
||||
// in other time functions.
|
||||
func (c Context) NowDate() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// Header gets the value of a header.
|
||||
func (c Context) Header() *dns.RR_Header {
|
||||
// TODO(miek)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IP gets the (remote) IP address of the client making the request.
|
||||
func (c Context) IP() string {
|
||||
ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return c.W.RemoteAddr().String()
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// Post gets the (remote) Port of the client making the request.
|
||||
func (c Context) Port() (string, error) {
|
||||
_, port, err := net.SplitHostPort(c.W.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return "0", err
|
||||
}
|
||||
return port, nil
|
||||
}
|
||||
|
||||
// Proto gets the protocol used as the transport. This
|
||||
// will be udp or tcp.
|
||||
func (c Context) Proto() string {
|
||||
if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok {
|
||||
return "udp"
|
||||
}
|
||||
if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok {
|
||||
return "tcp"
|
||||
}
|
||||
return "udp"
|
||||
}
|
||||
|
||||
// Family returns the family of the transport.
|
||||
// 1 for IPv4 and 2 for IPv6.
|
||||
func (c Context) Family() int {
|
||||
var a net.IP
|
||||
ip := c.W.RemoteAddr()
|
||||
if i, ok := ip.(*net.UDPAddr); ok {
|
||||
a = i.IP
|
||||
}
|
||||
if i, ok := ip.(*net.TCPAddr); ok {
|
||||
a = i.IP
|
||||
}
|
||||
|
||||
if a.To4() != nil {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
// Type returns the type of the question as a string.
|
||||
func (c Context) Type() string {
|
||||
return dns.Type(c.Req.Question[0].Qtype).String()
|
||||
}
|
||||
|
||||
// QType returns the type of the question as a uint16.
|
||||
func (c Context) QType() uint16 {
|
||||
return c.Req.Question[0].Qtype
|
||||
}
|
||||
|
||||
// Name returns the name of the question in the request. Note
|
||||
// this name will always have a closing dot and will be lower cased.
|
||||
func (c Context) Name() string {
|
||||
return strings.ToLower(dns.Name(c.Req.Question[0].Name).String())
|
||||
}
|
||||
|
||||
// QName returns the name of the question in the request.
|
||||
func (c Context) QName() string {
|
||||
return dns.Name(c.Req.Question[0].Name).String()
|
||||
}
|
||||
|
||||
// Class returns the class of the question in the request.
|
||||
func (c Context) Class() string {
|
||||
return dns.Class(c.Req.Question[0].Qclass).String()
|
||||
}
|
||||
|
||||
// QClass returns the class of the question in the request.
|
||||
func (c Context) QClass() uint16 {
|
||||
return c.Req.Question[0].Qclass
|
||||
}
|
||||
|
||||
// More convience types for extracting stuff from a message?
|
||||
// Header?
|
||||
|
||||
// ErrorMessage returns an error message suitable for sending
|
||||
// back to the client.
|
||||
func (c Context) ErrorMessage(rcode int) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(c.Req, rcode)
|
||||
return m
|
||||
}
|
||||
|
||||
// AnswerMessage returns an error message suitable for sending
|
||||
// back to the client.
|
||||
func (c Context) AnswerMessage() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(c.Req)
|
||||
return m
|
||||
}
|
||||
613
middleware/context_test.go
Normal file
613
middleware/context_test.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInclude(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
inputFilename := "test_file"
|
||||
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
|
||||
defer func() {
|
||||
err := os.Remove(absInFilePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
t.Fatalf("Failed to clean test file!")
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
fileContent string
|
||||
expectedContent string
|
||||
shouldErr bool
|
||||
expectedErrorContent string
|
||||
}{
|
||||
// Test 0 - all good
|
||||
{
|
||||
fileContent: `str1 {{ .Root }} str2`,
|
||||
expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
|
||||
shouldErr: false,
|
||||
expectedErrorContent: "",
|
||||
},
|
||||
// Test 1 - failure on template.Parse
|
||||
{
|
||||
fileContent: `str1 {{ .Root } str2`,
|
||||
expectedContent: "",
|
||||
shouldErr: true,
|
||||
expectedErrorContent: `unexpected "}" in operand`,
|
||||
},
|
||||
// Test 3 - failure on template.Execute
|
||||
{
|
||||
fileContent: `str1 {{ .InvalidField }} str2`,
|
||||
expectedContent: "",
|
||||
shouldErr: true,
|
||||
expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
testPrefix := getTestPrefix(i)
|
||||
|
||||
// WriteFile truncates the contentt
|
||||
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
|
||||
if err != nil {
|
||||
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
|
||||
}
|
||||
|
||||
content, err := context.Include(inputFilename)
|
||||
if err != nil {
|
||||
if !test.shouldErr {
|
||||
t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
|
||||
}
|
||||
if !strings.Contains(err.Error(), test.expectedErrorContent) {
|
||||
t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil && test.shouldErr {
|
||||
t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
|
||||
}
|
||||
|
||||
if content != test.expectedContent {
|
||||
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncludeNotExisting(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
_, err := context.Include("not_existing")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but found nil!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdown(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
inputFilename := "test_file"
|
||||
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
|
||||
defer func() {
|
||||
err := os.Remove(absInFilePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
t.Fatalf("Failed to clean test file!")
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
fileContent string
|
||||
expectedContent string
|
||||
}{
|
||||
// Test 0 - test parsing of markdown
|
||||
{
|
||||
fileContent: "* str1\n* str2\n",
|
||||
expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
testPrefix := getTestPrefix(i)
|
||||
|
||||
// WriteFile truncates the contentt
|
||||
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
|
||||
if err != nil {
|
||||
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
|
||||
}
|
||||
|
||||
content, _ := context.Markdown(inputFilename)
|
||||
if content != test.expectedContent {
|
||||
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookie(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
cookie *http.Cookie
|
||||
cookieName string
|
||||
expectedValue string
|
||||
}{
|
||||
// Test 0 - happy path
|
||||
{
|
||||
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
|
||||
cookieName: "cookieName",
|
||||
expectedValue: "cookieValue",
|
||||
},
|
||||
// Test 1 - try to get a non-existing cookie
|
||||
{
|
||||
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
|
||||
cookieName: "notExisting",
|
||||
expectedValue: "",
|
||||
},
|
||||
// Test 2 - partial name match
|
||||
{
|
||||
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
|
||||
cookieName: "cook",
|
||||
expectedValue: "",
|
||||
},
|
||||
// Test 3 - cookie with optional fields
|
||||
{
|
||||
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
|
||||
cookieName: "cookie",
|
||||
expectedValue: "cookieValue",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
testPrefix := getTestPrefix(i)
|
||||
|
||||
// reinitialize the context for each test
|
||||
context := getContextOrFail(t)
|
||||
|
||||
context.Req.AddCookie(test.cookie)
|
||||
|
||||
actualCookieVal := context.Cookie(test.cookieName)
|
||||
|
||||
if actualCookieVal != test.expectedValue {
|
||||
t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieMultipleCookies(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
|
||||
|
||||
// make sure that there's no state and multiple requests for different cookies return the correct result
|
||||
for i := 0; i < 10; i++ {
|
||||
context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
|
||||
actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
|
||||
if actualCookieVal != expectedCookieVal {
|
||||
t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeader(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
headerKey, headerVal := "Header1", "HeaderVal1"
|
||||
context.Req.Header.Add(headerKey, headerVal)
|
||||
|
||||
actualHeaderVal := context.Header(headerKey)
|
||||
if actualHeaderVal != headerVal {
|
||||
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
|
||||
}
|
||||
|
||||
missingHeaderVal := context.Header("not-existing")
|
||||
if missingHeaderVal != "" {
|
||||
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIP(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
tests := []struct {
|
||||
inputRemoteAddr string
|
||||
expectedIP string
|
||||
}{
|
||||
// Test 0 - ipv4 with port
|
||||
{"1.1.1.1:1111", "1.1.1.1"},
|
||||
// Test 1 - ipv4 without port
|
||||
{"1.1.1.1", "1.1.1.1"},
|
||||
// Test 2 - ipv6 with port
|
||||
{"[::1]:11", "::1"},
|
||||
// Test 3 - ipv6 without port and brackets
|
||||
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
|
||||
// Test 4 - ipv6 with zone and port
|
||||
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
testPrefix := getTestPrefix(i)
|
||||
|
||||
context.Req.RemoteAddr = test.inputRemoteAddr
|
||||
actualIP := context.IP()
|
||||
|
||||
if actualIP != test.expectedIP {
|
||||
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURL(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
inputURL := "http://localhost"
|
||||
context.Req.RequestURI = inputURL
|
||||
|
||||
if inputURL != context.URI() {
|
||||
t.Errorf("Expected url %s, found %s", inputURL, context.URI())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedHost string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
input: "localhost:123",
|
||||
expectedHost: "localhost",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: "localhost",
|
||||
expectedHost: "localhost",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: "[::]",
|
||||
expectedHost: "",
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedPort string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
input: "localhost:123",
|
||||
expectedPort: "123",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: "localhost",
|
||||
expectedPort: "80", // assuming 80 is the default port
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: ":8080",
|
||||
expectedPort: "8080",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: "[::]",
|
||||
expectedPort: "",
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
|
||||
}
|
||||
}
|
||||
|
||||
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
context.Req.Host = input
|
||||
var actualResult, testedObject string
|
||||
var err error
|
||||
|
||||
if isTestingHost {
|
||||
actualResult, err = context.Host()
|
||||
testedObject = "host"
|
||||
} else {
|
||||
actualResult, err = context.Port()
|
||||
testedObject = "port"
|
||||
}
|
||||
|
||||
if shouldErr && err == nil {
|
||||
t.Errorf("Expected error, found nil!")
|
||||
return
|
||||
}
|
||||
|
||||
if !shouldErr && err != nil {
|
||||
t.Errorf("Expected no error, found %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if actualResult != expectedResult {
|
||||
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethod(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
method := "POST"
|
||||
context.Req.Method = method
|
||||
|
||||
if method != context.Method() {
|
||||
t.Errorf("Expected method %s, found %s", method, context.Method())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPathMatches(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
|
||||
tests := []struct {
|
||||
urlStr string
|
||||
pattern string
|
||||
shouldMatch bool
|
||||
}{
|
||||
// Test 0
|
||||
{
|
||||
urlStr: "http://localhost/",
|
||||
pattern: "",
|
||||
shouldMatch: true,
|
||||
},
|
||||
// Test 1
|
||||
{
|
||||
urlStr: "http://localhost",
|
||||
pattern: "",
|
||||
shouldMatch: true,
|
||||
},
|
||||
// Test 1
|
||||
{
|
||||
urlStr: "http://localhost/",
|
||||
pattern: "/",
|
||||
shouldMatch: true,
|
||||
},
|
||||
// Test 3
|
||||
{
|
||||
urlStr: "http://localhost/?param=val",
|
||||
pattern: "/",
|
||||
shouldMatch: true,
|
||||
},
|
||||
// Test 4
|
||||
{
|
||||
urlStr: "http://localhost/dir1/dir2",
|
||||
pattern: "/dir2",
|
||||
shouldMatch: false,
|
||||
},
|
||||
// Test 5
|
||||
{
|
||||
urlStr: "http://localhost/dir1/dir2",
|
||||
pattern: "/dir1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
// Test 6
|
||||
{
|
||||
urlStr: "http://localhost:444/dir1/dir2",
|
||||
pattern: "/dir1",
|
||||
shouldMatch: true,
|
||||
},
|
||||
// Test 7
|
||||
{
|
||||
urlStr: "http://localhost/dir1/dir2",
|
||||
pattern: "*/dir2",
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
testPrefix := getTestPrefix(i)
|
||||
var err error
|
||||
context.Req.URL, err = url.Parse(test.urlStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
|
||||
}
|
||||
|
||||
matches := context.PathMatches(test.pattern)
|
||||
if matches != test.shouldMatch {
|
||||
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
tests := []struct {
|
||||
inputString string
|
||||
inputLength int
|
||||
expected string
|
||||
}{
|
||||
// Test 0 - small length
|
||||
{
|
||||
inputString: "string",
|
||||
inputLength: 1,
|
||||
expected: "s",
|
||||
},
|
||||
// Test 1 - exact length
|
||||
{
|
||||
inputString: "string",
|
||||
inputLength: 6,
|
||||
expected: "string",
|
||||
},
|
||||
// Test 2 - bigger length
|
||||
{
|
||||
inputString: "string",
|
||||
inputLength: 10,
|
||||
expected: "string",
|
||||
},
|
||||
// Test 3 - zero length
|
||||
{
|
||||
inputString: "string",
|
||||
inputLength: 0,
|
||||
expected: "",
|
||||
},
|
||||
// Test 4 - negative, smaller length
|
||||
{
|
||||
inputString: "string",
|
||||
inputLength: -5,
|
||||
expected: "tring",
|
||||
},
|
||||
// Test 5 - negative, exact length
|
||||
{
|
||||
inputString: "string",
|
||||
inputLength: -6,
|
||||
expected: "string",
|
||||
},
|
||||
// Test 6 - negative, bigger length
|
||||
{
|
||||
inputString: "string",
|
||||
inputLength: -7,
|
||||
expected: "string",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
actual := context.Truncate(test.inputString, test.inputLength)
|
||||
if actual != test.expected {
|
||||
t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripHTML(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
// Test 0 - no tags
|
||||
{
|
||||
input: `h1`,
|
||||
expected: `h1`,
|
||||
},
|
||||
// Test 1 - happy path
|
||||
{
|
||||
input: `<h1>h1</h1>`,
|
||||
expected: `h1`,
|
||||
},
|
||||
// Test 2 - tag in quotes
|
||||
{
|
||||
input: `<h1">">h1</h1>`,
|
||||
expected: `h1`,
|
||||
},
|
||||
// Test 3 - multiple tags
|
||||
{
|
||||
input: `<h1><b>h1</b></h1>`,
|
||||
expected: `h1`,
|
||||
},
|
||||
// Test 4 - tags not closed
|
||||
{
|
||||
input: `<h1`,
|
||||
expected: `<h1`,
|
||||
},
|
||||
// Test 5 - false start
|
||||
{
|
||||
input: `<h1<b>hi`,
|
||||
expected: `<h1hi`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
actual := context.StripHTML(test.input)
|
||||
if actual != test.expected {
|
||||
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripExt(t *testing.T) {
|
||||
context := getContextOrFail(t)
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
// Test 0 - empty input
|
||||
{
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
// Test 1 - relative file with ext
|
||||
{
|
||||
input: "file.ext",
|
||||
expected: "file",
|
||||
},
|
||||
// Test 2 - relative file without ext
|
||||
{
|
||||
input: "file",
|
||||
expected: "file",
|
||||
},
|
||||
// Test 3 - absolute file without ext
|
||||
{
|
||||
input: "/file",
|
||||
expected: "/file",
|
||||
},
|
||||
// Test 4 - absolute file with ext
|
||||
{
|
||||
input: "/file.ext",
|
||||
expected: "/file",
|
||||
},
|
||||
// Test 5 - with ext but ends with /
|
||||
{
|
||||
input: "/dir.ext/",
|
||||
expected: "/dir.ext/",
|
||||
},
|
||||
// Test 6 - file with ext under dir with ext
|
||||
{
|
||||
input: "/dir.ext/file.ext",
|
||||
expected: "/dir.ext/file",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
actual := context.StripExt(test.input)
|
||||
if actual != test.expected {
|
||||
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func initTestContext() (Context, error) {
|
||||
body := bytes.NewBufferString("request body")
|
||||
request, err := http.NewRequest("GET", "https://localhost", body)
|
||||
if err != nil {
|
||||
return Context{}, err
|
||||
}
|
||||
|
||||
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
|
||||
}
|
||||
|
||||
func getContextOrFail(t *testing.T) Context {
|
||||
context, err := initTestContext()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to prepare test context")
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
func getTestPrefix(testN int) string {
|
||||
return fmt.Sprintf("Test [%d]: ", testN)
|
||||
}
|
||||
100
middleware/errors/errors.go
Normal file
100
middleware/errors/errors.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Package errors implements an HTTP error handling middleware.
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ErrorHandler handles DNS errors (and errors from other middleware).
|
||||
type ErrorHandler struct {
|
||||
Next middleware.Handler
|
||||
LogFile string
|
||||
Log *log.Logger
|
||||
LogRoller *middleware.LogRoller
|
||||
Debug bool // if true, errors are written out to client rather than to a log
|
||||
}
|
||||
|
||||
func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
defer h.recovery(w, r)
|
||||
|
||||
rcode, err := h.Next.ServeDNS(w, r)
|
||||
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
|
||||
|
||||
if h.Debug {
|
||||
// Write error to response as a txt message instead of to log
|
||||
answer := debugMsg(rcode, r)
|
||||
txt, _ := dns.NewRR(". IN 0 TXT " + errMsg)
|
||||
answer.Answer = append(answer.Answer, txt)
|
||||
w.WriteMsg(answer)
|
||||
return 0, err
|
||||
}
|
||||
h.Log.Println(errMsg)
|
||||
}
|
||||
|
||||
return rcode, err
|
||||
}
|
||||
|
||||
func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) {
|
||||
rec := recover()
|
||||
if rec == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Obtain source of panic
|
||||
// From: https://gist.github.com/swdunlop/9629168
|
||||
var name, file string // function name, file name
|
||||
var line int
|
||||
var pc [16]uintptr
|
||||
n := runtime.Callers(3, pc[:])
|
||||
for _, pc := range pc[:n] {
|
||||
fn := runtime.FuncForPC(pc)
|
||||
if fn == nil {
|
||||
continue
|
||||
}
|
||||
file, line = fn.FileLine(pc)
|
||||
name = fn.Name()
|
||||
if !strings.HasPrefix(name, "runtime.") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Trim file path
|
||||
delim := "/coredns/"
|
||||
pkgPathPos := strings.Index(file, delim)
|
||||
if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) {
|
||||
file = file[pkgPathPos+len(delim):]
|
||||
}
|
||||
|
||||
panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec)
|
||||
if h.Debug {
|
||||
// Write error and stack trace to the response rather than to a log
|
||||
var stackBuf [4096]byte
|
||||
stack := stackBuf[:runtime.Stack(stackBuf[:], false)]
|
||||
answer := debugMsg(dns.RcodeServerFailure, r)
|
||||
// add stack buf in TXT, limited to 255 chars for now.
|
||||
txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255]))
|
||||
answer.Answer = append(answer.Answer, txt)
|
||||
w.WriteMsg(answer)
|
||||
} else {
|
||||
// Currently we don't use the function name, since file:line is more conventional
|
||||
h.Log.Printf(panicMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// debugMsg creates a debug message that gets send back to the client.
|
||||
func debugMsg(rcode int, r *dns.Msg) *dns.Msg {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(r, rcode)
|
||||
return answer
|
||||
}
|
||||
|
||||
const timeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
168
middleware/errors/errors_test.go
Normal file
168
middleware/errors/errors_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
)
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
// create a temporary page
|
||||
path := filepath.Join(os.TempDir(), "errors_test.html")
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(path)
|
||||
|
||||
const content = "This is a error page"
|
||||
_, err = f.WriteString(content)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
em := ErrorHandler{
|
||||
ErrorPages: map[int]string{
|
||||
http.StatusNotFound: path,
|
||||
http.StatusForbidden: "not_exist_file",
|
||||
},
|
||||
Log: log.New(&buf, "", 0),
|
||||
}
|
||||
_, notExistErr := os.Open("not_exist_file")
|
||||
|
||||
testErr := errors.New("test error")
|
||||
tests := []struct {
|
||||
next middleware.Handler
|
||||
expectedCode int
|
||||
expectedBody string
|
||||
expectedLog string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
next: genErrorHandler(http.StatusOK, nil, "normal"),
|
||||
expectedCode: http.StatusOK,
|
||||
expectedBody: "normal",
|
||||
expectedLog: "",
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
next: genErrorHandler(http.StatusMovedPermanently, testErr, ""),
|
||||
expectedCode: http.StatusMovedPermanently,
|
||||
expectedBody: "",
|
||||
expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr),
|
||||
expectedErr: testErr,
|
||||
},
|
||||
{
|
||||
next: genErrorHandler(http.StatusBadRequest, nil, ""),
|
||||
expectedCode: 0,
|
||||
expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest,
|
||||
http.StatusText(http.StatusBadRequest)),
|
||||
expectedLog: "",
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
next: genErrorHandler(http.StatusNotFound, nil, ""),
|
||||
expectedCode: 0,
|
||||
expectedBody: content,
|
||||
expectedLog: "",
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
next: genErrorHandler(http.StatusForbidden, nil, ""),
|
||||
expectedCode: 0,
|
||||
expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden,
|
||||
http.StatusText(http.StatusForbidden)),
|
||||
expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n",
|
||||
http.StatusForbidden, notExistErr),
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i, test := range tests {
|
||||
em.Next = test.next
|
||||
buf.Reset()
|
||||
rec := httptest.NewRecorder()
|
||||
code, err := em.ServeHTTP(rec, req)
|
||||
|
||||
if err != test.expectedErr {
|
||||
t.Errorf("Test %d: Expected error %v, but got %v",
|
||||
i, test.expectedErr, err)
|
||||
}
|
||||
if code != test.expectedCode {
|
||||
t.Errorf("Test %d: Expected status code %d, but got %d",
|
||||
i, test.expectedCode, code)
|
||||
}
|
||||
if body := rec.Body.String(); body != test.expectedBody {
|
||||
t.Errorf("Test %d: Expected body %q, but got %q",
|
||||
i, test.expectedBody, body)
|
||||
}
|
||||
if log := buf.String(); !strings.Contains(log, test.expectedLog) {
|
||||
t.Errorf("Test %d: Expected log %q, but got %q",
|
||||
i, test.expectedLog, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVisibleErrorWithPanic(t *testing.T) {
|
||||
const panicMsg = "I'm a panic"
|
||||
eh := ErrorHandler{
|
||||
ErrorPages: make(map[int]string),
|
||||
Debug: true,
|
||||
Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
panic(panicMsg)
|
||||
}),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
code, err := eh.ServeHTTP(rec, req)
|
||||
|
||||
if code != 0 {
|
||||
t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
|
||||
if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") {
|
||||
t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body)
|
||||
}
|
||||
if !strings.Contains(body, panicMsg) {
|
||||
t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body)
|
||||
}
|
||||
if len(body) < 500 {
|
||||
t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body))
|
||||
}
|
||||
}
|
||||
|
||||
func genErrorHandler(status int, err error, body string) middleware.Handler {
|
||||
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
if len(body) > 0 {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
fmt.Fprint(w, body)
|
||||
}
|
||||
return status, err
|
||||
})
|
||||
}
|
||||
0
middleware/etcd/TODO
Normal file
0
middleware/etcd/TODO
Normal file
10
middleware/exchange.go
Normal file
10
middleware/exchange.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// Exchang sends message m to the server.
|
||||
// TODO(miek): optionally it can do retries of other silly stuff.
|
||||
func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) {
|
||||
r, _, err := c.Exchange(m, server)
|
||||
return r, err
|
||||
}
|
||||
89
middleware/file/file.go
Normal file
89
middleware/file/file.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package file
|
||||
|
||||
// TODO(miek): the zone's implementation is basically non-existent
|
||||
// we return a list and when searching for an answer we iterate
|
||||
// over the list. This must be moved to a tree-like structure and
|
||||
// have some fluff for DNSSEC (and be memory efficient).
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type (
|
||||
File struct {
|
||||
Next middleware.Handler
|
||||
Zones Zones
|
||||
// Maybe a list of all zones as well, as a []string?
|
||||
}
|
||||
|
||||
Zone []dns.RR
|
||||
Zones struct {
|
||||
Z map[string]Zone // utterly braindead impl. TODO(miek): fix
|
||||
Names []string
|
||||
}
|
||||
)
|
||||
|
||||
func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
context := middleware.Context{W: w, Req: r}
|
||||
qname := context.Name()
|
||||
zone := middleware.Zones(f.Zones.Names).Matches(qname)
|
||||
if zone == "" {
|
||||
return f.Next.ServeDNS(w, r)
|
||||
}
|
||||
|
||||
names, nodata := f.Zones.Z[zone].lookup(qname, context.QType())
|
||||
var answer *dns.Msg
|
||||
switch {
|
||||
case nodata:
|
||||
answer = context.AnswerMessage()
|
||||
answer.Ns = names
|
||||
case len(names) == 0:
|
||||
answer = context.AnswerMessage()
|
||||
answer.Ns = names
|
||||
answer.Rcode = dns.RcodeNameError
|
||||
case len(names) > 0:
|
||||
answer = context.AnswerMessage()
|
||||
answer.Answer = names
|
||||
default:
|
||||
answer = context.ErrorMessage(dns.RcodeServerFailure)
|
||||
}
|
||||
// Check return size, etc. TODO(miek)
|
||||
w.WriteMsg(answer)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Lookup will try to find qname and qtype in z. It returns the
|
||||
// records found *or* a boolean saying NODATA. If the answer
|
||||
// is NODATA then the RR returned is the SOA record.
|
||||
//
|
||||
// TODO(miek): EXTREMELY STUPID IMPLEMENTATION.
|
||||
// Doesn't do much, no delegation, no cname, nothing really, etc.
|
||||
// TODO(miek): even NODATA looks broken
|
||||
func (z Zone) lookup(qname string, qtype uint16) ([]dns.RR, bool) {
|
||||
var (
|
||||
nodata bool
|
||||
rep []dns.RR
|
||||
soa dns.RR
|
||||
)
|
||||
|
||||
for _, rr := range z {
|
||||
if rr.Header().Rrtype == dns.TypeSOA {
|
||||
soa = rr
|
||||
}
|
||||
// Match function in Go DNS?
|
||||
if strings.ToLower(rr.Header().Name) == qname {
|
||||
if rr.Header().Rrtype == qtype {
|
||||
rep = append(rep, rr)
|
||||
nodata = false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
if nodata {
|
||||
return []dns.RR{soa}, true
|
||||
}
|
||||
return rep, false
|
||||
}
|
||||
325
middleware/file/file_test.go
Normal file
325
middleware/file/file_test.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
|
||||
var ErrCustom = errors.New("Custom Error")
|
||||
|
||||
// testFiles is a map with relative paths to test files as keys and file content as values.
|
||||
// The map represents the following structure:
|
||||
// - $TEMP/caddy_testdir/
|
||||
// '-- file1.html
|
||||
// '-- dirwithindex/
|
||||
// '---- index.html
|
||||
// '-- dir/
|
||||
// '---- file2.html
|
||||
// '---- hidden.html
|
||||
var testFiles = map[string]string{
|
||||
"file1.html": "<h1>file1.html</h1>",
|
||||
filepath.Join("dirwithindex", "index.html"): "<h1>dirwithindex/index.html</h1>",
|
||||
filepath.Join("dir", "file2.html"): "<h1>dir/file2.html</h1>",
|
||||
filepath.Join("dir", "hidden.html"): "<h1>dir/hidden.html</h1>",
|
||||
}
|
||||
|
||||
// TestServeHTTP covers positive scenarios when serving files.
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
|
||||
beforeServeHTTPTest(t)
|
||||
defer afterServeHTTPTest(t)
|
||||
|
||||
fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"})
|
||||
|
||||
movedPermanently := "Moved Permanently"
|
||||
|
||||
tests := []struct {
|
||||
url string
|
||||
|
||||
expectedStatus int
|
||||
expectedBodyContent string
|
||||
}{
|
||||
// Test 0 - access without any path
|
||||
{
|
||||
url: "https://foo",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
// Test 1 - access root (without index.html)
|
||||
{
|
||||
url: "https://foo/",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
// Test 2 - access existing file
|
||||
{
|
||||
url: "https://foo/file1.html",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBodyContent: testFiles["file1.html"],
|
||||
},
|
||||
// Test 3 - access folder with index file with trailing slash
|
||||
{
|
||||
url: "https://foo/dirwithindex/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
|
||||
},
|
||||
// Test 4 - access folder with index file without trailing slash
|
||||
{
|
||||
url: "https://foo/dirwithindex",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
// Test 5 - access folder without index file
|
||||
{
|
||||
url: "https://foo/dir/",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
// Test 6 - access folder without trailing slash
|
||||
{
|
||||
url: "https://foo/dir",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
// Test 6 - access file with trailing slash
|
||||
{
|
||||
url: "https://foo/file1.html/",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
// Test 7 - access not existing path
|
||||
{
|
||||
url: "https://foo/not_existing",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
// Test 8 - access a file, marked as hidden
|
||||
{
|
||||
url: "https://foo/dir/hidden.html",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
// Test 9 - access a index file directly
|
||||
{
|
||||
url: "https://foo/dirwithindex/index.html",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
|
||||
},
|
||||
// Test 10 - send a request with query params
|
||||
{
|
||||
url: "https://foo/dir?param1=val",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
expectedBodyContent: movedPermanently,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
request, err := http.NewRequest("GET", test.url, strings.NewReader(""))
|
||||
status, err := fileserver.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// check if error matches expectations
|
||||
if err != nil {
|
||||
t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err)
|
||||
}
|
||||
|
||||
// check status code
|
||||
if test.expectedStatus != status {
|
||||
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
|
||||
}
|
||||
|
||||
// check body content
|
||||
if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) {
|
||||
t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// beforeServeHTTPTest creates a test directory with the structure, defined in the variable testFiles
|
||||
func beforeServeHTTPTest(t *testing.T) {
|
||||
// make the root test dir
|
||||
err := os.Mkdir(testDir, os.ModePerm)
|
||||
if err != nil {
|
||||
if !os.IsExist(err) {
|
||||
t.Fatalf("Failed to create test dir. Error was: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for relFile, fileContent := range testFiles {
|
||||
absFile := filepath.Join(testDir, relFile)
|
||||
|
||||
// make sure the parent directories exist
|
||||
parentDir := filepath.Dir(absFile)
|
||||
_, err = os.Stat(parentDir)
|
||||
if err != nil {
|
||||
os.MkdirAll(parentDir, os.ModePerm)
|
||||
}
|
||||
|
||||
// now create the test files
|
||||
f, err := os.Create(absFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err)
|
||||
return
|
||||
}
|
||||
|
||||
// and fill them with content
|
||||
_, err = f.WriteString(fileContent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write to %s. Error was: %v", absFile, err)
|
||||
return
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// afterServeHTTPTest removes the test dir and all its content
|
||||
func afterServeHTTPTest(t *testing.T) {
|
||||
// cleans up everything under the test dir. No need to clean the individual files.
|
||||
err := os.RemoveAll(testDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err
|
||||
type failingFS struct {
|
||||
err error // the error to return when Open is called
|
||||
fileImpl http.File // inject the file implementation
|
||||
}
|
||||
|
||||
// Open returns the assigned failingFile and error
|
||||
func (f failingFS) Open(path string) (http.File, error) {
|
||||
return f.fileImpl, f.err
|
||||
}
|
||||
|
||||
// failingFile implements http.File but returns a predefined error on every Stat() method call.
|
||||
type failingFile struct {
|
||||
http.File
|
||||
err error
|
||||
}
|
||||
|
||||
// Stat returns nil FileInfo and the provided error on every call
|
||||
func (ff failingFile) Stat() (os.FileInfo, error) {
|
||||
return nil, ff.err
|
||||
}
|
||||
|
||||
// Close is noop and returns no error
|
||||
func (ff failingFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors.
|
||||
func TestServeHTTPFailingFS(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
fsErr error
|
||||
expectedStatus int
|
||||
expectedErr error
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
fsErr: os.ErrNotExist,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
fsErr: os.ErrPermission,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedErr: os.ErrPermission,
|
||||
},
|
||||
{
|
||||
fsErr: ErrCustom,
|
||||
expectedStatus: http.StatusServiceUnavailable,
|
||||
expectedErr: ErrCustom,
|
||||
expectedHeaders: map[string]string{"Retry-After": "5"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
// initialize a file server with the failing FileSystem
|
||||
fileserver := FileServer(failingFS{err: test.fsErr}, nil)
|
||||
|
||||
// prepare the request and response
|
||||
request, err := http.NewRequest("GET", "https://foo/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to build request. Error was: %v", err)
|
||||
}
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// check the status
|
||||
if status != test.expectedStatus {
|
||||
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
|
||||
}
|
||||
|
||||
// check the error
|
||||
if actualErr != test.expectedErr {
|
||||
t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
|
||||
}
|
||||
|
||||
// check the headers - a special case for server under load
|
||||
if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 {
|
||||
for expectedKey, expectedVal := range test.expectedHeaders {
|
||||
actualVal := responseRecorder.Header().Get(expectedKey)
|
||||
if expectedVal != actualVal {
|
||||
t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails.
|
||||
func TestServeHTTPFailingStat(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
statErr error
|
||||
expectedStatus int
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
statErr: os.ErrNotExist,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
statErr: os.ErrPermission,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedErr: os.ErrPermission,
|
||||
},
|
||||
{
|
||||
statErr: ErrCustom,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedErr: ErrCustom,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
// initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will
|
||||
fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil)
|
||||
|
||||
// prepare the request and response
|
||||
request, err := http.NewRequest("GET", "https://foo/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to build request. Error was: %v", err)
|
||||
}
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
|
||||
|
||||
// check the status
|
||||
if status != test.expectedStatus {
|
||||
t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
|
||||
}
|
||||
|
||||
// check the error
|
||||
if actualErr != test.expectedErr {
|
||||
t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
22
middleware/host.go
Normal file
22
middleware/host.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Host represents a host from the Caddyfile, may contain port.
|
||||
type Host string
|
||||
|
||||
// Standard host will return the host portion of host, stripping
|
||||
// of any port. The host will also be fully qualified and lowercased.
|
||||
func (h Host) StandardHost() string {
|
||||
// separate host and port
|
||||
host, _, err := net.SplitHostPort(string(h))
|
||||
if err != nil {
|
||||
host, _, _ = net.SplitHostPort(string(h) + ":")
|
||||
}
|
||||
return strings.ToLower(dns.Fqdn(host))
|
||||
}
|
||||
66
middleware/log/log.go
Normal file
66
middleware/log/log.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package log implements basic but useful request (access) logging middleware.
|
||||
package log
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Logger is a basic request logging middleware.
|
||||
type Logger struct {
|
||||
Next middleware.Handler
|
||||
Rules []Rule
|
||||
ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
|
||||
}
|
||||
|
||||
func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
for _, rule := range l.Rules {
|
||||
/*
|
||||
if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
|
||||
responseRecorder := middleware.NewResponseRecorder(w)
|
||||
status, err := l.Next.ServeHTTP(responseRecorder, r)
|
||||
if status >= 400 {
|
||||
// There was an error up the chain, but no response has been written yet.
|
||||
// The error must be handled here so the log entry will record the response size.
|
||||
if l.ErrorFunc != nil {
|
||||
l.ErrorFunc(responseRecorder, r, status)
|
||||
} else {
|
||||
// Default failover error handler
|
||||
responseRecorder.WriteHeader(status)
|
||||
fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status))
|
||||
}
|
||||
status = 0
|
||||
}
|
||||
rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue)
|
||||
rule.Log.Println(rep.Replace(rule.Format))
|
||||
return status, err
|
||||
}
|
||||
*/
|
||||
rule = rule
|
||||
}
|
||||
return l.Next.ServeDNS(w, r)
|
||||
}
|
||||
|
||||
// Rule configures the logging middleware.
|
||||
type Rule struct {
|
||||
PathScope string
|
||||
OutputFile string
|
||||
Format string
|
||||
Log *log.Logger
|
||||
Roller *middleware.LogRoller
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultLogFilename is the default log filename.
|
||||
DefaultLogFilename = "access.log"
|
||||
// CommonLogFormat is the common log format.
|
||||
CommonLogFormat = `{remote} ` + CommonLogEmptyValue + ` [{when}] "{type} {name} {proto}" {rcode} {size}`
|
||||
// CommonLogEmptyValue is the common empty log value.
|
||||
CommonLogEmptyValue = "-"
|
||||
// CombinedLogFormat is the combined log format.
|
||||
CombinedLogFormat = CommonLogFormat + ` "{>Referer}" "{>User-Agent}"` // Something here as well
|
||||
// DefaultLogFormat is the default log format.
|
||||
DefaultLogFormat = CommonLogFormat
|
||||
)
|
||||
48
middleware/log/log_test.go
Normal file
48
middleware/log/log_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type erroringMiddleware struct{}
|
||||
|
||||
func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
return http.StatusNotFound, nil
|
||||
}
|
||||
|
||||
func TestLoggedStatus(t *testing.T) {
|
||||
var f bytes.Buffer
|
||||
var next erroringMiddleware
|
||||
rule := Rule{
|
||||
PathScope: "/",
|
||||
Format: DefaultLogFormat,
|
||||
Log: log.New(&f, "", 0),
|
||||
}
|
||||
|
||||
logger := Logger{
|
||||
Rules: []Rule{rule},
|
||||
Next: next,
|
||||
}
|
||||
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
status, err := logger.ServeHTTP(rec, r)
|
||||
if status != 0 {
|
||||
t.Error("Expected status to be 0 - was", status)
|
||||
}
|
||||
|
||||
logged := f.String()
|
||||
if !strings.Contains(logged, "404 13") {
|
||||
t.Error("Expected 404 to be logged. Logged string -", logged)
|
||||
}
|
||||
}
|
||||
105
middleware/middleware.go
Normal file
105
middleware/middleware.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// Package middleware provides some types and functions common among middleware.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type (
|
||||
// Middleware is the middle layer which represents the traditional
|
||||
// idea of middleware: it chains one Handler to the next by being
|
||||
// passed the next Handler in the chain.
|
||||
Middleware func(Handler) Handler
|
||||
|
||||
// Handler is like dns.Handler except ServeDNS may return an rcode
|
||||
// and/or error.
|
||||
//
|
||||
// If ServeDNS writes to the response body, it should return a status
|
||||
// code of 0. This signals to other handlers above it that the response
|
||||
// body is already written, and that they should not write to it also.
|
||||
//
|
||||
// If ServeDNS encounters an error, it should return the error value
|
||||
// so it can be logged by designated error-handling middleware.
|
||||
//
|
||||
// If writing a response after calling another ServeDNS method, the
|
||||
// returned rcode SHOULD be used when writing the response.
|
||||
//
|
||||
// If handling errors after calling another ServeDNS method, the
|
||||
// returned error value SHOULD be logged or handled accordingly.
|
||||
//
|
||||
// Otherwise, return values should be propagated down the middleware
|
||||
// chain by returning them unchanged.
|
||||
Handler interface {
|
||||
ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error)
|
||||
}
|
||||
|
||||
// HandlerFunc is a convenience type like dns.HandlerFunc, except
|
||||
// ServeDNS returns an rcode and an error. See Handler
|
||||
// documentation for more information.
|
||||
HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error)
|
||||
)
|
||||
|
||||
// ServeDNS implements the Handler interface.
|
||||
func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
return f(w, r)
|
||||
}
|
||||
|
||||
// IndexFile looks for a file in /root/fpath/indexFile for each string
|
||||
// in indexFiles. If an index file is found, it returns the root-relative
|
||||
// path to the file and true. If no index file is found, empty string
|
||||
// and false is returned. fpath must end in a forward slash '/'
|
||||
// otherwise no index files will be tried (directory paths must end
|
||||
// in a forward slash according to HTTP).
|
||||
//
|
||||
// All paths passed into and returned from this function use '/' as the
|
||||
// path separator, just like URLs. IndexFle handles path manipulation
|
||||
// internally for systems that use different path separators.
|
||||
/*
|
||||
func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, bool) {
|
||||
if fpath[len(fpath)-1] != '/' || root == nil {
|
||||
return "", false
|
||||
}
|
||||
for _, indexFile := range indexFiles {
|
||||
// func (http.FileSystem).Open wants all paths separated by "/",
|
||||
// regardless of operating system convention, so use
|
||||
// path.Join instead of filepath.Join
|
||||
fp := path.Join(fpath, indexFile)
|
||||
f, err := root.Open(fp)
|
||||
if err == nil {
|
||||
f.Close()
|
||||
return fp, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it
|
||||
// as a Last-Modified header to the ResponseWriter. If the modTime is in the future
|
||||
// the current time is used instead.
|
||||
func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) {
|
||||
if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) {
|
||||
// the time does not appear to be valid. Don't put it in the response
|
||||
return
|
||||
}
|
||||
|
||||
// RFC 2616 - Section 14.29 - Last-Modified:
|
||||
// An origin server MUST NOT send a Last-Modified date which is later than the
|
||||
// server's time of message origination. In such cases, where the resource's last
|
||||
// modification would indicate some time in the future, the server MUST replace
|
||||
// that date with the message origination date.
|
||||
now := currentTime()
|
||||
if modTime.After(now) {
|
||||
modTime = now
|
||||
}
|
||||
|
||||
w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat))
|
||||
}
|
||||
*/
|
||||
|
||||
// currentTime, as it is defined here, returns time.Now().
|
||||
// It's defined as a variable for mocking time in tests.
|
||||
var currentTime = func() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
108
middleware/middleware_test.go
Normal file
108
middleware/middleware_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIndexfile(t *testing.T) {
|
||||
tests := []struct {
|
||||
rootDir http.FileSystem
|
||||
fpath string
|
||||
indexFiles []string
|
||||
shouldErr bool
|
||||
expectedFilePath string //retun value
|
||||
expectedBoolValue bool //return value
|
||||
}{
|
||||
{
|
||||
http.Dir("./templates/testdata"),
|
||||
"/images/",
|
||||
[]string{"img.htm"},
|
||||
false,
|
||||
"/images/img.htm",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for i, test := range tests {
|
||||
actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
|
||||
if actualBoolValue == true && test.shouldErr {
|
||||
t.Errorf("Test %d didn't error, but it should have", i)
|
||||
} else if actualBoolValue != true && !test.shouldErr {
|
||||
t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
|
||||
}
|
||||
if actualFilePath != test.expectedFilePath {
|
||||
t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
|
||||
i, test.expectedFilePath, actualFilePath)
|
||||
|
||||
}
|
||||
if actualBoolValue != test.expectedBoolValue {
|
||||
t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
|
||||
i, test.expectedBoolValue, actualBoolValue)
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLastModified(t *testing.T) {
|
||||
nowTime := time.Now()
|
||||
|
||||
// ovewrite the function to return reliable time
|
||||
originalGetCurrentTimeFunc := currentTime
|
||||
currentTime = func() time.Time {
|
||||
return nowTime
|
||||
}
|
||||
defer func() {
|
||||
currentTime = originalGetCurrentTimeFunc
|
||||
}()
|
||||
|
||||
pastTime := nowTime.Truncate(1 * time.Hour)
|
||||
futureTime := nowTime.Add(1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
inputModTime time.Time
|
||||
expectedIsHeaderSet bool
|
||||
expectedLastModified string
|
||||
}{
|
||||
{
|
||||
inputModTime: pastTime,
|
||||
expectedIsHeaderSet: true,
|
||||
expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
|
||||
},
|
||||
{
|
||||
inputModTime: nowTime,
|
||||
expectedIsHeaderSet: true,
|
||||
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
|
||||
},
|
||||
{
|
||||
inputModTime: futureTime,
|
||||
expectedIsHeaderSet: true,
|
||||
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
|
||||
},
|
||||
{
|
||||
inputModTime: time.Time{},
|
||||
expectedIsHeaderSet: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
|
||||
SetLastModifiedHeader(responseRecorder, test.inputModTime)
|
||||
actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
|
||||
|
||||
if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
|
||||
t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
|
||||
}
|
||||
|
||||
if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
|
||||
t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
|
||||
}
|
||||
|
||||
if test.expectedLastModified != actualLastModifiedHeader {
|
||||
t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
18
middleware/path.go
Normal file
18
middleware/path.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import "strings"
|
||||
|
||||
|
||||
// TODO(miek): matches for names.
|
||||
|
||||
// Path represents a URI path, maybe with pattern characters.
|
||||
type Path string
|
||||
|
||||
// Matches checks to see if other matches p.
|
||||
//
|
||||
// Path matching will probably not always be a direct
|
||||
// comparison; this method assures that paths can be
|
||||
// easily and consistently matched.
|
||||
func (p Path) Matches(other string) bool {
|
||||
return strings.HasPrefix(string(p), other)
|
||||
}
|
||||
31
middleware/prometheus/handler.go
Normal file
31
middleware/prometheus/handler.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
context := middleware.Context{W: w, Req: r}
|
||||
|
||||
qname := context.Name()
|
||||
qtype := context.Type()
|
||||
zone := middleware.Zones(m.ZoneNames).Matches(qname)
|
||||
if zone == "" {
|
||||
zone = "."
|
||||
}
|
||||
|
||||
// Record response to get status code and size of the reply.
|
||||
rw := middleware.NewResponseRecorder(w)
|
||||
status, err := m.Next.ServeDNS(rw, r)
|
||||
|
||||
requestCount.WithLabelValues(zone, qtype).Inc()
|
||||
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
|
||||
responseSize.WithLabelValues(zone).Observe(float64(rw.Size()))
|
||||
responseRcode.WithLabelValues(zone, strconv.Itoa(rw.Rcode())).Inc()
|
||||
|
||||
return status, err
|
||||
}
|
||||
80
middleware/prometheus/metrics.go
Normal file
80
middleware/prometheus/metrics.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
const namespace = "daddy"
|
||||
|
||||
var (
|
||||
requestCount *prometheus.CounterVec
|
||||
requestDuration *prometheus.HistogramVec
|
||||
responseSize *prometheus.HistogramVec
|
||||
responseRcode *prometheus.CounterVec
|
||||
)
|
||||
|
||||
const path = "/metrics"
|
||||
|
||||
// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics
|
||||
type Metrics struct {
|
||||
Next middleware.Handler
|
||||
Addr string // where to we listen
|
||||
Once sync.Once
|
||||
ZoneNames []string
|
||||
}
|
||||
|
||||
func (m *Metrics) Start() error {
|
||||
m.Once.Do(func() {
|
||||
define("")
|
||||
|
||||
prometheus.MustRegister(requestCount)
|
||||
prometheus.MustRegister(requestDuration)
|
||||
prometheus.MustRegister(responseSize)
|
||||
prometheus.MustRegister(responseRcode)
|
||||
|
||||
http.Handle(path, prometheus.Handler())
|
||||
go func() {
|
||||
fmt.Errorf("%s", http.ListenAndServe(m.Addr, nil))
|
||||
}()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func define(subsystem string) {
|
||||
if subsystem == "" {
|
||||
subsystem = "dns"
|
||||
}
|
||||
requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "request_count_total",
|
||||
Help: "Counter of DNS requests made per zone and type.",
|
||||
}, []string{"zone", "qtype"})
|
||||
|
||||
requestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "request_duration_seconds",
|
||||
Help: "Histogram of the time (in seconds) each request took.",
|
||||
}, []string{"zone"})
|
||||
|
||||
responseSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "response_size_bytes",
|
||||
Help: "Size of the returns response in bytes.",
|
||||
Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3},
|
||||
}, []string{"zone"})
|
||||
|
||||
responseRcode = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "rcode_code_count_total",
|
||||
Help: "Counter of response status codes.",
|
||||
}, []string{"zone", "rcode"})
|
||||
}
|
||||
101
middleware/proxy/policy.go
Normal file
101
middleware/proxy/policy.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// HostPool is a collection of UpstreamHosts.
|
||||
type HostPool []*UpstreamHost
|
||||
|
||||
// Policy decides how a host will be selected from a pool.
|
||||
type Policy interface {
|
||||
Select(pool HostPool) *UpstreamHost
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterPolicy("random", func() Policy { return &Random{} })
|
||||
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
|
||||
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
|
||||
}
|
||||
|
||||
// Random is a policy that selects up hosts from a pool at random.
|
||||
type Random struct{}
|
||||
|
||||
// Select selects an up host at random from the specified pool.
|
||||
func (r *Random) Select(pool HostPool) *UpstreamHost {
|
||||
// instead of just generating a random index
|
||||
// this is done to prevent selecting a down host
|
||||
var randHost *UpstreamHost
|
||||
count := 0
|
||||
for _, host := range pool {
|
||||
if host.Down() {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
if count == 1 {
|
||||
randHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
randHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
return randHost
|
||||
}
|
||||
|
||||
// LeastConn is a policy that selects the host with the least connections.
|
||||
type LeastConn struct{}
|
||||
|
||||
// Select selects the up host with the least number of connections in the
|
||||
// pool. If more than one host has the same least number of connections,
|
||||
// one of the hosts is chosen at random.
|
||||
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
|
||||
var bestHost *UpstreamHost
|
||||
count := 0
|
||||
leastConn := int64(1<<63 - 1)
|
||||
for _, host := range pool {
|
||||
if host.Down() {
|
||||
continue
|
||||
}
|
||||
hostConns := host.Conns
|
||||
if hostConns < leastConn {
|
||||
bestHost = host
|
||||
leastConn = hostConns
|
||||
count = 1
|
||||
} else if hostConns == leastConn {
|
||||
// randomly select host among hosts with least connections
|
||||
count++
|
||||
if count == 1 {
|
||||
bestHost = host
|
||||
} else {
|
||||
r := rand.Int() % count
|
||||
if r == (count - 1) {
|
||||
bestHost = host
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return bestHost
|
||||
}
|
||||
|
||||
// RoundRobin is a policy that selects hosts based on round robin ordering.
|
||||
type RoundRobin struct {
|
||||
Robin uint32
|
||||
}
|
||||
|
||||
// Select selects an up host from the pool using a round robin ordering scheme.
|
||||
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
|
||||
poolLen := uint32(len(pool))
|
||||
selection := atomic.AddUint32(&r.Robin, 1) % poolLen
|
||||
host := pool[selection]
|
||||
// if the currently selected host is down, just ffwd to up host
|
||||
for i := uint32(1); host.Down() && i < poolLen; i++ {
|
||||
host = pool[(selection+i)%poolLen]
|
||||
}
|
||||
if host.Down() {
|
||||
return nil
|
||||
}
|
||||
return host
|
||||
}
|
||||
87
middleware/proxy/policy_test.go
Normal file
87
middleware/proxy/policy_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var workableServer *httptest.Server
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
workableServer = httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
// do nothing
|
||||
}))
|
||||
r := m.Run()
|
||||
workableServer.Close()
|
||||
os.Exit(r)
|
||||
}
|
||||
|
||||
type customPolicy struct{}
|
||||
|
||||
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
|
||||
return pool[0]
|
||||
}
|
||||
|
||||
func testPool() HostPool {
|
||||
pool := []*UpstreamHost{
|
||||
{
|
||||
Name: workableServer.URL, // this should resolve (healthcheck test)
|
||||
},
|
||||
{
|
||||
Name: "http://shouldnot.resolve", // this shouldn't
|
||||
},
|
||||
{
|
||||
Name: "http://C",
|
||||
},
|
||||
}
|
||||
return HostPool(pool)
|
||||
}
|
||||
|
||||
func TestRoundRobinPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
rrPolicy := &RoundRobin{}
|
||||
h := rrPolicy.Select(pool)
|
||||
// First selected host is 1, because counter starts at 0
|
||||
// and increments before host is selected
|
||||
if h != pool[1] {
|
||||
t.Error("Expected first round robin host to be second host in the pool.")
|
||||
}
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected second round robin host to be third host in the pool.")
|
||||
}
|
||||
// mark host as down
|
||||
pool[0].Unhealthy = true
|
||||
h = rrPolicy.Select(pool)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected third round robin host to be first host in the pool.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
lcPolicy := &LeastConn{}
|
||||
pool[0].Conns = 10
|
||||
pool[1].Conns = 10
|
||||
h := lcPolicy.Select(pool)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected least connection host to be third host.")
|
||||
}
|
||||
pool[2].Conns = 100
|
||||
h = lcPolicy.Select(pool)
|
||||
if h != pool[0] && h != pool[1] {
|
||||
t.Error("Expected least connection host to be first or second host.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
customPolicy := &customPolicy{}
|
||||
h := customPolicy.Select(pool)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected custom policy host to be the first host.")
|
||||
}
|
||||
}
|
||||
120
middleware/proxy/proxy.go
Normal file
120
middleware/proxy/proxy.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Package proxy is middleware that proxies requests.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var errUnreachable = errors.New("unreachable backend")
|
||||
|
||||
// Proxy represents a middleware instance that can proxy requests.
|
||||
type Proxy struct {
|
||||
Next middleware.Handler
|
||||
Client Client
|
||||
Upstreams []Upstream
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
UDP *dns.Client
|
||||
TCP *dns.Client
|
||||
}
|
||||
|
||||
// Upstream manages a pool of proxy upstream hosts. Select should return a
|
||||
// suitable upstream host, or nil if no such hosts are available.
|
||||
type Upstream interface {
|
||||
// The domain name this upstream host should be routed on.
|
||||
From() string
|
||||
// Selects an upstream host to be routed to.
|
||||
Select() *UpstreamHost
|
||||
// Checks if subpdomain is not an ignored.
|
||||
IsAllowedPath(string) bool
|
||||
}
|
||||
|
||||
// UpstreamHostDownFunc can be used to customize how Down behaves.
|
||||
type UpstreamHostDownFunc func(*UpstreamHost) bool
|
||||
|
||||
// UpstreamHost represents a single proxy upstream
|
||||
type UpstreamHost struct {
|
||||
Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
|
||||
Name string // IP address (and port) of this upstream host
|
||||
Fails int32
|
||||
FailTimeout time.Duration
|
||||
Unhealthy bool
|
||||
ExtraHeaders http.Header
|
||||
CheckDown UpstreamHostDownFunc
|
||||
WithoutPathPrefix string
|
||||
}
|
||||
|
||||
// Down checks whether the upstream host is down or not.
|
||||
// Down will try to use uh.CheckDown first, and will fall
|
||||
// back to some default criteria if necessary.
|
||||
func (uh *UpstreamHost) Down() bool {
|
||||
if uh.CheckDown == nil {
|
||||
// Default settings
|
||||
return uh.Unhealthy || uh.Fails > 0
|
||||
}
|
||||
return uh.CheckDown(uh)
|
||||
}
|
||||
|
||||
// tryDuration is how long to try upstream hosts; failures result in
|
||||
// immediate retries until this duration ends or we get a nil host.
|
||||
var tryDuration = 60 * time.Second
|
||||
|
||||
// ServeDNS satisfies the middleware.Handler interface.
|
||||
func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
for _, upstream := range p.Upstreams {
|
||||
// allowed bla bla bla TODO(miek): fix full proxy spec from caddy
|
||||
start := time.Now()
|
||||
|
||||
// Since Select() should give us "up" hosts, keep retrying
|
||||
// hosts until timeout (or until we get a nil host).
|
||||
for time.Now().Sub(start) < tryDuration {
|
||||
host := upstream.Select()
|
||||
if host == nil {
|
||||
return dns.RcodeServerFailure, errUnreachable
|
||||
}
|
||||
// TODO(miek): PORT!
|
||||
reverseproxy := ReverseProxy{Host: host.Name, Client: p.Client}
|
||||
|
||||
atomic.AddInt64(&host.Conns, 1)
|
||||
backendErr := reverseproxy.ServeDNS(w, r, nil)
|
||||
atomic.AddInt64(&host.Conns, -1)
|
||||
if backendErr == nil {
|
||||
return 0, nil
|
||||
}
|
||||
timeout := host.FailTimeout
|
||||
if timeout == 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
atomic.AddInt32(&host.Fails, 1)
|
||||
go func(host *UpstreamHost, timeout time.Duration) {
|
||||
time.Sleep(timeout)
|
||||
atomic.AddInt32(&host.Fails, -1)
|
||||
}(host, timeout)
|
||||
}
|
||||
return dns.RcodeServerFailure, errUnreachable
|
||||
}
|
||||
return p.Next.ServeDNS(w, r)
|
||||
}
|
||||
|
||||
func Clients() Client {
|
||||
udp := newClient("udp", defaultTimeout)
|
||||
tcp := newClient("tcp", defaultTimeout)
|
||||
return Client{UDP: udp, TCP: tcp}
|
||||
}
|
||||
|
||||
// newClient returns a new client for proxy requests.
|
||||
func newClient(net string, timeout time.Duration) *dns.Client {
|
||||
if timeout == 0 {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
return &dns.Client{Net: net, ReadTimeout: timeout, WriteTimeout: timeout, SingleInflight: true}
|
||||
}
|
||||
|
||||
const defaultTimeout = 5 * time.Second
|
||||
317
middleware/proxy/proxy_test.go
Normal file
317
middleware/proxy/proxy_test.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func init() {
|
||||
tryDuration = 50 * time.Millisecond // prevent tests from hanging
|
||||
}
|
||||
|
||||
func TestReverseProxy(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var requestReceived bool
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestReceived = true
|
||||
w.Write([]byte("Hello, client"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
if !requestReceived {
|
||||
t.Error("Expected backend to receive request, but it didn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyInsecureSkipVerify(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var requestReceived bool
|
||||
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestReceived = true
|
||||
w.Write([]byte("Hello, client"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
if !requestReceived {
|
||||
t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
||||
// No-op websocket backend simply allows the WS connection to be
|
||||
// accepted then it will be immediately closed. Perfect for testing.
|
||||
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
|
||||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL)
|
||||
|
||||
// Create client request
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
r.Header = http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Upgrade": {"websocket"},
|
||||
"Origin": {wsNop.URL},
|
||||
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
|
||||
"Sec-WebSocket-Version": {"13"},
|
||||
}
|
||||
|
||||
// Capture the request
|
||||
w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}
|
||||
|
||||
// Booya! Do the test.
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
// Make sure the backend accepted the WS connection.
|
||||
// Mostly interested in the Upgrade and Connection response headers
|
||||
// and the 101 status code.
|
||||
expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
|
||||
actual := w.fakeConn.writeBuf.Bytes()
|
||||
if !bytes.Equal(actual, expected) {
|
||||
t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||
// Echo server allows us to test that socket bytes are properly
|
||||
// being proxied.
|
||||
wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
|
||||
io.Copy(ws, ws)
|
||||
}))
|
||||
defer wsEcho.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsEcho.URL)
|
||||
|
||||
// This is a full end-end test, so the proxy handler
|
||||
// has to be part of a server listening on a port. Our
|
||||
// WS client will connect to this test server, not
|
||||
// the echo client directly.
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
defer echoProxy.Close()
|
||||
|
||||
// Set up WebSocket client
|
||||
url := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
|
||||
ws, err := websocket.Dial(url, "", echoProxy.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Send test message
|
||||
trialMsg := "Is it working?"
|
||||
websocket.Message.Send(ws, trialMsg)
|
||||
|
||||
// It should be echoed back to us
|
||||
var actualMsg string
|
||||
websocket.Message.Receive(ws, &actualMsg)
|
||||
if actualMsg != trialMsg {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketProxy(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
trialMsg := "Is it working?"
|
||||
|
||||
var proxySuccess bool
|
||||
|
||||
// This is our fake "application" we want to proxy to
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Request was proxied when this is called
|
||||
proxySuccess = true
|
||||
|
||||
fmt.Fprint(w, trialMsg)
|
||||
}))
|
||||
|
||||
// Get absolute path for unix: socket
|
||||
socketPath, err := filepath.Abs("./test_socket")
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get absolute path: %v", err)
|
||||
}
|
||||
|
||||
// Change httptest.Server listener to listen to unix: socket
|
||||
ln, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to listen: %v", err)
|
||||
}
|
||||
ts.Listener = ln
|
||||
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||
p := newWebSocketTestProxy(url)
|
||||
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
defer echoProxy.Close()
|
||||
|
||||
res, err := http.Get(echoProxy.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to GET: %v", err)
|
||||
}
|
||||
|
||||
greeting, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to GET: %v", err)
|
||||
}
|
||||
|
||||
actualMsg := fmt.Sprintf("%s", greeting)
|
||||
|
||||
if !proxySuccess {
|
||||
t.Errorf("Expected request to be proxied, but it wasn't")
|
||||
}
|
||||
|
||||
if actualMsg != trialMsg {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||
uri, _ := url.Parse(name)
|
||||
u := &fakeUpstream{
|
||||
name: name,
|
||||
host: &UpstreamHost{
|
||||
Name: name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, ""),
|
||||
},
|
||||
}
|
||||
if insecure {
|
||||
u.host.ReverseProxy.Transport = InsecureTransport
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
type fakeUpstream struct {
|
||||
name string
|
||||
host *UpstreamHost
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) From() string {
|
||||
return "/"
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) Select() *UpstreamHost {
|
||||
return u.host
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) IsAllowedPath(requestPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// newWebSocketTestProxy returns a test proxy that will
|
||||
// redirect to the specified backendAddr. The function
|
||||
// also sets up the rules/environment for testing WebSocket
|
||||
// proxy.
|
||||
func newWebSocketTestProxy(backendAddr string) *Proxy {
|
||||
return &Proxy{
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}},
|
||||
}
|
||||
}
|
||||
|
||||
type fakeWsUpstream struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) From() string {
|
||||
return "/"
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) Select() *UpstreamHost {
|
||||
uri, _ := url.Parse(u.name)
|
||||
return &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, ""),
|
||||
ExtraHeaders: http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"}},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) IsAllowedPath(requestPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// recorderHijacker is a ResponseRecorder that can
|
||||
// be hijacked.
|
||||
type recorderHijacker struct {
|
||||
*httptest.ResponseRecorder
|
||||
fakeConn *fakeConn
|
||||
}
|
||||
|
||||
func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return rh.fakeConn, nil, nil
|
||||
}
|
||||
|
||||
type fakeConn struct {
|
||||
readBuf bytes.Buffer
|
||||
writeBuf bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
func (c *fakeConn) Close() error { return nil }
|
||||
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
|
||||
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
|
||||
36
middleware/proxy/reverseproxy.go
Normal file
36
middleware/proxy/reverseproxy.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Package proxy is middleware that proxies requests.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type ReverseProxy struct {
|
||||
Host string
|
||||
Client Client
|
||||
}
|
||||
|
||||
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
|
||||
// TODO(miek): use extra!
|
||||
var (
|
||||
reply *dns.Msg
|
||||
err error
|
||||
)
|
||||
context := middleware.Context{W: w, Req: r}
|
||||
|
||||
// tls+tcp ?
|
||||
if context.Proto() == "tcp" {
|
||||
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
|
||||
} else {
|
||||
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reply.Compress = true
|
||||
reply.Id = r.Id
|
||||
w.WriteMsg(reply)
|
||||
return nil
|
||||
}
|
||||
235
middleware/proxy/upstream.go
Normal file
235
middleware/proxy/upstream.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/core/parse"
|
||||
"github.com/miekg/coredns/middleware"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedPolicies = make(map[string]func() Policy)
|
||||
)
|
||||
|
||||
type staticUpstream struct {
|
||||
from string
|
||||
// TODO(miek): allows use to added headers
|
||||
proxyHeaders http.Header // TODO(miek): kill
|
||||
Hosts HostPool
|
||||
Policy Policy
|
||||
|
||||
FailTimeout time.Duration
|
||||
MaxFails int32
|
||||
HealthCheck struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
}
|
||||
WithoutPathPrefix string
|
||||
IgnoredSubPaths []string
|
||||
}
|
||||
|
||||
// NewStaticUpstreams parses the configuration input and sets up
|
||||
// static upstreams for the proxy middleware.
|
||||
func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
|
||||
var upstreams []Upstream
|
||||
for c.Next() {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
proxyHeaders: make(http.Header),
|
||||
Hosts: nil,
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
|
||||
if !c.Args(&upstream.from) {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
to := c.RemainingArgs()
|
||||
if len(to) == 0 {
|
||||
return upstreams, c.ArgErr()
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
if err := parseBlock(&c, upstream); err != nil {
|
||||
return upstreams, err
|
||||
}
|
||||
}
|
||||
|
||||
upstream.Hosts = make([]*UpstreamHost, len(to))
|
||||
for i, host := range to {
|
||||
uh := &UpstreamHost{
|
||||
Name: host,
|
||||
Conns: 0,
|
||||
Fails: 0,
|
||||
FailTimeout: upstream.FailTimeout,
|
||||
Unhealthy: false,
|
||||
ExtraHeaders: upstream.proxyHeaders,
|
||||
CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
|
||||
return func(uh *UpstreamHost) bool {
|
||||
if uh.Unhealthy {
|
||||
return true
|
||||
}
|
||||
if uh.Fails >= upstream.MaxFails &&
|
||||
upstream.MaxFails != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}(upstream),
|
||||
WithoutPathPrefix: upstream.WithoutPathPrefix,
|
||||
}
|
||||
upstream.Hosts[i] = uh
|
||||
}
|
||||
|
||||
if upstream.HealthCheck.Path != "" {
|
||||
go upstream.HealthCheckWorker(nil)
|
||||
}
|
||||
upstreams = append(upstreams, upstream)
|
||||
}
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
// RegisterPolicy adds a custom policy to the proxy.
|
||||
func RegisterPolicy(name string, policy func() Policy) {
|
||||
supportedPolicies[name] = policy
|
||||
}
|
||||
|
||||
func (u *staticUpstream) From() string {
|
||||
return u.from
|
||||
}
|
||||
|
||||
func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
|
||||
switch c.Val() {
|
||||
case "policy":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
policyCreateFunc, ok := supportedPolicies[c.Val()]
|
||||
if !ok {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.Policy = policyCreateFunc()
|
||||
case "fail_timeout":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.FailTimeout = dur
|
||||
case "max_fails":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
n, err := strconv.Atoi(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.MaxFails = int32(n)
|
||||
case "health_check":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.HealthCheck.Path = c.Val()
|
||||
u.HealthCheck.Interval = 30 * time.Second
|
||||
if c.NextArg() {
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.HealthCheck.Interval = dur
|
||||
}
|
||||
case "proxy_header":
|
||||
var header, value string
|
||||
if !c.Args(&header, &value) {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.proxyHeaders.Add(header, value)
|
||||
case "websocket":
|
||||
u.proxyHeaders.Add("Connection", "{>Connection}")
|
||||
u.proxyHeaders.Add("Upgrade", "{>Upgrade}")
|
||||
case "without":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.WithoutPathPrefix = c.Val()
|
||||
case "except":
|
||||
ignoredPaths := c.RemainingArgs()
|
||||
if len(ignoredPaths) == 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.IgnoredSubPaths = ignoredPaths
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *staticUpstream) healthCheck() {
|
||||
for _, host := range u.Hosts {
|
||||
hostURL := host.Name + u.HealthCheck.Path
|
||||
if r, err := http.Get(hostURL); err == nil {
|
||||
io.Copy(ioutil.Discard, r.Body)
|
||||
r.Body.Close()
|
||||
host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
|
||||
} else {
|
||||
host.Unhealthy = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
|
||||
ticker := time.NewTicker(u.HealthCheck.Interval)
|
||||
u.healthCheck()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
u.healthCheck()
|
||||
case <-stop:
|
||||
// TODO: the library should provide a stop channel and global
|
||||
// waitgroup to allow goroutines started by plugins a chance
|
||||
// to clean themselves up.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *staticUpstream) Select() *UpstreamHost {
|
||||
pool := u.Hosts
|
||||
if len(pool) == 1 {
|
||||
if pool[0].Down() {
|
||||
return nil
|
||||
}
|
||||
return pool[0]
|
||||
}
|
||||
allDown := true
|
||||
for _, host := range pool {
|
||||
if !host.Down() {
|
||||
allDown = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allDown {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.Policy == nil {
|
||||
return (&Random{}).Select(pool)
|
||||
}
|
||||
return u.Policy.Select(pool)
|
||||
}
|
||||
|
||||
func (u *staticUpstream) IsAllowedPath(requestPath string) bool {
|
||||
for _, ignoredSubPath := range u.IgnoredSubPaths {
|
||||
if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
83
middleware/proxy/upstream_test.go
Normal file
83
middleware/proxy/upstream_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHealthCheck(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
Hosts: testPool(),
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
upstream.healthCheck()
|
||||
if upstream.Hosts[0].Down() {
|
||||
t.Error("Expected first host in testpool to not fail healthcheck.")
|
||||
}
|
||||
if !upstream.Hosts[1].Down() {
|
||||
t.Error("Expected second host in testpool to fail healthcheck.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelect(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
Hosts: testPool()[:3],
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
upstream.Hosts[0].Unhealthy = true
|
||||
upstream.Hosts[1].Unhealthy = true
|
||||
upstream.Hosts[2].Unhealthy = true
|
||||
if h := upstream.Select(); h != nil {
|
||||
t.Error("Expected select to return nil as all host are down")
|
||||
}
|
||||
upstream.Hosts[2].Unhealthy = false
|
||||
if h := upstream.Select(); h == nil {
|
||||
t.Error("Expected select to not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterPolicy(t *testing.T) {
|
||||
name := "custom"
|
||||
customPolicy := &customPolicy{}
|
||||
RegisterPolicy(name, func() Policy { return customPolicy })
|
||||
if _, ok := supportedPolicies[name]; !ok {
|
||||
t.Error("Expected supportedPolicies to have a custom policy.")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAllowedPaths(t *testing.T) {
|
||||
upstream := &staticUpstream{
|
||||
from: "/proxy",
|
||||
IgnoredSubPaths: []string{"/download", "/static"},
|
||||
}
|
||||
tests := []struct {
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
{"/proxy", true},
|
||||
{"/proxy/dl", true},
|
||||
{"/proxy/download", false},
|
||||
{"/proxy/download/static", false},
|
||||
{"/proxy/static", false},
|
||||
{"/proxy/static/download", false},
|
||||
{"/proxy/something/download", true},
|
||||
{"/proxy/something/static", true},
|
||||
{"/proxy//static", false},
|
||||
{"/proxy//static//download", false},
|
||||
{"/proxy//download", false},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
isAllowed := upstream.IsAllowedPath(test.url)
|
||||
if test.expected != isAllowed {
|
||||
t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
70
middleware/recorder.go
Normal file
70
middleware/recorder.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ResponseRecorder is a type of ResponseWriter that captures
|
||||
// the rcode code written to it and also the size of the message
|
||||
// written in the response. A rcode code does not have
|
||||
// to be written, however, in which case 0 must be assumed.
|
||||
// It is best to have the constructor initialize this type
|
||||
// with that default status code.
|
||||
type ResponseRecorder struct {
|
||||
dns.ResponseWriter
|
||||
rcode int
|
||||
size int
|
||||
start time.Time
|
||||
}
|
||||
|
||||
// NewResponseRecorder makes and returns a new responseRecorder,
|
||||
// which captures the DNS rcode from the ResponseWriter
|
||||
// and also the length of the response message written through it.
|
||||
func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder {
|
||||
return &ResponseRecorder{
|
||||
ResponseWriter: w,
|
||||
rcode: 0,
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMsg records the status code and calls the
|
||||
// underlying ResponseWriter's WriteMsg method.
|
||||
func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error {
|
||||
r.rcode = res.Rcode
|
||||
r.size = res.Len()
|
||||
return r.ResponseWriter.WriteMsg(res)
|
||||
}
|
||||
|
||||
// Write is a wrapper that records the size of the message that gets written.
|
||||
func (r *ResponseRecorder) Write(buf []byte) (int, error) {
|
||||
n, err := r.ResponseWriter.Write(buf)
|
||||
if err == nil {
|
||||
r.size += n
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Size returns the size.
|
||||
func (r *ResponseRecorder) Size() int {
|
||||
return r.size
|
||||
}
|
||||
|
||||
// Rcode returns the rcode.
|
||||
func (r *ResponseRecorder) Rcode() int {
|
||||
return r.rcode
|
||||
}
|
||||
|
||||
// Start returns the start time of the ResponseRecorder.
|
||||
func (r *ResponseRecorder) Start() time.Time {
|
||||
return r.start
|
||||
}
|
||||
|
||||
// Hijack implements dns.Hijacker. It simply wraps the underlying
|
||||
// ResponseWriter's Hijack method if there is one, or returns an error.
|
||||
func (r *ResponseRecorder) Hijack() {
|
||||
r.ResponseWriter.Hijack()
|
||||
return
|
||||
}
|
||||
32
middleware/recorder_test.go
Normal file
32
middleware/recorder_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewResponseRecorder(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
if !(recordRequest.ResponseWriter == w) {
|
||||
t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
|
||||
}
|
||||
if recordRequest.status != http.StatusOK {
|
||||
t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
responseTestString := "test"
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
buf := []byte(responseTestString)
|
||||
recordRequest.Write(buf)
|
||||
if recordRequest.size != len(buf) {
|
||||
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
|
||||
}
|
||||
if w.Body.String() != responseTestString {
|
||||
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
|
||||
}
|
||||
}
|
||||
84
middleware/reflect/reflect.go
Normal file
84
middleware/reflect/reflect.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Reflect provides middleware that reflects back some client properties.
|
||||
// This is the default middleware when Caddy is run without configuration.
|
||||
//
|
||||
// The left-most label must be `who`.
|
||||
// When queried for type A (resp. AAAA), it sends back the IPv4 (resp. v6) address.
|
||||
// In the additional section the port number and transport are shown.
|
||||
// Basic use pattern:
|
||||
//
|
||||
// dig @localhost -p 1053 who.miek.nl A
|
||||
//
|
||||
// ;; ANSWER SECTION:
|
||||
// who.miek.nl. 0 IN A 127.0.0.1
|
||||
//
|
||||
// ;; ADDITIONAL SECTION:
|
||||
// who.miek.nl. 0 IN TXT "Port: 56195 (udp)"
|
||||
package reflect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Reflect struct {
|
||||
Next middleware.Handler
|
||||
}
|
||||
|
||||
func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
context := middleware.Context{Req: r, W: w}
|
||||
|
||||
class := r.Question[0].Qclass
|
||||
qname := r.Question[0].Name
|
||||
i, ok := dns.NextLabel(qname, 0)
|
||||
|
||||
if strings.ToLower(qname[:i]) != who || ok {
|
||||
err := context.ErrorMessage(dns.RcodeFormatError)
|
||||
w.WriteMsg(err)
|
||||
return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
|
||||
}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
answer.SetReply(r)
|
||||
answer.Compress = true
|
||||
answer.Authoritative = true
|
||||
|
||||
ip := context.IP()
|
||||
proto := context.Proto()
|
||||
port, _ := context.Port()
|
||||
family := context.Family()
|
||||
var rr dns.RR
|
||||
|
||||
switch family {
|
||||
case 1:
|
||||
rr = new(dns.A)
|
||||
rr.(*dns.A).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeA, Class: class, Ttl: 0}
|
||||
rr.(*dns.A).A = net.ParseIP(ip).To4()
|
||||
case 2:
|
||||
rr = new(dns.AAAA)
|
||||
rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeAAAA, Class: class, Ttl: 0}
|
||||
rr.(*dns.AAAA).AAAA = net.ParseIP(ip)
|
||||
}
|
||||
|
||||
t := new(dns.TXT)
|
||||
t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
|
||||
t.Txt = []string{"Port: " + port + " (" + proto + ")"}
|
||||
|
||||
switch context.Type() {
|
||||
case "TXT":
|
||||
answer.Answer = append(answer.Answer, t)
|
||||
answer.Extra = append(answer.Extra, rr)
|
||||
default:
|
||||
fallthrough
|
||||
case "AAAA", "A":
|
||||
answer.Answer = append(answer.Answer, rr)
|
||||
answer.Extra = append(answer.Extra, t)
|
||||
}
|
||||
w.WriteMsg(answer)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
const who = "who."
|
||||
1
middleware/reflect/reflect_test.go
Normal file
1
middleware/reflect/reflect_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package reflect
|
||||
98
middleware/replacer.go
Normal file
98
middleware/replacer.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Replacer is a type which can replace placeholder
|
||||
// substrings in a string with actual values from a
|
||||
// http.Request and responseRecorder. Always use
|
||||
// NewReplacer to get one of these.
|
||||
type Replacer interface {
|
||||
Replace(string) string
|
||||
Set(key, value string)
|
||||
}
|
||||
|
||||
type replacer struct {
|
||||
replacements map[string]string
|
||||
emptyValue string
|
||||
}
|
||||
|
||||
// NewReplacer makes a new replacer based on r and rr.
|
||||
// Do not create a new replacer until r and rr have all
|
||||
// the needed values, because this function copies those
|
||||
// values into the replacer. rr may be nil if it is not
|
||||
// available. emptyValue should be the string that is used
|
||||
// in place of empty string (can still be empty string).
|
||||
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
|
||||
context := Context{W: rr, Req: r}
|
||||
rep := replacer{
|
||||
replacements: map[string]string{
|
||||
"{type}": context.Type(),
|
||||
"{name}": context.Name(),
|
||||
"{class}": context.Class(),
|
||||
"{proto}": context.Proto(),
|
||||
"{when}": func() string {
|
||||
return time.Now().Format(timeFormat)
|
||||
}(),
|
||||
"{remote}": context.IP(),
|
||||
"{port}": func() string {
|
||||
p, _ := context.Port()
|
||||
return p
|
||||
}(),
|
||||
},
|
||||
emptyValue: emptyValue,
|
||||
}
|
||||
if rr != nil {
|
||||
rep.replacements["{rcode}"] = strconv.Itoa(rr.rcode)
|
||||
rep.replacements["{size}"] = strconv.Itoa(rr.size)
|
||||
rep.replacements["{latency}"] = time.Since(rr.start).String()
|
||||
}
|
||||
|
||||
return rep
|
||||
}
|
||||
|
||||
// Replace performs a replacement of values on s and returns
|
||||
// the string with the replaced values.
|
||||
func (r replacer) Replace(s string) string {
|
||||
// Header replacements - these are case-insensitive, so we can't just use strings.Replace()
|
||||
for strings.Contains(s, headerReplacer) {
|
||||
idxStart := strings.Index(s, headerReplacer)
|
||||
endOffset := idxStart + len(headerReplacer)
|
||||
idxEnd := strings.Index(s[endOffset:], "}")
|
||||
if idxEnd > -1 {
|
||||
placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
|
||||
replacement := r.replacements[placeholder]
|
||||
if replacement == "" {
|
||||
replacement = r.emptyValue
|
||||
}
|
||||
s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Regular replacements - these are easier because they're case-sensitive
|
||||
for placeholder, replacement := range r.replacements {
|
||||
if replacement == "" {
|
||||
replacement = r.emptyValue
|
||||
}
|
||||
s = strings.Replace(s, placeholder, replacement, -1)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Set sets key to value in the replacements map.
|
||||
func (r replacer) Set(key, value string) {
|
||||
r.replacements["{"+key+"}"] = value
|
||||
}
|
||||
|
||||
const (
|
||||
timeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
headerReplacer = "{>"
|
||||
)
|
||||
124
middleware/replacer_test.go
Normal file
124
middleware/replacer_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewReplacer(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost", reader)
|
||||
if err != nil {
|
||||
t.Fatal("Request Formation Failed\n")
|
||||
}
|
||||
replaceValues := NewReplacer(request, recordRequest, "")
|
||||
|
||||
switch v := replaceValues.(type) {
|
||||
case replacer:
|
||||
|
||||
if v.replacements["{host}"] != "localhost" {
|
||||
t.Error("Expected host to be localhost")
|
||||
}
|
||||
if v.replacements["{method}"] != "POST" {
|
||||
t.Error("Expected request method to be POST")
|
||||
}
|
||||
if v.replacements["{status}"] != "200" {
|
||||
t.Error("Expected status to be 200")
|
||||
}
|
||||
|
||||
default:
|
||||
t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplace(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost", reader)
|
||||
if err != nil {
|
||||
t.Fatal("Request Formation Failed\n")
|
||||
}
|
||||
request.Header.Set("Custom", "foobarbaz")
|
||||
request.Header.Set("ShorterVal", "1")
|
||||
repl := NewReplacer(request, recordRequest, "-")
|
||||
|
||||
if expected, actual := "This host is localhost.", repl.Replace("This host is {host}."); expected != actual {
|
||||
t.Errorf("{host} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := "This request method is POST.", repl.Replace("This request method is {method}."); expected != actual {
|
||||
t.Errorf("{method} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := "The response status is 200.", repl.Replace("The response status is {status}."); expected != actual {
|
||||
t.Errorf("{status} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := "The Custom header is foobarbaz.", repl.Replace("The Custom header is {>Custom}."); expected != actual {
|
||||
t.Errorf("{>Custom} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test header case-insensitivity
|
||||
if expected, actual := "The cUsToM header is foobarbaz...", repl.Replace("The cUsToM header is {>cUsToM}..."); expected != actual {
|
||||
t.Errorf("{>cUsToM} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test non-existent header/value
|
||||
if expected, actual := "The Non-Existent header is -.", repl.Replace("The Non-Existent header is {>Non-Existent}."); expected != actual {
|
||||
t.Errorf("{>Non-Existent} replacement: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test bad placeholder
|
||||
if expected, actual := "Bad {host placeholder...", repl.Replace("Bad {host placeholder..."); expected != actual {
|
||||
t.Errorf("bad placeholder: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test bad header placeholder
|
||||
if expected, actual := "Bad {>Custom placeholder", repl.Replace("Bad {>Custom placeholder"); expected != actual {
|
||||
t.Errorf("bad header placeholder: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test bad header placeholder with valid one later
|
||||
if expected, actual := "Bad -", repl.Replace("Bad {>Custom placeholder {>ShorterVal}"); expected != actual {
|
||||
t.Errorf("bad header placeholders: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test shorter header value with multiple placeholders
|
||||
if expected, actual := "Short value 1 then foobarbaz.", repl.Replace("Short value {>ShorterVal} then {>Custom}."); expected != actual {
|
||||
t.Errorf("short value: expected '%s', got '%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost", reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Request Formation Failed \n")
|
||||
}
|
||||
repl := NewReplacer(request, recordRequest, "")
|
||||
|
||||
repl.Set("host", "getcaddy.com")
|
||||
repl.Set("method", "GET")
|
||||
repl.Set("status", "201")
|
||||
repl.Set("variable", "value")
|
||||
|
||||
if repl.Replace("This host is {host}") != "This host is getcaddy.com" {
|
||||
t.Error("Expected host replacement failed")
|
||||
}
|
||||
if repl.Replace("This request method is {method}") != "This request method is GET" {
|
||||
t.Error("Expected method replacement failed")
|
||||
}
|
||||
if repl.Replace("The response status is {status}") != "The response status is 201" {
|
||||
t.Error("Expected status replacement failed")
|
||||
}
|
||||
if repl.Replace("The value of variable is {variable}") != "The value of variable is value" {
|
||||
t.Error("Expected variable replacement failed")
|
||||
}
|
||||
}
|
||||
130
middleware/rewrite/condition.go
Normal file
130
middleware/rewrite/condition.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Operators
|
||||
const (
|
||||
Is = "is"
|
||||
Not = "not"
|
||||
Has = "has"
|
||||
NotHas = "not_has"
|
||||
StartsWith = "starts_with"
|
||||
EndsWith = "ends_with"
|
||||
Match = "match"
|
||||
NotMatch = "not_match"
|
||||
)
|
||||
|
||||
func operatorError(operator string) error {
|
||||
return fmt.Errorf("Invalid operator %v", operator)
|
||||
}
|
||||
|
||||
func newReplacer(r *dns.Msg) middleware.Replacer {
|
||||
return middleware.NewReplacer(r, nil, "")
|
||||
}
|
||||
|
||||
// condition is a rewrite condition.
|
||||
type condition func(string, string) bool
|
||||
|
||||
var conditions = map[string]condition{
|
||||
Is: isFunc,
|
||||
Not: notFunc,
|
||||
Has: hasFunc,
|
||||
NotHas: notHasFunc,
|
||||
StartsWith: startsWithFunc,
|
||||
EndsWith: endsWithFunc,
|
||||
Match: matchFunc,
|
||||
NotMatch: notMatchFunc,
|
||||
}
|
||||
|
||||
// isFunc is condition for Is operator.
|
||||
// It checks for equality.
|
||||
func isFunc(a, b string) bool {
|
||||
return a == b
|
||||
}
|
||||
|
||||
// notFunc is condition for Not operator.
|
||||
// It checks for inequality.
|
||||
func notFunc(a, b string) bool {
|
||||
return a != b
|
||||
}
|
||||
|
||||
// hasFunc is condition for Has operator.
|
||||
// It checks if b is a substring of a.
|
||||
func hasFunc(a, b string) bool {
|
||||
return strings.Contains(a, b)
|
||||
}
|
||||
|
||||
// notHasFunc is condition for NotHas operator.
|
||||
// It checks if b is not a substring of a.
|
||||
func notHasFunc(a, b string) bool {
|
||||
return !strings.Contains(a, b)
|
||||
}
|
||||
|
||||
// startsWithFunc is condition for StartsWith operator.
|
||||
// It checks if b is a prefix of a.
|
||||
func startsWithFunc(a, b string) bool {
|
||||
return strings.HasPrefix(a, b)
|
||||
}
|
||||
|
||||
// endsWithFunc is condition for EndsWith operator.
|
||||
// It checks if b is a suffix of a.
|
||||
func endsWithFunc(a, b string) bool {
|
||||
return strings.HasSuffix(a, b)
|
||||
}
|
||||
|
||||
// matchFunc is condition for Match operator.
|
||||
// It does regexp matching of a against pattern in b
|
||||
// and returns if they match.
|
||||
func matchFunc(a, b string) bool {
|
||||
matched, _ := regexp.MatchString(b, a)
|
||||
return matched
|
||||
}
|
||||
|
||||
// notMatchFunc is condition for NotMatch operator.
|
||||
// It does regexp matching of a against pattern in b
|
||||
// and returns if they do not match.
|
||||
func notMatchFunc(a, b string) bool {
|
||||
matched, _ := regexp.MatchString(b, a)
|
||||
return !matched
|
||||
}
|
||||
|
||||
// If is statement for a rewrite condition.
|
||||
type If struct {
|
||||
A string
|
||||
Operator string
|
||||
B string
|
||||
}
|
||||
|
||||
// True returns true if the condition is true and false otherwise.
|
||||
// If r is not nil, it replaces placeholders before comparison.
|
||||
func (i If) True(r *dns.Msg) bool {
|
||||
if c, ok := conditions[i.Operator]; ok {
|
||||
a, b := i.A, i.B
|
||||
if r != nil {
|
||||
replacer := newReplacer(r)
|
||||
a = replacer.Replace(i.A)
|
||||
b = replacer.Replace(i.B)
|
||||
}
|
||||
return c(a, b)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NewIf creates a new If condition.
|
||||
func NewIf(a, operator, b string) (If, error) {
|
||||
if _, ok := conditions[operator]; !ok {
|
||||
return If{}, operatorError(operator)
|
||||
}
|
||||
return If{
|
||||
A: a,
|
||||
Operator: operator,
|
||||
B: b,
|
||||
}, nil
|
||||
}
|
||||
106
middleware/rewrite/condition_test.go
Normal file
106
middleware/rewrite/condition_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
condition string
|
||||
isTrue bool
|
||||
}{
|
||||
{"a is b", false},
|
||||
{"a is a", true},
|
||||
{"a not b", true},
|
||||
{"a not a", false},
|
||||
{"a has a", true},
|
||||
{"a has b", false},
|
||||
{"ba has b", true},
|
||||
{"bab has b", true},
|
||||
{"bab has bb", false},
|
||||
{"a not_has a", false},
|
||||
{"a not_has b", true},
|
||||
{"ba not_has b", false},
|
||||
{"bab not_has b", false},
|
||||
{"bab not_has bb", true},
|
||||
{"bab starts_with bb", false},
|
||||
{"bab starts_with ba", true},
|
||||
{"bab starts_with bab", true},
|
||||
{"bab ends_with bb", false},
|
||||
{"bab ends_with bab", true},
|
||||
{"bab ends_with ab", true},
|
||||
{"a match *", false},
|
||||
{"a match a", true},
|
||||
{"a match .*", true},
|
||||
{"a match a.*", true},
|
||||
{"a match b.*", false},
|
||||
{"ba match b.*", true},
|
||||
{"ba match b[a-z]", true},
|
||||
{"b0 match b[a-z]", false},
|
||||
{"b0a match b[a-z]", false},
|
||||
{"b0a match b[a-z]+", false},
|
||||
{"b0a match b[a-z0-9]+", true},
|
||||
{"a not_match *", true},
|
||||
{"a not_match a", false},
|
||||
{"a not_match .*", false},
|
||||
{"a not_match a.*", false},
|
||||
{"a not_match b.*", true},
|
||||
{"ba not_match b.*", false},
|
||||
{"ba not_match b[a-z]", false},
|
||||
{"b0 not_match b[a-z]", true},
|
||||
{"b0a not_match b[a-z]", true},
|
||||
{"b0a not_match b[a-z]+", true},
|
||||
{"b0a not_match b[a-z0-9]+", false},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
str := strings.Fields(test.condition)
|
||||
ifCond, err := NewIf(str[0], str[1], str[2])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
isTrue := ifCond.True(nil)
|
||||
if isTrue != test.isTrue {
|
||||
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
|
||||
}
|
||||
}
|
||||
|
||||
invalidOperators := []string{"ss", "and", "if"}
|
||||
for _, op := range invalidOperators {
|
||||
_, err := NewIf("a", op, "b")
|
||||
if err == nil {
|
||||
t.Errorf("Invalid operator %v used, expected error.", op)
|
||||
}
|
||||
}
|
||||
|
||||
replaceTests := []struct {
|
||||
url string
|
||||
condition string
|
||||
isTrue bool
|
||||
}{
|
||||
{"/home", "{uri} match /home", true},
|
||||
{"/hom", "{uri} match /home", false},
|
||||
{"/hom", "{uri} starts_with /home", false},
|
||||
{"/hom", "{uri} starts_with /h", true},
|
||||
{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
|
||||
{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
|
||||
}
|
||||
|
||||
for i, test := range replaceTests {
|
||||
r, err := http.NewRequest("GET", test.url, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
str := strings.Fields(test.condition)
|
||||
ifCond, err := NewIf(str[0], str[1], str[2])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
isTrue := ifCond.True(r)
|
||||
if isTrue != test.isTrue {
|
||||
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
|
||||
}
|
||||
}
|
||||
}
|
||||
38
middleware/rewrite/reverter.go
Normal file
38
middleware/rewrite/reverter.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package rewrite
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// ResponseRevert reverses the operations done on the question section of a packet.
|
||||
// This is need because the client will otherwise disregards the response, i.e.
|
||||
// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN'
|
||||
type ResponseReverter struct {
|
||||
dns.ResponseWriter
|
||||
original dns.Question
|
||||
}
|
||||
|
||||
func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter {
|
||||
return &ResponseReverter{
|
||||
ResponseWriter: w,
|
||||
original: r.Question[0],
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMsg records the status code and calls the
|
||||
// underlying ResponseWriter's WriteMsg method.
|
||||
func (r *ResponseReverter) WriteMsg(res *dns.Msg) error {
|
||||
res.Question[0] = r.original
|
||||
return r.ResponseWriter.WriteMsg(res)
|
||||
}
|
||||
|
||||
// Write is a wrapper that records the size of the message that gets written.
|
||||
func (r *ResponseReverter) Write(buf []byte) (int, error) {
|
||||
n, err := r.ResponseWriter.Write(buf)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Hijack implements dns.Hijacker. It simply wraps the underlying
|
||||
// ResponseWriter's Hijack method if there is one, or returns an error.
|
||||
func (r *ResponseReverter) Hijack() {
|
||||
r.ResponseWriter.Hijack()
|
||||
return
|
||||
}
|
||||
223
middleware/rewrite/rewrite.go
Normal file
223
middleware/rewrite/rewrite.go
Normal file
@@ -0,0 +1,223 @@
|
||||
// Package rewrite is middleware for rewriting requests internally to
|
||||
// something different.
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Result is the result of a rewrite
|
||||
type Result int
|
||||
|
||||
const (
|
||||
// RewriteIgnored is returned when rewrite is not done on request.
|
||||
RewriteIgnored Result = iota
|
||||
// RewriteDone is returned when rewrite is done on request.
|
||||
RewriteDone
|
||||
// RewriteStatus is returned when rewrite is not needed and status code should be set
|
||||
// for the request.
|
||||
RewriteStatus
|
||||
)
|
||||
|
||||
// Rewrite is middleware to rewrite requests internally before being handled.
|
||||
type Rewrite struct {
|
||||
Next middleware.Handler
|
||||
Rules []Rule
|
||||
}
|
||||
|
||||
// ServeHTTP implements the middleware.Handler interface.
|
||||
func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
wr := NewResponseReverter(w, r)
|
||||
for _, rule := range rw.Rules {
|
||||
switch result := rule.Rewrite(r); result {
|
||||
case RewriteDone:
|
||||
return rw.Next.ServeDNS(wr, r)
|
||||
case RewriteIgnored:
|
||||
break
|
||||
case RewriteStatus:
|
||||
// only valid for complex rules.
|
||||
// if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 {
|
||||
// return cRule.Status, nil
|
||||
// }
|
||||
}
|
||||
}
|
||||
return rw.Next.ServeDNS(w, r)
|
||||
}
|
||||
|
||||
// Rule describes an internal location rewrite rule.
|
||||
type Rule interface {
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
Rewrite(*dns.Msg) Result
|
||||
}
|
||||
|
||||
// SimpleRule is a simple rewrite rule. If the From and To look like a type
|
||||
// the type of the request is rewritten, otherwise the name is.
|
||||
// Note: TSIG signed requests will be invalid.
|
||||
type SimpleRule struct {
|
||||
From, To string
|
||||
fromType, toType uint16
|
||||
}
|
||||
|
||||
// NewSimpleRule creates a new Simple Rule
|
||||
func NewSimpleRule(from, to string) SimpleRule {
|
||||
tpf := dns.StringToType[from]
|
||||
tpt := dns.StringToType[to]
|
||||
|
||||
return SimpleRule{From: from, To: to, fromType: tpf, toType: tpt}
|
||||
}
|
||||
|
||||
// Rewrite rewrites the the current request.
|
||||
func (s SimpleRule) Rewrite(r *dns.Msg) Result {
|
||||
if s.fromType > 0 && s.toType > 0 {
|
||||
if r.Question[0].Qtype == s.fromType {
|
||||
r.Question[0].Qtype = s.toType
|
||||
return RewriteDone
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// if the question name matches the full name, or subset rewrite that
|
||||
// s.Question[0].Name
|
||||
return RewriteIgnored
|
||||
}
|
||||
|
||||
/*
|
||||
// ComplexRule is a rewrite rule based on a regular expression
|
||||
type ComplexRule struct {
|
||||
// Path base. Request to this path and subpaths will be rewritten
|
||||
Base string
|
||||
|
||||
// Path to rewrite to
|
||||
To string
|
||||
|
||||
// If set, neither performs rewrite nor proceeds
|
||||
// with request. Only returns code.
|
||||
Status int
|
||||
|
||||
// Extensions to filter by
|
||||
Exts []string
|
||||
|
||||
// Rewrite conditions
|
||||
Ifs []If
|
||||
|
||||
*regexp.Regexp
|
||||
}
|
||||
|
||||
// NewComplexRule creates a new RegexpRule. It returns an error if regexp
|
||||
// pattern (pattern) or extensions (ext) are invalid.
|
||||
func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
|
||||
// validate regexp if present
|
||||
var r *regexp.Regexp
|
||||
if pattern != "" {
|
||||
var err error
|
||||
r, err = regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// validate extensions if present
|
||||
for _, v := range ext {
|
||||
if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
|
||||
// check if no extension is specified
|
||||
if v != "/" && v != "!/" {
|
||||
return nil, fmt.Errorf("invalid extension %v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ComplexRule{
|
||||
Base: base,
|
||||
To: to,
|
||||
Status: status,
|
||||
Exts: ext,
|
||||
Ifs: ifs,
|
||||
Regexp: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) {
|
||||
rPath := req.URL.Path
|
||||
replacer := newReplacer(req)
|
||||
|
||||
// validate base
|
||||
if !middleware.Path(rPath).Matches(r.Base) {
|
||||
return
|
||||
}
|
||||
|
||||
// validate extensions
|
||||
if !r.matchExt(rPath) {
|
||||
return
|
||||
}
|
||||
|
||||
// validate regexp if present
|
||||
if r.Regexp != nil {
|
||||
// include trailing slash in regexp if present
|
||||
start := len(r.Base)
|
||||
if strings.HasSuffix(r.Base, "/") {
|
||||
start--
|
||||
}
|
||||
|
||||
matches := r.FindStringSubmatch(rPath[start:])
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
// no match
|
||||
return
|
||||
default:
|
||||
// set regexp match variables {1}, {2} ...
|
||||
for i := 1; i < len(matches); i++ {
|
||||
replacer.Set(fmt.Sprint(i), matches[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate rewrite conditions
|
||||
for _, i := range r.Ifs {
|
||||
if !i.True(req) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// if status is present, stop rewrite and return it.
|
||||
if r.Status != 0 {
|
||||
return RewriteStatus
|
||||
}
|
||||
|
||||
// attempt rewrite
|
||||
return To(fs, req, r.To, replacer)
|
||||
}
|
||||
|
||||
// matchExt matches rPath against registered file extensions.
|
||||
// Returns true if a match is found and false otherwise.
|
||||
func (r *ComplexRule) matchExt(rPath string) bool {
|
||||
f := filepath.Base(rPath)
|
||||
ext := path.Ext(f)
|
||||
if ext == "" {
|
||||
ext = "/"
|
||||
}
|
||||
|
||||
mustUse := false
|
||||
for _, v := range r.Exts {
|
||||
use := true
|
||||
if v[0] == '!' {
|
||||
use = false
|
||||
v = v[1:]
|
||||
}
|
||||
|
||||
if use {
|
||||
mustUse = true
|
||||
}
|
||||
|
||||
if ext == v {
|
||||
return use
|
||||
}
|
||||
}
|
||||
|
||||
if mustUse {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
*/
|
||||
159
middleware/rewrite/rewrite_test.go
Normal file
159
middleware/rewrite/rewrite_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
)
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
rw := Rewrite{
|
||||
Next: middleware.HandlerFunc(urlPrinter),
|
||||
Rules: []Rule{
|
||||
NewSimpleRule("/from", "/to"),
|
||||
NewSimpleRule("/a", "/b"),
|
||||
NewSimpleRule("/b", "/b{uri}"),
|
||||
},
|
||||
FileSys: http.Dir("."),
|
||||
}
|
||||
|
||||
regexps := [][]string{
|
||||
{"/reg/", ".*", "/to", ""},
|
||||
{"/r/", "[a-z]+", "/toaz", "!.html|"},
|
||||
{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
|
||||
{"/ab/", "ab", "/ab?{query}", ".txt|"},
|
||||
{"/ab/", "ab", "/ab?type=html&{query}", ".html|"},
|
||||
{"/abc/", "ab", "/abc/{file}", ".html|"},
|
||||
{"/abcd/", "ab", "/a/{dir}/{file}", ".html|"},
|
||||
{"/abcde/", "ab", "/a#{fragment}", ".html|"},
|
||||
{"/ab/", `.*\.jpg`, "/ajpg", ""},
|
||||
{"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""},
|
||||
{"/reg2grp", `(.*)`, "/{1}", ""},
|
||||
{"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""},
|
||||
}
|
||||
|
||||
for _, regexpRule := range regexps {
|
||||
var ext []string
|
||||
if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
|
||||
ext = s[:len(s)-1]
|
||||
}
|
||||
rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rw.Rules = append(rw.Rules, rule)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
from string
|
||||
expectedTo string
|
||||
}{
|
||||
{"/from", "/to"},
|
||||
{"/a", "/b"},
|
||||
{"/b", "/b/b"},
|
||||
{"/aa", "/aa"},
|
||||
{"/", "/"},
|
||||
{"/a?foo=bar", "/b?foo=bar"},
|
||||
{"/asdf?foo=bar", "/asdf?foo=bar"},
|
||||
{"/foo#bar", "/foo#bar"},
|
||||
{"/a#foo", "/b#foo"},
|
||||
{"/reg/foo", "/to"},
|
||||
{"/re", "/re"},
|
||||
{"/r/", "/r/"},
|
||||
{"/r/123", "/r/123"},
|
||||
{"/r/a123", "/toaz"},
|
||||
{"/r/abcz", "/toaz"},
|
||||
{"/r/z", "/toaz"},
|
||||
{"/r/z.html", "/r/z.html"},
|
||||
{"/r/z.js", "/toaz"},
|
||||
{"/url/asAB", "/to/url/asAB"},
|
||||
{"/url/aBsAB", "/url/aBsAB"},
|
||||
{"/url/a00sAB", "/to/url/a00sAB"},
|
||||
{"/url/a0z0sAB", "/to/url/a0z0sAB"},
|
||||
{"/ab/aa", "/ab/aa"},
|
||||
{"/ab/ab", "/ab/ab"},
|
||||
{"/ab/ab.txt", "/ab"},
|
||||
{"/ab/ab.txt?name=name", "/ab?name=name"},
|
||||
{"/ab/ab.html?name=name", "/ab?type=html&name=name"},
|
||||
{"/abc/ab.html", "/abc/ab.html"},
|
||||
{"/abcd/abcd.html", "/a/abcd/abcd.html"},
|
||||
{"/abcde/abcde.html", "/a"},
|
||||
{"/abcde/abcde.html#1234", "/a#1234"},
|
||||
{"/ab/ab.jpg", "/ajpg"},
|
||||
{"/reggrp/ad/12", "/a12"},
|
||||
{"/reggrp/ad/124a", "/a124/a"},
|
||||
{"/reggrp/ad/124abc", "/a124/abc"},
|
||||
{"/reg2grp/ad/124abc", "/ad/124abc"},
|
||||
{"/reg3grp/ad/aa/66", "/adaa66"},
|
||||
{"/reg3grp/ad612/n1n/ab", "/ad612n1nab"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
rw.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Body.String() != test.expectedTo {
|
||||
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'",
|
||||
i, test.expectedTo, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
statusTests := []struct {
|
||||
status int
|
||||
base string
|
||||
to string
|
||||
regexp string
|
||||
statusExpected bool
|
||||
}{
|
||||
{400, "/status", "", "", true},
|
||||
{400, "/ignore", "", "", false},
|
||||
{400, "/", "", "^/ignore", false},
|
||||
{400, "/", "", "(.*)", true},
|
||||
{400, "/status", "", "", true},
|
||||
}
|
||||
|
||||
for i, s := range statusTests {
|
||||
urlPath := fmt.Sprintf("/status%d", i)
|
||||
rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: No error expected for rule but found %v", i, err)
|
||||
}
|
||||
rw.Rules = []Rule{rule}
|
||||
req, err := http.NewRequest("GET", urlPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
code, err := rw.ServeHTTP(rec, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: No error expected for handler but found %v", i, err)
|
||||
}
|
||||
if s.statusExpected {
|
||||
if rec.Body.String() != "" {
|
||||
t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String())
|
||||
}
|
||||
if code != s.status {
|
||||
t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code)
|
||||
}
|
||||
} else {
|
||||
if code != 0 {
|
||||
t.Errorf("Test %d: Expected no status code found %d", i, code)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
fmt.Fprintf(w, r.URL.String())
|
||||
return 0, nil
|
||||
}
|
||||
0
middleware/rewrite/testdata/testdir/empty
vendored
Normal file
0
middleware/rewrite/testdata/testdir/empty
vendored
Normal file
1
middleware/rewrite/testdata/testfile
vendored
Normal file
1
middleware/rewrite/testdata/testfile
vendored
Normal file
@@ -0,0 +1 @@
|
||||
empty
|
||||
27
middleware/roller.go
Normal file
27
middleware/roller.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// LogRoller implements a middleware that provides a rolling logger.
|
||||
type LogRoller struct {
|
||||
Filename string
|
||||
MaxSize int
|
||||
MaxAge int
|
||||
MaxBackups int
|
||||
LocalTime bool
|
||||
}
|
||||
|
||||
// GetLogWriter returns an io.Writer that writes to a rolling logger.
|
||||
func (l LogRoller) GetLogWriter() io.Writer {
|
||||
return &lumberjack.Logger{
|
||||
Filename: l.Filename,
|
||||
MaxSize: l.MaxSize,
|
||||
MaxAge: l.MaxAge,
|
||||
MaxBackups: l.MaxBackups,
|
||||
LocalTime: l.LocalTime,
|
||||
}
|
||||
}
|
||||
21
middleware/zone.go
Normal file
21
middleware/zone.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import "strings"
|
||||
|
||||
type Zones []string
|
||||
|
||||
// Matches checks to see if other matches p.
|
||||
// The match will return the most specific zones
|
||||
// that matches other. The empty string signals a not found
|
||||
// condition.
|
||||
func (z Zones) Matches(qname string) string {
|
||||
zone := ""
|
||||
for _, zname := range z {
|
||||
if strings.HasSuffix(qname, zname) {
|
||||
if len(zname) > len(zone) {
|
||||
zone = zname
|
||||
}
|
||||
}
|
||||
}
|
||||
return zone
|
||||
}
|
||||
Reference in New Issue
Block a user