Revert pkg/nonwriter changes (#1829)

The DoH work (#1619) made changes to pkg/nonwriter.Writer that in
hindsight were not backwards compatible; it added override for the
LocalAddr() and RemoteAddr(). Instead of rolling back that PR, this PR
reverts those changes and creates a DoHWriter for use in the
https-server.go side of things.

This was only caught in the integration test making this hard to catch,
so we add a upstream_file_test.go that tries (doesn't work yet) to test
this in the unit tests as well. Esp. helpful when 'git bisecting'.

Fixes #1826
This commit is contained in:
Miek Gieben
2018-05-23 13:50:27 +01:00
committed by Chris O'Haver
parent 49891d2103
commit 0f74281a53
4 changed files with 81 additions and 19 deletions

View File

@@ -4,8 +4,10 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -54,3 +56,19 @@ func base64ToMsg(b64 string) (*dns.Msg, error) {
} }
var b64Enc = base64.RawURLEncoding var b64Enc = base64.RawURLEncoding
// DoHWriter is a nonwriter.Writer that adds more specific LocalAddr and RemoteAddr methods.
type DoHWriter struct {
nonwriter.Writer
// raddr is the remote's address. This can be optionally set.
raddr net.Addr
// laddr is our address. This can be optionally set.
laddr net.Addr
}
// RemoteAddr returns the remote address.
func (d *DoHWriter) RemoteAddr() net.Addr { return d.raddr }
// LocalAddr returns the local address.
func (d *DoHWriter) LocalAddr() net.Addr { return d.laddr }

View File

@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -119,12 +118,10 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// Create a non-writer with the correct addresses in it. // Create a DoHWriter with the correct addresses in it.
dw := &nonwriter.Writer{Laddr: s.listenAddr}
h, p, _ := net.SplitHostPort(r.RemoteAddr) h, p, _ := net.SplitHostPort(r.RemoteAddr)
po, _ := strconv.Atoi(p) port, _ := strconv.Atoi(p)
ip := net.ParseIP(h) dw := &DoHWriter{laddr: s.listenAddr, raddr: &net.TCPAddr{IP: net.ParseIP(h), Port: port}}
dw.Raddr = &net.TCPAddr{IP: ip, Port: po}
// We just call the normal chain handler - all error handling is done there. // We just call the normal chain handler - all error handling is done there.
// We should expect a packet to be returned that we can send to the client. // We should expect a packet to be returned that we can send to the client.

View File

@@ -2,8 +2,6 @@
package nonwriter package nonwriter
import ( import (
"net"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -11,11 +9,6 @@ import (
type Writer struct { type Writer struct {
dns.ResponseWriter dns.ResponseWriter
Msg *dns.Msg Msg *dns.Msg
// Raddr is the remote's address. This can be optionally set.
Raddr net.Addr
// Laddr is our address. This can be optionally set.
Laddr net.Addr
} }
// New makes and returns a new NonWriter. // New makes and returns a new NonWriter.
@@ -26,9 +19,3 @@ func (w *Writer) WriteMsg(res *dns.Msg) error {
w.Msg = res w.Msg = res
return nil return nil
} }
// RemoteAddr returns the remote address.
func (w *Writer) RemoteAddr() net.Addr { return w.Raddr }
// LocalAddr returns the local address.
func (w *Writer) LocalAddr() net.Addr { return w.Laddr }

View File

@@ -0,0 +1,60 @@
package test
import (
"testing"
"github.com/miekg/dns"
)
// TODO(miek): this test needs to be fleshed out.
func TestFileUpstream(t *testing.T) {
name, rm, err := TempFile(".", `$ORIGIN example.org.
@ 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. (
2017042745 ; serial
7200 ; refresh (2 hours)
3600 ; retry (1 hour)
1209600 ; expire (2 weeks)
3600 ; minimum (1 hour)
)
3600 IN NS a.iana-servers.net.
3600 IN NS b.iana-servers.net.
www 3600 IN CNAME www.example.net.
`)
if err != nil {
t.Fatalf("Failed to create zone: %s", err)
}
defer rm()
// Corefile with for example without proxy section.
corefile := `example.org:0 {
file ` + name + ` {
upstream
}
hosts {
10.0.0.1 www.example.net.
fallthrough
}
}
`
i, udp, _, err := CoreDNSServerAndPorts(corefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer i.Stop()
m := new(dns.Msg)
m.SetQuestion("www.example.org.", dns.TypeA)
m.SetEdns0(4096, true)
r, err := dns.Exchange(m, udp)
if err != nil {
t.Fatalf("Could not exchange msg: %s", err)
}
if r.Rcode == dns.RcodeServerFailure {
t.Fatalf("Rcode should not be dns.RcodeServerFailure")
}
t.Logf("%s", r)
}