mirror of
				https://github.com/coredns/coredns.git
				synced 2025-11-03 18:53:13 -05:00 
			
		
		
		
	
		
			
	
	
		
			169 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			169 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								package tsig
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"bufio"
							 | 
						||
| 
								 | 
							
									"fmt"
							 | 
						||
| 
								 | 
							
									"io"
							 | 
						||
| 
								 | 
							
									"os"
							 | 
						||
| 
								 | 
							
									"strings"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"github.com/coredns/caddy"
							 | 
						||
| 
								 | 
							
									"github.com/coredns/coredns/core/dnsserver"
							 | 
						||
| 
								 | 
							
									"github.com/coredns/coredns/plugin"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"github.com/miekg/dns"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func init() {
							 | 
						||
| 
								 | 
							
									caddy.RegisterPlugin(pluginName, caddy.Plugin{
							 | 
						||
| 
								 | 
							
										ServerType: "dns",
							 | 
						||
| 
								 | 
							
										Action:     setup,
							 | 
						||
| 
								 | 
							
									})
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func setup(c *caddy.Controller) error {
							 | 
						||
| 
								 | 
							
									t, err := parse(c)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return plugin.Error(pluginName, c.ArgErr())
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									config := dnsserver.GetConfig(c)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									config.TsigSecret = t.secrets
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									config.AddPlugin(func(next plugin.Handler) plugin.Handler {
							 | 
						||
| 
								 | 
							
										t.Next = next
							 | 
						||
| 
								 | 
							
										return t
							 | 
						||
| 
								 | 
							
									})
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func parse(c *caddy.Controller) (*TSIGServer, error) {
							 | 
						||
| 
								 | 
							
									t := &TSIGServer{
							 | 
						||
| 
								 | 
							
										secrets: make(map[string]string),
							 | 
						||
| 
								 | 
							
										types:   defaultQTypes,
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									for i := 0; c.Next(); i++ {
							 | 
						||
| 
								 | 
							
										if i > 0 {
							 | 
						||
| 
								 | 
							
											return nil, plugin.ErrOnce
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
										t.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
							 | 
						||
| 
								 | 
							
										for c.NextBlock() {
							 | 
						||
| 
								 | 
							
											switch c.Val() {
							 | 
						||
| 
								 | 
							
											case "secret":
							 | 
						||
| 
								 | 
							
												args := c.RemainingArgs()
							 | 
						||
| 
								 | 
							
												if len(args) != 2 {
							 | 
						||
| 
								 | 
							
													return nil, c.ArgErr()
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												k := plugin.Name(args[0]).Normalize()
							 | 
						||
| 
								 | 
							
												if _, exists := t.secrets[k]; exists {
							 | 
						||
| 
								 | 
							
													return nil, fmt.Errorf("key %q redefined", k)
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												t.secrets[k] = args[1]
							 | 
						||
| 
								 | 
							
											case "secrets":
							 | 
						||
| 
								 | 
							
												args := c.RemainingArgs()
							 | 
						||
| 
								 | 
							
												if len(args) != 1 {
							 | 
						||
| 
								 | 
							
													return nil, c.ArgErr()
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												f, err := os.Open(args[0])
							 | 
						||
| 
								 | 
							
												if err != nil {
							 | 
						||
| 
								 | 
							
													return nil, err
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												secrets, err := parseKeyFile(f)
							 | 
						||
| 
								 | 
							
												if err != nil {
							 | 
						||
| 
								 | 
							
													return nil, err
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												for k, s := range secrets {
							 | 
						||
| 
								 | 
							
													if _, exists := t.secrets[k]; exists {
							 | 
						||
| 
								 | 
							
														return nil, fmt.Errorf("key %q redefined", k)
							 | 
						||
| 
								 | 
							
													}
							 | 
						||
| 
								 | 
							
													t.secrets[k] = s
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
											case "require":
							 | 
						||
| 
								 | 
							
												t.types = qTypes{}
							 | 
						||
| 
								 | 
							
												args := c.RemainingArgs()
							 | 
						||
| 
								 | 
							
												if len(args) == 0 {
							 | 
						||
| 
								 | 
							
													return nil, c.ArgErr()
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												if args[0] == "all" {
							 | 
						||
| 
								 | 
							
													t.all = true
							 | 
						||
| 
								 | 
							
													continue
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												if args[0] == "none" {
							 | 
						||
| 
								 | 
							
													continue
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												for _, str := range args {
							 | 
						||
| 
								 | 
							
													qt, ok := dns.StringToType[str]
							 | 
						||
| 
								 | 
							
													if !ok {
							 | 
						||
| 
								 | 
							
														return nil, c.Errf("unknown query type '%s'", str)
							 | 
						||
| 
								 | 
							
													}
							 | 
						||
| 
								 | 
							
													t.types[qt] = struct{}{}
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
											default:
							 | 
						||
| 
								 | 
							
												return nil, c.Errf("unknown property '%s'", c.Val())
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return t, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func parseKeyFile(f io.Reader) (map[string]string, error) {
							 | 
						||
| 
								 | 
							
									secrets := make(map[string]string)
							 | 
						||
| 
								 | 
							
									s := bufio.NewScanner(f)
							 | 
						||
| 
								 | 
							
									for s.Scan() {
							 | 
						||
| 
								 | 
							
										fields := strings.Fields(s.Text())
							 | 
						||
| 
								 | 
							
										if len(fields) == 0 {
							 | 
						||
| 
								 | 
							
											continue
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										if fields[0] != "key" {
							 | 
						||
| 
								 | 
							
											return nil, fmt.Errorf("unexpected token %q", fields[0])
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										if len(fields) < 2 {
							 | 
						||
| 
								 | 
							
											return nil, fmt.Errorf("expected key name %q", s.Text())
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										key := strings.Trim(fields[1], "\"{")
							 | 
						||
| 
								 | 
							
										if len(key) == 0 {
							 | 
						||
| 
								 | 
							
											return nil, fmt.Errorf("expected key name %q", s.Text())
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										key = plugin.Name(key).Normalize()
							 | 
						||
| 
								 | 
							
										if _, ok := secrets[key]; ok {
							 | 
						||
| 
								 | 
							
											return nil, fmt.Errorf("key %q redefined", key)
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									key:
							 | 
						||
| 
								 | 
							
										for s.Scan() {
							 | 
						||
| 
								 | 
							
											fields := strings.Fields(s.Text())
							 | 
						||
| 
								 | 
							
											if len(fields) == 0 {
							 | 
						||
| 
								 | 
							
												continue
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											switch fields[0] {
							 | 
						||
| 
								 | 
							
											case "algorithm":
							 | 
						||
| 
								 | 
							
												continue
							 | 
						||
| 
								 | 
							
											case "secret":
							 | 
						||
| 
								 | 
							
												if len(fields) < 2 {
							 | 
						||
| 
								 | 
							
													return nil, fmt.Errorf("expected secret key %q", s.Text())
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												secret := strings.Trim(fields[1], "\";")
							 | 
						||
| 
								 | 
							
												if len(secret) == 0 {
							 | 
						||
| 
								 | 
							
													return nil, fmt.Errorf("expected secret key %q", s.Text())
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												secrets[key] = secret
							 | 
						||
| 
								 | 
							
											case "}":
							 | 
						||
| 
								 | 
							
												fallthrough
							 | 
						||
| 
								 | 
							
											case "};":
							 | 
						||
| 
								 | 
							
												break key
							 | 
						||
| 
								 | 
							
											default:
							 | 
						||
| 
								 | 
							
												return nil, fmt.Errorf("unexpected token %q", fields[0])
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										if _, ok := secrets[key]; !ok {
							 | 
						||
| 
								 | 
							
											return nil, fmt.Errorf("expected secret for key %q", key)
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return secrets, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// defaultQTypes is the empty qTypes value that parse assigns to a new
								// TSIGServer's types field; a "require" property replaces it with a fresh
								// qTypes value instead of modifying this one.
								var defaultQTypes = qTypes{}
							 |