mirror of
https://github.com/coredns/coredns.git
synced 2025-11-09 21:42:17 -05:00
First commit
This commit is contained in:
234
core/https/certificates.go
Normal file
234
core/https/certificates.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
// certCache stores certificates in memory,
|
||||
// keying certificates by name.
|
||||
var certCache = make(map[string]Certificate)
|
||||
var certCacheMu sync.RWMutex
|
||||
|
||||
// Certificate is a tls.Certificate with associated metadata tacked on.
|
||||
// Even if the metadata can be obtained by parsing the certificate,
|
||||
// we can be more efficient by extracting the metadata once so it's
|
||||
// just there, ready to use.
|
||||
type Certificate struct {
|
||||
tls.Certificate
|
||||
|
||||
// Names is the list of names this certificate is written for.
|
||||
// The first is the CommonName (if any), the rest are SAN.
|
||||
Names []string
|
||||
|
||||
// NotAfter is when the certificate expires.
|
||||
NotAfter time.Time
|
||||
|
||||
// Managed certificates are certificates that Caddy is managing,
|
||||
// as opposed to the user specifying a certificate and key file
|
||||
// or directory and managing the certificate resources themselves.
|
||||
Managed bool
|
||||
|
||||
// OnDemand certificates are obtained or loaded on-demand during TLS
|
||||
// handshakes (as opposed to preloaded certificates, which are loaded
|
||||
// at startup). If OnDemand is true, Managed must necessarily be true.
|
||||
// OnDemand certificates are maintained in the background just like
|
||||
// preloaded ones, however, if an OnDemand certificate fails to renew,
|
||||
// it is removed from the in-memory cache.
|
||||
OnDemand bool
|
||||
|
||||
// OCSP contains the certificate's parsed OCSP response.
|
||||
OCSP *ocsp.Response
|
||||
}
|
||||
|
||||
// getCertificate gets a certificate that matches name (a server name)
|
||||
// from the in-memory cache. If there is no exact match for name, it
|
||||
// will be checked against names of the form '*.example.com' (wildcard
|
||||
// certificates) according to RFC 6125. If a match is found, matched will
|
||||
// be true. If no matches are found, matched will be false and a default
|
||||
// certificate will be returned with defaulted set to true. If no default
|
||||
// certificate is set, defaulted will be set to false.
|
||||
//
|
||||
// The logic in this function is adapted from the Go standard library,
|
||||
// which is by the Go Authors.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
||||
var ok bool
|
||||
|
||||
// Not going to trim trailing dots here since RFC 3546 says,
|
||||
// "The hostname is represented ... without a trailing dot."
|
||||
// Just normalize to lowercase.
|
||||
name = strings.ToLower(name)
|
||||
|
||||
certCacheMu.RLock()
|
||||
defer certCacheMu.RUnlock()
|
||||
|
||||
// exact match? great, let's use it
|
||||
if cert, ok = certCache[name]; ok {
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a match
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok = certCache[candidate]; ok {
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// if nothing matches, use the default certificate or bust
|
||||
cert, defaulted = certCache[""]
|
||||
return
|
||||
}
|
||||
|
||||
// cacheManagedCertificate loads the certificate for domain into the
|
||||
// cache, flagging it as Managed and, if onDemand is true, as OnDemand
|
||||
// (meaning that it was obtained or loaded during a TLS handshake).
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) {
|
||||
cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
|
||||
if err != nil {
|
||||
return cert, err
|
||||
}
|
||||
cert.Managed = true
|
||||
cert.OnDemand = onDemand
|
||||
cacheCertificate(cert)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
|
||||
// and keyFile, which must be in PEM format. It stores the certificate in
|
||||
// memory. The Managed and OnDemand flags of the certificate will be set to
|
||||
// false.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
|
||||
cert, err := makeCertificateFromDisk(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheCertificate(cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes
|
||||
// of the certificate and key, then caches it in memory.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
|
||||
cert, err := makeCertificate(certBytes, keyBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheCertificate(cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeCertificateFromDisk makes a Certificate by loading the
|
||||
// certificate and key files. It fills out all the fields in
|
||||
// the certificate except for the Managed and OnDemand flags.
|
||||
// (It is up to the caller to set those.)
|
||||
func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
|
||||
certPEMBlock, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
keyPEMBlock, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
return makeCertificate(certPEMBlock, keyPEMBlock)
|
||||
}
|
||||
|
||||
// makeCertificate turns a certificate PEM bundle and a key PEM block into
|
||||
// a Certificate, with OCSP and other relevant metadata tagged with it,
|
||||
// except for the OnDemand and Managed flags. It is up to the caller to
|
||||
// set those properties.
|
||||
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
||||
var cert Certificate
|
||||
|
||||
// Convert to a tls.Certificate
|
||||
tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
||||
if err != nil {
|
||||
return cert, err
|
||||
}
|
||||
if len(tlsCert.Certificate) == 0 {
|
||||
return cert, errors.New("certificate is empty")
|
||||
}
|
||||
|
||||
// Parse leaf certificate and extract relevant metadata
|
||||
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err != nil {
|
||||
return cert, err
|
||||
}
|
||||
if leaf.Subject.CommonName != "" {
|
||||
cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
|
||||
}
|
||||
for _, name := range leaf.DNSNames {
|
||||
if name != leaf.Subject.CommonName {
|
||||
cert.Names = append(cert.Names, strings.ToLower(name))
|
||||
}
|
||||
}
|
||||
cert.NotAfter = leaf.NotAfter
|
||||
|
||||
// Staple OCSP
|
||||
ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock)
|
||||
if err != nil {
|
||||
// An error here is not a problem because a certificate may simply
|
||||
// not contain a link to an OCSP server. But we should log it anyway.
|
||||
log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err)
|
||||
} else if ocspResp.Status == ocsp.Good {
|
||||
tlsCert.OCSPStaple = ocspBytes
|
||||
cert.OCSP = ocspResp
|
||||
}
|
||||
|
||||
cert.Certificate = tlsCert
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// cacheCertificate adds cert to the in-memory cache. If the cache is
|
||||
// empty, cert will be used as the default certificate. If the cache is
|
||||
// full, random entries are deleted until there is room to map all the
|
||||
// names on the certificate.
|
||||
//
|
||||
// This certificate will be keyed to the names in cert.Names. Any name
|
||||
// that is already a key in the cache will be replaced with this cert.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func cacheCertificate(cert Certificate) {
|
||||
certCacheMu.Lock()
|
||||
if _, ok := certCache[""]; !ok {
|
||||
// use as default
|
||||
cert.Names = append(cert.Names, "")
|
||||
certCache[""] = cert
|
||||
}
|
||||
for len(certCache)+len(cert.Names) > 10000 {
|
||||
// for simplicity, just remove random elements
|
||||
for key := range certCache {
|
||||
if key == "" { // ... but not the default cert
|
||||
continue
|
||||
}
|
||||
delete(certCache, key)
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, name := range cert.Names {
|
||||
certCache[name] = cert
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
59
core/https/certificates_test.go
Normal file
59
core/https/certificates_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package https
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestUnexportedGetCertificate(t *testing.T) {
|
||||
defer func() { certCache = make(map[string]Certificate) }()
|
||||
|
||||
// When cache is empty
|
||||
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
|
||||
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
|
||||
}
|
||||
|
||||
// When cache has one certificate in it (also is default)
|
||||
defaultCert := Certificate{Names: []string{"example.com", ""}}
|
||||
certCache[""] = defaultCert
|
||||
certCache["example.com"] = defaultCert
|
||||
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
|
||||
// When retrieving wildcard certificate
|
||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
|
||||
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
||||
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
|
||||
// When no certificate matches, the default is returned
|
||||
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
|
||||
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||
} else if cert.Names[0] != "example.com" {
|
||||
t.Errorf("Expected default cert, got: %v", cert)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCertificate(t *testing.T) {
|
||||
defer func() { certCache = make(map[string]Certificate) }()
|
||||
|
||||
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
|
||||
if _, ok := certCache["example.com"]; !ok {
|
||||
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
|
||||
}
|
||||
if _, ok := certCache["sub.example.com"]; !ok {
|
||||
t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't")
|
||||
}
|
||||
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
|
||||
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
|
||||
}
|
||||
|
||||
cacheCertificate(Certificate{Names: []string{"example2.com"}})
|
||||
if _, ok := certCache["example2.com"]; !ok {
|
||||
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
|
||||
}
|
||||
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
|
||||
t.Error("Expected second cert to NOT be cached as default, but it was")
|
||||
}
|
||||
}
|
||||
215
core/https/client.go
Normal file
215
core/https/client.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/server"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// acmeMu ensures that only one ACME challenge occurs at a time.
|
||||
var acmeMu sync.Mutex
|
||||
|
||||
// ACMEClient is an acme.Client with custom state attached.
|
||||
type ACMEClient struct {
|
||||
*acme.Client
|
||||
AllowPrompts bool // if false, we assume AlternatePort must be used
|
||||
}
|
||||
|
||||
// NewACMEClient creates a new ACMEClient given an email and whether
|
||||
// prompting the user is allowed. Clients should not be kept and
|
||||
// re-used over long periods of time, but immediate re-use is more
|
||||
// efficient than re-creating on every iteration.
|
||||
var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) {
|
||||
// Look up or create the LE user account
|
||||
leUser, err := getUser(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The client facilitates our communication with the CA server.
|
||||
client, err := acme.NewClient(CAUrl, &leUser, KeyType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If not registered, the user must register an account with the CA
|
||||
// and agree to terms
|
||||
if leUser.Registration == nil {
|
||||
reg, err := client.Register()
|
||||
if err != nil {
|
||||
return nil, errors.New("registration error: " + err.Error())
|
||||
}
|
||||
leUser.Registration = reg
|
||||
|
||||
if allowPrompts { // can't prompt a user who isn't there
|
||||
if !Agreed && reg.TosURL == "" {
|
||||
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
|
||||
}
|
||||
if !Agreed && reg.TosURL == "" {
|
||||
return nil, errors.New("user must agree to terms")
|
||||
}
|
||||
}
|
||||
|
||||
err = client.AgreeToTOS()
|
||||
if err != nil {
|
||||
saveUser(leUser) // Might as well try, right?
|
||||
return nil, errors.New("error agreeing to terms: " + err.Error())
|
||||
}
|
||||
|
||||
// save user to the file system
|
||||
err = saveUser(leUser)
|
||||
if err != nil {
|
||||
return nil, errors.New("could not save user: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return &ACMEClient{
|
||||
Client: client,
|
||||
AllowPrompts: allowPrompts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewACMEClientGetEmail creates a new ACMEClient and gets an email
|
||||
// address at the same time (a server config is required, since it
|
||||
// may contain an email address in it).
|
||||
func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) {
|
||||
return NewACMEClient(getEmail(config, allowPrompts), allowPrompts)
|
||||
}
|
||||
|
||||
// Configure configures c according to bindHost, which is the host (not
|
||||
// whole address) to bind the listener to in solving the http and tls-sni
|
||||
// challenges.
|
||||
func (c *ACMEClient) Configure(bindHost string) {
|
||||
// If we allow prompts, operator must be present. In our case,
|
||||
// that is synonymous with saying the server is not already
|
||||
// started. So if the user is still there, we don't use
|
||||
// AlternatePort because we don't need to proxy the challenges.
|
||||
// Conversely, if the operator is not there, the server has
|
||||
// already started and we need to proxy the challenge.
|
||||
if c.AllowPrompts {
|
||||
// Operator is present; server is not already listening
|
||||
c.SetHTTPAddress(net.JoinHostPort(bindHost, ""))
|
||||
c.SetTLSAddress(net.JoinHostPort(bindHost, ""))
|
||||
//c.ExcludeChallenges([]acme.Challenge{acme.DNS01})
|
||||
} else {
|
||||
// Operator is not present; server is started, so proxy challenges
|
||||
c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort))
|
||||
c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort))
|
||||
//c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
|
||||
}
|
||||
c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS...
|
||||
}
|
||||
|
||||
// Obtain obtains a single certificate for names. It stores the certificate
|
||||
// on the disk if successful.
|
||||
func (c *ACMEClient) Obtain(names []string) error {
|
||||
Attempts:
|
||||
for attempts := 0; attempts < 2; attempts++ {
|
||||
acmeMu.Lock()
|
||||
certificate, failures := c.ObtainCertificate(names, true, nil)
|
||||
acmeMu.Unlock()
|
||||
if len(failures) > 0 {
|
||||
// Error - try to fix it or report it to the user and abort
|
||||
var errMsg string // we'll combine all the failures into a single error message
|
||||
var promptedForAgreement bool // only prompt user for agreement at most once
|
||||
|
||||
for errDomain, obtainErr := range failures {
|
||||
// TODO: Double-check, will obtainErr ever be nil?
|
||||
if tosErr, ok := obtainErr.(acme.TOSError); ok {
|
||||
// Terms of Service agreement error; we can probably deal with this
|
||||
if !Agreed && !promptedForAgreement && c.AllowPrompts {
|
||||
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
|
||||
promptedForAgreement = true
|
||||
}
|
||||
if Agreed || !c.AllowPrompts {
|
||||
err := c.AgreeToTOS()
|
||||
if err != nil {
|
||||
return errors.New("error agreeing to updated terms: " + err.Error())
|
||||
}
|
||||
continue Attempts
|
||||
}
|
||||
}
|
||||
|
||||
// If user did not agree or it was any other kind of error, just append to the list of errors
|
||||
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
|
||||
}
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// Success - immediately save the certificate resource
|
||||
err := saveCertResource(certificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error saving assets for %v: %v", names, err)
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Renew renews the managed certificate for name. Right now our storage
|
||||
// mechanism only supports one name per certificate, so this function only
|
||||
// accepts one domain as input. It can be easily modified to support SAN
|
||||
// certificates if, one day, they become desperately needed enough that our
|
||||
// storage mechanism is upgraded to be more complex to support SAN certs.
|
||||
//
|
||||
// Anyway, this function is safe for concurrent use.
|
||||
func (c *ACMEClient) Renew(name string) error {
|
||||
// Prepare for renewal (load PEM cert, key, and meta)
|
||||
certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var certMeta acme.CertificateResource
|
||||
err = json.Unmarshal(metaBytes, &certMeta)
|
||||
certMeta.Certificate = certBytes
|
||||
certMeta.PrivateKey = keyBytes
|
||||
|
||||
// Perform renewal and retry if necessary, but not too many times.
|
||||
var newCertMeta acme.CertificateResource
|
||||
var success bool
|
||||
for attempts := 0; attempts < 2; attempts++ {
|
||||
acmeMu.Lock()
|
||||
newCertMeta, err = c.RenewCertificate(certMeta, true)
|
||||
acmeMu.Unlock()
|
||||
if err == nil {
|
||||
success = true
|
||||
break
|
||||
}
|
||||
|
||||
// If the legal terms changed and need to be agreed to again,
|
||||
// we can handle that.
|
||||
if _, ok := err.(acme.TOSError); ok {
|
||||
err := c.AgreeToTOS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For any other kind of error, wait 10s and try again.
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
|
||||
if !success {
|
||||
return errors.New("too many renewal attempts; last error: " + err.Error())
|
||||
}
|
||||
|
||||
return saveCertResource(newCertMeta)
|
||||
}
|
||||
57
core/https/crypto.go
Normal file
57
core/https/crypto.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file.
|
||||
func loadPrivateKey(file string) (crypto.PrivateKey, error) {
|
||||
keyBytes, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyBlock, _ := pem.Decode(keyBytes)
|
||||
|
||||
switch keyBlock.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||
}
|
||||
|
||||
return nil, errors.New("unknown private key type")
|
||||
}
|
||||
|
||||
// savePrivateKey saves a PEM-encoded ECC/RSA private key to file.
|
||||
func savePrivateKey(key crypto.PrivateKey, file string) error {
|
||||
var pemType string
|
||||
var keyBytes []byte
|
||||
switch key := key.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
var err error
|
||||
pemType = "EC"
|
||||
keyBytes, err = x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case *rsa.PrivateKey:
|
||||
pemType = "RSA"
|
||||
keyBytes = x509.MarshalPKCS1PrivateKey(key)
|
||||
}
|
||||
|
||||
pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes}
|
||||
keyOut, err := os.Create(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyOut.Chmod(0600)
|
||||
defer keyOut.Close()
|
||||
return pem.Encode(keyOut, &pemKey)
|
||||
}
|
||||
111
core/https/crypto_test.go
Normal file
111
core/https/crypto_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
|
||||
keyFile := "test.key"
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test save
|
||||
err = savePrivateKey(privateKey, keyFile)
|
||||
if err != nil {
|
||||
t.Fatal("error saving private key:", err)
|
||||
}
|
||||
|
||||
// it doesn't make sense to test file permission on windows
|
||||
if runtime.GOOS != "windows" {
|
||||
// get info of the key file
|
||||
info, err := os.Stat(keyFile)
|
||||
if err != nil {
|
||||
t.Fatal("error stating private key:", err)
|
||||
}
|
||||
// verify permission of key file is correct
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Error("Expected key file to have permission 0600, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
// test load
|
||||
loadedKey, err := loadPrivateKey(keyFile)
|
||||
if err != nil {
|
||||
t.Error("error loading private key:", err)
|
||||
}
|
||||
|
||||
// verify loaded key is correct
|
||||
if !PrivateKeysSame(privateKey, loadedKey) {
|
||||
t.Error("Expected key bytes to be the same, but they weren't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndLoadECCPrivateKey(t *testing.T) {
|
||||
keyFile := "test.key"
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test save
|
||||
err = savePrivateKey(privateKey, keyFile)
|
||||
if err != nil {
|
||||
t.Fatal("error saving private key:", err)
|
||||
}
|
||||
|
||||
// it doesn't make sense to test file permission on windows
|
||||
if runtime.GOOS != "windows" {
|
||||
// get info of the key file
|
||||
info, err := os.Stat(keyFile)
|
||||
if err != nil {
|
||||
t.Fatal("error stating private key:", err)
|
||||
}
|
||||
// verify permission of key file is correct
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Error("Expected key file to have permission 0600, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
// test load
|
||||
loadedKey, err := loadPrivateKey(keyFile)
|
||||
if err != nil {
|
||||
t.Error("error loading private key:", err)
|
||||
}
|
||||
|
||||
// verify loaded key is correct
|
||||
if !PrivateKeysSame(privateKey, loadedKey) {
|
||||
t.Error("Expected key bytes to be the same, but they weren't")
|
||||
}
|
||||
}
|
||||
|
||||
// PrivateKeysSame compares the bytes of a and b and returns true if they are the same.
|
||||
func PrivateKeysSame(a, b crypto.PrivateKey) bool {
|
||||
return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b))
|
||||
}
|
||||
|
||||
// PrivateKeyBytes returns the bytes of DER-encoded key.
|
||||
func PrivateKeyBytes(key crypto.PrivateKey) []byte {
|
||||
var keyBytes []byte
|
||||
switch key := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
keyBytes = x509.MarshalPKCS1PrivateKey(key)
|
||||
case *ecdsa.PrivateKey:
|
||||
keyBytes, _ = x509.MarshalECPrivateKey(key)
|
||||
}
|
||||
return keyBytes
|
||||
}
|
||||
42
core/https/handler.go
Normal file
42
core/https/handler.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const challengeBasePath = "/.well-known/acme-challenge"
|
||||
|
||||
// RequestCallback proxies challenge requests to ACME client if the
|
||||
// request path starts with challengeBasePath. It returns true if it
|
||||
// handled the request and no more needs to be done; it returns false
|
||||
// if this call was a no-op and the request still needs handling.
|
||||
func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
|
||||
if strings.HasPrefix(r.URL.Path, challengeBasePath) {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.Printf("[ERROR] ACME proxy handler: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(upstream)
|
||||
proxy.Transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs
|
||||
}
|
||||
proxy.ServeHTTP(w, r)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
63
core/https/handler_test.go
Normal file
63
core/https/handler_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestCallbackNoOp(t *testing.T) {
|
||||
// try base paths that aren't handled by this handler
|
||||
for _, url := range []string{
|
||||
"http://localhost/",
|
||||
"http://localhost/foo.html",
|
||||
"http://localhost/.git",
|
||||
"http://localhost/.well-known/",
|
||||
"http://localhost/.well-known/acme-challenging",
|
||||
} {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not craft request, got error: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
if RequestCallback(rw, req) {
|
||||
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCallbackSuccess(t *testing.T) {
|
||||
expectedPath := challengeBasePath + "/asdf"
|
||||
|
||||
// Set up fake acme handler backend to make sure proxying succeeds
|
||||
var proxySuccess bool
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxySuccess = true
|
||||
if r.URL.Path != expectedPath {
|
||||
t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
|
||||
// Custom listener that uses the port we expect
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to start test server listener: %v", err)
|
||||
}
|
||||
ts.Listener = ln
|
||||
|
||||
// Start our engines and run the test
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not craft request, got error: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
RequestCallback(rw, req)
|
||||
|
||||
if !proxySuccess {
|
||||
t.Fatal("Expected request to be proxied, but it wasn't")
|
||||
}
|
||||
}
|
||||
320
core/https/handshake.go
Normal file
320
core/https/handshake.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/server"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// GetCertificate gets a certificate to satisfy clientHello as long as
|
||||
// the certificate is already cached in memory. It will not be loaded
|
||||
// from disk or obtained from the CA during the handshake.
|
||||
//
|
||||
// This function is safe for use as a tls.Config.GetCertificate callback.
|
||||
func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := getCertDuringHandshake(clientHello.ServerName, false, false)
|
||||
return &cert.Certificate, err
|
||||
}
|
||||
|
||||
// GetOrObtainCertificate will get a certificate to satisfy clientHello, even
|
||||
// if that means obtaining a new certificate from a CA during the handshake.
|
||||
// It first checks the in-memory cache, then accesses disk, then accesses the
|
||||
// network if it must. An obtained certificate will be stored on disk and
|
||||
// cached in memory.
|
||||
//
|
||||
// This function is safe for use as a tls.Config.GetCertificate callback.
|
||||
func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := getCertDuringHandshake(clientHello.ServerName, true, true)
|
||||
return &cert.Certificate, err
|
||||
}
|
||||
|
||||
// getCertDuringHandshake will get a certificate for name. It first tries
|
||||
// the in-memory cache. If no certificate for name is in the cache and if
|
||||
// loadIfNecessary == true, it goes to disk to load it into the cache and
|
||||
// serve it. If it's not on disk and if obtainIfNecessary == true, the
|
||||
// certificate will be obtained from the CA, cached, and served. If
|
||||
// obtainIfNecessary is true, then loadIfNecessary must also be set to true.
|
||||
// An error will be returned if and only if no certificate is available.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||
// First check our in-memory cache to see if we've already loaded it
|
||||
cert, matched, defaulted := getCertificate(name)
|
||||
if matched {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
if loadIfNecessary {
|
||||
// Then check to see if we have one on disk
|
||||
loadedCert, err := cacheManagedCertificate(name, true)
|
||||
if err == nil {
|
||||
loadedCert, err = handshakeMaintenance(name, loadedCert)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
|
||||
}
|
||||
return loadedCert, nil
|
||||
}
|
||||
|
||||
if obtainIfNecessary {
|
||||
// By this point, we need to ask the CA for a certificate
|
||||
|
||||
name = strings.ToLower(name)
|
||||
|
||||
// Make sure aren't over any applicable limits
|
||||
err := checkLimitsForObtainingNewCerts(name)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
|
||||
// Name has to qualify for a certificate
|
||||
if !HostQualifies(name) {
|
||||
return cert, errors.New("hostname '" + name + "' does not qualify for certificate")
|
||||
}
|
||||
|
||||
// Obtain certificate from the CA
|
||||
return obtainOnDemandCertificate(name)
|
||||
}
|
||||
}
|
||||
|
||||
if defaulted {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
return Certificate{}, errors.New("no certificate for " + name)
|
||||
}
|
||||
|
||||
// checkLimitsForObtainingNewCerts checks to see if name can be issued right
|
||||
// now according to mitigating factors we keep track of and preferences the
|
||||
// user has set. If a non-nil error is returned, do not issue a new certificate
|
||||
// for name.
|
||||
func checkLimitsForObtainingNewCerts(name string) error {
|
||||
// User can set hard limit for number of certs for the process to issue
|
||||
if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue {
|
||||
return fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue)
|
||||
}
|
||||
|
||||
// Make sure name hasn't failed a challenge recently
|
||||
failedIssuanceMu.RLock()
|
||||
when, ok := failedIssuance[name]
|
||||
failedIssuanceMu.RUnlock()
|
||||
if ok {
|
||||
return fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String())
|
||||
}
|
||||
|
||||
// Make sure, if we've issued a few certificates already, that we haven't
|
||||
// issued any recently
|
||||
lastIssueTimeMu.Lock()
|
||||
since := time.Since(lastIssueTime)
|
||||
lastIssueTimeMu.Unlock()
|
||||
if atomic.LoadInt32(OnDemandIssuedCount) >= 10 && since < 10*time.Minute {
|
||||
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
|
||||
}
|
||||
|
||||
// 👍Good to go
|
||||
return nil
|
||||
}
|
||||
|
||||
// obtainOnDemandCertificate obtains a certificate for name for the given
|
||||
// name. If another goroutine has already started obtaining a cert for
|
||||
// name, it will wait and use what the other goroutine obtained.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func obtainOnDemandCertificate(name string) (Certificate, error) {
|
||||
// We must protect this process from happening concurrently, so synchronize.
|
||||
obtainCertWaitChansMu.Lock()
|
||||
wait, ok := obtainCertWaitChans[name]
|
||||
if ok {
|
||||
// lucky us -- another goroutine is already obtaining the certificate.
|
||||
// wait for it to finish obtaining the cert and then we'll use it.
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
<-wait
|
||||
return getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// looks like it's up to us to do all the work and obtain the cert
|
||||
wait = make(chan struct{})
|
||||
obtainCertWaitChans[name] = wait
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
// Unblock waiters and delete waitgroup when we return
|
||||
defer func() {
|
||||
obtainCertWaitChansMu.Lock()
|
||||
close(wait)
|
||||
delete(obtainCertWaitChans, name)
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
}()
|
||||
|
||||
log.Printf("[INFO] Obtaining new certificate for %s", name)
|
||||
|
||||
// obtain cert
|
||||
client, err := NewACMEClientGetEmail(server.Config{}, false)
|
||||
if err != nil {
|
||||
return Certificate{}, errors.New("error creating client: " + err.Error())
|
||||
}
|
||||
client.Configure("") // TODO: which BindHost?
|
||||
err = client.Obtain([]string{name})
|
||||
if err != nil {
|
||||
// Failed to solve challenge, so don't allow another on-demand
|
||||
// issue for this name to be attempted for a little while.
|
||||
failedIssuanceMu.Lock()
|
||||
failedIssuance[name] = time.Now()
|
||||
go func(name string) {
|
||||
time.Sleep(5 * time.Minute)
|
||||
failedIssuanceMu.Lock()
|
||||
delete(failedIssuance, name)
|
||||
failedIssuanceMu.Unlock()
|
||||
}(name)
|
||||
failedIssuanceMu.Unlock()
|
||||
return Certificate{}, err
|
||||
}
|
||||
|
||||
// Success - update counters and stuff
|
||||
atomic.AddInt32(OnDemandIssuedCount, 1)
|
||||
lastIssueTimeMu.Lock()
|
||||
lastIssueTime = time.Now()
|
||||
lastIssueTimeMu.Unlock()
|
||||
|
||||
// The certificate is already on disk; now just start over to load it and serve it
|
||||
return getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// handshakeMaintenance performs a check on cert for expiration and OCSP
|
||||
// validity.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
|
||||
// Check cert expiration
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < renewDurationBefore {
|
||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
||||
return renewDynamicCertificate(name)
|
||||
}
|
||||
|
||||
// Check OCSP staple validity
|
||||
if cert.OCSP != nil {
|
||||
refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
|
||||
if time.Now().After(refreshTime) {
|
||||
err := stapleOCSP(&cert, nil)
|
||||
if err != nil {
|
||||
// An error with OCSP stapling is not the end of the world, and in fact, is
|
||||
// quite common considering not all certs have issuer URLs that support it.
|
||||
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
|
||||
}
|
||||
certCacheMu.Lock()
|
||||
certCache[name] = cert
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// renewDynamicCertificate renews currentCert using the clientHello. It returns the
|
||||
// certificate to use and an error, if any. currentCert may be returned even if an
|
||||
// error occurs, since we perform renewals before they expire and it may still be
|
||||
// usable. name should already be lower-cased before calling this function.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func renewDynamicCertificate(name string) (Certificate, error) {
|
||||
obtainCertWaitChansMu.Lock()
|
||||
wait, ok := obtainCertWaitChans[name]
|
||||
if ok {
|
||||
// lucky us -- another goroutine is already renewing the certificate.
|
||||
// wait for it to finish, then we'll use the new one.
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
<-wait
|
||||
return getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// looks like it's up to us to do all the work and renew the cert
|
||||
wait = make(chan struct{})
|
||||
obtainCertWaitChans[name] = wait
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
// unblock waiters and delete waitgroup when we return
|
||||
defer func() {
|
||||
obtainCertWaitChansMu.Lock()
|
||||
close(wait)
|
||||
delete(obtainCertWaitChans, name)
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
}()
|
||||
|
||||
log.Printf("[INFO] Renewing certificate for %s", name)
|
||||
|
||||
client, err := NewACMEClientGetEmail(server.Config{}, false)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
client.Configure("") // TODO: Bind address of relevant listener, yuck
|
||||
err = client.Renew(name)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
|
||||
return getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// stapleOCSP staples OCSP information to cert for hostname name.
|
||||
// If you have it handy, you should pass in the PEM-encoded certificate
|
||||
// bundle; otherwise the DER-encoded cert will have to be PEM-encoded.
|
||||
// If you don't have the PEM blocks handy, just pass in nil.
|
||||
//
|
||||
// Errors here are not necessarily fatal, it could just be that the
|
||||
// certificate doesn't have an issuer URL.
|
||||
func stapleOCSP(cert *Certificate, pemBundle []byte) error {
|
||||
if pemBundle == nil {
|
||||
// The function in the acme package that gets OCSP requires a PEM-encoded cert
|
||||
bundle := new(bytes.Buffer)
|
||||
for _, derBytes := range cert.Certificate.Certificate {
|
||||
pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
}
|
||||
pemBundle = bundle.Bytes()
|
||||
}
|
||||
|
||||
ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert.Certificate.OCSPStaple = ocspBytes
|
||||
cert.OCSP = ocspResp
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
|
||||
var obtainCertWaitChans = make(map[string]chan struct{})
|
||||
var obtainCertWaitChansMu sync.Mutex
|
||||
|
||||
// OnDemandIssuedCount is the number of certificates that have been issued
|
||||
// on-demand by this process. It is only safe to modify this count atomically.
|
||||
// If it reaches onDemandMaxIssue, on-demand issuances will fail.
|
||||
var OnDemandIssuedCount = new(int32)
|
||||
|
||||
// onDemandMaxIssue is set based on max_certs in tls config. It specifies the
|
||||
// maximum number of certificates that can be issued.
|
||||
// TODO: This applies globally, but we should probably make a server-specific
|
||||
// way to keep track of these limits and counts, since it's specified in the
|
||||
// Caddyfile...
|
||||
var onDemandMaxIssue int32
|
||||
|
||||
// failedIssuance is a set of names that we recently failed to get a
|
||||
// certificate for from the ACME CA. They are removed after some time.
|
||||
// When a name is in this map, do not issue a certificate for it on-demand.
|
||||
var failedIssuance = make(map[string]time.Time)
|
||||
var failedIssuanceMu sync.RWMutex
|
||||
|
||||
// lastIssueTime records when we last obtained a certificate successfully.
|
||||
// If this value is recent, do not make any on-demand certificate requests.
|
||||
var lastIssueTime time.Time
|
||||
var lastIssueTimeMu sync.Mutex
|
||||
54
core/https/handshake_test.go
Normal file
54
core/https/handshake_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetCertificate(t *testing.T) {
|
||||
defer func() { certCache = make(map[string]Certificate) }()
|
||||
|
||||
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
||||
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
||||
helloNoSNI := &tls.ClientHelloInfo{}
|
||||
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
|
||||
|
||||
// When cache is empty
|
||||
if cert, err := GetCertificate(hello); err == nil {
|
||||
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
|
||||
}
|
||||
if cert, err := GetCertificate(helloNoSNI); err == nil {
|
||||
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
||||
}
|
||||
|
||||
// When cache has one certificate in it (also is default)
|
||||
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
||||
certCache[""] = defaultCert
|
||||
certCache["example.com"] = defaultCert
|
||||
if cert, err := GetCertificate(hello); err != nil {
|
||||
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
||||
}
|
||||
if cert, err := GetCertificate(helloNoSNI); err != nil {
|
||||
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
|
||||
}
|
||||
|
||||
// When retrieving wildcard certificate
|
||||
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
|
||||
if cert, err := GetCertificate(helloSub); err != nil {
|
||||
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
||||
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
||||
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
||||
}
|
||||
|
||||
// When no certificate matches, the default is returned
|
||||
if cert, err := GetCertificate(helloNoMatch); err != nil {
|
||||
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Expected default cert with no matches, got: %v", cert)
|
||||
}
|
||||
}
|
||||
358
core/https/https.go
Normal file
358
core/https/https.go
Normal file
@@ -0,0 +1,358 @@
|
||||
// Package https facilitates the management of TLS assets and integrates
|
||||
// Let's Encrypt functionality into Caddy with first-class support for
|
||||
// creating and renewing certificates automatically. It is designed to
|
||||
// configure sites for HTTPS by default.
|
||||
package https
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/coredns/server"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// Activate sets up TLS for each server config in configs
|
||||
// as needed; this consists of acquiring and maintaining
|
||||
// certificates and keys for qualifying configs and enabling
|
||||
// OCSP stapling for all TLS-enabled configs.
|
||||
//
|
||||
// This function may prompt the user to provide an email
|
||||
// address if none is available through other means. It
|
||||
// prefers the email address specified in the config, but
|
||||
// if that is not available it will check the command line
|
||||
// argument. If absent, it will use the most recent email
|
||||
// address from last time. If there isn't one, the user
|
||||
// will be prompted and shown SA link.
|
||||
//
|
||||
// Also note that calling this function activates asset
|
||||
// management automatically, which keeps certificates
|
||||
// renewed and OCSP stapling updated.
|
||||
//
|
||||
// Activate returns the updated list of configs, since
|
||||
// some may have been appended, for example, to redirect
|
||||
// plaintext HTTP requests to their HTTPS counterpart.
|
||||
// This function only appends; it does not splice.
|
||||
func Activate(configs []server.Config) ([]server.Config, error) {
|
||||
// just in case previous caller forgot...
|
||||
Deactivate()
|
||||
|
||||
// pre-screen each config and earmark the ones that qualify for managed TLS
|
||||
MarkQualified(configs)
|
||||
|
||||
// place certificates and keys on disk
|
||||
err := ObtainCerts(configs, true, false)
|
||||
if err != nil {
|
||||
return configs, err
|
||||
}
|
||||
|
||||
// update TLS configurations
|
||||
err = EnableTLS(configs, true)
|
||||
if err != nil {
|
||||
return configs, err
|
||||
}
|
||||
|
||||
// renew all relevant certificates that need renewal. this is important
|
||||
// to do right away for a couple reasons, mainly because each restart,
|
||||
// the renewal ticker is reset, so if restarts happen more often than
|
||||
// the ticker interval, renewals would never happen. but doing
|
||||
// it right away at start guarantees that renewals aren't missed.
|
||||
err = renewManagedCertificates(true)
|
||||
if err != nil {
|
||||
return configs, err
|
||||
}
|
||||
|
||||
// keep certificates renewed and OCSP stapling updated
|
||||
go maintainAssets(stopChan)
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
// Deactivate cleans up long-term, in-memory resources
|
||||
// allocated by calling Activate(). Essentially, it stops
|
||||
// the asset maintainer from running, meaning that certificates
|
||||
// will not be renewed, OCSP staples will not be updated, etc.
|
||||
func Deactivate() (err error) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
err = errors.New("already deactivated")
|
||||
}
|
||||
}()
|
||||
close(stopChan)
|
||||
stopChan = make(chan struct{})
|
||||
return
|
||||
}
|
||||
|
||||
// MarkQualified scans each config and, if it qualifies for managed
|
||||
// TLS, it sets the Managed field of the TLSConfig to true.
|
||||
func MarkQualified(configs []server.Config) {
|
||||
for i := 0; i < len(configs); i++ {
|
||||
if ConfigQualifies(configs[i]) {
|
||||
configs[i].TLS.Managed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ObtainCerts obtains certificates for all these configs as long as a
|
||||
// certificate does not already exist on disk. It does not modify the
|
||||
// configs at all; it only obtains and stores certificates and keys to
|
||||
// the disk. If allowPrompts is true, the user may be shown a prompt.
|
||||
// If proxyACME is true, the ACME challenges will be proxied to our alt port.
|
||||
func ObtainCerts(configs []server.Config, allowPrompts, proxyACME bool) error {
|
||||
// We group configs by email so we don't make the same clients over and
|
||||
// over. This has the potential to prompt the user for an email, but we
|
||||
// prevent that by assuming that if we already have a listener that can
|
||||
// proxy ACME challenge requests, then the server is already running and
|
||||
// the operator is no longer present.
|
||||
groupedConfigs := groupConfigsByEmail(configs, allowPrompts)
|
||||
|
||||
for email, group := range groupedConfigs {
|
||||
// Wait as long as we can before creating the client, because it
|
||||
// may not be needed, for example, if we already have what we
|
||||
// need on disk. Creating a client involves the network and
|
||||
// potentially prompting the user, etc., so only do if necessary.
|
||||
var client *ACMEClient
|
||||
|
||||
for _, cfg := range group {
|
||||
if !HostQualifies(cfg.Host) || existingCertAndKey(cfg.Host) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Now we definitely do need a client
|
||||
if client == nil {
|
||||
var err error
|
||||
client, err = NewACMEClient(email, allowPrompts)
|
||||
if err != nil {
|
||||
return errors.New("error creating client: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// c.Configure assumes that allowPrompts == !proxyACME,
|
||||
// but that's not always true. For example, a restart where
|
||||
// the user isn't present and we're not listening on port 80.
|
||||
// TODO: This could probably be refactored better.
|
||||
if proxyACME {
|
||||
client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
|
||||
client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
|
||||
client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
|
||||
} else {
|
||||
client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, ""))
|
||||
client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, ""))
|
||||
client.ExcludeChallenges([]acme.Challenge{acme.DNS01})
|
||||
}
|
||||
|
||||
err := client.Obtain([]string{cfg.Host})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// groupConfigsByEmail groups configs by the email address to be used by an
|
||||
// ACME client. It only groups configs that have TLS enabled and that are
|
||||
// marked as Managed. If userPresent is true, the operator MAY be prompted
|
||||
// for an email address.
|
||||
func groupConfigsByEmail(configs []server.Config, userPresent bool) map[string][]server.Config {
|
||||
initMap := make(map[string][]server.Config)
|
||||
for _, cfg := range configs {
|
||||
if !cfg.TLS.Managed {
|
||||
continue
|
||||
}
|
||||
leEmail := getEmail(cfg, userPresent)
|
||||
initMap[leEmail] = append(initMap[leEmail], cfg)
|
||||
}
|
||||
return initMap
|
||||
}
|
||||
|
||||
// EnableTLS configures each config to use TLS according to default settings.
|
||||
// It will only change configs that are marked as managed, and assumes that
|
||||
// certificates and keys are already on disk. If loadCertificates is true,
|
||||
// the certificates will be loaded from disk into the cache for this process
|
||||
// to use. If false, TLS will still be enabled and configured with default
|
||||
// settings, but no certificates will be parsed loaded into the cache, and
|
||||
// the returned error value will always be nil.
|
||||
func EnableTLS(configs []server.Config, loadCertificates bool) error {
|
||||
for i := 0; i < len(configs); i++ {
|
||||
if !configs[i].TLS.Managed {
|
||||
continue
|
||||
}
|
||||
configs[i].TLS.Enabled = true
|
||||
if loadCertificates && HostQualifies(configs[i].Host) {
|
||||
_, err := cacheManagedCertificate(configs[i].Host, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
setDefaultTLSParams(&configs[i])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hostHasOtherPort returns true if there is another config in the list with the same
|
||||
// hostname that has port otherPort, or false otherwise. All the configs are checked
|
||||
// against the hostname of allConfigs[thisConfigIdx].
|
||||
func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort string) bool {
|
||||
for i, otherCfg := range allConfigs {
|
||||
if i == thisConfigIdx {
|
||||
continue // has to be a config OTHER than the one we're comparing against
|
||||
}
|
||||
if otherCfg.Host == allConfigs[thisConfigIdx].Host && otherCfg.Port == otherPort {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConfigQualifies returns true if cfg qualifies for
|
||||
// fully managed TLS (but not on-demand TLS, which is
|
||||
// not considered here). It does NOT check to see if a
|
||||
// cert and key already exist for the config. If the
|
||||
// config does qualify, you should set cfg.TLS.Managed
|
||||
// to true and check that instead, because the process of
|
||||
// setting up the config may make it look like it
|
||||
// doesn't qualify even though it originally did.
|
||||
func ConfigQualifies(cfg server.Config) bool {
|
||||
return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key
|
||||
|
||||
// user can force-disable automatic HTTPS for this host
|
||||
cfg.Port != "80" &&
|
||||
cfg.TLS.LetsEncryptEmail != "off" &&
|
||||
|
||||
// we get can't certs for some kinds of hostnames, but
|
||||
// on-demand TLS allows empty hostnames at startup
|
||||
(HostQualifies(cfg.Host) || cfg.TLS.OnDemand)
|
||||
}
|
||||
|
||||
// HostQualifies returns true if the hostname alone
|
||||
// appears eligible for automatic HTTPS. For example,
|
||||
// localhost, empty hostname, and IP addresses are
|
||||
// not eligible because we cannot obtain certificates
|
||||
// for those names.
|
||||
func HostQualifies(hostname string) bool {
|
||||
return hostname != "localhost" && // localhost is ineligible
|
||||
|
||||
// hostname must not be empty
|
||||
strings.TrimSpace(hostname) != "" &&
|
||||
|
||||
// cannot be an IP address, see
|
||||
// https://community.letsencrypt.org/t/certificate-for-static-ip/84/2?u=mholt
|
||||
// (also trim [] from either end, since that special case can sneak through
|
||||
// for IPv6 addresses using the -host flag and with empty/no Caddyfile)
|
||||
net.ParseIP(strings.Trim(hostname, "[]")) == nil
|
||||
}
|
||||
|
||||
// existingCertAndKey returns true if the host has a certificate
|
||||
// and private key in storage already, false otherwise.
|
||||
func existingCertAndKey(host string) bool {
|
||||
_, err := os.Stat(storage.SiteCertFile(host))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_, err = os.Stat(storage.SiteKeyFile(host))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// saveCertResource saves the certificate resource to disk. This
|
||||
// includes the certificate file itself, the private key, and the
|
||||
// metadata file.
|
||||
func saveCertResource(cert acme.CertificateResource) error {
|
||||
err := os.MkdirAll(storage.Site(cert.Domain), 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save cert
|
||||
err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save private key
|
||||
err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save cert metadata
|
||||
jsonBytes, err := json.MarshalIndent(&cert, "", "\t")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke revokes the certificate for host via ACME protocol.
|
||||
func Revoke(host string) error {
|
||||
if !existingCertAndKey(host) {
|
||||
return errors.New("no certificate and key for " + host)
|
||||
}
|
||||
|
||||
email := getEmail(server.Config{Host: host}, true)
|
||||
if email == "" {
|
||||
return errors.New("email is required to revoke")
|
||||
}
|
||||
|
||||
client, err := NewACMEClient(email, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certFile := storage.SiteCertFile(host)
|
||||
certBytes, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.RevokeCertificate(certBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Remove(certFile)
|
||||
if err != nil {
|
||||
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultEmail represents the Let's Encrypt account email to use if none provided
|
||||
DefaultEmail string
|
||||
|
||||
// Agreed indicates whether user has agreed to the Let's Encrypt SA
|
||||
Agreed bool
|
||||
|
||||
// CAUrl represents the base URL to the CA's ACME endpoint
|
||||
CAUrl string
|
||||
)
|
||||
|
||||
// AlternatePort is the port on which the acme client will open a
|
||||
// listener and solve the CA's challenges. If this alternate port
|
||||
// is used instead of the default port (80 or 443), then the
|
||||
// default port for the challenge must be forwarded to this one.
|
||||
const AlternatePort = "5033"
|
||||
|
||||
// KeyType is the type to use for new keys.
|
||||
// This shouldn't need to change except for in tests;
|
||||
// the size can be drastically reduced for speed.
|
||||
var KeyType = acme.EC384
|
||||
|
||||
// stopChan is used to signal the maintenance goroutine
|
||||
// to terminate.
|
||||
var stopChan chan struct{}
|
||||
332
core/https/https_test.go
Normal file
332
core/https/https_test.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/coredns/middleware/redirect"
|
||||
"github.com/miekg/coredns/server"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
func TestHostQualifies(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
host string
|
||||
expect bool
|
||||
}{
|
||||
{"localhost", false},
|
||||
{"127.0.0.1", false},
|
||||
{"127.0.1.5", false},
|
||||
{"::1", false},
|
||||
{"[::1]", false},
|
||||
{"[::]", false},
|
||||
{"::", false},
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{"0.0.0.0", false},
|
||||
{"192.168.1.3", false},
|
||||
{"10.0.2.1", false},
|
||||
{"169.112.53.4", false},
|
||||
{"foobar.com", true},
|
||||
{"sub.foobar.com", true},
|
||||
} {
|
||||
if HostQualifies(test.host) && !test.expect {
|
||||
t.Errorf("Test %d: Expected '%s' to NOT qualify, but it did", i, test.host)
|
||||
}
|
||||
if !HostQualifies(test.host) && test.expect {
|
||||
t.Errorf("Test %d: Expected '%s' to qualify, but it did NOT", i, test.host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigQualifies(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
cfg server.Config
|
||||
expect bool
|
||||
}{
|
||||
{server.Config{Host: ""}, false},
|
||||
{server.Config{Host: "localhost"}, false},
|
||||
{server.Config{Host: "123.44.3.21"}, false},
|
||||
{server.Config{Host: "example.com"}, true},
|
||||
{server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false},
|
||||
{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false},
|
||||
{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true},
|
||||
{server.Config{Host: "example.com", Scheme: "http"}, false},
|
||||
{server.Config{Host: "example.com", Port: "80"}, false},
|
||||
{server.Config{Host: "example.com", Port: "1234"}, true},
|
||||
{server.Config{Host: "example.com", Scheme: "https"}, true},
|
||||
{server.Config{Host: "example.com", Port: "80", Scheme: "https"}, false},
|
||||
} {
|
||||
if test.expect && !ConfigQualifies(test.cfg) {
|
||||
t.Errorf("Test %d: Expected config to qualify, but it did NOT: %#v", i, test.cfg)
|
||||
}
|
||||
if !test.expect && ConfigQualifies(test.cfg) {
|
||||
t.Errorf("Test %d: Expected config to NOT qualify, but it did: %#v", i, test.cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirPlaintextHost(t *testing.T) {
|
||||
cfg := redirPlaintextHost(server.Config{
|
||||
Host: "example.com",
|
||||
BindHost: "93.184.216.34",
|
||||
Port: "1234",
|
||||
})
|
||||
|
||||
// Check host and port
|
||||
if actual, expected := cfg.Host, "example.com"; actual != expected {
|
||||
t.Errorf("Expected redir config to have host %s but got %s", expected, actual)
|
||||
}
|
||||
if actual, expected := cfg.BindHost, "93.184.216.34"; actual != expected {
|
||||
t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual)
|
||||
}
|
||||
if actual, expected := cfg.Port, "80"; actual != expected {
|
||||
t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Make sure redirect handler is set up properly
|
||||
if cfg.Middleware == nil || len(cfg.Middleware) != 1 {
|
||||
t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware)
|
||||
}
|
||||
|
||||
handler, ok := cfg.Middleware[0](nil).(redirect.Redirect)
|
||||
if !ok {
|
||||
t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler)
|
||||
}
|
||||
if len(handler.Rules) != 1 {
|
||||
t.Fatalf("Expected one redirect rule, got: %#v", handler.Rules)
|
||||
}
|
||||
|
||||
// Check redirect rule for correctness
|
||||
if actual, expected := handler.Rules[0].FromScheme, "http"; actual != expected {
|
||||
t.Errorf("Expected redirect rule to be from scheme '%s' but is actually from '%s'", expected, actual)
|
||||
}
|
||||
if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected {
|
||||
t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual)
|
||||
}
|
||||
if actual, expected := handler.Rules[0].To, "https://{host}:1234{uri}"; actual != expected {
|
||||
t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
|
||||
}
|
||||
if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected {
|
||||
t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual)
|
||||
}
|
||||
|
||||
// browsers can infer a default port from scheme, so make sure the port
|
||||
// doesn't get added in explicitly for default ports like 443 for https.
|
||||
cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"})
|
||||
handler, ok = cfg.Middleware[0](nil).(redirect.Redirect)
|
||||
if actual, expected := handler.Rules[0].To, "https://{host}{uri}"; actual != expected {
|
||||
t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveCertResource(t *testing.T) {
|
||||
storage = Storage("./le_test_save")
|
||||
defer func() {
|
||||
err := os.RemoveAll(string(storage))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err)
|
||||
}
|
||||
}()
|
||||
|
||||
domain := "example.com"
|
||||
certContents := "certificate"
|
||||
keyContents := "private key"
|
||||
metaContents := `{
|
||||
"domain": "example.com",
|
||||
"certUrl": "https://example.com/cert",
|
||||
"certStableUrl": "https://example.com/cert/stable"
|
||||
}`
|
||||
|
||||
cert := acme.CertificateResource{
|
||||
Domain: domain,
|
||||
CertURL: "https://example.com/cert",
|
||||
CertStableURL: "https://example.com/cert/stable",
|
||||
PrivateKey: []byte(keyContents),
|
||||
Certificate: []byte(certContents),
|
||||
}
|
||||
|
||||
err := saveCertResource(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error reading certificate file, got: %v", err)
|
||||
}
|
||||
if string(certFile) != certContents {
|
||||
t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile))
|
||||
}
|
||||
|
||||
keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error reading private key file, got: %v", err)
|
||||
}
|
||||
if string(keyFile) != keyContents {
|
||||
t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile))
|
||||
}
|
||||
|
||||
metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain))
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error reading meta file, got: %v", err)
|
||||
}
|
||||
if string(metaFile) != metaContents {
|
||||
t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingCertAndKey(t *testing.T) {
|
||||
storage = Storage("./le_test_existing")
|
||||
defer func() {
|
||||
err := os.RemoveAll(string(storage))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err)
|
||||
}
|
||||
}()
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
if existingCertAndKey(domain) {
|
||||
t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain)
|
||||
}
|
||||
|
||||
err := saveCertResource(acme.CertificateResource{
|
||||
Domain: domain,
|
||||
PrivateKey: []byte("key"),
|
||||
Certificate: []byte("cert"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if !existingCertAndKey(domain) {
|
||||
t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostHasOtherPort(t *testing.T) {
|
||||
configs := []server.Config{
|
||||
{Host: "example.com", Port: "80"},
|
||||
{Host: "sub1.example.com", Port: "80"},
|
||||
{Host: "sub1.example.com", Port: "443"},
|
||||
}
|
||||
|
||||
if hostHasOtherPort(configs, 0, "80") {
|
||||
t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`)
|
||||
}
|
||||
if hostHasOtherPort(configs, 0, "443") {
|
||||
t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`)
|
||||
}
|
||||
if !hostHasOtherPort(configs, 1, "443") {
|
||||
t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakePlaintextRedirects(t *testing.T) {
|
||||
configs := []server.Config{
|
||||
// Happy path = standard redirect from 80 to 443
|
||||
{Host: "example.com", TLS: server.TLSConfig{Managed: true}},
|
||||
|
||||
// Host on port 80 already defined; don't change it (no redirect)
|
||||
{Host: "sub1.example.com", Port: "80", Scheme: "http"},
|
||||
{Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}},
|
||||
|
||||
// Redirect from port 80 to port 5000 in this case
|
||||
{Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}},
|
||||
|
||||
// Can redirect from 80 to either 443 or 5001, but choose 443
|
||||
{Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}},
|
||||
{Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}},
|
||||
}
|
||||
|
||||
result := MakePlaintextRedirects(configs)
|
||||
expectedRedirCount := 3
|
||||
|
||||
if len(result) != len(configs)+expectedRedirCount {
|
||||
t.Errorf("Expected %d redirect(s) to be added, but got %d",
|
||||
expectedRedirCount, len(result)-len(configs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableTLS(t *testing.T) {
|
||||
configs := []server.Config{
|
||||
{Host: "example.com", TLS: server.TLSConfig{Managed: true}},
|
||||
{}, // not managed - no changes!
|
||||
}
|
||||
|
||||
EnableTLS(configs, false)
|
||||
|
||||
if !configs[0].TLS.Enabled {
|
||||
t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false")
|
||||
}
|
||||
if configs[1].TLS.Enabled {
|
||||
t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupConfigsByEmail(t *testing.T) {
|
||||
if groupConfigsByEmail([]server.Config{}, false) == nil {
|
||||
t.Errorf("With empty input, returned map was nil, but expected non-nil map")
|
||||
}
|
||||
|
||||
configs := []server.Config{
|
||||
{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
|
||||
{Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
|
||||
{Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
|
||||
{Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
|
||||
{Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
|
||||
{Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed
|
||||
}
|
||||
DefaultEmail = "test@example.com"
|
||||
|
||||
groups := groupConfigsByEmail(configs, true)
|
||||
|
||||
if groups == nil {
|
||||
t.Fatalf("Returned map was nil, but expected values")
|
||||
}
|
||||
|
||||
if len(groups) != 2 {
|
||||
t.Errorf("Expected 2 groups, got %d: %#v", len(groups), groups)
|
||||
}
|
||||
if len(groups["foo@bar"]) != 2 {
|
||||
t.Errorf("Expected 2 configs for foo@bar, got %d: %#v", len(groups["foobar"]), groups["foobar"])
|
||||
}
|
||||
if len(groups[DefaultEmail]) != 3 {
|
||||
t.Errorf("Expected 3 configs for %s, got %d: %#v", DefaultEmail, len(groups["foobar"]), groups["foobar"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkQualified(t *testing.T) {
|
||||
// TODO: TestConfigQualifies and this test share the same config list...
|
||||
configs := []server.Config{
|
||||
{Host: ""},
|
||||
{Host: "localhost"},
|
||||
{Host: "123.44.3.21"},
|
||||
{Host: "example.com"},
|
||||
{Host: "example.com", TLS: server.TLSConfig{Manual: true}},
|
||||
{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}},
|
||||
{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}},
|
||||
{Host: "example.com", Scheme: "http"},
|
||||
{Host: "example.com", Port: "80"},
|
||||
{Host: "example.com", Port: "1234"},
|
||||
{Host: "example.com", Scheme: "https"},
|
||||
{Host: "example.com", Port: "80", Scheme: "https"},
|
||||
}
|
||||
expectedManagedCount := 4
|
||||
|
||||
MarkQualified(configs)
|
||||
|
||||
count := 0
|
||||
for _, cfg := range configs {
|
||||
if cfg.TLS.Managed {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
if count != expectedManagedCount {
|
||||
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
|
||||
}
|
||||
}
|
||||
211
core/https/maintain.go
Normal file
211
core/https/maintain.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/server"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
const (
|
||||
// RenewInterval is how often to check certificates for renewal.
|
||||
RenewInterval = 12 * time.Hour
|
||||
|
||||
// OCSPInterval is how often to check if OCSP stapling needs updating.
|
||||
OCSPInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// maintainAssets is a permanently-blocking function
|
||||
// that loops indefinitely and, on a regular schedule, checks
|
||||
// certificates for expiration and initiates a renewal of certs
|
||||
// that are expiring soon. It also updates OCSP stapling and
|
||||
// performs other maintenance of assets.
|
||||
//
|
||||
// You must pass in the channel which you'll close when
|
||||
// maintenance should stop, to allow this goroutine to clean up
|
||||
// after itself and unblock.
|
||||
func maintainAssets(stopChan chan struct{}) {
|
||||
renewalTicker := time.NewTicker(RenewInterval)
|
||||
ocspTicker := time.NewTicker(OCSPInterval)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-renewalTicker.C:
|
||||
log.Println("[INFO] Scanning for expiring certificates")
|
||||
renewManagedCertificates(false)
|
||||
log.Println("[INFO] Done checking certificates")
|
||||
case <-ocspTicker.C:
|
||||
log.Println("[INFO] Scanning for stale OCSP staples")
|
||||
updateOCSPStaples()
|
||||
log.Println("[INFO] Done checking OCSP staples")
|
||||
case <-stopChan:
|
||||
renewalTicker.Stop()
|
||||
ocspTicker.Stop()
|
||||
log.Println("[INFO] Stopped background maintenance routine")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func renewManagedCertificates(allowPrompts bool) (err error) {
|
||||
var renewed, deleted []Certificate
|
||||
var client *ACMEClient
|
||||
visitedNames := make(map[string]struct{})
|
||||
|
||||
certCacheMu.RLock()
|
||||
for name, cert := range certCache {
|
||||
if !cert.Managed {
|
||||
continue
|
||||
}
|
||||
|
||||
// the list of names on this cert should never be empty...
|
||||
if cert.Names == nil || len(cert.Names) == 0 {
|
||||
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names)
|
||||
deleted = append(deleted, cert)
|
||||
continue
|
||||
}
|
||||
|
||||
// skip names whose certificate we've already renewed
|
||||
if _, ok := visitedNames[name]; ok {
|
||||
continue
|
||||
}
|
||||
for _, name := range cert.Names {
|
||||
visitedNames[name] = struct{}{}
|
||||
}
|
||||
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < renewDurationBefore {
|
||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
||||
|
||||
if client == nil {
|
||||
client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.Configure("") // TODO: Bind address of relevant listener, yuck
|
||||
}
|
||||
|
||||
err := client.Renew(cert.Names[0]) // managed certs better have only one name
|
||||
if err != nil {
|
||||
if client.AllowPrompts && timeLeft < 0 {
|
||||
// Certificate renewal failed, the operator is present, and the certificate
|
||||
// is already expired; we should stop immediately and return the error. Note
|
||||
// that we used to do this any time a renewal failed at startup. However,
|
||||
// after discussion in https://github.com/miekg/coredns/issues/642 we decided to
|
||||
// only stop startup if the certificate is expired. We still log the error
|
||||
// otherwise.
|
||||
certCacheMu.RUnlock()
|
||||
return err
|
||||
}
|
||||
log.Printf("[ERROR] %v", err)
|
||||
if cert.OnDemand {
|
||||
deleted = append(deleted, cert)
|
||||
}
|
||||
} else {
|
||||
renewed = append(renewed, cert)
|
||||
}
|
||||
}
|
||||
}
|
||||
certCacheMu.RUnlock()
|
||||
|
||||
// Apply changes to the cache
|
||||
for _, cert := range renewed {
|
||||
_, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand)
|
||||
if err != nil {
|
||||
if client.AllowPrompts {
|
||||
return err // operator is present, so report error immediately
|
||||
}
|
||||
log.Printf("[ERROR] %v", err)
|
||||
}
|
||||
}
|
||||
for _, cert := range deleted {
|
||||
certCacheMu.Lock()
|
||||
for _, name := range cert.Names {
|
||||
delete(certCache, name)
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateOCSPStaples() {
|
||||
// Create a temporary place to store updates
|
||||
// until we release the potentially long-lived
|
||||
// read lock and use a short-lived write lock.
|
||||
type ocspUpdate struct {
|
||||
rawBytes []byte
|
||||
parsed *ocsp.Response
|
||||
}
|
||||
updated := make(map[string]ocspUpdate)
|
||||
|
||||
// A single SAN certificate maps to multiple names, so we use this
|
||||
// set to make sure we don't waste cycles checking OCSP for the same
|
||||
// certificate multiple times.
|
||||
visited := make(map[string]struct{})
|
||||
|
||||
certCacheMu.RLock()
|
||||
for name, cert := range certCache {
|
||||
// skip this certificate if we've already visited it,
|
||||
// and if not, mark all the names as visited
|
||||
if _, ok := visited[name]; ok {
|
||||
continue
|
||||
}
|
||||
for _, n := range cert.Names {
|
||||
visited[n] = struct{}{}
|
||||
}
|
||||
|
||||
// no point in updating OCSP for expired certificates
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
continue
|
||||
}
|
||||
|
||||
var lastNextUpdate time.Time
|
||||
if cert.OCSP != nil {
|
||||
// start checking OCSP staple about halfway through validity period for good measure
|
||||
lastNextUpdate = cert.OCSP.NextUpdate
|
||||
refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
|
||||
|
||||
// since OCSP is already stapled, we need only check if we're in that "refresh window"
|
||||
if time.Now().Before(refreshTime) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
err := stapleOCSP(&cert, nil)
|
||||
if err != nil {
|
||||
if cert.OCSP != nil {
|
||||
// if it was no staple before, that's fine, otherwise we should log the error
|
||||
log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// By this point, we've obtained the latest OCSP response.
|
||||
// If there was no staple before, or if the response is updated, make
|
||||
// sure we apply the update to all names on the certificate.
|
||||
if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate {
|
||||
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
|
||||
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
|
||||
for _, n := range cert.Names {
|
||||
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
|
||||
}
|
||||
}
|
||||
}
|
||||
certCacheMu.RUnlock()
|
||||
|
||||
// This write lock should be brief since we have all the info we need now.
|
||||
certCacheMu.Lock()
|
||||
for name, update := range updated {
|
||||
cert := certCache[name]
|
||||
cert.OCSP = update.parsed
|
||||
cert.Certificate.OCSPStaple = update.rawBytes
|
||||
certCache[name] = cert
|
||||
}
|
||||
certCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// renewDurationBefore is how long before expiration to renew certificates.
|
||||
const renewDurationBefore = (24 * time.Hour) * 30
|
||||
321
core/https/setup.go
Normal file
321
core/https/setup.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/coredns/core/setup"
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/coredns/server"
|
||||
)
|
||||
|
||||
// Setup sets up the TLS configuration and installs certificates that
|
||||
// are specified by the user in the config file. All the automatic HTTPS
|
||||
// stuff comes later outside of this function.
|
||||
func Setup(c *setup.Controller) (middleware.Middleware, error) {
|
||||
if c.Port == "80" {
|
||||
c.TLS.Enabled = false
|
||||
log.Printf("[WARNING] TLS disabled for %s.", c.Address())
|
||||
return nil, nil
|
||||
}
|
||||
c.TLS.Enabled = true
|
||||
|
||||
// TODO(miek): disabled for now
|
||||
return nil, nil
|
||||
|
||||
for c.Next() {
|
||||
var certificateFile, keyFile, loadDir, maxCerts string
|
||||
|
||||
args := c.RemainingArgs()
|
||||
switch len(args) {
|
||||
case 1:
|
||||
c.TLS.LetsEncryptEmail = args[0]
|
||||
|
||||
// user can force-disable managed TLS this way
|
||||
if c.TLS.LetsEncryptEmail == "off" {
|
||||
c.TLS.Enabled = false
|
||||
return nil, nil
|
||||
}
|
||||
case 2:
|
||||
certificateFile = args[0]
|
||||
keyFile = args[1]
|
||||
c.TLS.Manual = true
|
||||
}
|
||||
|
||||
// Optional block with extra parameters
|
||||
var hadBlock bool
|
||||
for c.NextBlock() {
|
||||
hadBlock = true
|
||||
switch c.Val() {
|
||||
case "protocols":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 2 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
value, ok := supportedProtocols[strings.ToLower(args[0])]
|
||||
if !ok {
|
||||
return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val())
|
||||
}
|
||||
c.TLS.ProtocolMinVersion = value
|
||||
value, ok = supportedProtocols[strings.ToLower(args[1])]
|
||||
if !ok {
|
||||
return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val())
|
||||
}
|
||||
c.TLS.ProtocolMaxVersion = value
|
||||
case "ciphers":
|
||||
for c.NextArg() {
|
||||
value, ok := supportedCiphersMap[strings.ToUpper(c.Val())]
|
||||
if !ok {
|
||||
return nil, c.Errf("Wrong cipher name or cipher not supported '%s'", c.Val())
|
||||
}
|
||||
c.TLS.Ciphers = append(c.TLS.Ciphers, value)
|
||||
}
|
||||
case "clients":
|
||||
c.TLS.ClientCerts = c.RemainingArgs()
|
||||
if len(c.TLS.ClientCerts) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
case "load":
|
||||
c.Args(&loadDir)
|
||||
c.TLS.Manual = true
|
||||
case "max_certs":
|
||||
c.Args(&maxCerts)
|
||||
c.TLS.OnDemand = true
|
||||
default:
|
||||
return nil, c.Errf("Unknown keyword '%s'", c.Val())
|
||||
}
|
||||
}
|
||||
|
||||
// tls requires at least one argument if a block is not opened
|
||||
if len(args) == 0 && !hadBlock {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
|
||||
// set certificate limit if on-demand TLS is enabled
|
||||
if maxCerts != "" {
|
||||
maxCertsNum, err := strconv.Atoi(maxCerts)
|
||||
if err != nil || maxCertsNum < 1 {
|
||||
return nil, c.Err("max_certs must be a positive integer")
|
||||
}
|
||||
if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost...
|
||||
onDemandMaxIssue = int32(maxCertsNum)
|
||||
}
|
||||
}
|
||||
|
||||
// don't try to load certificates unless we're supposed to
|
||||
if !c.TLS.Enabled || !c.TLS.Manual {
|
||||
continue
|
||||
}
|
||||
|
||||
// load a single certificate and key, if specified
|
||||
if certificateFile != "" && keyFile != "" {
|
||||
err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err)
|
||||
}
|
||||
log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile)
|
||||
}
|
||||
|
||||
// load a directory of certificates, if specified
|
||||
if loadDir != "" {
|
||||
err := loadCertsInDir(c, loadDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setDefaultTLSParams(c.Config)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// loadCertsInDir loads all the certificates/keys in dir, as long as
|
||||
// the file ends with .pem. This method of loading certificates is
|
||||
// modeled after haproxy, which expects the certificate and key to
|
||||
// be bundled into the same file:
|
||||
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
|
||||
//
|
||||
// This function may write to the log as it walks the directory tree.
|
||||
func loadCertsInDir(c *setup.Controller, dir string) error {
|
||||
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
|
||||
certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
|
||||
var foundKey bool // use only the first key in the file
|
||||
|
||||
bundle, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
// Decode next block so we can see what type it is
|
||||
var derBlock *pem.Block
|
||||
derBlock, bundle = pem.Decode(bundle)
|
||||
if derBlock == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if derBlock.Type == "CERTIFICATE" {
|
||||
// Re-encode certificate as PEM, appending to certificate chain
|
||||
pem.Encode(certBuilder, derBlock)
|
||||
} else if derBlock.Type == "EC PARAMETERS" {
|
||||
// EC keys generated from openssl can be composed of two blocks:
|
||||
// parameters and key (parameter block should come first)
|
||||
if !foundKey {
|
||||
// Encode parameters
|
||||
pem.Encode(keyBuilder, derBlock)
|
||||
|
||||
// Key must immediately follow
|
||||
derBlock, bundle = pem.Decode(bundle)
|
||||
if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" {
|
||||
return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path)
|
||||
}
|
||||
pem.Encode(keyBuilder, derBlock)
|
||||
foundKey = true
|
||||
}
|
||||
} else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") {
|
||||
// RSA key
|
||||
if !foundKey {
|
||||
pem.Encode(keyBuilder, derBlock)
|
||||
foundKey = true
|
||||
}
|
||||
} else {
|
||||
return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type)
|
||||
}
|
||||
}
|
||||
|
||||
certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes()
|
||||
if len(certPEMBytes) == 0 {
|
||||
return c.Errf("%s: failed to parse PEM data", path)
|
||||
}
|
||||
if len(keyPEMBytes) == 0 {
|
||||
return c.Errf("%s: no private key block found", path)
|
||||
}
|
||||
|
||||
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
|
||||
if err != nil {
|
||||
return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err)
|
||||
}
|
||||
log.Printf("[INFO] Successfully loaded TLS assets from %s", path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// setDefaultTLSParams sets the default TLS cipher suites, protocol versions,
|
||||
// and server preferences of a server.Config if they were not previously set
|
||||
// (it does not overwrite; only fills in missing values). It will also set the
|
||||
// port to 443 if not already set, TLS is enabled, TLS is manual, and the host
|
||||
// does not equal localhost.
|
||||
func setDefaultTLSParams(c *server.Config) {
|
||||
// If no ciphers provided, use default list
|
||||
if len(c.TLS.Ciphers) == 0 {
|
||||
c.TLS.Ciphers = defaultCiphers
|
||||
}
|
||||
|
||||
// Not a cipher suite, but still important for mitigating protocol downgrade attacks
|
||||
// (prepend since having it at end breaks http2 due to non-h2-approved suites before it)
|
||||
c.TLS.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, c.TLS.Ciphers...)
|
||||
|
||||
// Set default protocol min and max versions - must balance compatibility and security
|
||||
if c.TLS.ProtocolMinVersion == 0 {
|
||||
c.TLS.ProtocolMinVersion = tls.VersionTLS10
|
||||
}
|
||||
if c.TLS.ProtocolMaxVersion == 0 {
|
||||
c.TLS.ProtocolMaxVersion = tls.VersionTLS12
|
||||
}
|
||||
|
||||
// Prefer server cipher suites
|
||||
c.TLS.PreferServerCipherSuites = true
|
||||
|
||||
// Default TLS port is 443; only use if port is not manually specified,
|
||||
// TLS is enabled, and the host is not localhost
|
||||
if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" {
|
||||
c.Port = "443"
|
||||
}
|
||||
}
|
||||
|
||||
// Map of supported protocols.
|
||||
// SSLv3 will be not supported in future release.
|
||||
// HTTP/2 only supports TLS 1.2 and higher.
|
||||
var supportedProtocols = map[string]uint16{
|
||||
"ssl3.0": tls.VersionSSL30,
|
||||
"tls1.0": tls.VersionTLS10,
|
||||
"tls1.1": tls.VersionTLS11,
|
||||
"tls1.2": tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// Map of supported ciphers, used only for parsing config.
|
||||
//
|
||||
// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites,
|
||||
// including all but two of the suites below (the two GCM suites).
|
||||
// See https://http2.github.io/http2-spec/#BadCipherSuites
|
||||
//
|
||||
// TLS_FALLBACK_SCSV is not in this list because we manually ensure
|
||||
// it is always added (even though it is not technically a cipher suite).
|
||||
//
|
||||
// This map, like any map, is NOT ORDERED. Do not range over this map.
|
||||
var supportedCiphersMap = map[string]uint16{
|
||||
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
"ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
"ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
"ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
"RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
"RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
"ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
}
|
||||
|
||||
// List of supported cipher suites in descending order of preference.
|
||||
// Ordering is very important! Getting the wrong order will break
|
||||
// mainstream clients, especially with HTTP/2.
|
||||
//
|
||||
// Note that TLS_FALLBACK_SCSV is not in this list since it is always
|
||||
// added manually.
|
||||
var supportedCiphers = []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
}
|
||||
|
||||
// List of all the ciphers we want to use by default
|
||||
var defaultCiphers = []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
}
|
||||
232
core/https/setup_test.go
Normal file
232
core/https/setup_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/coredns/core/setup"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Write test certificates to disk before tests, and clean up
|
||||
// when we're done.
|
||||
err := ioutil.WriteFile(certFile, testCert, 0644)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
err = ioutil.WriteFile(keyFile, testKey, 0644)
|
||||
if err != nil {
|
||||
os.Remove(certFile)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
result := m.Run()
|
||||
|
||||
os.Remove(certFile)
|
||||
os.Remove(keyFile)
|
||||
os.Exit(result)
|
||||
}
|
||||
|
||||
func TestSetupParseBasic(t *testing.T) {
|
||||
c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``)
|
||||
|
||||
_, err := Setup(c)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
|
||||
// Basic checks
|
||||
if !c.TLS.Manual {
|
||||
t.Error("Expected TLS Manual=true, but was false")
|
||||
}
|
||||
if !c.TLS.Enabled {
|
||||
t.Error("Expected TLS Enabled=true, but was false")
|
||||
}
|
||||
|
||||
// Security defaults
|
||||
if c.TLS.ProtocolMinVersion != tls.VersionTLS10 {
|
||||
t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
|
||||
}
|
||||
if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
|
||||
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", c.TLS.ProtocolMaxVersion)
|
||||
}
|
||||
|
||||
// Cipher checks
|
||||
expectedCiphers := []uint16{
|
||||
tls.TLS_FALLBACK_SCSV,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
}
|
||||
|
||||
// Ensure count is correct (plus one for TLS_FALLBACK_SCSV)
|
||||
if len(c.TLS.Ciphers) != len(expectedCiphers) {
|
||||
t.Errorf("Expected %v Ciphers (including TLS_FALLBACK_SCSV), got %v",
|
||||
len(expectedCiphers), len(c.TLS.Ciphers))
|
||||
}
|
||||
|
||||
// Ensure ordering is correct
|
||||
for i, actual := range c.TLS.Ciphers {
|
||||
if actual != expectedCiphers[i] {
|
||||
t.Errorf("Expected cipher in position %d to be %0x, got %0x", i, expectedCiphers[i], actual)
|
||||
}
|
||||
}
|
||||
|
||||
if !c.TLS.PreferServerCipherSuites {
|
||||
t.Error("Expected PreferServerCipherSuites = true, but was false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupParseIncompleteParams(t *testing.T) {
|
||||
// Using tls without args is an error because it's unnecessary.
|
||||
c := setup.NewTestController(`tls`)
|
||||
_, err := Setup(c)
|
||||
if err == nil {
|
||||
t.Error("Expected an error, but didn't get one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupParseWithOptionalParams(t *testing.T) {
|
||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
protocols ssl3.0 tls1.2
|
||||
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
|
||||
}`
|
||||
c := setup.NewTestController(params)
|
||||
|
||||
_, err := Setup(c)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
|
||||
if c.TLS.ProtocolMinVersion != tls.VersionSSL30 {
|
||||
t.Errorf("Expected 'ssl3.0 (0x0300)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
|
||||
}
|
||||
|
||||
if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
|
||||
t.Errorf("Expected 'tls1.2 (0x0302)' as ProtocolMaxVersion, got %#v", c.TLS.ProtocolMaxVersion)
|
||||
}
|
||||
|
||||
if len(c.TLS.Ciphers)-1 != 3 {
|
||||
t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupDefaultWithOptionalParams(t *testing.T) {
|
||||
params := `tls {
|
||||
ciphers RSA-3DES-EDE-CBC-SHA
|
||||
}`
|
||||
c := setup.NewTestController(params)
|
||||
|
||||
_, err := Setup(c)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
if len(c.TLS.Ciphers)-1 != 1 {
|
||||
t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: If we allow this... but probably not a good idea.
|
||||
// func TestSetupDisableHTTPRedirect(t *testing.T) {
|
||||
// c := NewTestController(`tls {
|
||||
// allow_http
|
||||
// }`)
|
||||
// _, err := TLS(c)
|
||||
// if err != nil {
|
||||
// t.Errorf("Expected no error, but got %v", err)
|
||||
// }
|
||||
// if !c.TLS.DisableHTTPRedir {
|
||||
// t.Error("Expected HTTP redirect to be disabled, but it wasn't")
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
||||
// Test protocols wrong params
|
||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
protocols ssl tls
|
||||
}`
|
||||
c := setup.NewTestController(params)
|
||||
_, err := Setup(c)
|
||||
if err == nil {
|
||||
t.Errorf("Expected errors, but no error returned")
|
||||
}
|
||||
|
||||
// Test ciphers wrong params
|
||||
params = `tls ` + certFile + ` ` + keyFile + ` {
|
||||
ciphers not-valid-cipher
|
||||
}`
|
||||
c = setup.NewTestController(params)
|
||||
_, err = Setup(c)
|
||||
if err == nil {
|
||||
t.Errorf("Expected errors, but no error returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupParseWithClientAuth(t *testing.T) {
|
||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients client_ca.crt client2_ca.crt
|
||||
}`
|
||||
c := setup.NewTestController(params)
|
||||
_, err := Setup(c)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
|
||||
if count := len(c.TLS.ClientCerts); count != 2 {
|
||||
t.Fatalf("Expected two client certs, had %d", count)
|
||||
}
|
||||
if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" {
|
||||
t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual)
|
||||
}
|
||||
if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" {
|
||||
t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual)
|
||||
}
|
||||
|
||||
// Test missing client cert file
|
||||
params = `tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients
|
||||
}`
|
||||
c = setup.NewTestController(params)
|
||||
_, err = Setup(c)
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, but no error returned")
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
certFile = "test_cert.pem"
|
||||
keyFile = "test_key.pem"
|
||||
)
|
||||
|
||||
var testCert = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ
|
||||
bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG
|
||||
A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7
|
||||
9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk
|
||||
pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG
|
||||
A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls
|
||||
b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw
|
||||
RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N
|
||||
Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A==
|
||||
-----END CERTIFICATE-----
|
||||
`)
|
||||
|
||||
var testKey = []byte(`-----BEGIN EC PARAMETERS-----
|
||||
BggqhkjOPQMBBw==
|
||||
-----END EC PARAMETERS-----
|
||||
-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49
|
||||
AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL
|
||||
SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`)
|
||||
94
core/https/storage.go
Normal file
94
core/https/storage.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/coredns/core/assets"
|
||||
)
|
||||
|
||||
// storage is used to get file paths in a consistent,
|
||||
// cross-platform way for persisting Let's Encrypt assets
|
||||
// on the file system.
|
||||
var storage = Storage(filepath.Join(assets.Path(), "letsencrypt"))
|
||||
|
||||
// Storage is a root directory and facilitates
|
||||
// forming file paths derived from it.
|
||||
type Storage string
|
||||
|
||||
// Sites gets the directory that stores site certificate and keys.
|
||||
func (s Storage) Sites() string {
|
||||
return filepath.Join(string(s), "sites")
|
||||
}
|
||||
|
||||
// Site returns the path to the folder containing assets for domain.
|
||||
func (s Storage) Site(domain string) string {
|
||||
return filepath.Join(s.Sites(), domain)
|
||||
}
|
||||
|
||||
// SiteCertFile returns the path to the certificate file for domain.
|
||||
func (s Storage) SiteCertFile(domain string) string {
|
||||
return filepath.Join(s.Site(domain), domain+".crt")
|
||||
}
|
||||
|
||||
// SiteKeyFile returns the path to domain's private key file.
|
||||
func (s Storage) SiteKeyFile(domain string) string {
|
||||
return filepath.Join(s.Site(domain), domain+".key")
|
||||
}
|
||||
|
||||
// SiteMetaFile returns the path to the domain's asset metadata file.
|
||||
func (s Storage) SiteMetaFile(domain string) string {
|
||||
return filepath.Join(s.Site(domain), domain+".json")
|
||||
}
|
||||
|
||||
// Users gets the directory that stores account folders.
|
||||
func (s Storage) Users() string {
|
||||
return filepath.Join(string(s), "users")
|
||||
}
|
||||
|
||||
// User gets the account folder for the user with email.
|
||||
func (s Storage) User(email string) string {
|
||||
if email == "" {
|
||||
email = emptyEmail
|
||||
}
|
||||
return filepath.Join(s.Users(), email)
|
||||
}
|
||||
|
||||
// UserRegFile gets the path to the registration file for
|
||||
// the user with the given email address.
|
||||
func (s Storage) UserRegFile(email string) string {
|
||||
if email == "" {
|
||||
email = emptyEmail
|
||||
}
|
||||
fileName := emailUsername(email)
|
||||
if fileName == "" {
|
||||
fileName = "registration"
|
||||
}
|
||||
return filepath.Join(s.User(email), fileName+".json")
|
||||
}
|
||||
|
||||
// UserKeyFile gets the path to the private key file for
|
||||
// the user with the given email address.
|
||||
func (s Storage) UserKeyFile(email string) string {
|
||||
if email == "" {
|
||||
email = emptyEmail
|
||||
}
|
||||
fileName := emailUsername(email)
|
||||
if fileName == "" {
|
||||
fileName = "private"
|
||||
}
|
||||
return filepath.Join(s.User(email), fileName+".key")
|
||||
}
|
||||
|
||||
// emailUsername returns the username portion of an
|
||||
// email address (part before '@') or the original
|
||||
// input if it can't find the "@" symbol.
|
||||
func emailUsername(email string) string {
|
||||
at := strings.Index(email, "@")
|
||||
if at == -1 {
|
||||
return email
|
||||
} else if at == 0 {
|
||||
return email[1:]
|
||||
}
|
||||
return email[:at]
|
||||
}
|
||||
88
core/https/storage_test.go
Normal file
88
core/https/storage_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStorage(t *testing.T) {
|
||||
storage = Storage("./le_test")
|
||||
|
||||
if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected {
|
||||
t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected {
|
||||
t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected {
|
||||
t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected {
|
||||
t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected {
|
||||
t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected {
|
||||
t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected {
|
||||
t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected {
|
||||
t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected {
|
||||
t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// Test with empty emails
|
||||
if expected, actual := filepath.Join("le_test", "users", emptyEmail), storage.User(emptyEmail); actual != expected {
|
||||
t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected {
|
||||
t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected {
|
||||
t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailUsername(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input, expect string
|
||||
}{
|
||||
{
|
||||
input: "username@example.com",
|
||||
expect: "username",
|
||||
},
|
||||
{
|
||||
input: "plus+addressing@example.com",
|
||||
expect: "plus+addressing",
|
||||
},
|
||||
{
|
||||
input: "me+plus-addressing@example.com",
|
||||
expect: "me+plus-addressing",
|
||||
},
|
||||
{
|
||||
input: "not-an-email",
|
||||
expect: "not-an-email",
|
||||
},
|
||||
{
|
||||
input: "@foobar.com",
|
||||
expect: "foobar.com",
|
||||
},
|
||||
{
|
||||
input: emptyEmail,
|
||||
expect: emptyEmail,
|
||||
},
|
||||
{
|
||||
input: "",
|
||||
expect: "",
|
||||
},
|
||||
} {
|
||||
if actual := emailUsername(test.input); actual != test.expect {
|
||||
t.Errorf("Test %d: Expected username to be '%s' but was '%s'", i, test.expect, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
200
core/https/user.go
Normal file
200
core/https/user.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/coredns/server"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// User represents a Let's Encrypt user account.
|
||||
type User struct {
|
||||
Email string
|
||||
Registration *acme.RegistrationResource
|
||||
key crypto.PrivateKey
|
||||
}
|
||||
|
||||
// GetEmail gets u's email.
|
||||
func (u User) GetEmail() string {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
// GetRegistration gets u's registration resource.
|
||||
func (u User) GetRegistration() *acme.RegistrationResource {
|
||||
return u.Registration
|
||||
}
|
||||
|
||||
// GetPrivateKey gets u's private key.
|
||||
func (u User) GetPrivateKey() crypto.PrivateKey {
|
||||
return u.key
|
||||
}
|
||||
|
||||
// getUser loads the user with the given email from disk.
|
||||
// If the user does not exist, it will create a new one,
|
||||
// but it does NOT save new users to the disk or register
|
||||
// them via ACME. It does NOT prompt the user.
|
||||
func getUser(email string) (User, error) {
|
||||
var user User
|
||||
|
||||
// open user file
|
||||
regFile, err := os.Open(storage.UserRegFile(email))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// create a new user
|
||||
return newUser(email)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
defer regFile.Close()
|
||||
|
||||
// load user information
|
||||
err = json.NewDecoder(regFile).Decode(&user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
// load their private key
|
||||
user.key, err = loadPrivateKey(storage.UserKeyFile(email))
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// saveUser persists a user's key and account registration
|
||||
// to the file system. It does NOT register the user via ACME
|
||||
// or prompt the user.
|
||||
func saveUser(user User) error {
|
||||
// make user account folder
|
||||
err := os.MkdirAll(storage.User(user.Email), 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// save private key file
|
||||
err = savePrivateKey(user.key, storage.UserKeyFile(user.Email))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// save registration file
|
||||
jsonBytes, err := json.MarshalIndent(&user, "", "\t")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(storage.UserRegFile(user.Email), jsonBytes, 0600)
|
||||
}
|
||||
|
||||
// newUser creates a new User for the given email address
|
||||
// with a new private key. This function does NOT save the
|
||||
// user to disk or register it via ACME. If you want to use
|
||||
// a user account that might already exist, call getUser
|
||||
// instead. It does NOT prompt the user.
|
||||
func newUser(email string) (User, error) {
|
||||
user := User{Email: email}
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
return user, errors.New("error generating private key: " + err.Error())
|
||||
}
|
||||
user.key = privateKey
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// getEmail does everything it can to obtain an email
|
||||
// address from the user to use for TLS for cfg. If it
|
||||
// cannot get an email address, it returns empty string.
|
||||
// (It will warn the user of the consequences of an
|
||||
// empty email.) This function MAY prompt the user for
|
||||
// input. If userPresent is false, the operator will
|
||||
// NOT be prompted and an empty email may be returned.
|
||||
func getEmail(cfg server.Config, userPresent bool) string {
|
||||
// First try the tls directive from the Caddyfile
|
||||
leEmail := cfg.TLS.LetsEncryptEmail
|
||||
if leEmail == "" {
|
||||
// Then try memory (command line flag or typed by user previously)
|
||||
leEmail = DefaultEmail
|
||||
}
|
||||
if leEmail == "" {
|
||||
// Then try to get most recent user email ~/.caddy/users file
|
||||
userDirs, err := ioutil.ReadDir(storage.Users())
|
||||
if err == nil {
|
||||
var mostRecent os.FileInfo
|
||||
for _, dir := range userDirs {
|
||||
if !dir.IsDir() {
|
||||
continue
|
||||
}
|
||||
if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) {
|
||||
leEmail = dir.Name()
|
||||
DefaultEmail = leEmail // save for next time
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if leEmail == "" && userPresent {
|
||||
// Alas, we must bother the user and ask for an email address;
|
||||
// if they proceed they also agree to the SA.
|
||||
reader := bufio.NewReader(stdin)
|
||||
fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.")
|
||||
fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:")
|
||||
fmt.Println(" " + saURL) // TODO: Show current SA link
|
||||
fmt.Println("Please enter your email address so you can recover your account if needed.")
|
||||
fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
|
||||
fmt.Print("Email address: ")
|
||||
var err error
|
||||
leEmail, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
leEmail = strings.TrimSpace(leEmail)
|
||||
DefaultEmail = leEmail
|
||||
Agreed = true
|
||||
}
|
||||
return leEmail
|
||||
}
|
||||
|
||||
// promptUserAgreement prompts the user to agree to the agreement
|
||||
// at agreementURL via stdin. If the agreement has changed, then pass
|
||||
// true as the second argument. If this is the user's first time
|
||||
// agreeing, pass false. It returns whether the user agreed or not.
|
||||
func promptUserAgreement(agreementURL string, changed bool) bool {
|
||||
if changed {
|
||||
fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL)
|
||||
fmt.Print("Do you agree to the new terms? (y/n): ")
|
||||
} else {
|
||||
fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL)
|
||||
fmt.Print("Do you agree to the terms? (y/n): ")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(stdin)
|
||||
answer, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
answer = strings.ToLower(strings.TrimSpace(answer))
|
||||
|
||||
return answer == "y" || answer == "yes"
|
||||
}
|
||||
|
||||
// stdin is used to read the user's input if prompted;
|
||||
// this is changed by tests during tests.
|
||||
var stdin = io.ReadWriter(os.Stdin)
|
||||
|
||||
// The name of the folder for accounts where the email
|
||||
// address was not provided; default 'username' if you will.
|
||||
const emptyEmail = "default"
|
||||
|
||||
// TODO: Use latest
|
||||
const saURL = "https://letsencrypt.org/documents/LE-SA-v1.0.1-July-27-2015.pdf"
|
||||
196
core/https/user_test.go
Normal file
196
core/https/user_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/server"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 128)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not generate test private key: %v", err)
|
||||
}
|
||||
u := User{
|
||||
Email: "me@mine.com",
|
||||
Registration: new(acme.RegistrationResource),
|
||||
key: privateKey,
|
||||
}
|
||||
|
||||
if expected, actual := "me@mine.com", u.GetEmail(); actual != expected {
|
||||
t.Errorf("Expected email '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if u.GetRegistration() == nil {
|
||||
t.Error("Expected a registration resource, but got nil")
|
||||
}
|
||||
if expected, actual := privateKey, u.GetPrivateKey(); actual != expected {
|
||||
t.Errorf("Expected the private key at address %p but got one at %p instead ", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
email := "me@foobar.com"
|
||||
user, err := newUser(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user: %v", err)
|
||||
}
|
||||
if user.key == nil {
|
||||
t.Error("Private key is nil")
|
||||
}
|
||||
if user.Email != email {
|
||||
t.Errorf("Expected email to be %s, but was %s", email, user.Email)
|
||||
}
|
||||
if user.Registration != nil {
|
||||
t.Error("New user already has a registration resource; it shouldn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveUser(t *testing.T) {
|
||||
storage = Storage("./testdata")
|
||||
defer os.RemoveAll(string(storage))
|
||||
|
||||
email := "me@foobar.com"
|
||||
user, err := newUser(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user: %v", err)
|
||||
}
|
||||
|
||||
err = saveUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving user: %v", err)
|
||||
}
|
||||
_, err = os.Stat(storage.UserRegFile(email))
|
||||
if err != nil {
|
||||
t.Errorf("Cannot access user registration file, error: %v", err)
|
||||
}
|
||||
_, err = os.Stat(storage.UserKeyFile(email))
|
||||
if err != nil {
|
||||
t.Errorf("Cannot access user private key file, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserDoesNotAlreadyExist(t *testing.T) {
|
||||
storage = Storage("./testdata")
|
||||
defer os.RemoveAll(string(storage))
|
||||
|
||||
user, err := getUser("user_does_not_exist@foobar.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting user: %v", err)
|
||||
}
|
||||
|
||||
if user.key == nil {
|
||||
t.Error("Expected user to have a private key, but it was nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserAlreadyExists(t *testing.T) {
|
||||
storage = Storage("./testdata")
|
||||
defer os.RemoveAll(string(storage))
|
||||
|
||||
email := "me@foobar.com"
|
||||
|
||||
// Set up test
|
||||
user, err := newUser(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user: %v", err)
|
||||
}
|
||||
err = saveUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving user: %v", err)
|
||||
}
|
||||
|
||||
// Expect to load user from disk
|
||||
user2, err := getUser(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting user: %v", err)
|
||||
}
|
||||
|
||||
// Assert keys are the same
|
||||
if !PrivateKeysSame(user.key, user2.key) {
|
||||
t.Error("Expected private key to be the same after loading, but it wasn't")
|
||||
}
|
||||
|
||||
// Assert emails are the same
|
||||
if user.Email != user2.Email {
|
||||
t.Errorf("Expected emails to be equal, but was '%s' before and '%s' after loading", user.Email, user2.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEmail(t *testing.T) {
|
||||
// let's not clutter up the output
|
||||
origStdout := os.Stdout
|
||||
os.Stdout = nil
|
||||
defer func() { os.Stdout = origStdout }()
|
||||
|
||||
storage = Storage("./testdata")
|
||||
defer os.RemoveAll(string(storage))
|
||||
DefaultEmail = "test2@foo.com"
|
||||
|
||||
// Test1: Use email in config
|
||||
config := server.Config{
|
||||
TLS: server.TLSConfig{
|
||||
LetsEncryptEmail: "test1@foo.com",
|
||||
},
|
||||
}
|
||||
actual := getEmail(config, true)
|
||||
if actual != "test1@foo.com" {
|
||||
t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual)
|
||||
}
|
||||
|
||||
// Test2: Use default email from flag (or user previously typing it)
|
||||
actual = getEmail(server.Config{}, true)
|
||||
if actual != DefaultEmail {
|
||||
t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual)
|
||||
}
|
||||
|
||||
// Test3: Get input from user
|
||||
DefaultEmail = ""
|
||||
stdin = new(bytes.Buffer)
|
||||
_, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not simulate user input, error: %v", err)
|
||||
}
|
||||
actual = getEmail(server.Config{}, true)
|
||||
if actual != "test3@foo.com" {
|
||||
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
|
||||
}
|
||||
|
||||
// Test4: Get most recent email from before
|
||||
DefaultEmail = ""
|
||||
for i, eml := range []string{
|
||||
"test4-3@foo.com",
|
||||
"test4-2@foo.com",
|
||||
"test4-1@foo.com",
|
||||
} {
|
||||
u, err := newUser(eml)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user %d: %v", i, err)
|
||||
}
|
||||
err = saveUser(u)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving user %d: %v", i, err)
|
||||
}
|
||||
|
||||
// Change modified time so they're all different, so the test becomes deterministic
|
||||
f, err := os.Stat(storage.User(eml))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not access user folder for '%s': %v", eml, err)
|
||||
}
|
||||
chTime := f.ModTime().Add(-(time.Duration(i) * time.Second))
|
||||
if err := os.Chtimes(storage.User(eml), chTime, chTime); err != nil {
|
||||
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
|
||||
}
|
||||
}
|
||||
actual = getEmail(server.Config{}, true)
|
||||
if actual != "test4-3@foo.com" {
|
||||
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user