| 
									
										
										
										
											2019-09-04 23:43:45 +08:00
										 |  |  | package acl
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							|  |  |  | 	"net"
 | 
					
						
							|  |  |  | 	"strings"
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-24 18:14:41 +02:00
										 |  |  | 	"github.com/coredns/caddy"
 | 
					
						
							| 
									
										
										
										
											2019-09-04 23:43:45 +08:00
										 |  |  | 	"github.com/coredns/coredns/core/dnsserver"
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/plugin"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/infobloxopen/go-trees/iptree"
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-31 20:03:18 +02:00
										 |  |  | const pluginName = "acl"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func init() { plugin.Register(pluginName, setup) }
 | 
					
						
							| 
									
										
										
										
											2019-09-04 23:43:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func newDefaultFilter() *iptree.Tree {
 | 
					
						
							|  |  |  | 	defaultFilter := iptree.NewTree()
 | 
					
						
							|  |  |  | 	_, IPv4All, _ := net.ParseCIDR("0.0.0.0/0")
 | 
					
						
							|  |  |  | 	_, IPv6All, _ := net.ParseCIDR("::/0")
 | 
					
						
							|  |  |  | 	defaultFilter.InplaceInsertNet(IPv4All, struct{}{})
 | 
					
						
							|  |  |  | 	defaultFilter.InplaceInsertNet(IPv6All, struct{}{})
 | 
					
						
							|  |  |  | 	return defaultFilter
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func setup(c *caddy.Controller) error {
 | 
					
						
							|  |  |  | 	a, err := parse(c)
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							| 
									
										
										
										
											2020-03-31 20:03:18 +02:00
										 |  |  | 		return plugin.Error(pluginName, err)
 | 
					
						
							| 
									
										
										
										
											2019-09-04 23:43:45 +08:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
 | 
					
						
							|  |  |  | 		a.Next = next
 | 
					
						
							|  |  |  | 		return a
 | 
					
						
							|  |  |  | 	})
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func parse(c *caddy.Controller) (ACL, error) {
 | 
					
						
							|  |  |  | 	a := ACL{}
 | 
					
						
							|  |  |  | 	for c.Next() {
 | 
					
						
							|  |  |  | 		r := rule{}
 | 
					
						
							| 
									
										
										
										
											2021-05-17 22:19:54 +02:00
										 |  |  | 		args := c.RemainingArgs()
 | 
					
						
							|  |  |  | 		r.zones = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys)
 | 
					
						
							| 
									
										
										
										
											2019-09-04 23:43:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		for c.NextBlock() {
 | 
					
						
							|  |  |  | 			p := policy{}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			action := strings.ToLower(c.Val())
 | 
					
						
							|  |  |  | 			if action == "allow" {
 | 
					
						
							|  |  |  | 				p.action = actionAllow
 | 
					
						
							|  |  |  | 			} else if action == "block" {
 | 
					
						
							|  |  |  | 				p.action = actionBlock
 | 
					
						
							| 
									
										
										
										
											2021-02-01 09:52:23 -05:00
										 |  |  | 			} else if action == "filter" {
 | 
					
						
							|  |  |  | 				p.action = actionFilter
 | 
					
						
							| 
									
										
										
										
											2019-09-04 23:43:45 +08:00
										 |  |  | 			} else {
 | 
					
						
							| 
									
										
										
										
											2021-02-01 09:52:23 -05:00
										 |  |  | 				return a, c.Errf("unexpected token %q; expect 'allow', 'block', or 'filter'", c.Val())
 | 
					
						
							| 
									
										
										
										
											2019-09-04 23:43:45 +08:00
										 |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			p.qtypes = make(map[uint16]struct{})
 | 
					
						
							|  |  |  | 			p.filter = iptree.NewTree()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			hasTypeSection := false
 | 
					
						
							|  |  |  | 			hasNetSection := false
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			remainingTokens := c.RemainingArgs()
 | 
					
						
							|  |  |  | 			for len(remainingTokens) > 0 {
 | 
					
						
							|  |  |  | 				if !isPreservedIdentifier(remainingTokens[0]) {
 | 
					
						
							|  |  |  | 					return a, c.Errf("unexpected token %q; expect 'type | net'", remainingTokens[0])
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							|  |  |  | 				section := strings.ToLower(remainingTokens[0])
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				i := 1
 | 
					
						
							|  |  |  | 				var tokens []string
 | 
					
						
							|  |  |  | 				for ; i < len(remainingTokens) && !isPreservedIdentifier(remainingTokens[i]); i++ {
 | 
					
						
							|  |  |  | 					tokens = append(tokens, remainingTokens[i])
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							|  |  |  | 				remainingTokens = remainingTokens[i:]
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				if len(tokens) == 0 {
 | 
					
						
							|  |  |  | 					return a, c.Errf("no token specified in %q section", section)
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				switch section {
 | 
					
						
							|  |  |  | 				case "type":
 | 
					
						
							|  |  |  | 					hasTypeSection = true
 | 
					
						
							|  |  |  | 					for _, token := range tokens {
 | 
					
						
							|  |  |  | 						if token == "*" {
 | 
					
						
							|  |  |  | 							p.qtypes[dns.TypeNone] = struct{}{}
 | 
					
						
							|  |  |  | 							break
 | 
					
						
							|  |  |  | 						}
 | 
					
						
							|  |  |  | 						qtype, ok := dns.StringToType[token]
 | 
					
						
							|  |  |  | 						if !ok {
 | 
					
						
							|  |  |  | 							return a, c.Errf("unexpected token %q; expect legal QTYPE", token)
 | 
					
						
							|  |  |  | 						}
 | 
					
						
							|  |  |  | 						p.qtypes[qtype] = struct{}{}
 | 
					
						
							|  |  |  | 					}
 | 
					
						
							|  |  |  | 				case "net":
 | 
					
						
							|  |  |  | 					hasNetSection = true
 | 
					
						
							|  |  |  | 					for _, token := range tokens {
 | 
					
						
							|  |  |  | 						if token == "*" {
 | 
					
						
							|  |  |  | 							p.filter = newDefaultFilter()
 | 
					
						
							|  |  |  | 							break
 | 
					
						
							|  |  |  | 						}
 | 
					
						
							|  |  |  | 						token = normalize(token)
 | 
					
						
							|  |  |  | 						_, source, err := net.ParseCIDR(token)
 | 
					
						
							|  |  |  | 						if err != nil {
 | 
					
						
							|  |  |  | 							return a, c.Errf("illegal CIDR notation %q", token)
 | 
					
						
							|  |  |  | 						}
 | 
					
						
							|  |  |  | 						p.filter.InplaceInsertNet(source, struct{}{})
 | 
					
						
							|  |  |  | 					}
 | 
					
						
							|  |  |  | 				default:
 | 
					
						
							|  |  |  | 					return a, c.Errf("unexpected token %q; expect 'type | net'", section)
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// optional `type` section means all record types.
 | 
					
						
							|  |  |  | 			if !hasTypeSection {
 | 
					
						
							|  |  |  | 				p.qtypes[dns.TypeNone] = struct{}{}
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// optional `net` means all ip addresses.
 | 
					
						
							|  |  |  | 			if !hasNetSection {
 | 
					
						
							|  |  |  | 				p.filter = newDefaultFilter()
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			r.policies = append(r.policies, p)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		a.Rules = append(a.Rules, r)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return a, nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func isPreservedIdentifier(token string) bool {
 | 
					
						
							|  |  |  | 	identifier := strings.ToLower(token)
 | 
					
						
							|  |  |  | 	return identifier == "type" || identifier == "net"
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // normalize appends '/32' for any single IPv4 address and '/128' for IPv6.
 | 
					
						
							|  |  |  | func normalize(rawNet string) string {
 | 
					
						
							|  |  |  | 	if idx := strings.IndexAny(rawNet, "/"); idx >= 0 {
 | 
					
						
							|  |  |  | 		return rawNet
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if idx := strings.IndexAny(rawNet, ":"); idx >= 0 {
 | 
					
						
							|  |  |  | 		return rawNet + "/128"
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return rawNet + "/32"
 | 
					
						
							|  |  |  | }
 |