Skip to content

Commit

Permalink
Merge pull request #291 from pracucci/add-custom-dialer-option-to-htt…
Browse files Browse the repository at this point in the history
…p-client

Added functional option to allow to customize DialContext() in HTTP client
  • Loading branch information
roidelapluie authored Apr 21, 2021
2 parents 4240322 + 6a9c79c commit 3b362f5
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 19 deletions.
81 changes: 70 additions & 11 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package config

import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
Expand All @@ -38,6 +40,12 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
FollowRedirects: true,
}

// defaultHTTPClientOptions holds the default HTTP client options.
var defaultHTTPClientOptions = httpClientOptions{
keepAlivesEnabled: true,
http2Enabled: true,
}

type closeIdler interface {
CloseIdleConnections()
}
Expand Down Expand Up @@ -194,15 +202,50 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
return unmarshal((*plain)(a))
}

// DialContextFunc defines the signature of the DialContext() function implemented
// by net.Dialer.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)

type httpClientOptions struct {
dialContextFunc DialContextFunc
keepAlivesEnabled bool
http2Enabled bool
}

// HTTPClientOption defines an option that can be applied to the HTTP client.
type HTTPClientOption func(options *httpClientOptions)

// WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
func WithDialContextFunc(fn DialContextFunc) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.dialContextFunc = fn
}
}

// WithKeepAlivesDisabled allows to disable HTTP keepalive.
func WithKeepAlivesDisabled() HTTPClientOption {
return func(opts *httpClientOptions) {
opts.keepAlivesEnabled = false
}
}

// WithHTTP2Disabled allows to disable HTTP2.
func WithHTTP2Disabled() HTTPClientOption {
return func(opts *httpClientOptions) {
opts.http2Enabled = false
}
}

// NewClient returns a http.Client using the specified http.RoundTripper.
func newClient(rt http.RoundTripper) *http.Client {
return &http.Client{Transport: rt}
}

// NewClientFromConfig returns a new HTTP client configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (*http.Client, error) {
rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2)
// given config.HTTPClientConfig and config.HTTPClientOption.
// The name is used as go-conntrack metric label.
func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (*http.Client, error) {
rt, err := NewRoundTripperFromConfig(cfg, name, optFuncs...)
if err != nil {
return nil, err
}
Expand All @@ -216,29 +259,45 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, e
}

// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (http.RoundTripper, error) {
// given config.HTTPClientConfig and config.HTTPClientOption.
// The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
opts := defaultHTTPClientOptions
for _, f := range optFuncs {
f(&opts)
}

var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)

if opts.dialContextFunc != nil {
dialContext = conntrack.NewDialContextFunc(
conntrack.DialWithDialContextFunc((func(context.Context, string, string) (net.Conn, error))(opts.dialContextFunc)),
conntrack.DialWithTracing(),
conntrack.DialWithName(name))
} else {
dialContext = conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name))
}

newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
var rt http.RoundTripper = &http.Transport{
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
MaxIdleConns: 20000,
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
DisableKeepAlives: disableKeepAlives,
DisableKeepAlives: !opts.keepAlivesEnabled,
TLSClientConfig: tlsConfig,
DisableCompression: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
IdleConnTimeout: 5 * time.Minute,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
DialContext: dialContext,
}
if enableHTTP2 {
if opts.http2Enabled {
// HTTP/2 support is golang has many problematic cornercases where
// dead connections would be kept and used in connection pools.
// https://github.com/golang/go/issues/32388
Expand Down
37 changes: 29 additions & 8 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package config

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -50,6 +53,7 @@ const (
MissingKey = "missing/secret.key"

ExpectedMessage = "I'm here to serve you!!!"
ExpectedError = "expected error"
AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
AuthorizationCredentialsFile = "testdata/bearer.token"
AuthorizationType = "APIKEY"
Expand Down Expand Up @@ -350,7 +354,7 @@ func TestNewClientFromConfig(t *testing.T) {
if err != nil {
t.Fatal(err.Error())
}
client, err := NewClientFromConfig(validConfig.clientConfig, "test", false, true)
client, err := NewClientFromConfig(validConfig.clientConfig, "test")
if err != nil {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
continue
Expand Down Expand Up @@ -400,7 +404,7 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
}

for _, invalidConfig := range newClientInvalidConfig {
client, err := NewClientFromConfig(invalidConfig.clientConfig, "test", false, true)
client, err := NewClientFromConfig(invalidConfig.clientConfig, "test")
if client != nil {
t.Errorf("A client instance was returned instead of nil using this config: %+v", invalidConfig.clientConfig)
}
Expand All @@ -413,6 +417,23 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
}
}

func TestCustomDialContextFunc(t *testing.T) {
dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
return nil, errors.New(ExpectedError)
}

cfg := HTTPClientConfig{}
client, err := NewClientFromConfig(cfg, "test", WithDialContextFunc(dialFn))
if err != nil {
t.Fatalf("Can't create a client from this config: %+v", cfg)
}

_, err = client.Get("http://localhost")
if err == nil || !strings.Contains(err.Error(), ExpectedError) {
t.Errorf("Expected error %q but got %q", ExpectedError, err)
}
}

func TestMissingBearerAuthFile(t *testing.T) {
cfg := HTTPClientConfig{
BearerTokenFile: MissingBearerTokenFile,
Expand All @@ -439,7 +460,7 @@ func TestMissingBearerAuthFile(t *testing.T) {
}
defer testServer.Close()

client, err := NewClientFromConfig(cfg, "test", false, true)
client, err := NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -637,7 +658,7 @@ func TestBasicAuthNoPassword(t *testing.T) {
if err != nil {
t.Fatalf("Error loading HTTP client config: %v", err)
}
client, err := NewClientFromConfig(*cfg, "test", false, true)
client, err := NewClientFromConfig(*cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand All @@ -663,7 +684,7 @@ func TestBasicAuthNoUsername(t *testing.T) {
if err != nil {
t.Fatalf("Error loading HTTP client config: %v", err)
}
client, err := NewClientFromConfig(*cfg, "test", false, true)
client, err := NewClientFromConfig(*cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand All @@ -689,7 +710,7 @@ func TestBasicAuthPasswordFile(t *testing.T) {
if err != nil {
t.Fatalf("Error loading HTTP client config: %v", err)
}
client, err := NewClientFromConfig(*cfg, "test", false, true)
client, err := NewClientFromConfig(*cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand Down Expand Up @@ -840,7 +861,7 @@ func TestTLSRoundTripper(t *testing.T) {
writeCertificate(bs, tc.cert, cert)
writeCertificate(bs, tc.key, key)
if c == nil {
c, err = NewClientFromConfig(cfg, "test", false, true)
c, err = NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand Down Expand Up @@ -912,7 +933,7 @@ func TestTLSRoundTripperRaces(t *testing.T) {
writeCertificate(bs, TLSCAChainPath, ca)
writeCertificate(bs, ClientCertificatePath, cert)
writeCertificate(bs, ClientKeyNoPassPath, key)
c, err = NewClientFromConfig(cfg, "test", false, true)
c, err = NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand Down

0 comments on commit 3b362f5

Please sign in to comment.