mirror of
https://github.com/coredns/coredns.git
synced 2026-02-12 02:13:10 -05:00
feat(proxyproto): add proxy protocol support (#7738)
Signed-off-by: Adphi <philippe.adrien.nousse@gmail.com>
This commit is contained in:
136
plugin/pkg/proxyproto/proxyproto.go
Normal file
136
plugin/pkg/proxyproto/proxyproto.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package proxyproto
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
clog "github.com/coredns/coredns/plugin/pkg/log"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
)
|
||||
|
||||
var (
|
||||
_ net.PacketConn = (*PacketConn)(nil)
|
||||
_ net.Addr = (*Addr)(nil)
|
||||
)
|
||||
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
ConnPolicy proxyproto.ConnPolicyFunc
|
||||
ValidateHeader proxyproto.Validator
|
||||
ReadHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
for {
|
||||
n, addr, err = c.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return n, addr, err
|
||||
}
|
||||
n, addr, err = c.readFrom(p[:n], addr)
|
||||
if err != nil {
|
||||
// drop invalid packet as returning error would cause the ReadFrom caller to exit
|
||||
// which could result in DoS if an attacker sends intentional invalid packets
|
||||
clog.Warningf("dropping invalid Proxy Protocol packet from %s: %v", addr.String(), err)
|
||||
continue
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if pa, ok := addr.(*Addr); ok {
|
||||
addr = pa.u
|
||||
}
|
||||
return c.PacketConn.WriteTo(p, addr)
|
||||
}
|
||||
|
||||
func (c *PacketConn) readFrom(p []byte, addr net.Addr) (_ int, _ net.Addr, err error) {
|
||||
var policy proxyproto.Policy
|
||||
if c.ConnPolicy != nil {
|
||||
policy, err = c.ConnPolicy(proxyproto.ConnPolicyOptions{
|
||||
Upstream: addr,
|
||||
Downstream: c.LocalAddr(),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("applying Proxy Protocol connection policy: %w", err)
|
||||
}
|
||||
}
|
||||
if policy == proxyproto.SKIP {
|
||||
return len(p), addr, nil
|
||||
}
|
||||
header, payload, err := parseProxyProtocol(p)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if header != nil && c.ValidateHeader != nil {
|
||||
if err := c.ValidateHeader(header); err != nil {
|
||||
return 0, nil, fmt.Errorf("validating Proxy Protocol header: %w", err)
|
||||
}
|
||||
}
|
||||
switch policy {
|
||||
case proxyproto.REJECT:
|
||||
if header != nil {
|
||||
return 0, nil, errors.New("connection rejected by Proxy Protocol connection policy")
|
||||
}
|
||||
case proxyproto.REQUIRE:
|
||||
if header == nil {
|
||||
return 0, nil, errors.New("PROXY Protocol header required but not present")
|
||||
}
|
||||
fallthrough
|
||||
case proxyproto.USE:
|
||||
if header != nil {
|
||||
srcAddr, _, _ := header.UDPAddrs()
|
||||
addr = &Addr{u: addr, r: srcAddr}
|
||||
}
|
||||
default:
|
||||
}
|
||||
copy(p, payload)
|
||||
return len(payload), addr, nil
|
||||
}
|
||||
|
||||
type Addr struct {
|
||||
u net.Addr
|
||||
r net.Addr
|
||||
}
|
||||
|
||||
func (a *Addr) Network() string {
|
||||
return a.u.Network()
|
||||
}
|
||||
|
||||
func (a *Addr) String() string {
|
||||
return a.r.String()
|
||||
}
|
||||
|
||||
func parseProxyProtocol(packet []byte) (*proxyproto.Header, []byte, error) {
|
||||
reader := bufio.NewReader(bytes.NewReader(packet))
|
||||
|
||||
header, err := proxyproto.Read(reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, proxyproto.ErrNoProxyProtocol) {
|
||||
return nil, packet, nil
|
||||
}
|
||||
return nil, nil, fmt.Errorf("parsing Proxy Protocol header (packet size: %d): %w", len(packet), err)
|
||||
}
|
||||
|
||||
if header.Version != 2 {
|
||||
return nil, nil, fmt.Errorf("unsupported Proxy Protocol version %d (only v2 supported for UDP)", header.Version)
|
||||
}
|
||||
|
||||
_, _, ok := header.UDPAddrs()
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("PROXY Protocol header is not UDP type (transport protocol: 0x%x)", header.TransportProtocol)
|
||||
}
|
||||
|
||||
headerLen := len(packet) - reader.Buffered()
|
||||
if headerLen < 0 || headerLen > len(packet) {
|
||||
return nil, nil, fmt.Errorf("invalid header length: %d", headerLen)
|
||||
}
|
||||
|
||||
payload := packet[headerLen:]
|
||||
return header, payload, nil
|
||||
}
|
||||
Reference in New Issue
Block a user