Add optional TLS support to /metrics endpoint (#7255)

* Use exporter-toolkit to enable optional TLS encryption on /metrics endpoint

Signed-off-by: peppi-lotta <peppi-lotta.saari@est.tech>

* Implement startup listener to signal server readiness

Signed-off-by: peppi-lotta <peppi-lotta.saari@est.tech>

---------

Signed-off-by: peppi-lotta <peppi-lotta.saari@est.tech>
This commit is contained in:
Peppi-Lotta
2026-03-12 22:49:00 +02:00
committed by GitHub
parent a8c802e1b3
commit 7ff001dca7
13 changed files with 553 additions and 8 deletions

View File

@@ -3,18 +3,22 @@ package metrics
import (
"context"
"log/slog"
"net"
"net/http"
"os"
"sync"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/prometheus/exporter-toolkit/web"
)
// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics .
@@ -34,6 +38,8 @@ type Metrics struct {
zoneMu sync.RWMutex
plugins map[string]struct{} // all available plugins, used to determine which plugin made the client write
tlsConfigPath string
}
// New returns a new instance of Metrics with the given address.
@@ -83,6 +89,32 @@ func (m *Metrics) ZoneNames() []string {
return s
}
// startupListener wraps a net.Listener to detect when Accept() is first called
type startupListener struct {
net.Listener
readyOnce sync.Once
ready chan struct{}
}
func newStartupListener(l net.Listener) *startupListener {
return &startupListener{
Listener: l,
ready: make(chan struct{}),
}
}
func (sl *startupListener) Accept() (net.Conn, error) {
// Signal ready on first Accept() call (server is running)
sl.readyOnce.Do(func() {
close(sl.ready)
})
return sl.Listener.Accept()
}
func (sl *startupListener) Ready() <-chan struct{} {
return sl.ready
}
// OnStartup sets up the metrics on startup.
func (m *Metrics) OnStartup() error {
ln, err := reuseport.Listen("tcp", m.Addr)
@@ -91,7 +123,9 @@ func (m *Metrics) OnStartup() error {
return err
}
m.ln = ln
startupListener := newStartupListener(ln)
m.ln = startupListener
m.lnSetup = true
m.mux = http.NewServeMux()
@@ -99,6 +133,7 @@ func (m *Metrics) OnStartup() error {
// creating some helper variables to avoid data races on m.srv and m.ln
server := &http.Server{
Addr: m.Addr,
Handler: m.mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
@@ -106,10 +141,53 @@ func (m *Metrics) OnStartup() error {
}
m.srv = server
if m.tlsConfigPath == "" {
go func() {
if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
log.Errorf("Failed to start HTTP metrics server: %s", err)
}
}()
ListenAddr = ln.Addr().String() // For tests.
return nil
}
// Check TLS config file existence
if _, err := os.Stat(m.tlsConfigPath); os.IsNotExist(err) {
log.Errorf("TLS config file does not exist: %s", m.tlsConfigPath)
return err
}
// Create web config for ListenAndServe
webConfig := &web.FlagConfig{
WebListenAddresses: &[]string{m.Addr},
WebSystemdSocket: new(bool), // false by default
WebConfigFile: &m.tlsConfigPath,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
// Create channels for synchronization
startUpErr := make(chan error, 1)
go func() {
server.Serve(ln)
// Try to start the server and report result if there an error.
// web.Serve() never returns nil, it always returns a non-nil error and
// it doesn't retun anything if server starts successfully.
// startupListener handles capturing succesful startup.
err := web.Serve(m.ln, server, webConfig, logger)
if err != nil && err != http.ErrServerClosed {
log.Errorf("Failed to start HTTPS metrics server: %v", err)
startUpErr <- err
}
}()
// Wait for startup errors
select {
case err := <-startUpErr:
return err
case <-startupListener.Ready():
log.Infof("Server is ready and accepting connections")
}
ListenAddr = ln.Addr().String() // For tests.
return nil
}

View File

@@ -2,9 +2,18 @@ package metrics
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"os"
"testing"
"time"
@@ -17,6 +26,388 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
)
const (
serverCertFile = "test_data/server.crt"
serverKeyFile = "test_data/server.key"
clientCertFile = "test_data/client_selfsigned.crt"
clientKeyFile = "test_data/client_selfsigned.key"
tlsCaChainFile = "test_data/tls-ca-chain.pem"
)
func createTestCertFiles(t *testing.T) error {
t.Helper()
// Generate CA certificate
caCert, caKey, err := generateCA()
if err != nil {
t.Fatalf("Failed to generate CA certificate: %v", err)
return err
}
// Generate server certificate signed by CA
cert, key, err := generateCert(caCert, caKey)
if err != nil {
t.Fatalf("Failed to generate server certificate: %v", err)
return err
}
// Generate client CA certificate
clientCaCert, clientCaKey, err := generateCA()
if err != nil {
t.Fatalf("Failed to generate client CA certificate: %v", err)
return err
}
// Generate client certificate signed by CA
clientCert, clientKey, err := generateCert(clientCaCert, clientCaKey)
if err != nil {
t.Fatalf("Failed to generate client certificate: %v", err)
return err
}
// Create ca chain file
caChain := append(caCert, clientCaCert...)
// Write certificates to temporary files
err = writeFile(t, string(cert), serverCertFile)
if err != nil {
t.Fatalf("Failed to write server certificate: %v", err)
return err
}
err = writeFile(t, string(key), serverKeyFile)
if err != nil {
t.Fatalf("Failed to write server key: %v", err)
return err
}
err = writeFile(t, string(clientCert), clientCertFile)
if err != nil {
t.Fatalf("Failed to write client certificate: %v", err)
return err
}
err = writeFile(t, string(clientKey), clientKeyFile)
if err != nil {
t.Fatalf("Failed to write client key: %v", err)
return err
}
err = writeFile(t, string(caChain), tlsCaChainFile)
if err != nil {
t.Fatalf("Failed to write CA certificate: %v", err)
return err
}
return nil
}
func generateCA() ([]byte, []byte, error) {
ca := &x509.Certificate{
SerialNumber: big.NewInt(2023),
Subject: pkix.Name{
Organization: []string{"Test CA"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}
caPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})
caPrivKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})
return caPEM, caPrivKeyPEM, nil
}
func generateCert(caCertPEM, caKeyPEM []byte) ([]byte, []byte, error) {
caCertBlock, _ := pem.Decode(caCertPEM)
caCert, err := x509.ParseCertificate(caCertBlock.Bytes)
if err != nil {
return nil, nil, err
}
caKeyBlock, _ := pem.Decode(caKeyPEM)
caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes)
if err != nil {
return nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: big.NewInt(2023),
Subject: pkix.Name{
Organization: []string{"Test Server"},
},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, caCert, &certPrivKey.PublicKey, caKey)
if err != nil {
return nil, nil, err
}
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
certPrivKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
return certPEM, certPrivKeyPEM, nil
}
func cleanupTestCertFiles() {
os.Remove(serverCertFile)
os.Remove(serverKeyFile)
os.Remove(clientCertFile)
os.Remove(clientKeyFile)
os.Remove(tlsCaChainFile)
}
func writeFile(t *testing.T, content, path string) error {
t.Helper()
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
return err
}
return nil
}
func getTLSClient(clientCertName bool) *http.Client {
cert, err := os.ReadFile(tlsCaChainFile)
if err != nil {
panic("Unable to start TLS client. Check cert path")
}
var clientCertficate tls.Certificate
if clientCertName {
clientCertficate, err = tls.LoadX509KeyPair(
clientCertFile,
clientKeyFile,
)
if err != nil {
panic(fmt.Sprintf("failed to load client certificate: %v", err))
}
}
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: func() *x509.CertPool {
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(cert)
return caCertPool
}(),
GetClientCertificate: func(req *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientCertficate, nil
},
},
},
}
return client
}
func TestMetricsTLS(t *testing.T) {
err := createTestCertFiles(t)
if err != nil {
t.Fatalf("Failed to create test certificate files: %v", err)
}
defer cleanupTestCertFiles()
tests := []struct {
name string
tlsConfigPath string
UseTLSClient bool
clientCertificate bool
caFile string
expectStartupError bool
expectRequestError bool
}{
{
name: "No TLS config: starts a HTTP server, connect successfully with default client",
tlsConfigPath: "",
},
{
name: "No TLS config: starts HTTP server, connection fails with TLS client",
tlsConfigPath: "",
UseTLSClient: true,
expectRequestError: true,
},
{
name: "Empty TLS config: starts a HTTP server",
tlsConfigPath: "test_data/configs/empty.yml",
},
{
name: "Valid TLS config, no client cert, successful connection with TLS client",
tlsConfigPath: "test_data/configs/valid_verifyclientcertifgiven.yml",
UseTLSClient: true,
},
{
name: `Valid TLS config, connection fails with default client`,
tlsConfigPath: "test_data/configs/valid_verifyclientcertifgiven.yml",
expectRequestError: true,
},
{
name: `Valid TLS config with RequireAnyClientCert, connection succeeds with TLS client presenting (valid) certificate`,
tlsConfigPath: "test_data/configs/valid_requireanyclientcert.yml",
UseTLSClient: true,
clientCertificate: true,
},
{
name: "Wrong path to TLS config file fails to start server",
tlsConfigPath: "test_data/configs/this-does-not-exist.yml",
UseTLSClient: true,
expectStartupError: true,
},
{
name: `TLS config hasinvalid structure, fails to start server`,
tlsConfigPath: "test_data/configs/junk.yml",
UseTLSClient: true,
expectStartupError: true,
},
{
name: "Missing key file, fails to start server",
tlsConfigPath: "test_data/configs/keyPath_empty.yml",
UseTLSClient: true,
expectStartupError: true,
},
{
name: "Missing cert file, fails to start server",
tlsConfigPath: "test_data/configs/certPath_empty.yml",
UseTLSClient: true,
expectStartupError: true,
},
{
name: "Wrong key file path, fails to start server",
tlsConfigPath: "test_data/configs/keyPath_invalid.yml",
UseTLSClient: true,
expectStartupError: true,
},
{
name: "Wrong cert file path, fails to start server",
tlsConfigPath: "test_data/configs/certPath_invalid.yml",
UseTLSClient: true,
expectStartupError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
met := New("localhost:0")
met.tlsConfigPath = tt.tlsConfigPath
// Start server
err := met.OnStartup()
if tt.expectStartupError {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Fatalf("Failed to start metrics handler: %s", err)
}
defer met.OnFinalShutdown()
// Wait for server to be ready
select {
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for server to start")
case <-func() chan struct{} {
ch := make(chan struct{})
go func() {
for {
conn, err := net.DialTimeout("tcp", ListenAddr, 100*time.Millisecond)
if err == nil {
conn.Close()
close(ch)
return
}
time.Sleep(100 * time.Millisecond)
}
}()
return ch
}():
}
// Create appropriate client and protocol
var client *http.Client
var protocol string
if tt.UseTLSClient {
client = getTLSClient(tt.clientCertificate)
protocol = "https"
} else {
client = http.DefaultClient
protocol = "http"
}
// Try multiple times to account for server startup time
var resp *http.Response
var err2 error
for i := range 10 {
url := fmt.Sprintf("%s://%s/metrics", protocol, ListenAddr)
t.Logf("Attempt %d: Connecting to %s", i+1, url)
resp, err2 = client.Get(url)
if err2 == nil {
t.Logf("Successfully connected to metrics server")
break
}
t.Logf("Connection attempt failed: %v", err2)
time.Sleep(200 * time.Millisecond)
}
if err2 != nil {
if tt.expectRequestError {
return
}
t.Fatalf("Failed to connect to metrics server: %v", err2)
}
if resp != nil {
defer resp.Body.Close()
}
if tt.expectRequestError {
// If we expect a request error but got a response, check if it's a bad status code
// which indicates the connection succeeded but the request was invalid (e.g., HTTP to HTTPS server)
if resp.StatusCode == http.StatusBadRequest {
// Got expected error response
return
}
// Got unexpected response status
t.Fatalf("Expected request error with status %d but got response with status %d", http.StatusBadRequest, resp.StatusCode)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
})
}
}
func TestMetrics(t *testing.T) {
met := New("localhost:0")
if err := met.OnStartup(); err != nil {

View File

@@ -9,12 +9,10 @@ import (
"github.com/coredns/coredns/coremain"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/uniq"
)
var (
log = clog.NewWithPlugin("prometheus")
u = uniq.New()
registry = newReg()
)
@@ -97,6 +95,27 @@ func parse(c *caddy.Controller) (*Metrics, error) {
default:
return met, c.ArgErr()
}
// Parse TLS block if present
for c.NextBlock() {
switch c.Val() {
case "tls":
if met.tlsConfigPath != "" {
return nil, c.Err("tls block already specified")
}
// Get cert and key files as positional arguments
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
tlsCfgPath := args[0]
met.tlsConfigPath = tlsCfgPath
default:
return nil, c.Errf("unknown option: %s", c.Val())
}
}
}
return met, nil
}

View File

@@ -0,0 +1,3 @@
tls_server_config:
cert_file: ""
key_file: "../server.key"

View File

@@ -0,0 +1,3 @@
tls_server_config:
cert_file: "somefile"
key_file: "../server.key"

View File

@@ -0,0 +1,20 @@
hWkNKCp3fvIx3jKnsaBI
TuEjdwNS8A2vYdFbiKqr
ay3RiOtykgt4m6m3KOol
ZreGpJRGmpDSVV9cioiF
r7kDOHhHU2frvv0nLcY2
uQMQM4XgqFkCG6gFAIJZ
g99tTkrZhN9b6pkJ6J2y
rzdt729HrA2RblDGYfjs
MW7GxrBdlCnliYJGPhfr
g9kaXxMXcDwsw0C0rv0u
637ZmfRGElb6VBVOtgqn
RG0MRezjLYCJQBMUdRDE
RzO4VicAzj7asVZAT3oo
nPw267UONk7h7KBYRgch
Alj38foWqjV3heXXdahm
TrMzMgl6JIQ1x4OZB5i4
qlrXFJoeV6Pr77nuiEh9
3yE5vMnnKHm2nImEfzMG
bI01UDObHRSaoJLC0vTD
G9tlcKU883NkQ6nsxJ8Y

View File

@@ -0,0 +1,3 @@
tls_server_config:
cert_file: "../server.crt"
key_file: ""

View File

@@ -0,0 +1,3 @@
tls_server_config:
cert_file: "../server.cert"
key_file: "somefile"

View File

@@ -0,0 +1,4 @@
tls_server_config:
cert_file: "../server.crt"
key_file: "../server.key"
client_auth_type: "RequireAnyClientCert"

View File

@@ -0,0 +1,4 @@
tls_server_config:
cert_file: "../server.crt"
key_file: "../server.key"
client_auth_type: "VerifyClientCertIfGiven"