From 82373d76c4a5a50b4d5c26a37398c76d306232c1 Mon Sep 17 00:00:00 2001 From: Joe Turki Date: Sat, 11 Jan 2025 03:31:42 -0600 Subject: [PATCH] Upgrade golangci-lint, more linters --- .golangci.yml | 49 +-- client.go | 55 +++- client_test.go | 39 ++- examples/lt-cred-generator/main.go | 1 + examples/stun-only-server/main.go | 4 +- examples/turn-client/tcp-alloc/main.go | 6 +- examples/turn-client/tcp/main.go | 4 +- examples/turn-client/udp/main.go | 4 +- .../add-software-attribute/main.go | 25 +- examples/turn-server/log/main.go | 13 +- .../turn-server/lt-cred-turn-rest/main.go | 10 +- examples/turn-server/lt-cred/main.go | 10 +- examples/turn-server/perm-filter/main.go | 15 +- examples/turn-server/port-range/main.go | 15 +- .../turn-server/simple-multithreaded/main.go | 7 +- examples/turn-server/simple/main.go | 11 +- examples/turn-server/tcp/main.go | 5 +- examples/turn-server/tls/main.go | 9 +- internal/allocation/allocation.go | 69 +++-- internal/allocation/allocation_manager.go | 54 ++-- .../allocation/allocation_manager_test.go | 44 ++- internal/allocation/allocation_test.go | 140 +++++---- internal/allocation/channel_bind.go | 2 +- internal/allocation/five_tuple.go | 15 +- internal/allocation/permission.go | 2 +- internal/client/allocation.go | 10 +- internal/client/binding.go | 8 +- internal/client/binding_test.go | 40 +-- internal/client/client.go | 2 +- internal/client/client_test.go | 2 + internal/client/periodic_timer.go | 8 +- internal/client/periodic_timer_test.go | 8 +- internal/client/permission.go | 5 +- internal/client/permission_test.go | 20 +- internal/client/tcp_alloc.go | 39 ++- internal/client/tcp_conn.go | 2 +- internal/client/tcp_conn_test.go | 14 +- internal/client/transaction.go | 37 ++- internal/client/trylock.go | 1 + internal/client/trylock_test.go | 2 + internal/client/udp_conn.go | 101 +++--- internal/client/udp_conn_test.go | 2 +- internal/ipnet/util.go | 5 +- internal/proto/addr.go | 2 + internal/proto/chandata.go | 22 +- internal/proto/chandata_test.go | 26 +- internal/proto/chann.go | 3 +- internal/proto/chann_test.go | 26 +- internal/proto/connection_id.go | 4 +- internal/proto/data.go | 4 +- internal/proto/data_test.go | 18 +- internal/proto/dontfrag.go | 7 +- internal/proto/dontfrag_test.go | 12 +- internal/proto/evenport.go | 5 +- internal/proto/evenport_test.go | 10 +- internal/proto/fuzz_test.go | 16 +- internal/proto/lifetime.go | 8 +- internal/proto/lifetime_test.go | 20 +- internal/proto/peeraddr.go | 4 +- internal/proto/proto_test.go | 3 + internal/proto/relayedaddr.go | 4 +- internal/proto/reqfamily.go | 2 + internal/proto/reqfamily_test.go | 32 +- internal/proto/reqtrans.go | 4 +- internal/proto/reqtrans_test.go | 40 +-- internal/proto/rsrvtoken.go | 4 +- internal/proto/rsrvtoken_test.go | 22 +- internal/server/nonce.go | 12 +- internal/server/server.go | 54 +++- internal/server/stun.go | 10 +- internal/server/turn.go | 293 +++++++++++------- internal/server/turn_test.go | 24 +- internal/server/util.go | 68 ++-- lt_cred.go | 52 +++- relay_address_generator_none.go | 16 +- relay_address_generator_range.go | 18 +- relay_address_generator_static.go | 15 +- server.go | 53 ++-- server_config.go | 19 +- server_test.go | 38 ++- stun_conn.go | 53 ++-- 81 files changed, 1168 insertions(+), 769 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index a3235bec..50211be0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,17 +25,32 @@ linters-settings: - ^os.Exit$ - ^panic$ - ^print(ln)?$ + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. @@ -44,20 +59,20 @@ linters: - exportloopref # checks for pointers to enclosing loop variables - forbidigo # Forbids identifiers - forcetypeassert # finds forced type assertions + - funlen # Tool for detection of long functions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - - err113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code @@ -65,9 +80,15 @@ linters: - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes @@ -75,28 +96,21 @@ linters: - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - - funlen # Tool for detection of long functions - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - gomnd # An analyzer to detect magic numbers. + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - - lll # Reports long lines - - maintidx # maintidx measures the maintainability index of each function. - - makezero # Finds slice declarations with non-zero initial length - - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements - - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated @@ -104,8 +118,7 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - - varnamelen # checks that the length of a variable's name matches its scope + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! diff --git a/client.go b/client.go index 6e38ae10..328324e8 100644 --- a/client.go +++ b/client.go @@ -49,7 +49,7 @@ type ClientConfig struct { LoggerFactory logging.LoggerFactory } -// Client is a STUN server client +// Client is a STUN server client. type Client struct { conn net.PacketConn // Read-only net transport.Net // Read-only @@ -72,7 +72,8 @@ type Client struct { log logging.LeveledLogger // Read-only } -// NewClient returns a new Client instance. listeningAddress is the address and port to listen on, default "0.0.0.0:0" +// NewClient returns a new Client instance. listeningAddress is the address and port to listen on, +// default "0.0.0.0:0". func NewClient(config *ClientConfig) (*Client, error) { loggerFactory := config.LoggerFactory if loggerFactory == nil { @@ -119,7 +120,7 @@ func NewClient(config *ClientConfig) (*Client, error) { log.Debugf("Resolved TURN server %s to %s", config.TURNServerAddr, turnServ) } - c := &Client{ + client := &Client{ conn: config.Conn, stunServerAddr: stunServ, turnServerAddr: turnServ, @@ -133,25 +134,25 @@ func NewClient(config *ClientConfig) (*Client, error) { log: log, } - return c, nil + return client, nil } -// TURNServerAddr return the TURN server address +// TURNServerAddr return the TURN server address. func (c *Client) TURNServerAddr() net.Addr { return c.turnServerAddr } -// STUNServerAddr return the STUN server address +// STUNServerAddr return the STUN server address. func (c *Client) STUNServerAddr() net.Addr { return c.stunServerAddr } -// Username returns username +// Username returns username. func (c *Client) Username() stun.Username { return c.username } -// Realm return realm +// Realm return realm. func (c *Client) Realm() stun.Realm { return c.realm } @@ -175,12 +176,14 @@ func (c *Client) Listen() error { n, from, err := c.conn.ReadFrom(buf) if err != nil { c.log.Debugf("Failed to read: %s. Exiting loop", err) + break } _, err = c.HandleInbound(buf[:n], from) if err != nil { c.log.Debugf("Failed to handle inbound message: %s. Exiting loop", err) + break } } @@ -191,7 +194,7 @@ func (c *Client) Listen() error { return nil } -// Close closes this client +// Close closes this client. func (c *Client) Close() { c.mutexTrMap.Lock() defer c.mutexTrMap.Unlock() @@ -201,7 +204,7 @@ func (c *Client) Close() { // TransactionID & Base64: https://play.golang.org/p/EEgmJDI971P -// SendBindingRequestTo sends a new STUN request to the given transport address +// SendBindingRequestTo sends a new STUN request to the given transport address. func (c *Client) SendBindingRequestTo(to net.Addr) (net.Addr, error) { attrs := []stun.Setter{stun.TransactionID, stun.BindingRequest} if len(c.software) > 0 { @@ -228,15 +231,21 @@ func (c *Client) SendBindingRequestTo(to net.Addr) (net.Addr, error) { }, nil } -// SendBindingRequest sends a new STUN request to the STUN server +// SendBindingRequest sends a new STUN request to the STUN server. func (c *Client) SendBindingRequest() (net.Addr, error) { if c.stunServerAddr == nil { return nil, errSTUNServerAddressNotSet } + return c.SendBindingRequestTo(c.stunServerAddr) } -func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddress, proto.Lifetime, stun.Nonce, error) { +func (c *Client) sendAllocateRequest(protocol proto.Protocol) ( // nolint:cyclop,funlen + proto.RelayedAddress, + proto.Lifetime, + stun.Nonce, + error, +) { var relayed proto.RelayedAddress var lifetime proto.Lifetime var nonce stun.Nonce @@ -295,6 +304,7 @@ func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddr if err = code.GetFrom(res); err == nil { return relayed, lifetime, nonce, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113 } + return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:goerr113 } @@ -307,10 +317,11 @@ func (c *Client) sendAllocateRequest(protocol proto.Protocol) (proto.RelayedAddr if err := lifetime.GetFrom(res); err != nil { return relayed, lifetime, nonce, err } + return relayed, lifetime, nonce, nil } -// Allocate sends a TURN allocation request to the given transport address +// Allocate sends a TURN allocation request to the given transport address. func (c *Client) Allocate() (net.PacketConn, error) { if err := c.allocTryLock.Lock(); err != nil { return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) @@ -403,10 +414,11 @@ func (c *Client) CreatePermission(addrs ...net.Addr) error { return err } } + return nil } -// PerformTransaction performs STUN transaction +// PerformTransaction performs STUN transaction. func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult bool) (client.TransactionResult, error, ) { @@ -442,11 +454,12 @@ func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult if res.Err != nil { return res, res.Err } + return res, nil } // OnDeallocated is called when de-allocation of relay address has been complete. -// (Called by UDPConn) +// (Called by UDPConn). func (c *Client) OnDeallocated(net.Addr) { c.setRelayedUDPConn(nil) c.setTCPAllocation(nil) @@ -494,7 +507,7 @@ func (c *Client) HandleInbound(data []byte, from net.Addr) (bool, error) { return false, nil } -func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { +func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { // nolint:cyclop,funlen raw := make([]byte, len(data)) copy(raw, data) @@ -507,7 +520,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { return fmt.Errorf("%w : %s", errUnexpectedSTUNRequestMessage, msg.String()) } - if msg.Type.Class == stun.ClassIndication { + if msg.Type.Class == stun.ClassIndication { // nolint:nestif switch msg.Type.Method { case stun.MethodData: var peerAddr proto.PeerAddress @@ -529,6 +542,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { relayedConn := c.relayedUDPConn() if relayedConn == nil { c.log.Debug("No relayed conn allocated") + return nil // Silently discard } relayedConn.HandleInbound(data, from) @@ -553,6 +567,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { allocation := c.getTCPAllocation() if allocation == nil { c.log.Debug("No TCP allocation exists") + return nil // Silently discard } @@ -560,6 +575,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { default: c.log.Debug("Received unsupported STUN method") } + return nil } @@ -576,6 +592,7 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { c.mutexTrMap.Unlock() // Silently discard c.log.Debugf("No transaction for %s", msg) + return nil } @@ -607,6 +624,7 @@ func (c *Client) handleChannelData(data []byte) error { relayedConn := c.relayedUDPConn() if relayedConn == nil { c.log.Debug("No relayed conn allocated") + return nil // Silently discard } @@ -618,6 +636,7 @@ func (c *Client) handleChannelData(data []byte) error { c.log.Tracef("Channel data received from %s (ch=%d)", addr.String(), int(chData.Number)) relayedConn.HandleInbound(chData.Data, addr) + return nil } @@ -638,6 +657,7 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int) { }) { c.log.Debug("No listener for transaction") } + return } @@ -651,6 +671,7 @@ func (c *Client) onRtxTimeout(trKey string, nRtx int) { }) { c.log.Debug("No listener for transaction") } + return } tr.StartRtxTimer(c.onRtxTimeout) diff --git a/client_test.go b/client_test.go index 0627b95e..924e5767 100644 --- a/client_test.go +++ b/client_test.go @@ -18,11 +18,17 @@ import ( "github.com/stretchr/testify/require" ) -func buildMsg(transactionID [stun.TransactionIDSize]byte, msgType stun.MessageType, additional ...stun.Setter) []stun.Setter { +func buildMsg( + transactionID [stun.TransactionIDSize]byte, + msgType stun.MessageType, + additional ...stun.Setter, +) []stun.Setter { return append([]stun.Setter{&stun.Message{TransactionID: transactionID}, msgType}, additional...) } func createListeningTestClient(t *testing.T, loggerFactory logging.LoggerFactory) (*Client, net.PacketConn, bool) { + t.Helper() + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") assert.NoError(t, err) @@ -37,7 +43,12 @@ func createListeningTestClient(t *testing.T, loggerFactory logging.LoggerFactory return c, conn, true } -func createListeningTestClientWithSTUNServ(t *testing.T, loggerFactory logging.LoggerFactory) (*Client, net.PacketConn, bool) { +func createListeningTestClientWithSTUNServ(t *testing.T, loggerFactory logging.LoggerFactory) ( // nolint:lll + *Client, net.PacketConn, + bool, +) { + t.Helper() + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") assert.NoError(t, err) @@ -54,30 +65,30 @@ func createListeningTestClientWithSTUNServ(t *testing.T, loggerFactory logging.L return c, conn, true } -func TestClientWithSTUN(t *testing.T) { +func TestClientWithSTUN(t *testing.T) { // nolint:funlen loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("SendBindingRequest", func(t *testing.T) { - c, pc, ok := createListeningTestClientWithSTUNServ(t, loggerFactory) + client, pc, ok := createListeningTestClientWithSTUNServ(t, loggerFactory) if !ok { return } - defer c.Close() + defer client.Close() - resp, err := c.SendBindingRequest() + resp, err := client.SendBindingRequest() assert.NoError(t, err, "should succeed") log.Debugf("mapped-addr: %s", resp) - assert.Equal(t, 0, c.trMap.Size(), "should be no transaction left") + assert.Equal(t, 0, client.trMap.Size(), "should be no transaction left") assert.NoError(t, pc.Close()) }) t.Run("SendBindingRequestTo Parallel", func(t *testing.T) { - c, pc, ok := createListeningTestClient(t, loggerFactory) + client, pc, ok := createListeningTestClient(t, loggerFactory) if !ok { return } - defer c.Close() + defer client.Close() // Simple channel fo go routine start signaling started := make(chan struct{}) @@ -90,14 +101,14 @@ func TestClientWithSTUN(t *testing.T) { // stun1.l.google.com:19302, more at https://gist.github.com/zziuni/3741933#file-stuns-L5 go func() { close(started) - _, err1 = c.SendBindingRequestTo(to) + _, err1 = client.SendBindingRequestTo(to) close(finished) }() // Block until go routine is started to make two almost parallel requests <-started - if _, err = c.SendBindingRequestTo(to); err != nil { + if _, err = client.SendBindingRequestTo(to); err != nil { t.Fatal(err) } @@ -136,7 +147,7 @@ func TestClientWithSTUN(t *testing.T) { // Create an allocation, and then delete all nonces // The subsequent Write on the allocation will cause a CreatePermission -// which will be forced to handle a stale nonce response +// which will be forced to handle a stale nonce response. func TestClientNonceExpiration(t *testing.T) { udpListener, err := net.ListenPacket("udp4", "0.0.0.0:3478") assert.NoError(t, err) @@ -186,8 +197,8 @@ func TestClientNonceExpiration(t *testing.T) { assert.NoError(t, server.Close()) } -// Create a TCP-based allocation and verify allocation can be created -func TestTCPClient(t *testing.T) { +// Create a TCP-based allocation and verify allocation can be created. +func TestTCPClient(t *testing.T) { // nolint:funlen // Setup server tcpListener, err := net.Listen("tcp4", "0.0.0.0:13478") //nolint: gosec require.NoError(t, err) diff --git a/examples/lt-cred-generator/main.go b/examples/lt-cred-generator/main.go index fc5ea527..997d1c0e 100644 --- a/examples/lt-cred-generator/main.go +++ b/examples/lt-cred-generator/main.go @@ -25,6 +25,7 @@ func main() { if showHelp != nil && *showHelp { log.Println("Usage:") log.Println("$ lt-cred-generator | xargs go run examples/turn-client/udp/main.go -host localhost -ping=true -user=") + return } diff --git a/examples/stun-only-server/main.go b/examples/stun-only-server/main.go index da10e582..fe4b7d14 100644 --- a/examples/stun-only-server/main.go +++ b/examples/stun-only-server/main.go @@ -33,7 +33,7 @@ func main() { log.Panicf("Failed to create STUN server listener: %s", err) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ // PacketConnConfigs is a list of UDP Listeners and the configuration around them PacketConnConfigs: []turn.PacketConnConfig{ { @@ -50,7 +50,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-client/tcp-alloc/main.go b/examples/turn-client/tcp-alloc/main.go index 1fee091d..ea3c0da3 100644 --- a/examples/turn-client/tcp-alloc/main.go +++ b/examples/turn-client/tcp-alloc/main.go @@ -18,7 +18,7 @@ import ( func setupSignalingChannel(addrCh chan string, signaling bool, relayAddr string) { addr := "127.0.0.1:5000" - if signaling { + if signaling { // nolint:nestif go func() { listener, err := net.Listen("tcp", addr) if err != nil { @@ -61,7 +61,7 @@ func setupSignalingChannel(addrCh chan string, signaling bool, relayAddr string) } } -func main() { +func main() { // nolint:funlen,cyclop host := flag.String("host", "", "TURN Server name.") port := flag.Int("port", 3478, "Listening port.") user := flag.String("user", "", "A pair of username and password (e.g. \"user=pass\")") @@ -147,7 +147,7 @@ func main() { buf := make([]byte, 4096) var n int - if *signaling { + if *signaling { // nolint:nestif conn, err := allocation.DialTCP("tcp", nil, peerAddr) if err != nil { log.Panicf("Failed to dial: %s", err) diff --git a/examples/turn-client/tcp/main.go b/examples/turn-client/tcp/main.go index f7a7ec1d..80a55bbd 100644 --- a/examples/turn-client/tcp/main.go +++ b/examples/turn-client/tcp/main.go @@ -16,7 +16,7 @@ import ( "github.com/pion/turn/v4" ) -func main() { +func main() { // nolint:funlen,cyclop host := flag.String("host", "", "TURN Server name.") port := flag.Int("port", 3478, "Listening port.") user := flag.String("user", "", "A pair of username and password (e.g. \"user=pass\")") @@ -92,7 +92,7 @@ func main() { } } -func doPingTest(client *turn.Client, relayConn net.PacketConn) error { +func doPingTest(client *turn.Client, relayConn net.PacketConn) error { // nolint:cyclop,funlen // Send BindingRequest to learn our external IP mappedAddr, err := client.SendBindingRequest() if err != nil { diff --git a/examples/turn-client/udp/main.go b/examples/turn-client/udp/main.go index d53ea8f7..46f8c68a 100644 --- a/examples/turn-client/udp/main.go +++ b/examples/turn-client/udp/main.go @@ -16,7 +16,7 @@ import ( "github.com/pion/turn/v4" ) -func main() { +func main() { // nolint:funlen,cyclop host := flag.String("host", "", "TURN Server name.") port := flag.Int("port", 3478, "Listening port.") user := flag.String("user", "", "A pair of username and password (e.g. \"user=pass\")") @@ -96,7 +96,7 @@ func main() { } } -func doPingTest(client *turn.Client, relayConn net.PacketConn) error { +func doPingTest(client *turn.Client, relayConn net.PacketConn) error { // nolint:cyclop,funlen // Send BindingRequest to learn our external IP mappedAddr, err := client.SendBindingRequest() if err != nil { diff --git a/examples/turn-server/add-software-attribute/main.go b/examples/turn-server/add-software-attribute/main.go index bffeb1e9..344cdf2b 100644 --- a/examples/turn-server/add-software-attribute/main.go +++ b/examples/turn-server/add-software-attribute/main.go @@ -19,15 +19,15 @@ import ( "github.com/pion/turn/v4" ) -// attributeAdder wraps a PacketConn and appends the SOFTWARE attribute to STUN packets -// This pattern could be used to capture/inspect/modify data as well +// attributeAdder wraps a PacketConn and appends the SOFTWARE attribute to STUN packets. +// This pattern could be used to capture/inspect/modify data as well. type attributeAdder struct { net.PacketConn } -func (s *attributeAdder) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if stun.IsMessage(p) { - m := &stun.Message{Raw: p} +func (s *attributeAdder) WriteTo(payload []byte, addr net.Addr) (n int, err error) { + if stun.IsMessage(payload) { + m := &stun.Message{Raw: payload} if err = m.Decode(); err != nil { return } @@ -37,10 +37,10 @@ func (s *attributeAdder) WriteTo(p []byte, addr net.Addr) (n int, err error) { } m.Encode() - p = m.Raw + payload = m.Raw } - return s.PacketConn.WriteTo(p, addr) + return s.PacketConn.WriteTo(payload, addr) } func main() { @@ -71,7 +71,7 @@ func main() { usersMap[kv[1]] = turn.GenerateAuthKey(kv[1], *realm, kv[2]) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -80,6 +80,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // PacketConnConfigs is a list of UDP Listeners and the configuration around them @@ -87,8 +88,10 @@ func main() { { PacketConn: &attributeAdder{udpListener}, RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(*publicIP), // Claim that we are listening on IP passed by user (This should be your Public IP) - Address: "0.0.0.0", // But actually be listening on every interface + // Claim that we are listening on IP passed by user (This should be your Public IP) + RelayAddress: net.ParseIP(*publicIP), + // But actually be listening on every interface + Address: "0.0.0.0", }, }, }, @@ -102,7 +105,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/log/main.go b/examples/turn-server/log/main.go index fc974e07..f9212a83 100644 --- a/examples/turn-server/log/main.go +++ b/examples/turn-server/log/main.go @@ -20,7 +20,7 @@ import ( ) // stunLogger wraps a PacketConn and prints incoming/outgoing STUN packets -// This pattern could be used to capture/inspect/modify data as well +// This pattern could be used to capture/inspect/modify data as well. type stunLogger struct { net.PacketConn } @@ -79,7 +79,7 @@ func main() { usersMap[kv[1]] = turn.GenerateAuthKey(kv[1], *realm, kv[2]) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -88,6 +88,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // PacketConnConfigs is a list of UDP Listeners and the configuration around them @@ -95,8 +96,10 @@ func main() { { PacketConn: &stunLogger{udpListener}, RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(*publicIP), // Claim that we are listening on IP passed by user (This should be your Public IP) - Address: "0.0.0.0", // But actually be listening on every interface + // Claim that we are listening on IP passed by user (This should be your Public IP) + RelayAddress: net.ParseIP(*publicIP), + // But actually be listening on every interface + Address: "0.0.0.0", }, }, }, @@ -110,7 +113,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/lt-cred-turn-rest/main.go b/examples/turn-server/lt-cred-turn-rest/main.go index 655cdbc3..9025e0db 100644 --- a/examples/turn-server/lt-cred-turn-rest/main.go +++ b/examples/turn-server/lt-cred-turn-rest/main.go @@ -43,7 +43,7 @@ func main() { // and process them yourself. logger := logging.NewDefaultLeveledLoggerForScope("lt-creds", logging.LogLevelTrace, os.Stdout) - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, AuthHandler: turn.LongTermTURNRESTAuthHandler(*authSecret, logger), // PacketConnConfigs is a list of UDP Listeners and the configuration around them @@ -51,8 +51,10 @@ func main() { { PacketConn: udpListener, RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(*publicIP), // Claim that we are listening on IP passed by user (This should be your Public IP) - Address: "0.0.0.0", // But actually be listening on every interface + // Claim that we are listening on IP passed by user (This should be your Public IP). + RelayAddress: net.ParseIP(*publicIP), + // But actually be listening on every interface. + Address: "0.0.0.0", }, }, }, @@ -66,7 +68,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/lt-cred/main.go b/examples/turn-server/lt-cred/main.go index b05d2aa3..d62b3646 100644 --- a/examples/turn-server/lt-cred/main.go +++ b/examples/turn-server/lt-cred/main.go @@ -46,7 +46,7 @@ func main() { // and process them yourself. logger := logging.NewDefaultLeveledLoggerForScope("lt-creds", logging.LogLevelTrace, os.Stdout) - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -57,8 +57,10 @@ func main() { { PacketConn: udpListener, RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(*publicIP), // Claim that we are listening on IP passed by user (This should be your Public IP) - Address: "0.0.0.0", // But actually be listening on every interface + // Claim that we are listening on IP passed by user (This should be your Public IP). + RelayAddress: net.ParseIP(*publicIP), + // But actually be listening on every interface. + Address: "0.0.0.0", }, }, }, @@ -72,7 +74,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/perm-filter/main.go b/examples/turn-server/perm-filter/main.go index ab2c24b8..2554594a 100644 --- a/examples/turn-server/perm-filter/main.go +++ b/examples/turn-server/perm-filter/main.go @@ -21,7 +21,7 @@ import ( "github.com/pion/turn/v4" ) -func main() { +func main() { // nolint:funlen publicIP := flag.String("public-ip", "", "IP Address that TURN can be contacted by.") port := flag.Int("port", 3478, "Listening port.") users := flag.String("users", "", "List of username and password (e.g. \"user=pass,user=pass\")") @@ -49,7 +49,7 @@ func main() { usersMap[kv[1]] = turn.GenerateAuthKey(kv[1], *realm, kv[2]) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -58,6 +58,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // PacketConnConfigs is a list of UDP Listeners and the configuration around them @@ -65,8 +66,10 @@ func main() { { PacketConn: udpListener, RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(*publicIP), // Claim that we are listening on IP passed by user (This should be your Public IP) - Address: "0.0.0.0", // But actually be listening on every interface + // Claim that we are listening on IP passed by user (This should be your Public IP) + RelayAddress: net.ParseIP(*publicIP), + // But actually be listening on every interface + Address: "0.0.0.0", }, // allow peer connections only to the client's own (host or server-reflexive) IP PermissionHandler: func(clientAddr net.Addr, peerIP net.IP) bool { @@ -74,11 +77,13 @@ func main() { if clientIP[0] != peerIP.String() { log.Printf("Blocking request from client IP %s to peer %s", clientIP[0], peerIP.String()) + return false } log.Printf("Admitting request from client IP %s to peer %s", clientIP[0], peerIP.String()) + return true }, }, @@ -93,7 +98,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/port-range/main.go b/examples/turn-server/port-range/main.go index 9ed536ed..28864487 100644 --- a/examples/turn-server/port-range/main.go +++ b/examples/turn-server/port-range/main.go @@ -46,7 +46,7 @@ func main() { usersMap[kv[1]] = turn.GenerateAuthKey(kv[1], *realm, kv[2]) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -55,6 +55,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // PacketConnConfigs is a list of UDP Listeners and the configuration around them @@ -62,10 +63,12 @@ func main() { { PacketConn: udpListener, RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{ - RelayAddress: net.ParseIP(*publicIP), // Claim that we are listening on IP passed by user (This should be your Public IP) - Address: "0.0.0.0", // But actually be listening on every interface - MinPort: 50000, - MaxPort: 55000, + // Claim that we are listening on IP passed by user (This should be your Public IP) + RelayAddress: net.ParseIP(*publicIP), + // But actually be listening on every interface + Address: "0.0.0.0", + MinPort: 50000, + MaxPort: 55000, }, }, }, @@ -79,7 +82,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/simple-multithreaded/main.go b/examples/turn-server/simple-multithreaded/main.go index fb89473c..dd477afd 100644 --- a/examples/turn-server/simple-multithreaded/main.go +++ b/examples/turn-server/simple-multithreaded/main.go @@ -21,7 +21,7 @@ import ( "golang.org/x/sys/unix" ) -func main() { +func main() { // nolint:funlen,cyclop publicIP := flag.String("public-ip", "", "IP Address that TURN can be contacted by.") port := flag.Int("port", 3478, "Listening port.") users := flag.String("users", "", "List of username and password (e.g. \"user=pass,user=pass\")") @@ -85,7 +85,7 @@ func main() { log.Printf("Server %d listening on %s\n", i, conn.LocalAddr().String()) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -94,6 +94,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // PacketConnConfigs is a list of UDP Listeners and the configuration around them @@ -108,7 +109,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panicf("Failed to close TURN server: %s", err) } } diff --git a/examples/turn-server/simple/main.go b/examples/turn-server/simple/main.go index 294b76a5..bf3ea524 100644 --- a/examples/turn-server/simple/main.go +++ b/examples/turn-server/simple/main.go @@ -45,7 +45,7 @@ func main() { usersMap[kv[1]] = turn.GenerateAuthKey(kv[1], *realm, kv[2]) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -54,6 +54,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // PacketConnConfigs is a list of UDP Listeners and the configuration around them @@ -61,8 +62,10 @@ func main() { { PacketConn: udpListener, RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ - RelayAddress: net.ParseIP(*publicIP), // Claim that we are listening on IP passed by user (This should be your Public IP) - Address: "0.0.0.0", // But actually be listening on every interface + // Claim that we are listening on IP passed by user (This should be your Public IP) + RelayAddress: net.ParseIP(*publicIP), + // But actually be listening on every interface + Address: "0.0.0.0", }, }, }, @@ -76,7 +79,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/tcp/main.go b/examples/turn-server/tcp/main.go index ef3d147c..7663836b 100644 --- a/examples/turn-server/tcp/main.go +++ b/examples/turn-server/tcp/main.go @@ -45,7 +45,7 @@ func main() { usersMap[kv[1]] = turn.GenerateAuthKey(kv[1], *realm, kv[2]) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -54,6 +54,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // ListenerConfig is a list of Listeners and the configuration around them @@ -76,7 +77,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/examples/turn-server/tls/main.go b/examples/turn-server/tls/main.go index 5d9d971c..76870eb3 100644 --- a/examples/turn-server/tls/main.go +++ b/examples/turn-server/tls/main.go @@ -18,7 +18,7 @@ import ( "github.com/pion/turn/v4" ) -func main() { +func main() { // nolint:funlen publicIP := flag.String("public-ip", "", "IP Address that TURN can be contacted by.") port := flag.Int("port", 5349, "Listening port.") users := flag.String("users", "", "List of username and password (e.g. \"user=pass,user=pass\")") @@ -36,6 +36,7 @@ func main() { cer, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Println(err) + return } @@ -48,6 +49,7 @@ func main() { }) if err != nil { log.Println(err) + return } @@ -58,7 +60,7 @@ func main() { usersMap[kv[1]] = turn.GenerateAuthKey(kv[1], *realm, kv[2]) } - s, err := turn.NewServer(turn.ServerConfig{ + server, err := turn.NewServer(turn.ServerConfig{ Realm: *realm, // Set AuthHandler callback // This is called every time a user tries to authenticate with the TURN server @@ -67,6 +69,7 @@ func main() { if key, ok := usersMap[username]; ok { return key, true } + return nil, false }, // ListenerConfig is a list of Listeners and the configuration around them @@ -89,7 +92,7 @@ func main() { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - if err = s.Close(); err != nil { + if err = server.Close(); err != nil { log.Panic(err) } } diff --git a/internal/allocation/allocation.go b/internal/allocation/allocation.go index 5b5ff369..9e2d5352 100644 --- a/internal/allocation/allocation.go +++ b/internal/allocation/allocation.go @@ -22,7 +22,7 @@ type allocationResponse struct { } // Allocation is tied to a FiveTuple and relays traffic -// use CreateAllocation and GetAllocation to operate +// use CreateAllocation and GetAllocation to operate. type Allocation struct { RelayAddr net.Addr Protocol Protocol @@ -55,7 +55,7 @@ func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, log logging. } } -// GetPermission gets the Permission from the allocation +// GetPermission gets the Permission from the allocation. func (a *Allocation) GetPermission(addr net.Addr) *Permission { a.permissionsLock.RLock() defer a.permissionsLock.RUnlock() @@ -63,9 +63,9 @@ func (a *Allocation) GetPermission(addr net.Addr) *Permission { return a.permissions[ipnet.FingerprintAddr(addr)] } -// AddPermission adds a new permission to the allocation -func (a *Allocation) AddPermission(p *Permission) { - fingerprint := ipnet.FingerprintAddr(p.Addr) +// AddPermission adds a new permission to the allocation. +func (a *Allocation) AddPermission(perms *Permission) { + fingerprint := ipnet.FingerprintAddr(perms.Addr) a.permissionsLock.RLock() existedPermission, ok := a.permissions[fingerprint] @@ -73,18 +73,19 @@ func (a *Allocation) AddPermission(p *Permission) { if ok { existedPermission.refresh(permissionTimeout) + return } - p.allocation = a + perms.allocation = a a.permissionsLock.Lock() - a.permissions[fingerprint] = p + a.permissions[fingerprint] = perms a.permissionsLock.Unlock() - p.start(permissionTimeout) + perms.start(permissionTimeout) } -// RemovePermission removes the net.Addr's fingerprint from the allocation's permissions +// RemovePermission removes the net.Addr's fingerprint from the allocation's permissions. func (a *Allocation) RemovePermission(addr net.Addr) { a.permissionsLock.Lock() defer a.permissionsLock.Unlock() @@ -92,13 +93,13 @@ func (a *Allocation) RemovePermission(addr net.Addr) { } // AddChannelBind adds a new ChannelBind to the allocation, it also updates the -// permissions needed for this ChannelBind -func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) error { +// permissions needed for this ChannelBind. +func (a *Allocation) AddChannelBind(chanBind *ChannelBind, lifetime time.Duration) error { // Check that this channel id isn't bound to another transport address, and // that this transport address isn't bound to another channel number. - channelByNumber := a.GetChannelByNumber(c.Number) + channelByNumber := a.GetChannelByNumber(chanBind.Number) - if channelByNumber != a.GetChannelByAddr(c.Peer) { + if channelByNumber != a.GetChannelByAddr(chanBind.Peer) { return errSameChannelDifferentPeer } @@ -107,12 +108,12 @@ func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) erro a.channelBindingsLock.Lock() defer a.channelBindingsLock.Unlock() - c.allocation = a - a.channelBindings = append(a.channelBindings, c) - c.start(lifetime) + chanBind.allocation = a + a.channelBindings = append(a.channelBindings, chanBind) + chanBind.start(lifetime) // Channel binds also refresh permissions. - a.AddPermission(NewPermission(c.Peer, a.log)) + a.AddPermission(NewPermission(chanBind.Peer, a.log)) } else { channelByNumber.refresh(lifetime) @@ -123,7 +124,7 @@ func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) erro return nil } -// RemoveChannelBind removes the ChannelBind from this allocation by id +// RemoveChannelBind removes the ChannelBind from this allocation by id. func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool { a.channelBindingsLock.Lock() defer a.channelBindingsLock.Unlock() @@ -131,6 +132,7 @@ func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool { for i := len(a.channelBindings) - 1; i >= 0; i-- { if a.channelBindings[i].Number == number { a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...) + return true } } @@ -138,7 +140,7 @@ func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool { return false } -// GetChannelByNumber gets the ChannelBind from this allocation by id +// GetChannelByNumber gets the ChannelBind from this allocation by id. func (a *Allocation) GetChannelByNumber(number proto.ChannelNumber) *ChannelBind { a.channelBindingsLock.RLock() defer a.channelBindingsLock.RUnlock() @@ -147,10 +149,11 @@ func (a *Allocation) GetChannelByNumber(number proto.ChannelNumber) *ChannelBind return cb } } + return nil } -// GetChannelByAddr gets the ChannelBind from this allocation by net.Addr +// GetChannelByAddr gets the ChannelBind from this allocation by net.Addr. func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind { a.channelBindingsLock.RLock() defer a.channelBindingsLock.RUnlock() @@ -159,17 +162,18 @@ func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind { return cb } } + return nil } -// Refresh updates the allocations lifetime +// Refresh updates the allocations lifetime. func (a *Allocation) Refresh(lifetime time.Duration) { if !a.lifetimeTimer.Reset(lifetime) { a.log.Errorf("Failed to reset allocation timer for %v", a.fiveTuple) } } -// SetResponseCache cache allocation response for retransmit allocation request +// SetResponseCache cache allocation response for retransmit allocation request. func (a *Allocation) SetResponseCache(transactionID [stun.TransactionIDSize]byte, attrs []stun.Setter) { a.responseCache.Store(&allocationResponse{ transactionID: transactionID, @@ -177,15 +181,16 @@ func (a *Allocation) SetResponseCache(transactionID [stun.TransactionIDSize]byte }) } -// GetResponseCache return response cache for retransmit allocation request +// GetResponseCache return response cache for retransmit allocation request. func (a *Allocation) GetResponseCache() (id [stun.TransactionIDSize]byte, attrs []stun.Setter) { if res, ok := a.responseCache.Load().(*allocationResponse); ok && res != nil { id, attrs = res.transactionID, res.responseAttrs } + return } -// Close closes the allocation +// Close closes the allocation. func (a *Allocation) Close() error { select { case <-a.closed: @@ -233,13 +238,14 @@ func (a *Allocation) Close() error { const rtpMTU = 1600 -func (a *Allocation) packetHandler(m *Manager) { +func (a *Allocation) packetHandler(manager *Manager) { buffer := make([]byte, rtpMTU) for { n, srcAddr, err := a.RelaySocket.ReadFrom(buffer) if err != nil { - m.DeleteAllocation(a.fiveTuple) + manager.DeleteAllocation(a.fiveTuple) + return } @@ -248,7 +254,7 @@ func (a *Allocation) packetHandler(m *Manager) { n, srcAddr) - if channel := a.GetChannelByAddr(srcAddr); channel != nil { + if channel := a.GetChannelByAddr(srcAddr); channel != nil { // nolint:nestif channelData := &proto.ChannelData{ Data: buffer[:n], Number: channel.Number, @@ -262,15 +268,22 @@ func (a *Allocation) packetHandler(m *Manager) { udpAddr, ok := srcAddr.(*net.UDPAddr) if !ok { a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err) + return } peerAddressAttr := proto.PeerAddress{IP: udpAddr.IP, Port: udpAddr.Port} dataAttr := proto.Data(buffer[:n]) - msg, err := stun.Build(stun.TransactionID, stun.NewType(stun.MethodData, stun.ClassIndication), peerAddressAttr, dataAttr) + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodData, stun.ClassIndication), + peerAddressAttr, + dataAttr, + ) if err != nil { a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err) + return } a.log.Debugf("Relaying message from %s to client at %s", diff --git a/internal/allocation/allocation_manager.go b/internal/allocation/allocation_manager.go index 2b765921..a3b011f4 100644 --- a/internal/allocation/allocation_manager.go +++ b/internal/allocation/allocation_manager.go @@ -25,7 +25,7 @@ type reservation struct { port int } -// Manager is used to hold active allocations +// Manager is used to hold active allocations. type Manager struct { lock sync.RWMutex log logging.LeveledLogger @@ -58,21 +58,23 @@ func NewManager(config ManagerConfig) (*Manager, error) { }, nil } -// GetAllocation fetches the allocation matching the passed FiveTuple +// GetAllocation fetches the allocation matching the passed FiveTuple. func (m *Manager) GetAllocation(fiveTuple *FiveTuple) *Allocation { m.lock.RLock() defer m.lock.RUnlock() + return m.allocations[fiveTuple.Fingerprint()] } -// AllocationCount returns the number of existing allocations +// AllocationCount returns the number of existing allocations. func (m *Manager) AllocationCount() int { m.lock.RLock() defer m.lock.RUnlock() + return len(m.allocations) } -// Close closes the manager and closes all allocations it manages +// Close closes the manager and closes all allocations it manages. func (m *Manager) Close() error { m.lock.Lock() defer m.lock.Unlock() @@ -82,11 +84,17 @@ func (m *Manager) Close() error { return err } } + return nil } -// CreateAllocation creates a new allocation and starts relaying -func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) { +// CreateAllocation creates a new allocation and starts relaying. +func (m *Manager) CreateAllocation( + fiveTuple *FiveTuple, + turnSocket net.PacketConn, + requestedPort int, + lifetime time.Duration, +) (*Allocation, error) { switch { case fiveTuple == nil: return nil, errNilFiveTuple @@ -100,34 +108,35 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo return nil, errLifetimeZero } - if a := m.GetAllocation(fiveTuple); a != nil { + if alloc := m.GetAllocation(fiveTuple); alloc != nil { return nil, fmt.Errorf("%w: %v", errDupeFiveTuple, fiveTuple) } - a := NewAllocation(turnSocket, fiveTuple, m.log) + alloc := NewAllocation(turnSocket, fiveTuple, m.log) conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort) if err != nil { return nil, err } - a.RelaySocket = conn - a.RelayAddr = relayAddr + alloc.RelaySocket = conn + alloc.RelayAddr = relayAddr - m.log.Debugf("Listening on relay address: %s", a.RelayAddr) + m.log.Debugf("Listening on relay address: %s", alloc.RelayAddr) - a.lifetimeTimer = time.AfterFunc(lifetime, func() { - m.DeleteAllocation(a.fiveTuple) + alloc.lifetimeTimer = time.AfterFunc(lifetime, func() { + m.DeleteAllocation(alloc.fiveTuple) }) m.lock.Lock() - m.allocations[fiveTuple.Fingerprint()] = a + m.allocations[fiveTuple.Fingerprint()] = alloc m.lock.Unlock() - go a.packetHandler(m) - return a, nil + go alloc.packetHandler(m) + + return alloc, nil } -// DeleteAllocation removes an allocation +// DeleteAllocation removes an allocation. func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) { fingerprint := fiveTuple.Fingerprint() @@ -145,7 +154,7 @@ func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) { } } -// CreateReservation stores the reservation for the token+port +// CreateReservation stores the reservation for the token+port. func (m *Manager) CreateReservation(reservationToken string, port int) { time.AfterFunc(30*time.Second, func() { m.lock.Lock() @@ -153,6 +162,7 @@ func (m *Manager) CreateReservation(reservationToken string, port int) { for i := len(m.reservations) - 1; i >= 0; i-- { if m.reservations[i].token == reservationToken { m.reservations = append(m.reservations[:i], m.reservations[i+1:]...) + return } } @@ -166,7 +176,7 @@ func (m *Manager) CreateReservation(reservationToken string, port int) { m.lock.Unlock() } -// GetReservation returns the port for a given reservation if it exists +// GetReservation returns the port for a given reservation if it exists. func (m *Manager) GetReservation(reservationToken string) (int, bool) { m.lock.RLock() defer m.lock.RUnlock() @@ -176,10 +186,11 @@ func (m *Manager) GetReservation(reservationToken string) (int, bool) { return r.port, true } } + return 0, false } -// GetRandomEvenPort returns a random un-allocated udp4 port +// GetRandomEvenPort returns a random un-allocated udp4 port. func (m *Manager) GetRandomEvenPort() (int, error) { for i := 0; i < 128; i++ { conn, addr, err := m.allocatePacketConn("udp4", 0) @@ -199,11 +210,12 @@ func (m *Manager) GetRandomEvenPort() (int, error) { return udpAddr.Port, nil } } + return 0, errFailedToAllocateEvenPort } // GrantPermission handles permission requests by calling the permission handler callback -// associated with the TURN server listener socket +// associated with the TURN server listener socket. func (m *Manager) GrantPermission(sourceAddr net.Addr, peerIP net.IP) error { // No permission handler: open if m.permissionHandler == nil { diff --git a/internal/allocation/allocation_manager_test.go b/internal/allocation/allocation_manager_test.go index 014d85d3..5a68cda2 100644 --- a/internal/allocation/allocation_manager_test.go +++ b/internal/allocation/allocation_manager_test.go @@ -47,8 +47,10 @@ func TestManager(t *testing.T) { } } -// Test invalid Allocation creations +// Test invalid Allocation creations. func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) { + t.Helper() + m, err := newTestManager() assert.NoError(t, err) @@ -63,8 +65,10 @@ func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) { } } -// Test valid Allocation creations +// Test valid Allocation creations. func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { + t.Helper() + m, err := newTestManager() assert.NoError(t, err) @@ -78,8 +82,10 @@ func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { } } -// Test that two allocations can't be created with the same FiveTuple +// Test that two allocations can't be created with the same FiveTuple. func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.PacketConn) { + t.Helper() + m, err := newTestManager() assert.NoError(t, err) @@ -94,26 +100,30 @@ func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.Pack } func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() + t.Helper() + + manager, err := newTestManager() assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := manager.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } - if a := m.GetAllocation(fiveTuple); a == nil { + if a := manager.GetAllocation(fiveTuple); a == nil { t.Errorf("Failed to get allocation right after creation") } - m.DeleteAllocation(fiveTuple) - if a := m.GetAllocation(fiveTuple); a != nil { + manager.DeleteAllocation(fiveTuple) + if a := manager.GetAllocation(fiveTuple); a != nil { t.Errorf("Get allocation with %v should be nil after delete", fiveTuple) } } -// Test that allocation should be closed if timeout +// Test that allocation should be closed if timeout. func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { + t.Helper() + m, err := newTestManager() assert.NoError(t, err) @@ -140,22 +150,24 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { } } -// Test for manager close +// Test for manager close. func subTestManagerClose(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() + t.Helper() + + manager, err := newTestManager() assert.NoError(t, err) allocations := make([]*Allocation, 2) - a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second) + a1, _ := manager.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second) allocations[0] = a1 - a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute) + a2, _ := manager.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute) allocations[1] = a2 // Make a1 timeout time.Sleep(2 * time.Second) - if err := m.Close(); err != nil { + if err := manager.Close(); err != nil { t.Errorf("Manager close with error: %v", err) } @@ -189,15 +201,19 @@ func newTestManager() (*Manager, error) { }, AllocateConn: func(string, int) (net.Conn, net.Addr, error) { return nil, nil, nil }, } + return NewManager(config) } func isClose(conn io.Closer) bool { closeErr := conn.Close() + return closeErr != nil && strings.Contains(closeErr.Error(), "use of closed network connection") } func subTestGetRandomEvenPort(t *testing.T, _ net.PacketConn) { + t.Helper() + m, err := newTestManager() assert.NoError(t, err) diff --git a/internal/allocation/allocation_test.go b/internal/allocation/allocation_test.go index 49269d68..74068394 100644 --- a/internal/allocation/allocation_test.go +++ b/internal/allocation/allocation_test.go @@ -46,7 +46,9 @@ func TestAllocation(t *testing.T) { } func subTestGetPermission(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -63,32 +65,34 @@ func subTestGetPermission(t *testing.T) { t.Fatalf("failed to resolve: %s", err) } - p := &Permission{ + perms := &Permission{ Addr: addr, } - p2 := &Permission{ + perms2 := &Permission{ Addr: addr2, } - p3 := &Permission{ + perms3 := &Permission{ Addr: addr3, } - a.AddPermission(p) - a.AddPermission(p2) - a.AddPermission(p3) + alloc.AddPermission(perms) + alloc.AddPermission(perms2) + alloc.AddPermission(perms3) - foundP1 := a.GetPermission(addr) - assert.Equal(t, p, foundP1, "Should keep the first one.") + foundP1 := alloc.GetPermission(addr) + assert.Equal(t, perms, foundP1, "Should keep the first one.") - foundP2 := a.GetPermission(addr2) - assert.Equal(t, p, foundP2, "Second one should be ignored.") + foundP2 := alloc.GetPermission(addr2) + assert.Equal(t, perms, foundP2, "Second one should be ignored.") - foundP3 := a.GetPermission(addr3) - assert.Equal(t, p3, foundP3, "Permission with another IP should be found") + foundP3 := alloc.GetPermission(addr3) + assert.Equal(t, perms3, foundP3, "Permission with another IP should be found") } func subTestAddPermission(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -99,15 +103,17 @@ func subTestAddPermission(t *testing.T) { Addr: addr, } - a.AddPermission(p) - assert.Equal(t, a, p.allocation, "Permission's allocation should be the adder.") + alloc.AddPermission(p) + assert.Equal(t, alloc, p.allocation, "Permission's allocation should be the adder.") - foundPermission := a.GetPermission(p.Addr) + foundPermission := alloc.GetPermission(p.Addr) assert.Equal(t, p, foundPermission) } func subTestRemovePermission(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -118,19 +124,21 @@ func subTestRemovePermission(t *testing.T) { Addr: addr, } - a.AddPermission(p) + alloc.AddPermission(p) - foundPermission := a.GetPermission(p.Addr) + foundPermission := alloc.GetPermission(p.Addr) assert.Equal(t, p, foundPermission, "Got permission is not same as the the added.") - a.RemovePermission(p.Addr) + alloc.RemovePermission(p.Addr) - foundPermission = a.GetPermission(p.Addr) + foundPermission = alloc.GetPermission(p.Addr) assert.Nil(t, foundPermission, "Got permission should be nil after removed.") } func subTestAddChannelBind(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -139,22 +147,24 @@ func subTestAddChannelBind(t *testing.T) { c := NewChannelBind(proto.MinChannelNumber, addr, nil) - err = a.AddChannelBind(c, proto.DefaultLifetime) + err = alloc.AddChannelBind(c, proto.DefaultLifetime) assert.Nil(t, err, "should succeed") - assert.Equal(t, a, c.allocation, "allocation should be the caller.") + assert.Equal(t, alloc, c.allocation, "allocation should be the caller.") c2 := NewChannelBind(proto.MinChannelNumber+1, addr, nil) - err = a.AddChannelBind(c2, proto.DefaultLifetime) + err = alloc.AddChannelBind(c2, proto.DefaultLifetime) assert.NotNil(t, err, "should failed with conflicted peer address") addr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3479") c3 := NewChannelBind(proto.MinChannelNumber, addr2, nil) - err = a.AddChannelBind(c3, proto.DefaultLifetime) + err = alloc.AddChannelBind(c3, proto.DefaultLifetime) assert.NotNil(t, err, "should fail with conflicted number.") } func subTestGetChannelByNumber(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -163,17 +173,19 @@ func subTestGetChannelByNumber(t *testing.T) { c := NewChannelBind(proto.MinChannelNumber, addr, nil) - _ = a.AddChannelBind(c, proto.DefaultLifetime) + _ = alloc.AddChannelBind(c, proto.DefaultLifetime) - existChannel := a.GetChannelByNumber(c.Number) + existChannel := alloc.GetChannelByNumber(c.Number) assert.Equal(t, c, existChannel) - notExistChannel := a.GetChannelByNumber(proto.MinChannelNumber + 1) + notExistChannel := alloc.GetChannelByNumber(proto.MinChannelNumber + 1) assert.Nil(t, notExistChannel, "should be nil for not existed channel.") } func subTestGetChannelByAddr(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -182,18 +194,20 @@ func subTestGetChannelByAddr(t *testing.T) { c := NewChannelBind(proto.MinChannelNumber, addr, nil) - _ = a.AddChannelBind(c, proto.DefaultLifetime) + _ = alloc.AddChannelBind(c, proto.DefaultLifetime) - existChannel := a.GetChannelByAddr(c.Peer) + existChannel := alloc.GetChannelByAddr(c.Peer) assert.Equal(t, c, existChannel) addr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:3479") - notExistChannel := a.GetChannelByAddr(addr2) + notExistChannel := alloc.GetChannelByAddr(addr2) assert.Nil(t, notExistChannel, "should be nil for not existed channel.") } func subTestRemoveChannelBind(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -202,33 +216,37 @@ func subTestRemoveChannelBind(t *testing.T) { c := NewChannelBind(proto.MinChannelNumber, addr, nil) - _ = a.AddChannelBind(c, proto.DefaultLifetime) + _ = alloc.AddChannelBind(c, proto.DefaultLifetime) - a.RemoveChannelBind(c.Number) + alloc.RemoveChannelBind(c.Number) - channelByNumber := a.GetChannelByNumber(c.Number) + channelByNumber := alloc.GetChannelByNumber(c.Number) assert.Nil(t, channelByNumber) - channelByAddr := a.GetChannelByAddr(c.Peer) + channelByAddr := alloc.GetChannelByAddr(c.Peer) assert.Nil(t, channelByAddr) } func subTestAllocationRefresh(t *testing.T) { - a := NewAllocation(nil, nil, nil) + t.Helper() + + alloc := NewAllocation(nil, nil, nil) var wg sync.WaitGroup wg.Add(1) - a.lifetimeTimer = time.AfterFunc(proto.DefaultLifetime, func() { + alloc.lifetimeTimer = time.AfterFunc(proto.DefaultLifetime, func() { wg.Done() }) - a.Refresh(0) + alloc.Refresh(0) wg.Wait() // LifetimeTimer has expired - assert.False(t, a.lifetimeTimer.Stop()) + assert.False(t, alloc.lifetimeTimer.Stop()) } func subTestAllocationClose(t *testing.T) { + t.Helper() + network := "udp" l, err := net.ListenPacket(network, "0.0.0.0:0") @@ -236,10 +254,10 @@ func subTestAllocationClose(t *testing.T) { panic(err) } - a := NewAllocation(nil, nil, nil) - a.RelaySocket = l + alloc := NewAllocation(nil, nil, nil) + alloc.RelaySocket = l // Add mock lifetimeTimer - a.lifetimeTimer = time.AfterFunc(proto.DefaultLifetime, func() {}) + alloc.lifetimeTimer = time.AfterFunc(proto.DefaultLifetime, func() {}) // Add channel addr, err := net.ResolveUDPAddr(network, "127.0.0.1:3478") @@ -248,20 +266,22 @@ func subTestAllocationClose(t *testing.T) { } c := NewChannelBind(proto.MinChannelNumber, addr, nil) - _ = a.AddChannelBind(c, proto.DefaultLifetime) + _ = alloc.AddChannelBind(c, proto.DefaultLifetime) // Add permission - a.AddPermission(NewPermission(addr, nil)) + alloc.AddPermission(NewPermission(addr, nil)) - err = a.Close() + err = alloc.Close() assert.Nil(t, err, "should succeed") - assert.True(t, isClose(a.RelaySocket), "should be closed") + assert.True(t, isClose(alloc.RelaySocket), "should be closed") } -func subTestPacketHandler(t *testing.T) { +func subTestPacketHandler(t *testing.T) { // nolint:funlen + t.Helper() + network := "udp" - m, _ := newTestManager() + manager, _ := newTestManager() // TURN server initialization turnSocket, err := net.ListenPacket(network, "127.0.0.1:0") @@ -289,7 +309,7 @@ func subTestPacketHandler(t *testing.T) { } }() - a, err := m.CreateAllocation(&FiveTuple{ + alloc, err := manager.CreateAllocation(&FiveTuple{ SrcAddr: clientListener.LocalAddr(), DstAddr: turnSocket.LocalAddr(), }, turnSocket, 0, proto.DefaultLifetime) @@ -307,12 +327,12 @@ func subTestPacketHandler(t *testing.T) { } // Add permission with peer1 address - a.AddPermission(NewPermission(peerListener1.LocalAddr(), m.log)) + alloc.AddPermission(NewPermission(peerListener1.LocalAddr(), manager.log)) // Add channel with min channel number and peer2 address - channelBind := NewChannelBind(proto.MinChannelNumber, peerListener2.LocalAddr(), m.log) - _ = a.AddChannelBind(channelBind, proto.DefaultLifetime) + channelBind := NewChannelBind(proto.MinChannelNumber, peerListener2.LocalAddr(), manager.log) + _ = alloc.AddChannelBind(channelBind, proto.DefaultLifetime) - _, port, _ := ipnet.AddrIPPort(a.RelaySocket.LocalAddr()) + _, port, _ := ipnet.AddrIPPort(alloc.RelaySocket.LocalAddr()) relayAddrWithHostStr := fmt.Sprintf("127.0.0.1:%d", port) relayAddrWithHost, _ := net.ResolveUDPAddr(network, relayAddrWithHostStr) @@ -350,13 +370,15 @@ func subTestPacketHandler(t *testing.T) { assert.Equal(t, targetText2, string(channelData.Data), "get data doesn't equal the target text.") // Listeners close - _ = m.Close() + _ = manager.Close() _ = clientListener.Close() _ = peerListener1.Close() _ = peerListener2.Close() } func subTestResponseCache(t *testing.T) { + t.Helper() + a := NewAllocation(nil, nil, nil) transactionID := [stun.TransactionIDSize]byte{1, 2, 3} responseAttrs := []stun.Setter{ diff --git a/internal/allocation/channel_bind.go b/internal/allocation/channel_bind.go index 19b18431..6ad9b46f 100644 --- a/internal/allocation/channel_bind.go +++ b/internal/allocation/channel_bind.go @@ -22,7 +22,7 @@ type ChannelBind struct { log logging.LeveledLogger } -// NewChannelBind creates a new ChannelBind +// NewChannelBind creates a new ChannelBind. func NewChannelBind(number proto.ChannelNumber, peer net.Addr, log logging.LeveledLogger) *ChannelBind { return &ChannelBind{ Number: number, diff --git a/internal/allocation/five_tuple.go b/internal/allocation/five_tuple.go index 6d812caf..14761611 100644 --- a/internal/allocation/five_tuple.go +++ b/internal/allocation/five_tuple.go @@ -7,10 +7,10 @@ import ( "net" ) -// Protocol is an enum for relay protocol +// Protocol is an enum for relay protocol. type Protocol uint8 -// Network protocols for relay +// Network protocols for relay. const ( UDP Protocol = iota TCP @@ -27,19 +27,19 @@ type FiveTuple struct { SrcAddr, DstAddr net.Addr } -// Equal asserts if two FiveTuples are equal +// Equal asserts if two FiveTuples are equal. func (f *FiveTuple) Equal(b *FiveTuple) bool { return f.Fingerprint() == b.Fingerprint() } -// FiveTupleFingerprint is a comparable representation of a FiveTuple +// FiveTupleFingerprint is a comparable representation of a FiveTuple. type FiveTupleFingerprint struct { srcIP, dstIP [16]byte srcPort, dstPort uint16 protocol Protocol } -// Fingerprint is the identity of a FiveTuple +// Fingerprint is the identity of a FiveTuple. func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) { srcIP, srcPort := netAddrIPAndPort(f.SrcAddr) copy(fp.srcIP[:], srcIP) @@ -48,15 +48,16 @@ func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) { copy(fp.dstIP[:], dstIP) fp.dstPort = dstPort fp.protocol = f.Protocol + return } func netAddrIPAndPort(addr net.Addr) (net.IP, uint16) { switch a := addr.(type) { case *net.UDPAddr: - return a.IP.To16(), uint16(a.Port) + return a.IP.To16(), uint16(a.Port) // nolint:gosec // G115 case *net.TCPAddr: - return a.IP.To16(), uint16(a.Port) + return a.IP.To16(), uint16(a.Port) // nolint:gosec // G115 default: return nil, 0 } diff --git a/internal/allocation/permission.go b/internal/allocation/permission.go index a774f2ce..7b02adc3 100644 --- a/internal/allocation/permission.go +++ b/internal/allocation/permission.go @@ -22,7 +22,7 @@ type Permission struct { log logging.LeveledLogger } -// NewPermission create a new Permission +// NewPermission create a new Permission. func NewPermission(addr net.Addr, log logging.LeveledLogger) *Permission { return &Permission{ Addr: addr, diff --git a/internal/client/allocation.go b/internal/client/allocation.go index a7f74636..ed64eb87 100644 --- a/internal/client/allocation.go +++ b/internal/client/allocation.go @@ -16,7 +16,7 @@ import ( "github.com/pion/turn/v4/internal/proto" ) -// AllocationConfig is a set of configuration params use by NewUDPConn and NewTCPAllocation +// AllocationConfig is a set of configuration params use by NewUDPConn and NewTCPAllocation. type AllocationConfig struct { Client Client RelayedAddr net.Addr @@ -82,6 +82,7 @@ func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) er if dontWait { a.log.Debug("Refresh request sent") + return nil } @@ -93,10 +94,13 @@ func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) er if err = code.GetFrom(res); err == nil { if code.Code == stun.CodeStaleNonce { a.setNonceFromMsg(res) + return errTryAgain } + return err } + return fmt.Errorf("%s", res.Type) //nolint:goerr113 } @@ -108,6 +112,7 @@ func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) er a.setLifetime(updatedLifetime.Duration) a.log.Debugf("Updated lifetime: %d seconds", int(a.lifetime().Seconds())) + return nil } @@ -115,6 +120,7 @@ func (a *allocation) refreshPermissions() error { addrs := a.permMap.addrs() if len(addrs) == 0 { a.log.Debug("No permission to refresh") + return nil } if err := a.CreatePermissions(addrs...); err != nil { @@ -122,9 +128,11 @@ func (a *allocation) refreshPermissions() error { return errTryAgain } a.log.Errorf("Fail to refresh permissions: %s", err) + return err } a.log.Debug("Refresh permissions successful") + return nil } diff --git a/internal/client/binding.go b/internal/client/binding.go index a0217476..1203cceb 100644 --- a/internal/client/binding.go +++ b/internal/client/binding.go @@ -61,7 +61,7 @@ func (b *binding) refreshedAt() time.Time { return b._refreshedAt } -// Thread-safe binding map +// Thread-safe binding map. type bindingManager struct { chanMap map[uint16]*binding addrMap map[string]*binding @@ -84,6 +84,7 @@ func (mgr *bindingManager) assignChannelNumber() uint16 { } else { mgr.next++ } + return n } @@ -100,6 +101,7 @@ func (mgr *bindingManager) create(addr net.Addr) *binding { mgr.chanMap[b.number] = b mgr.addrMap[b.addr.String()] = b + return b } @@ -108,6 +110,7 @@ func (mgr *bindingManager) findByAddr(addr net.Addr) (*binding, bool) { defer mgr.mutex.RUnlock() b, ok := mgr.addrMap[addr.String()] + return b, ok } @@ -116,6 +119,7 @@ func (mgr *bindingManager) findByNumber(number uint16) (*binding, bool) { defer mgr.mutex.RUnlock() b, ok := mgr.chanMap[number] + return b, ok } @@ -130,6 +134,7 @@ func (mgr *bindingManager) deleteByAddr(addr net.Addr) bool { delete(mgr.addrMap, addr.String()) delete(mgr.chanMap, b.number) + return true } @@ -144,6 +149,7 @@ func (mgr *bindingManager) deleteByNumber(number uint16) bool { delete(mgr.addrMap, b.addr.String()) delete(mgr.chanMap, number) + return true } diff --git a/internal/client/binding_test.go b/internal/client/binding_test.go index 5fd982c1..4d9f18a5 100644 --- a/internal/client/binding_test.go +++ b/internal/client/binding_test.go @@ -10,56 +10,56 @@ import ( "github.com/stretchr/testify/assert" ) -func TestBindingManager(t *testing.T) { +func TestBindingManager(t *testing.T) { // nolint:funlen t.Run("number assignment", func(t *testing.T) { - m := newBindingManager() - var n uint16 + bm := newBindingManager() + var chanNum uint16 for i := uint16(0); i < 10; i++ { - n = m.assignChannelNumber() - assert.Equal(t, minChannelNumber+i, n, "should match") + chanNum = bm.assignChannelNumber() + assert.Equal(t, minChannelNumber+i, chanNum, "should match") } - m.next = uint16(0x7ff0) + bm.next = uint16(0x7ff0) for i := uint16(0); i < 16; i++ { - n = m.assignChannelNumber() - assert.Equal(t, 0x7ff0+i, n, "should match") + chanNum = bm.assignChannelNumber() + assert.Equal(t, 0x7ff0+i, chanNum, "should match") } // Back to min - n = m.assignChannelNumber() - assert.Equal(t, minChannelNumber, n, "should match") + chanNum = bm.assignChannelNumber() + assert.Equal(t, minChannelNumber, chanNum, "should match") }) t.Run("method test", func(t *testing.T) { lo := net.IPv4(127, 0, 0, 1) count := 100 - m := newBindingManager() + bm := newBindingManager() for i := 0; i < count; i++ { addr := &net.UDPAddr{IP: lo, Port: 10000 + i} - b0 := m.create(addr) - b1, ok := m.findByAddr(addr) + b0 := bm.create(addr) + b1, ok := bm.findByAddr(addr) assert.True(t, ok, "should succeed") - b2, ok := m.findByNumber(b0.number) + b2, ok := bm.findByNumber(b0.number) assert.True(t, ok, "should succeed") assert.Equal(t, b0, b1, "should match") assert.Equal(t, b0, b2, "should match") } - assert.Equal(t, count, m.size(), "should match") - assert.Equal(t, count, len(m.addrMap), "should match") + assert.Equal(t, count, bm.size(), "should match") + assert.Equal(t, count, len(bm.addrMap), "should match") for i := 0; i < count; i++ { addr := &net.UDPAddr{IP: lo, Port: 10000 + i} if i%2 == 0 { - assert.True(t, m.deleteByAddr(addr), "should return true") + assert.True(t, bm.deleteByAddr(addr), "should return true") } else { - assert.True(t, m.deleteByNumber(minChannelNumber+uint16(i)), "should return true") + assert.True(t, bm.deleteByNumber(minChannelNumber+uint16(i)), "should return true") // nolint:gosec // G115 } } - assert.Equal(t, 0, m.size(), "should match") - assert.Equal(t, 0, len(m.addrMap), "should match") + assert.Equal(t, 0, bm.size(), "should match") + assert.Equal(t, 0, len(bm.addrMap), "should match") }) t.Run("failure test", func(t *testing.T) { diff --git a/internal/client/client.go b/internal/client/client.go index 10b49a28..45041c90 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -10,7 +10,7 @@ import ( "github.com/pion/stun/v3" ) -// Client is an interface for the public turn.Client in order to break cyclic dependencies +// Client is an interface for the public turn.Client in order to break cyclic dependencies. type Client interface { WriteTo(data []byte, to net.Addr) (int, error) PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 95d68c90..03ba6d43 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -19,6 +19,7 @@ func (c *mockClient) WriteTo(data []byte, to net.Addr) (int, error) { if c.writeTo != nil { return c.writeTo(data, to) } + return 0, nil } @@ -26,6 +27,7 @@ func (c *mockClient) PerformTransaction(msg *stun.Message, to net.Addr, dontWait if c.performTransaction != nil { return c.performTransaction(msg, to, dontWait) } + return TransactionResult{}, errFake } diff --git a/internal/client/periodic_timer.go b/internal/client/periodic_timer.go index 5660bc45..9e949129 100644 --- a/internal/client/periodic_timer.go +++ b/internal/client/periodic_timer.go @@ -8,10 +8,10 @@ import ( "time" ) -// PeriodicTimerTimeoutHandler is a handler called on timeout +// PeriodicTimerTimeoutHandler is a handler called on timeout. type PeriodicTimerTimeoutHandler func(timerID int) -// PeriodicTimer is a periodic timer +// PeriodicTimer is a periodic timer. type PeriodicTimer struct { id int interval time.Duration @@ -20,7 +20,7 @@ type PeriodicTimer struct { mutex sync.RWMutex } -// NewPeriodicTimer create a new timer +// NewPeriodicTimer create a new timer. func NewPeriodicTimer(id int, timeoutHandler PeriodicTimerTimeoutHandler, interval time.Duration) *PeriodicTimer { return &PeriodicTimer{ id: id, @@ -76,7 +76,7 @@ func (t *PeriodicTimer) Stop() { } // IsRunning tests if the timer is running. -// Debug purpose only +// Debug purpose only. func (t *PeriodicTimer) IsRunning() bool { t.mutex.RLock() defer t.mutex.RUnlock() diff --git a/internal/client/periodic_timer_test.go b/internal/client/periodic_timer_test.go index d77a16b3..514bbb57 100644 --- a/internal/client/periodic_timer_test.go +++ b/internal/client/periodic_timer_test.go @@ -34,7 +34,13 @@ func TestPeriodicTimer(t *testing.T) { time.Sleep(120 * time.Millisecond) rt.Stop() assert.False(t, rt.IsRunning(), "should not be running") - assert.Equal(t, 4, int(atomic.LoadUint64(&nCbs)), "should be called 4 times (actual: %d)", atomic.LoadUint64(&nCbs)) + assert.Equal( + t, + uint64(4), + atomic.LoadUint64(&nCbs), + "should be called 4 times (actual: %d)", + atomic.LoadUint64(&nCbs), + ) }) t.Run("stop inside handler", func(t *testing.T) { diff --git a/internal/client/permission.go b/internal/client/permission.go index 0436e4ea..d6708d86 100644 --- a/internal/client/permission.go +++ b/internal/client/permission.go @@ -32,7 +32,7 @@ func (p *permission) state() permState { return permState(atomic.LoadInt32((*int32)(&p.st))) } -// Thread-safe permission map +// Thread-safe permission map. type permissionMap struct { permMap map[string]*permission mutex sync.RWMutex @@ -43,6 +43,7 @@ func (m *permissionMap) insert(addr net.Addr, p *permission) bool { defer m.mutex.Unlock() p.addr = addr m.permMap[ipnet.FingerprintAddr(addr)] = p + return true } @@ -50,6 +51,7 @@ func (m *permissionMap) find(addr net.Addr) (*permission, bool) { m.mutex.RLock() defer m.mutex.RUnlock() p, ok := m.permMap[ipnet.FingerprintAddr(addr)] + return p, ok } @@ -67,6 +69,7 @@ func (m *permissionMap) addrs() []net.Addr { for _, p := range m.permMap { addrs = append(addrs, p.addr) } + return addrs } diff --git a/internal/client/permission_test.go b/internal/client/permission_test.go index 7b65d372..5c617c96 100644 --- a/internal/client/permission_test.go +++ b/internal/client/permission_test.go @@ -22,7 +22,7 @@ func TestPermission(t *testing.T) { }) } -func TestPermissionMap(t *testing.T) { +func TestPermissionMap(t *testing.T) { // nolint:funlen t.Run("Basic operations", func(t *testing.T) { pm := newPermissionMap() assert.NotNil(t, pm) @@ -41,20 +41,20 @@ func TestPermissionMap(t *testing.T) { assert.True(t, pm.insert(tcpAddr, perm3)) assert.Equal(t, 3, len(pm.permMap)) - p, ok := pm.find(udpAddr1) + perms, ok := pm.find(udpAddr1) assert.True(t, ok) - assert.Equal(t, perm1, p) - assert.Equal(t, permStateIdle, p.st) + assert.Equal(t, perm1, perms) + assert.Equal(t, permStateIdle, perms.st) - p, ok = pm.find(udpAddr2) + perms, ok = pm.find(udpAddr2) assert.True(t, ok) - assert.Equal(t, perm2, p) - assert.Equal(t, permStatePermitted, p.st) + assert.Equal(t, perm2, perms) + assert.Equal(t, permStatePermitted, perms.st) - p, ok = pm.find(tcpAddr) + perms, ok = pm.find(tcpAddr) assert.True(t, ok) - assert.Equal(t, perm3, p) - assert.Equal(t, permStateIdle, p.st) + assert.Equal(t, perm3, perms) + assert.Equal(t, permStateIdle, perms.st) addrs := pm.addrs() ips := []net.IP{} diff --git a/internal/client/tcp_alloc.go b/internal/client/tcp_alloc.go index 82948387..5686fa20 100644 --- a/internal/client/tcp_alloc.go +++ b/internal/client/tcp_alloc.go @@ -34,9 +34,9 @@ type TCPAllocation struct { allocation } -// NewTCPAllocation creates a new instance of TCPConn +// NewTCPAllocation creates a new instance of TCPConn. func NewTCPAllocation(config *AllocationConfig) *TCPAllocation { - a := &TCPAllocation{ + alloc := &TCPAllocation{ connAttemptCh: make(chan *connectionAttempt, 10), acceptTimer: time.NewTimer(time.Duration(math.MaxInt64)), allocation: allocation{ @@ -54,31 +54,31 @@ func NewTCPAllocation(config *AllocationConfig) *TCPAllocation { }, } - a.log.Debugf("Initial lifetime: %d seconds", int(a.lifetime().Seconds())) + alloc.log.Debugf("Initial lifetime: %d seconds", int(alloc.lifetime().Seconds())) - a.refreshAllocTimer = NewPeriodicTimer( + alloc.refreshAllocTimer = NewPeriodicTimer( timerIDRefreshAlloc, - a.onRefreshTimers, - a.lifetime()/2, + alloc.onRefreshTimers, + alloc.lifetime()/2, ) - a.refreshPermsTimer = NewPeriodicTimer( + alloc.refreshPermsTimer = NewPeriodicTimer( timerIDRefreshPerms, - a.onRefreshTimers, + alloc.onRefreshTimers, permRefreshInterval, ) - if a.refreshAllocTimer.Start() { - a.log.Debug("Started refreshAllocTimer") + if alloc.refreshAllocTimer.Start() { + alloc.log.Debug("Started refreshAllocTimer") } - if a.refreshPermsTimer.Start() { - a.log.Debug("Started refreshPermsTimer") + if alloc.refreshPermsTimer.Start() { + alloc.log.Debug("Started refreshPermsTimer") } - return a + return alloc } -// Connect sends a Connect request to the turn server and returns a chosen connection ID +// Connect sends a Connect request to the turn server and returns a chosen connection ID. func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) { setters := []stun.Setter{ stun.TransactionID, @@ -119,6 +119,7 @@ func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) { } a.log.Debugf("Connect request successful (cid=%v)", cid) + return cid, nil } @@ -218,8 +219,8 @@ func (a *TCPAllocation) DialTCPWithConn(conn net.Conn, _ string, rAddr *net.TCPA return dataConn, nil } -// BindConnection associates the provided connection -func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID) error { +// BindConnection associates the provided connection. +func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID) error { // nolint:cyclop msg, err := stun.Build( stun.TransactionID, stun.NewType(stun.MethodConnectionBind, stun.ClassRequest), @@ -272,9 +273,11 @@ func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID if err = code.GetFrom(res); err == nil { return fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113 } + return fmt.Errorf("%s", res.Type) //nolint:goerr113 case stun.ClassSuccessResponse: a.log.Debug("Successful connectionBind request") + return nil default: return fmt.Errorf("%w: %s", errUnexpectedSTUNRequestMessage, res.String()) @@ -347,6 +350,7 @@ func (a *TCPAllocation) SetDeadline(t time.Time) error { d = time.Until(t) } a.acceptTimer.Reset(d) + return nil } @@ -358,10 +362,11 @@ func (a *TCPAllocation) Close() error { a.refreshPermsTimer.Stop() a.client.OnDeallocated(a.relayedAddr) + return a.refreshAllocation(0, true /* dontWait=true */) } -// Addr returns the relayed address of the allocation +// Addr returns the relayed address of the allocation. func (a *TCPAllocation) Addr() net.Addr { return a.relayedAddr } diff --git a/internal/client/tcp_conn.go b/internal/client/tcp_conn.go index 990b6d42..18330bec 100644 --- a/internal/client/tcp_conn.go +++ b/internal/client/tcp_conn.go @@ -23,7 +23,7 @@ const ( var _ transport.TCPConn = (*TCPConn)(nil) // Includes type check for net.Conn // TCPConn wraps a transport.TCPConn and returns the allocations relayed -// transport address in response to TCPConn.LocalAddress() +// transport address in response to TCPConn.LocalAddress(). type TCPConn struct { transport.TCPConn remoteAddress *net.TCPAddr diff --git a/internal/client/tcp_conn_test.go b/internal/client/tcp_conn_test.go index b7ef6d8c..88e022e1 100644 --- a/internal/client/tcp_conn_test.go +++ b/internal/client/tcp_conn_test.go @@ -19,7 +19,11 @@ type dummyTCPConn struct { transport.TCPConn } -func buildMsg(transactionID [stun.TransactionIDSize]byte, msgType stun.MessageType, additional ...stun.Setter) []stun.Setter { +func buildMsg( + transactionID [stun.TransactionIDSize]byte, + msgType stun.MessageType, + additional ...stun.Setter, +) []stun.Setter { return append([]stun.Setter{&stun.Message{TransactionID: transactionID}, msgType}, additional...) } @@ -37,10 +41,11 @@ func (c dummyTCPConn) Read(b []byte) (int, error) { } copy(b, msg.Raw) + return len(msg.Raw), nil } -func TestTCPConn(t *testing.T) { +func TestTCPConn(t *testing.T) { //nolint:funlen t.Run("Connect()", func(t *testing.T) { var cid proto.ConnectionID = 5 client := &mockClient{ @@ -52,8 +57,10 @@ func TestTCPConn(t *testing.T) { cid, ) assert.NoError(t, err) + return TransactionResult{Msg: msg}, nil } + return TransactionResult{}, errFake }, } @@ -91,8 +98,10 @@ func TestTCPConn(t *testing.T) { stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}, ) assert.NoError(t, err) + return TransactionResult{Msg: msg}, nil } + return TransactionResult{}, errFake }, } @@ -171,6 +180,7 @@ func TestTCPConn(t *testing.T) { cid, ) assert.NoError(t, err) + return TransactionResult{Msg: msg}, nil }, } diff --git a/internal/client/transaction.go b/internal/client/transaction.go index b1c9105e..744df668 100644 --- a/internal/client/transaction.go +++ b/internal/client/transaction.go @@ -15,7 +15,7 @@ const ( maxRtxInterval time.Duration = 1600 * time.Millisecond ) -// TransactionResult is a bag of result values of a transaction +// TransactionResult is a bag of result values of a transaction. type TransactionResult struct { Msg *stun.Message From net.Addr @@ -23,7 +23,7 @@ type TransactionResult struct { Err error } -// TransactionConfig is a set of config params used by NewTransaction +// TransactionConfig is a set of config params used by NewTransaction. type TransactionConfig struct { Key string Raw []byte @@ -32,7 +32,7 @@ type TransactionConfig struct { IgnoreResult bool // True to throw away the result of this transaction (it will not be readable using WaitForResult) } -// Transaction represents a transaction +// Transaction represents a transaction. type Transaction struct { Key string // Read-only Raw []byte // Read-only @@ -44,7 +44,7 @@ type Transaction struct { mutex sync.RWMutex } -// NewTransaction creates a new instance of Transaction +// NewTransaction creates a new instance of Transaction. func NewTransaction(config *TransactionConfig) *Transaction { var resultCh chan TransactionResult if !config.IgnoreResult { @@ -60,7 +60,7 @@ func NewTransaction(config *TransactionConfig) *Transaction { } } -// StartRtxTimer starts the transaction timer +// StartRtxTimer starts the transaction timer. func (t *Transaction) StartRtxTimer(onTimeout func(trKey string, nRtx int)) { t.mutex.Lock() defer t.mutex.Unlock() @@ -78,7 +78,7 @@ func (t *Transaction) StartRtxTimer(onTimeout func(trKey string, nRtx int)) { }) } -// StopRtxTimer stop the transaction timer +// StopRtxTimer stop the transaction timer. func (t *Transaction) StopRtxTimer() { t.mutex.Lock() defer t.mutex.Unlock() @@ -88,7 +88,7 @@ func (t *Transaction) StopRtxTimer() { } } -// WriteResult writes the result to the result channel +// WriteResult writes the result to the result channel. func (t *Transaction) WriteResult(res TransactionResult) bool { if t.resultCh == nil { return false @@ -99,7 +99,7 @@ func (t *Transaction) WriteResult(res TransactionResult) bool { return true } -// WaitForResult waits for the transaction result +// WaitForResult waits for the transaction result. func (t *Transaction) WaitForResult() TransactionResult { if t.resultCh == nil { return TransactionResult{ @@ -111,17 +111,18 @@ func (t *Transaction) WaitForResult() TransactionResult { if !ok { result.Err = errTransactionClosed } + return result } -// Close closes the transaction +// Close closes the transaction. func (t *Transaction) Close() { if t.resultCh != nil { close(t.resultCh) } } -// Retries returns the number of retransmission it has made +// Retries returns the number of retransmission it has made. func (t *Transaction) Retries() int { t.mutex.RLock() defer t.mutex.RUnlock() @@ -129,38 +130,40 @@ func (t *Transaction) Retries() int { return t.nRtx } -// TransactionMap is a thread-safe transaction map +// TransactionMap is a thread-safe transaction map. type TransactionMap struct { trMap map[string]*Transaction mutex sync.RWMutex } -// NewTransactionMap create a new instance of the transaction map +// NewTransactionMap create a new instance of the transaction map. func NewTransactionMap() *TransactionMap { return &TransactionMap{ trMap: map[string]*Transaction{}, } } -// Insert inserts a transaction to the map +// Insert inserts a transaction to the map. func (m *TransactionMap) Insert(key string, tr *Transaction) bool { m.mutex.Lock() defer m.mutex.Unlock() m.trMap[key] = tr + return true } -// Find looks up a transaction by its key +// Find looks up a transaction by its key. func (m *TransactionMap) Find(key string) (*Transaction, bool) { m.mutex.RLock() defer m.mutex.RUnlock() tr, ok := m.trMap[key] + return tr, ok } -// Delete deletes a transaction by its key +// Delete deletes a transaction by its key. func (m *TransactionMap) Delete(key string) { m.mutex.Lock() defer m.mutex.Unlock() @@ -168,7 +171,7 @@ func (m *TransactionMap) Delete(key string) { delete(m.trMap, key) } -// CloseAndDeleteAll closes and deletes all transactions +// CloseAndDeleteAll closes and deletes all transactions. func (m *TransactionMap) CloseAndDeleteAll() { m.mutex.Lock() defer m.mutex.Unlock() @@ -179,7 +182,7 @@ func (m *TransactionMap) CloseAndDeleteAll() { } } -// Size returns the length of the transaction map +// Size returns the length of the transaction map. func (m *TransactionMap) Size() int { m.mutex.RLock() defer m.mutex.RUnlock() diff --git a/internal/client/trylock.go b/internal/client/trylock.go index 6bfe6ff3..8380b0c8 100644 --- a/internal/client/trylock.go +++ b/internal/client/trylock.go @@ -18,6 +18,7 @@ func (c *TryLock) Lock() error { if !atomic.CompareAndSwapInt32(&c.n, 0, 1) { return errDoubleLock } + return nil } diff --git a/internal/client/trylock_test.go b/internal/client/trylock_test.go index cd3d1428..98547b05 100644 --- a/internal/client/trylock_test.go +++ b/internal/client/trylock_test.go @@ -18,6 +18,7 @@ func TestTryLock(t *testing.T) { return err } defer cl.Unlock() + return nil } @@ -34,6 +35,7 @@ func TestTryLock(t *testing.T) { } defer cl.Unlock() time.Sleep(50 * time.Millisecond) + return nil } diff --git a/internal/client/udp_conn.go b/internal/client/udp_conn.go index 7bb18765..f423d767 100644 --- a/internal/client/udp_conn.go +++ b/internal/client/udp_conn.go @@ -33,7 +33,7 @@ type inboundData struct { } // UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. -// compatible with net.PacketConn and net.Conn +// compatible with net.PacketConn and net.Conn. type UDPConn struct { bindingMgr *bindingManager // Thread-safe readCh chan *inboundData // Thread-safe @@ -41,9 +41,9 @@ type UDPConn struct { allocation } -// NewUDPConn creates a new instance of UDPConn +// NewUDPConn creates a new instance of UDPConn. func NewUDPConn(config *AllocationConfig) *UDPConn { - c := &UDPConn{ + conn := &UDPConn{ bindingMgr: newBindingManager(), readCh: make(chan *inboundData, maxReadQueueSize), closeCh: make(chan struct{}), @@ -63,28 +63,28 @@ func NewUDPConn(config *AllocationConfig) *UDPConn { }, } - c.log.Debugf("Initial lifetime: %d seconds", int(c.lifetime().Seconds())) + conn.log.Debugf("Initial lifetime: %d seconds", int(conn.lifetime().Seconds())) - c.refreshAllocTimer = NewPeriodicTimer( + conn.refreshAllocTimer = NewPeriodicTimer( timerIDRefreshAlloc, - c.onRefreshTimers, - c.lifetime()/2, + conn.onRefreshTimers, + conn.lifetime()/2, ) - c.refreshPermsTimer = NewPeriodicTimer( + conn.refreshPermsTimer = NewPeriodicTimer( timerIDRefreshPerms, - c.onRefreshTimers, + conn.onRefreshTimers, permRefreshInterval, ) - if c.refreshAllocTimer.Start() { - c.log.Debugf("Started refresh allocation timer") + if conn.refreshAllocTimer.Start() { + conn.log.Debugf("Started refresh allocation timer") } - if c.refreshPermsTimer.Start() { - c.log.Debugf("Started refresh permission timer") + if conn.refreshPermsTimer.Start() { + conn.log.Debugf("Started refresh permission timer") } - return c + return conn } // ReadFrom reads a packet from the connection, @@ -105,6 +105,7 @@ func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { if n < len(ibData.data) { return 0, nil, io.ErrShortBuffer } + return n, ibData.from, nil case <-c.readTimer.C: @@ -134,19 +135,21 @@ func (a *allocation) createPermission(perm *permission, addr net.Addr) error { // Punch a hole! (this would block a bit..) if err := a.CreatePermissions(addr); err != nil { a.permMap.delete(addr) + return err } perm.setState(permStatePermitted) } + return nil } -// WriteTo writes a packet with payload p to addr. +// WriteTo writes a packet with payload to addr. // WriteTo can be made to time out and return // an Error with Timeout() == true after a fixed time limit; // see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. -func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { //nolint: gocognit +func (c *UDPConn) WriteTo(payload []byte, addr net.Addr) (int, error) { //nolint:gocognit,cyclop,funlen var err error _, ok := addr.(*net.UDPAddr) if !ok { @@ -177,31 +180,32 @@ func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { //nolint: goco } // Bind channel - b, ok := c.bindingMgr.findByAddr(addr) + bound, ok := c.bindingMgr.findByAddr(addr) if !ok { - b = c.bindingMgr.create(addr) + bound = c.bindingMgr.create(addr) } - bindSt := b.state() + bindSt := bound.state() + //nolint:nestif if bindSt == bindingStateIdle || bindSt == bindingStateRequest || bindSt == bindingStateFailed { func() { // Block only callers with the same binding until // the binding transaction has been complete - b.muBind.Lock() - defer b.muBind.Unlock() + bound.muBind.Lock() + defer bound.muBind.Unlock() // Binding state may have been changed while waiting. check again. - if b.state() == bindingStateIdle { - b.setState(bindingStateRequest) + if bound.state() == bindingStateIdle { + bound.setState(bindingStateRequest) go func() { - err2 := c.bind(b) + err2 := c.bind(bound) if err2 != nil { c.log.Warnf("Failed to bind bind(): %s", err2) - b.setState(bindingStateFailed) + bound.setState(bindingStateFailed) // Keep going... } else { - b.setState(bindingStateReady) + bound.setState(bindingStateReady) } }() } @@ -213,7 +217,7 @@ func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { //nolint: goco msg, err = stun.Build( stun.TransactionID, stun.NewType(stun.MethodSend, stun.ClassIndication), - proto.Data(p), + proto.Data(payload), peerAddr, stun.Fingerprint, ) @@ -230,31 +234,32 @@ func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { //nolint: goco // Check if the binding needs a refresh func() { - b.muBind.Lock() - defer b.muBind.Unlock() + bound.muBind.Lock() + defer bound.muBind.Unlock() - if b.state() == bindingStateReady && time.Since(b.refreshedAt()) > 5*time.Minute { - b.setState(bindingStateRefresh) + if bound.state() == bindingStateReady && time.Since(bound.refreshedAt()) > 5*time.Minute { + bound.setState(bindingStateRefresh) go func() { - err = c.bind(b) + err = c.bind(bound) if err != nil { c.log.Warnf("Failed to bind() for refresh: %s", err) - b.setState(bindingStateFailed) + bound.setState(bindingStateFailed) // Keep going... } else { - b.setRefreshedAt(time.Now()) - b.setState(bindingStateReady) + bound.setRefreshedAt(time.Now()) + bound.setState(bindingStateReady) } }() } }() // Send via ChannelData - _, err = c.sendChannelData(p, b.number) + _, err = c.sendChannelData(payload, bound.number) if err != nil { return 0, err } - return len(p), nil + + return len(payload), nil } // Close closes the connection. @@ -271,6 +276,7 @@ func (c *UDPConn) Close() error { } c.client.OnDeallocated(c.relayedAddr) + return c.refreshAllocation(0, true /* dontWait=true */) } @@ -309,6 +315,7 @@ func (c *UDPConn) SetReadDeadline(t time.Time) error { d = time.Until(t) } c.readTimer.Reset(d) + return nil } @@ -372,17 +379,20 @@ func (a *allocation) CreatePermissions(addrs ...net.Addr) error { if err = code.GetFrom(res); err == nil { if code.Code == stun.CodeStaleNonce { a.setNonceFromMsg(res) + return errTryAgain } + return fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113 } + return fmt.Errorf("%s", res.Type) //nolint:goerr113 } return nil } -// HandleInbound passes inbound data in UDPConn +// HandleInbound passes inbound data in UDPConn. func (c *UDPConn) HandleInbound(data []byte, from net.Addr) { // Copy data copied := make([]byte, len(data)) @@ -396,21 +406,22 @@ func (c *UDPConn) HandleInbound(data []byte, from net.Addr) { } // FindAddrByChannelNumber returns a peer address associated with the -// channel number on this UDPConn +// channel number on this UDPConn. func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) { b, ok := c.bindingMgr.findByNumber(chNum) if !ok { return nil, false } + return b.addr, true } -func (c *UDPConn) bind(b *binding) error { +func (c *UDPConn) bind(bound *binding) error { setters := []stun.Setter{ stun.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassRequest), - addr2PeerAddress(b.addr), - proto.ChannelNumber(b.number), + addr2PeerAddress(bound.addr), + proto.ChannelNumber(bound.number), c.username, c.realm, c.nonce(), @@ -425,7 +436,8 @@ func (c *UDPConn) bind(b *binding) error { trRes, err := c.client.PerformTransaction(msg, c.serverAddr, false) if err != nil { - c.bindingMgr.deleteByAddr(b.addr) + c.bindingMgr.deleteByAddr(bound.addr) + return err } @@ -435,7 +447,7 @@ func (c *UDPConn) bind(b *binding) error { return fmt.Errorf("unexpected response type %s", res.Type) //nolint:goerr113 } - c.log.Debugf("Channel binding successful: %s %d", b.addr, b.number) + c.log.Debugf("Channel binding successful: %s %d", bound.addr, bound.number) // Success. return nil @@ -451,5 +463,6 @@ func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) { if err != nil { return 0, err } + return len(data), nil } diff --git a/internal/client/udp_conn_test.go b/internal/client/udp_conn_test.go index 889ffe4b..0b01315a 100644 --- a/internal/client/udp_conn_test.go +++ b/internal/client/udp_conn_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestUDPConn(t *testing.T) { +func TestUDPConn(t *testing.T) { //nolint:funlen t.Run("bind()", func(t *testing.T) { client := &mockClient{ performTransaction: func(*stun.Message, net.Addr, bool) (TransactionResult, error) { diff --git a/internal/ipnet/util.go b/internal/ipnet/util.go index c3a5b32f..9753ef35 100644 --- a/internal/ipnet/util.go +++ b/internal/ipnet/util.go @@ -11,7 +11,7 @@ import ( var errFailedToCastAddr = errors.New("failed to cast net.Addr to *net.UDPAddr or *net.TCPAddr") -// AddrIPPort extracts the IP and Port from a net.Addr +// AddrIPPort extracts the IP and Port from a net.Addr. func AddrIPPort(a net.Addr) (net.IP, int, error) { aUDP, ok := a.(*net.UDPAddr) if ok { @@ -27,7 +27,7 @@ func AddrIPPort(a net.Addr) (net.IP, int, error) { } // AddrEqual asserts that two net.Addrs are equal -// Currently only supports UDP but will be extended in the future to support others +// Currently only supports UDP but will be extended in the future to support others. func AddrEqual(a, b net.Addr) bool { aUDP, ok := a.(*net.UDPAddr) if !ok { @@ -51,5 +51,6 @@ func FingerprintAddr(addr net.Addr) string { case *net.TCPAddr: // Do we really need this case? return a.IP.String() } + return "" // Should never happen } diff --git a/internal/proto/addr.go b/internal/proto/addr.go index 7d4db8bf..38b3027f 100644 --- a/internal/proto/addr.go +++ b/internal/proto/addr.go @@ -28,6 +28,7 @@ func (a Addr) Equal(b Addr) bool { if a.Port != b.Port { return false } + return a.IP.Equal(b.IP) } @@ -64,5 +65,6 @@ func (t FiveTuple) Equal(b FiveTuple) bool { if !t.Server.Equal(b.Server) { return false } + return true } diff --git a/internal/proto/chandata.go b/internal/proto/chandata.go index df937119..4c9bb262 100644 --- a/internal/proto/chandata.go +++ b/internal/proto/chandata.go @@ -12,7 +12,7 @@ import ( // ChannelData represents The ChannelData Message. // -// See RFC 5766 Section 11.4 +// See RFC 5766 Section 11.4. type ChannelData struct { Data []byte // Can be sub slice of Raw Length int // Ignored while encoding, len(Data) is used @@ -20,21 +20,22 @@ type ChannelData struct { Raw []byte } -// Equal returns true if b == c. -func (c *ChannelData) Equal(b *ChannelData) bool { - if c == nil && b == nil { +// Equal returns true if compareTo == c. +func (c *ChannelData) Equal(compareTo *ChannelData) bool { + if c == nil && compareTo == nil { return true } - if c == nil || b == nil { + if c == nil || compareTo == nil { return false } - if c.Number != b.Number { + if c.Number != compareTo.Number { return false } - if len(c.Data) != len(b.Data) { + if len(c.Data) != len(compareTo.Data) { return false } - return bytes.Equal(c.Data, b.Data) + + return bytes.Equal(c.Data, compareTo.Data) } // Grow ensures that internal buffer will fit v more bytes and @@ -76,6 +77,7 @@ func nearestPaddedValueLength(l int) int { if n < l { n += padding } + return n } @@ -90,7 +92,7 @@ func (c *ChannelData) WriteHeader() { _ = c.Raw[:channelDataHeaderSize] binary.BigEndian.PutUint16(c.Raw[:channelDataNumberSize], uint16(c.Number)) binary.BigEndian.PutUint16(c.Raw[channelDataNumberSize:channelDataHeaderSize], - uint16(len(c.Data)), + uint16(len(c.Data)), // nolint:gosec // G115 ) } @@ -118,6 +120,7 @@ func (c *ChannelData) Decode() error { if int(l) > len(buf[channelDataHeaderSize:]) { return ErrBadChannelDataLength } + return nil } @@ -139,5 +142,6 @@ func IsChannelData(buf []byte) bool { // Quick check for channel number. num := binary.BigEndian.Uint16(buf[0:channelNumberSize]) + return isChannelNumberValid(num) } diff --git a/internal/proto/chandata_test.go b/internal/proto/chandata_test.go index 796ebfea..ad01f994 100644 --- a/internal/proto/chandata_test.go +++ b/internal/proto/chandata_test.go @@ -13,25 +13,25 @@ import ( ) func TestChannelData_Encode(t *testing.T) { - d := &ChannelData{ + chanData := &ChannelData{ Data: []byte{1, 2, 3, 4}, Number: MinChannelNumber + 1, } - d.Encode() + chanData.Encode() b := &ChannelData{} - b.Raw = append(b.Raw, d.Raw...) + b.Raw = append(b.Raw, chanData.Raw...) if err := b.Decode(); err != nil { t.Error(err) } - if !b.Equal(d) { + if !b.Equal(chanData) { t.Error("not equal") } - if !IsChannelData(b.Raw) || !IsChannelData(d.Raw) { + if !IsChannelData(b.Raw) || !IsChannelData(chanData.Raw) { t.Error("unexpected IsChannelData") } } -func TestChannelData_Equal(t *testing.T) { +func TestChannelData_Equal(t *testing.T) { // nolint:funlen for _, tc := range []struct { name string a, b *ChannelData @@ -235,14 +235,14 @@ func TestChromeChannelData(t *testing.T) { // All hex streams decoded to raw binary format and stored in data slice. // Decoding packets to messages. for i, packet := range data { - m := new(ChannelData) - m.Raw = packet - if err := m.Decode(); err != nil { + chanData := new(ChannelData) + chanData.Raw = packet + if err := chanData.Decode(); err != nil { t.Errorf("Packet %d: %v", i, err) } encoded := &ChannelData{ - Data: m.Data, - Number: m.Number, + Data: chanData.Data, + Number: chanData.Number, } encoded.Encode() decoded := new(ChannelData) @@ -250,11 +250,11 @@ func TestChromeChannelData(t *testing.T) { if err := decoded.Decode(); err != nil { t.Error(err) } - if !decoded.Equal(m) { + if !decoded.Equal(chanData) { t.Error("should be equal") } - messages = append(messages, m) + messages = append(messages, chanData) } if len(messages) != 2 { t.Error("unexpected message slice list") diff --git a/internal/proto/chann.go b/internal/proto/chann.go index 3aeb59f3..da017bf6 100644 --- a/internal/proto/chann.go +++ b/internal/proto/chann.go @@ -15,7 +15,7 @@ import ( // // The CHANNEL-NUMBER attribute contains the number of the channel. // -// RFC 5766 Section 14.1 +// RFC 5766 Section 14.1. type ChannelNumber uint16 // Encoded as uint16 func (n ChannelNumber) String() string { return strconv.Itoa(int(n)) } @@ -29,6 +29,7 @@ func (n ChannelNumber) AddTo(m *stun.Message) error { binary.BigEndian.PutUint16(v[:2], uint16(n)) // v[2:4] are zeroes (RFFU = 0) m.Add(stun.AttrChannelNumber, v) + return nil } diff --git a/internal/proto/chann_test.go b/internal/proto/chann_test.go index 04c85fda..ebbc1fb4 100644 --- a/internal/proto/chann_test.go +++ b/internal/proto/chann_test.go @@ -35,7 +35,7 @@ func BenchmarkChannelNumber(b *testing.B) { }) } -func TestChannelNumber(t *testing.T) { +func TestChannelNumber(t *testing.T) { // nolint:cyclop,funlen t.Run("String", func(t *testing.T) { n := ChannelNumber(112) if n.String() != "112" { @@ -43,12 +43,12 @@ func TestChannelNumber(t *testing.T) { } }) t.Run("NoAlloc", func(t *testing.T) { - m := &stun.Message{} + stunMsg := &stun.Message{} if wasAllocs(func() { // Case with ChannelNumber on stack. n := ChannelNumber(6) - n.AddTo(m) //nolint - m.Reset() + n.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } @@ -57,30 +57,30 @@ func TestChannelNumber(t *testing.T) { nP := &n if wasAllocs(func() { // On heap. - nP.AddTo(m) //nolint - m.Reset() + nP.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } }) t.Run("AddTo", func(t *testing.T) { - m := new(stun.Message) - n := ChannelNumber(6) - if err := n.AddTo(m); err != nil { + stunMsg := new(stun.Message) + chanNumber := ChannelNumber(6) + if err := chanNumber.AddTo(stunMsg); err != nil { t.Error(err) } - m.WriteHeader() + stunMsg.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { + if _, err := decoded.Write(stunMsg.Raw); err != nil { t.Fatal("failed to decode message:", err) } var numDecoded ChannelNumber if err := numDecoded.GetFrom(decoded); err != nil { t.Fatal(err) } - if numDecoded != n { - t.Errorf("Decoded %d, expected %d", numDecoded, n) + if numDecoded != chanNumber { + t.Errorf("Decoded %d, expected %d", numDecoded, chanNumber) } if wasAllocs(func() { var num ChannelNumber diff --git a/internal/proto/connection_id.go b/internal/proto/connection_id.go index 568deff1..d95d9b04 100644 --- a/internal/proto/connection_id.go +++ b/internal/proto/connection_id.go @@ -14,7 +14,7 @@ import ( // The CONNECTION-ID attribute uniquely identifies a peer data // connection. It is a 32-bit unsigned integral value. // -// RFC 6062 Section 6.2.1 +// RFC 6062 Section 6.2.1. type ConnectionID uint32 const connectionIDSize = 4 // uint32: 4 bytes, 32 bits @@ -24,6 +24,7 @@ func (c ConnectionID) AddTo(m *stun.Message) error { v := make([]byte, lifetimeSize) binary.BigEndian.PutUint32(v, uint32(c)) m.Add(stun.AttrConnectionID, v) + return nil } @@ -38,5 +39,6 @@ func (c *ConnectionID) GetFrom(m *stun.Message) error { } _ = v[connectionIDSize-1] // Asserting length *(*uint32)(c) = binary.BigEndian.Uint32(v) + return nil } diff --git a/internal/proto/data.go b/internal/proto/data.go index dcba23b4..ea2f0c86 100644 --- a/internal/proto/data.go +++ b/internal/proto/data.go @@ -13,12 +13,13 @@ import "github.com/pion/stun/v3" // the UDP header if the data was been sent directly between the client // and the peer). // -// RFC 5766 Section 14.4 +// RFC 5766 Section 14.4. type Data []byte // AddTo adds DATA to message. func (d Data) AddTo(m *stun.Message) error { m.Add(stun.AttrData, d) + return nil } @@ -29,5 +30,6 @@ func (d *Data) GetFrom(m *stun.Message) error { return err } *d = v + return nil } diff --git a/internal/proto/data_test.go b/internal/proto/data_test.go index 1f18e451..de714ba7 100644 --- a/internal/proto/data_test.go +++ b/internal/proto/data_test.go @@ -34,13 +34,13 @@ func BenchmarkData(b *testing.B) { func TestData(t *testing.T) { t.Run("NoAlloc", func(t *testing.T) { - m := new(stun.Message) + stunMsg := new(stun.Message) v := []byte{1, 2, 3, 4} if wasAllocs(func() { // On stack. d := Data(v) - d.AddTo(m) //nolint - m.Reset() + d.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } @@ -48,16 +48,16 @@ func TestData(t *testing.T) { d := &Data{1, 2, 3, 4} if wasAllocs(func() { // On heap. - d.AddTo(m) //nolint - m.Reset() + d.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) - d := Data{1, 2, 33, 44, 0x13, 0xaf} - if err := d.AddTo(m); err != nil { + data := Data{1, 2, 33, 44, 0x13, 0xaf} + if err := data.AddTo(m); err != nil { t.Fatal(err) } m.WriteHeader() @@ -70,8 +70,8 @@ func TestData(t *testing.T) { if err := dataDecoded.GetFrom(decoded); err != nil { t.Fatal(err) } - if !bytes.Equal(dataDecoded, d) { - t.Error(dataDecoded, "!=", d, "(expected)") + if !bytes.Equal(dataDecoded, data) { + t.Error(dataDecoded, "!=", data, "(expected)") } if wasAllocs(func() { var dataDecoded Data diff --git a/internal/proto/dontfrag.go b/internal/proto/dontfrag.go index e4be0af6..d46ae8f1 100644 --- a/internal/proto/dontfrag.go +++ b/internal/proto/dontfrag.go @@ -8,7 +8,7 @@ import ( ) // DontFragmentAttr is a deprecated alias for DontFragment -// Deprecated: Please use DontFragment +// Deprecated: Please use DontFragment. type DontFragmentAttr = DontFragment // DontFragment represents DONT-FRAGMENT attribute. @@ -18,7 +18,7 @@ type DontFragmentAttr = DontFragment // application data onward to the peer. This attribute has no value // part and thus the attribute length field is 0. // -// RFC 5766 Section 14.8 +// RFC 5766 Section 14.8. type DontFragment struct{} const dontFragmentSize = 0 @@ -26,6 +26,7 @@ const dontFragmentSize = 0 // AddTo adds DONT-FRAGMENT attribute to message. func (DontFragment) AddTo(m *stun.Message) error { m.Add(stun.AttrDontFragment, nil) + return nil } @@ -35,11 +36,13 @@ func (d *DontFragment) GetFrom(m *stun.Message) error { if err != nil { return err } + return stun.CheckSize(stun.AttrDontFragment, len(v), dontFragmentSize) } // IsSet returns true if DONT-FRAGMENT attribute is set. func (DontFragment) IsSet(m *stun.Message) bool { _, err := m.Get(stun.AttrDontFragment) + return err == nil } diff --git a/internal/proto/dontfrag_test.go b/internal/proto/dontfrag_test.go index 477d5632..f8100d9a 100644 --- a/internal/proto/dontfrag_test.go +++ b/internal/proto/dontfrag_test.go @@ -20,21 +20,21 @@ func TestDontFragment(t *testing.T) { } }) t.Run("AddTo", func(t *testing.T) { - m := new(stun.Message) - if err := dontFrag.AddTo(m); err != nil { + stunMsg := new(stun.Message) + if err := dontFrag.AddTo(stunMsg); err != nil { t.Error(err) } - m.WriteHeader() + stunMsg.WriteHeader() t.Run("IsSet", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { + if _, err := decoded.Write(stunMsg.Raw); err != nil { t.Fatal("failed to decode message:", err) } - if !dontFrag.IsSet(m) { + if !dontFrag.IsSet(stunMsg) { t.Error("should be set") } if wasAllocs(func() { - dontFrag.IsSet(m) + dontFrag.IsSet(stunMsg) }) { t.Error("unexpected allocations") } diff --git a/internal/proto/evenport.go b/internal/proto/evenport.go index 31468c75..6989ab89 100644 --- a/internal/proto/evenport.go +++ b/internal/proto/evenport.go @@ -11,7 +11,7 @@ import "github.com/pion/stun/v3" // relayed transport address be even, and (optionally) that the server // reserve the next-higher port number. // -// RFC 5766 Section 14.6 +// RFC 5766 Section 14.6. type EvenPort struct { // ReservePort means that the server is requested to reserve // the next-higher port number (on the same IP address) @@ -23,6 +23,7 @@ func (p EvenPort) String() string { if p.ReservePort { return "reserve: true" } + return "reserve: false" } @@ -39,6 +40,7 @@ func (p EvenPort) AddTo(m *stun.Message) error { v[0] = firstBitSet } m.Add(stun.AttrEvenPort, v) + return nil } @@ -54,5 +56,6 @@ func (p *EvenPort) GetFrom(m *stun.Message) error { if v[0]&firstBitSet > 0 { p.ReservePort = true } + return nil } diff --git a/internal/proto/evenport_test.go b/internal/proto/evenport_test.go index 7b09d177..2e5abcaf 100644 --- a/internal/proto/evenport_test.go +++ b/internal/proto/evenport_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestEvenPort(t *testing.T) { +func TestEvenPort(t *testing.T) { // nolint:cyclop,funlen t.Run("String", func(t *testing.T) { p := EvenPort{} if p.String() != "reserve: false" { @@ -42,10 +42,10 @@ func TestEvenPort(t *testing.T) { }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) - p := EvenPort{ + evenPortAttr := EvenPort{ ReservePort: true, } - if err := p.AddTo(m); err != nil { + if err := evenPortAttr.AddTo(m); err != nil { t.Error(err) } m.WriteHeader() @@ -58,8 +58,8 @@ func TestEvenPort(t *testing.T) { if err := port.GetFrom(decoded); err != nil { t.Fatal(err) } - if port != p { - t.Errorf("Decoded %q, expected %q", port.String(), p.String()) + if port != evenPortAttr { + t.Errorf("Decoded %q, expected %q", port.String(), evenPortAttr.String()) } if wasAllocs(func() { port.GetFrom(decoded) //nolint diff --git a/internal/proto/fuzz_test.go b/internal/proto/fuzz_test.go index f9e5da3c..947f50ef 100644 --- a/internal/proto/fuzz_test.go +++ b/internal/proto/fuzz_test.go @@ -27,6 +27,7 @@ func (a attrs) pick(v byte) struct { t stun.AttrType } { idx := int(v) % len(a) + return a[idx] } @@ -66,6 +67,7 @@ func FuzzSetters(f *testing.F) { fmt.Println("unexpected 404") //nolint panic(err) //nolint } + return } @@ -90,10 +92,10 @@ func FuzzSetters(f *testing.F) { } func FuzzChannelData(f *testing.F) { - d := &ChannelData{} + channelData := &ChannelData{} f.Fuzz(func(_ *testing.T, data []byte) { - d.Reset() + channelData.Reset() if len(data) > channelDataHeaderSize { // Make sure the channel id is in the proper range @@ -104,18 +106,18 @@ func FuzzChannelData(f *testing.F) { } } - d.Raw = append(d.Raw, data...) - if d.Decode() != nil { + channelData.Raw = append(channelData.Raw, data...) + if channelData.Decode() != nil { return } - d.Encode() - if !d.Number.Valid() { + channelData.Encode() + if !channelData.Number.Valid() { return } d2 := &ChannelData{} - d2.Raw = d.Raw + d2.Raw = channelData.Raw if err := d2.Decode(); err != nil { panic(err) //nolint } diff --git a/internal/proto/lifetime.go b/internal/proto/lifetime.go index 6f55c7a9..37b86f81 100644 --- a/internal/proto/lifetime.go +++ b/internal/proto/lifetime.go @@ -12,7 +12,7 @@ import ( // DefaultLifetime in RFC 5766 is 10 minutes. // -// RFC 5766 Section 2.2 +// RFC 5766 Section 2.2. const DefaultLifetime = time.Minute * 10 // Lifetime represents LIFETIME attribute. @@ -23,12 +23,12 @@ const DefaultLifetime = time.Minute * 10 // unsigned integral value representing the number of seconds remaining // until expiration. // -// RFC 5766 Section 14.2 +// RFC 5766 Section 14.2. type Lifetime struct { time.Duration } -// Seconds in uint32 +// Seconds in uint32. const lifetimeSize = 4 // 4 bytes, 32 bits // AddTo adds LIFETIME to message. @@ -36,6 +36,7 @@ func (l Lifetime) AddTo(m *stun.Message) error { v := make([]byte, lifetimeSize) binary.BigEndian.PutUint32(v, uint32(l.Seconds())) m.Add(stun.AttrLifetime, v) + return nil } @@ -51,5 +52,6 @@ func (l *Lifetime) GetFrom(m *stun.Message) error { _ = v[lifetimeSize-1] // Asserting length seconds := binary.BigEndian.Uint32(v) l.Duration = time.Second * time.Duration(seconds) + return nil } diff --git a/internal/proto/lifetime_test.go b/internal/proto/lifetime_test.go index a4aba420..08afc005 100644 --- a/internal/proto/lifetime_test.go +++ b/internal/proto/lifetime_test.go @@ -36,7 +36,7 @@ func BenchmarkLifetime(b *testing.B) { }) } -func TestLifetime(t *testing.T) { +func TestLifetime(t *testing.T) { // nolint:cyclop,funlen t.Run("String", func(t *testing.T) { l := Lifetime{time.Second * 10} if l.String() != "10s" { @@ -44,14 +44,14 @@ func TestLifetime(t *testing.T) { } }) t.Run("NoAlloc", func(t *testing.T) { - m := &stun.Message{} + stunMsg := &stun.Message{} if wasAllocs(func() { // On stack. l := Lifetime{ Duration: time.Minute, } - l.AddTo(m) //nolint - m.Reset() + l.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } @@ -59,16 +59,16 @@ func TestLifetime(t *testing.T) { l := &Lifetime{time.Second} if wasAllocs(func() { // On heap. - l.AddTo(m) //nolint - m.Reset() + l.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) - l := Lifetime{time.Second * 10} - if err := l.AddTo(m); err != nil { + lifetime := Lifetime{time.Second * 10} + if err := lifetime.AddTo(m); err != nil { t.Error(err) } m.WriteHeader() @@ -81,8 +81,8 @@ func TestLifetime(t *testing.T) { if err := life.GetFrom(decoded); err != nil { t.Fatal(err) } - if life != l { - t.Errorf("Decoded %q, expected %q", life, l) + if life != lifetime { + t.Errorf("Decoded %q, expected %q", life, lifetime) } if wasAllocs(func() { life.GetFrom(decoded) //nolint diff --git a/internal/proto/peeraddr.go b/internal/proto/peeraddr.go index f4d9de7b..a919bd15 100644 --- a/internal/proto/peeraddr.go +++ b/internal/proto/peeraddr.go @@ -15,7 +15,7 @@ import ( // seen from the TURN server. (For example, the peer's server-reflexive // transport address if the peer is behind a NAT.) // -// RFC 5766 Section 14.3 +// RFC 5766 Section 14.3. type PeerAddress struct { IP net.IP Port int @@ -41,5 +41,5 @@ func (a *PeerAddress) GetFrom(m *stun.Message) error { // seen from the TURN server. (For example, the peer's server-reflexive // transport address if the peer is behind a NAT.) // -// RFC 5766 Section 14.3 +// RFC 5766 Section 14.3. type XORPeerAddress = PeerAddress diff --git a/internal/proto/proto_test.go b/internal/proto/proto_test.go index 8c2cadd3..37a6223a 100644 --- a/internal/proto/proto_test.go +++ b/internal/proto/proto_test.go @@ -18,6 +18,8 @@ func wasAllocs(f func()) bool { } func loadData(tb testing.TB, name string) []byte { + tb.Helper() + name = filepath.Join("testdata", name) f, err := os.Open(name) // #nosec if err != nil { @@ -32,5 +34,6 @@ func loadData(tb testing.TB, name string) []byte { if err != nil { tb.Fatal(err) } + return v } diff --git a/internal/proto/relayedaddr.go b/internal/proto/relayedaddr.go index c5fb2686..1d22ade9 100644 --- a/internal/proto/relayedaddr.go +++ b/internal/proto/relayedaddr.go @@ -14,7 +14,7 @@ import ( // It specifies the address and port that the server allocated to the // client. It is encoded in the same way as XOR-MAPPED-ADDRESS. // -// RFC 5766 Section 14.5 +// RFC 5766 Section 14.5. type RelayedAddress struct { IP net.IP Port int @@ -39,5 +39,5 @@ func (a *RelayedAddress) GetFrom(m *stun.Message) error { // It specifies the address and port that the server allocated to the // client. It is encoded in the same way as XOR-MAPPED-ADDRESS. // -// RFC 5766 Section 14.5 +// RFC 5766 Section 14.5. type XORRelayedAddress = RelayedAddress diff --git a/internal/proto/reqfamily.go b/internal/proto/reqfamily.go index 01016a98..d00c302f 100644 --- a/internal/proto/reqfamily.go +++ b/internal/proto/reqfamily.go @@ -32,6 +32,7 @@ func (f *RequestedAddressFamily) GetFrom(m *stun.Message) error { default: return errInvalidRequestedFamilyValue } + return nil } @@ -54,6 +55,7 @@ func (f RequestedAddressFamily) AddTo(m *stun.Message) error { // The RFFU field MUST be set to zero on transmission and MUST be // ignored on reception. It is reserved for future uses. m.Add(stun.AttrRequestedAddressFamily, v) + return nil } diff --git a/internal/proto/reqfamily_test.go b/internal/proto/reqfamily_test.go index 71ac087b..63492e3c 100644 --- a/internal/proto/reqfamily_test.go +++ b/internal/proto/reqfamily_test.go @@ -10,7 +10,7 @@ import ( "github.com/pion/stun/v3" ) -func TestRequestedAddressFamily(t *testing.T) { +func TestRequestedAddressFamily(t *testing.T) { // nolint:cyclop,funlen t.Run("String", func(t *testing.T) { if RequestedFamilyIPv4.String() != "IPv4" { t.Errorf("bad string %q, expected %q", RequestedFamilyIPv4, @@ -27,47 +27,47 @@ func TestRequestedAddressFamily(t *testing.T) { } }) t.Run("NoAlloc", func(t *testing.T) { - m := &stun.Message{} + stunMsg := &stun.Message{} if wasAllocs(func() { // On stack. r := RequestedFamilyIPv4 - r.AddTo(m) //nolint - m.Reset() + r.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } - r := new(RequestedAddressFamily) - *r = RequestedFamilyIPv4 + requestFamilyAttr := new(RequestedAddressFamily) + *requestFamilyAttr = RequestedFamilyIPv4 if wasAllocs(func() { // On heap. - r.AddTo(m) //nolint - m.Reset() + requestFamilyAttr.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } }) t.Run("AddTo", func(t *testing.T) { - m := new(stun.Message) - r := RequestedFamilyIPv4 - if err := r.AddTo(m); err != nil { + stunMsg := new(stun.Message) + requestFamilyAddr := RequestedFamilyIPv4 + if err := requestFamilyAddr.AddTo(stunMsg); err != nil { t.Error(err) } - m.WriteHeader() + stunMsg.WriteHeader() t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { + if _, err := decoded.Write(stunMsg.Raw); err != nil { t.Fatal("failed to decode message:", err) } var req RequestedAddressFamily if err := req.GetFrom(decoded); err != nil { t.Fatal(err) } - if req != r { - t.Errorf("Decoded %q, expected %q", req, r) + if req != requestFamilyAddr { + t.Errorf("Decoded %q, expected %q", req, requestFamilyAddr) } if wasAllocs(func() { - r.GetFrom(decoded) //nolint + requestFamilyAddr.GetFrom(decoded) //nolint }) { t.Error("Unexpected allocations") } diff --git a/internal/proto/reqtrans.go b/internal/proto/reqtrans.go index 111dcd69..b907e384 100644 --- a/internal/proto/reqtrans.go +++ b/internal/proto/reqtrans.go @@ -36,7 +36,7 @@ func (p Protocol) String() string { // protocol for the allocated transport address. RFC 5766 only allows the use of // code point 17 (User Datagram Protocol). // -// RFC 5766 Section 14.7 +// RFC 5766 Section 14.7. type RequestedTransport struct { Protocol Protocol } @@ -55,6 +55,7 @@ func (t RequestedTransport) AddTo(m *stun.Message) error { // The RFFU field MUST be set to zero on transmission and MUST be // ignored on reception. It is reserved for future uses. m.Add(stun.AttrRequestedTransport, v) + return nil } @@ -68,5 +69,6 @@ func (t *RequestedTransport) GetFrom(m *stun.Message) error { return err } t.Protocol = Protocol(v[0]) + return nil } diff --git a/internal/proto/reqtrans_test.go b/internal/proto/reqtrans_test.go index e7416840..2a78f728 100644 --- a/internal/proto/reqtrans_test.go +++ b/internal/proto/reqtrans_test.go @@ -10,40 +10,40 @@ import ( "github.com/pion/stun/v3" ) -func TestRequestedTransport(t *testing.T) { +func TestRequestedTransport(t *testing.T) { // nolint:cyclop,funlen t.Run("String", func(t *testing.T) { - r := RequestedTransport{ + transAttr := RequestedTransport{ Protocol: ProtoUDP, } - if r.String() != "protocol: UDP" { - t.Errorf("bad string %q, expected %q", r, + if transAttr.String() != "protocol: UDP" { + t.Errorf("bad string %q, expected %q", transAttr, "protocol: UDP", ) } - r = RequestedTransport{ + transAttr = RequestedTransport{ Protocol: ProtoTCP, } - if r.String() != "protocol: TCP" { - t.Errorf("bad string %q, expected %q", r, + if transAttr.String() != "protocol: TCP" { + t.Errorf("bad string %q, expected %q", transAttr, "protocol: TCP", ) } - r.Protocol = 254 - if r.String() != "protocol: 254" { - t.Errorf("bad string %q, expected %q", r, + transAttr.Protocol = 254 + if transAttr.String() != "protocol: 254" { + t.Errorf("bad string %q, expected %q", transAttr, "protocol: 254", ) } }) t.Run("NoAlloc", func(t *testing.T) { - m := &stun.Message{} + stunMsg := &stun.Message{} if wasAllocs(func() { // On stack. r := RequestedTransport{ Protocol: ProtoUDP, } - r.AddTo(m) //nolint - m.Reset() + r.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } @@ -53,18 +53,18 @@ func TestRequestedTransport(t *testing.T) { } if wasAllocs(func() { // On heap. - r.AddTo(m) //nolint - m.Reset() + r.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } }) t.Run("AddTo", func(t *testing.T) { m := new(stun.Message) - r := RequestedTransport{ + transAttr := RequestedTransport{ Protocol: ProtoUDP, } - if err := r.AddTo(m); err != nil { + if err := transAttr.AddTo(m); err != nil { t.Error(err) } m.WriteHeader() @@ -79,11 +79,11 @@ func TestRequestedTransport(t *testing.T) { if err := req.GetFrom(decoded); err != nil { t.Fatal(err) } - if req != r { - t.Errorf("Decoded %q, expected %q", req, r) + if req != transAttr { + t.Errorf("Decoded %q, expected %q", req, transAttr) } if wasAllocs(func() { - r.GetFrom(decoded) //nolint + transAttr.GetFrom(decoded) //nolint }) { t.Error("Unexpected allocations") } diff --git a/internal/proto/rsrvtoken.go b/internal/proto/rsrvtoken.go index 9c816485..aed38861 100644 --- a/internal/proto/rsrvtoken.go +++ b/internal/proto/rsrvtoken.go @@ -14,7 +14,7 @@ import "github.com/pion/stun/v3" // attribute in a subsequent Allocate request to request the server use // that relayed transport address for the allocation. // -// RFC 5766 Section 14.9 +// RFC 5766 Section 14.9. type ReservationToken []byte const reservationTokenSize = 8 // 8 bytes @@ -25,6 +25,7 @@ func (t ReservationToken) AddTo(m *stun.Message) error { return err } m.Add(stun.AttrReservationToken, t) + return nil } @@ -38,5 +39,6 @@ func (t *ReservationToken) GetFrom(m *stun.Message) error { return err } *t = v + return nil } diff --git a/internal/proto/rsrvtoken_test.go b/internal/proto/rsrvtoken_test.go index 35475134..64fdff47 100644 --- a/internal/proto/rsrvtoken_test.go +++ b/internal/proto/rsrvtoken_test.go @@ -11,15 +11,15 @@ import ( "github.com/pion/stun/v3" ) -func TestReservationToken(t *testing.T) { +func TestReservationToken(t *testing.T) { // nolint:cyclop,funlen t.Run("NoAlloc", func(t *testing.T) { - m := &stun.Message{} + stunMsg := &stun.Message{} tok := make([]byte, 8) if wasAllocs(func() { // On stack. tk := ReservationToken(tok) - tk.AddTo(m) //nolint - m.Reset() + tk.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } @@ -27,30 +27,30 @@ func TestReservationToken(t *testing.T) { tk := make(ReservationToken, 8) if wasAllocs(func() { // On heap. - tk.AddTo(m) //nolint - m.Reset() + tk.AddTo(stunMsg) //nolint + stunMsg.Reset() }) { t.Error("Unexpected allocations") } }) t.Run("AddTo", func(t *testing.T) { - m := new(stun.Message) + stunMsg := new(stun.Message) tk := make(ReservationToken, 8) tk[2] = 33 tk[7] = 1 - if err := tk.AddTo(m); err != nil { + if err := tk.AddTo(stunMsg); err != nil { t.Error(err) } - m.WriteHeader() + stunMsg.WriteHeader() t.Run("HandleErr", func(t *testing.T) { badTk := ReservationToken{34, 45} - if !stun.IsAttrSizeInvalid(badTk.AddTo(m)) { + if !stun.IsAttrSizeInvalid(badTk.AddTo(stunMsg)) { t.Error("IsAttrSizeInvalid should be true") } }) t.Run("GetFrom", func(t *testing.T) { decoded := new(stun.Message) - if _, err := decoded.Write(m.Raw); err != nil { + if _, err := decoded.Write(stunMsg.Raw); err != nil { t.Fatal("failed to decode message:", err) } var tok ReservationToken diff --git a/internal/server/nonce.go b/internal/server/nonce.go index b3f3131e..22e466f9 100644 --- a/internal/server/nonce.go +++ b/internal/server/nonce.go @@ -19,7 +19,7 @@ const ( nonceKeyLength = 64 ) -// NewNonceHash creates a NonceHash +// NewNonceHash creates a NonceHash. func NewNonceHash() (*NonceHash, error) { key := make([]byte, nonceKeyLength) if _, err := rand.Read(key); err != nil { @@ -29,15 +29,15 @@ func NewNonceHash() (*NonceHash, error) { return &NonceHash{key}, nil } -// NonceHash is used to create and verify nonces +// NonceHash is used to create and verify nonces. type NonceHash struct { key []byte } -// Generate a nonce +// Generate a nonce. func (n *NonceHash) Generate() (string, error) { nonce := make([]byte, 8, nonceLength) - binary.BigEndian.PutUint64(nonce, uint64(time.Now().UnixMilli())) + binary.BigEndian.PutUint64(nonce, uint64(time.Now().UnixMilli())) // nolint:gosec // G115 hash := hmac.New(sha256.New, n.key) if _, err := hash.Write(nonce[:8]); err != nil { @@ -48,14 +48,14 @@ func (n *NonceHash) Generate() (string, error) { return hex.EncodeToString(nonce), nil } -// Validate checks that nonce is signed and is not expired +// Validate checks that nonce is signed and is not expired. func (n *NonceHash) Validate(nonce string) error { b, err := hex.DecodeString(nonce) if err != nil || len(b) != nonceLength { return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint } - if ts := time.UnixMilli(int64(binary.BigEndian.Uint64(b))); time.Since(ts) > nonceLifetime { + if ts := time.UnixMilli(int64(binary.BigEndian.Uint64(b))); time.Since(ts) > nonceLifetime { // nolint:gosec // G115 return errInvalidNonce } diff --git a/internal/server/server.go b/internal/server/server.go index 253492e9..4f5c3148 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,7 @@ import ( "github.com/pion/turn/v4/internal/proto" ) -// Request contains all the state needed to process a single incoming datagram +// Request contains all the state needed to process a single incoming datagram. type Request struct { // Current Request State Conn net.PacketConn @@ -33,7 +33,7 @@ type Request struct { ChannelBindTimeout time.Duration } -// HandleRequest processes the give Request +// HandleRequest processes the give Request. func HandleRequest(r Request) error { r.Log.Debugf("Received %d bytes of udp from %s on %s", len(r.Buff), r.SrcAddr, r.Conn.LocalAddr()) @@ -44,42 +44,62 @@ func HandleRequest(r Request) error { return handleTURNPacket(r) } -func handleDataPacket(r Request) error { - r.Log.Debugf("Received DataPacket from %s", r.SrcAddr.String()) - c := proto.ChannelData{Raw: r.Buff} +func handleDataPacket(req Request) error { + req.Log.Debugf("Received DataPacket from %s", req.SrcAddr.String()) + c := proto.ChannelData{Raw: req.Buff} if err := c.Decode(); err != nil { return fmt.Errorf("%w: %v", errFailedToCreateChannelData, err) //nolint:errorlint } - err := handleChannelData(r, &c) + err := handleChannelData(req, &c) if err != nil { - err = fmt.Errorf("%w from %v: %v", errUnableToHandleChannelData, r.SrcAddr, err) //nolint:errorlint + err = fmt.Errorf("%w from %v: %v", errUnableToHandleChannelData, req.SrcAddr, err) //nolint:errorlint } return err } -func handleTURNPacket(r Request) error { - r.Log.Debug("Handling TURN packet") - m := &stun.Message{Raw: append([]byte{}, r.Buff...)} - if err := m.Decode(); err != nil { - return fmt.Errorf("%w: %v", errFailedToCreateSTUNPacket, err) //nolint:errorlint +func handleTURNPacket(req Request) error { + req.Log.Debug("Handling TURN packet") + stunMsg := &stun.Message{Raw: append([]byte{}, req.Buff...)} + if err := stunMsg.Decode(); err != nil { + // nolint:errorlint + return fmt.Errorf("%w: %v", errFailedToCreateSTUNPacket, err) } - h, err := getMessageHandler(m.Type.Class, m.Type.Method) + handler, err := getMessageHandler(stunMsg.Type.Class, stunMsg.Type.Method) if err != nil { - return fmt.Errorf("%w %v-%v from %v: %v", errUnhandledSTUNPacket, m.Type.Method, m.Type.Class, r.SrcAddr, err) //nolint:errorlint + // nolint:errorlint + return fmt.Errorf( + "%w %v-%v from %v: %v", + errUnhandledSTUNPacket, + stunMsg.Type.Method, + stunMsg.Type.Class, + req.SrcAddr, + err, + ) } - err = h(r, m) + err = handler(req, stunMsg) if err != nil { - return fmt.Errorf("%w %v-%v from %v: %v", errFailedToHandle, m.Type.Method, m.Type.Class, r.SrcAddr, err) //nolint:errorlint + // nolint:errorlint + return fmt.Errorf( + "%w %v-%v from %v: %v", + errFailedToHandle, + stunMsg.Type.Method, + stunMsg.Type.Class, + req.SrcAddr, + err, + ) } return nil } -func getMessageHandler(class stun.MessageClass, method stun.Method) (func(r Request, m *stun.Message) error, error) { +func getMessageHandler(class stun.MessageClass, method stun.Method) ( // nolint:cyclop + func(req Request, stunMsg *stun.Message) error, + error, +) { switch class { case stun.ClassIndication: switch method { diff --git a/internal/server/stun.go b/internal/server/stun.go index 393bb99d..1880bf17 100644 --- a/internal/server/stun.go +++ b/internal/server/stun.go @@ -8,18 +8,18 @@ import ( "github.com/pion/turn/v4/internal/ipnet" ) -func handleBindingRequest(r Request, m *stun.Message) error { - r.Log.Debugf("Received BindingRequest from %s", r.SrcAddr) +func handleBindingRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received BindingRequest from %s", req.SrcAddr) - ip, port, err := ipnet.AddrIPPort(r.SrcAddr) + ip, port, err := ipnet.AddrIPPort(req.SrcAddr) if err != nil { return err } - attrs := buildMsg(m.TransactionID, stun.BindingSuccess, &stun.XORMappedAddress{ + attrs := buildMsg(stunMsg.TransactionID, stun.BindingSuccess, &stun.XORMappedAddress{ IP: ip, Port: port, }, stun.Fingerprint) - return buildAndSend(r.Conn, r.SrcAddr, attrs...) + return buildAndSend(req.Conn, req.SrcAddr, attrs...) } diff --git a/internal/server/turn.go b/internal/server/turn.go index 46e45ecb..e49beb95 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -17,42 +17,61 @@ import ( const runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" // See: https://tools.ietf.org/html/rfc5766#section-6.2 -func handleAllocateRequest(r Request, m *stun.Message) error { - r.Log.Debugf("Received AllocateRequest from %s", r.SrcAddr) +// . +func handleAllocateRequest(req Request, stunMsg *stun.Message) error { // nolint:cyclop,funlen + req.Log.Debugf("Received AllocateRequest from %s", req.SrcAddr) // 1. The server MUST require that the request be authenticated. This // authentication MUST be done using the long-term credential // mechanism of [https://tools.ietf.org/html/rfc5389#section-10.2.2] // unless the client and server agree to use another mechanism through // some procedure outside the scope of this document. - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodAllocate) + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodAllocate) if !hasAuth { return err } fiveTuple := &allocation.FiveTuple{ - SrcAddr: r.SrcAddr, - DstAddr: r.Conn.LocalAddr(), + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP, } requestedPort := 0 reservationToken := "" - badRequestMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}) - insufficientCapacityMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeInsufficientCapacity}) + badRequestMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}, + ) + insufficientCapacityMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeInsufficientCapacity}, + ) // 2. The server checks if the 5-tuple is currently in use by an // existing allocation. If yes, the server rejects the request with // a 437 (Allocation Mismatch) error. - if alloc := r.AllocationManager.GetAllocation(fiveTuple); alloc != nil { + if alloc := req.AllocationManager.GetAllocation(fiveTuple); alloc != nil { id, attrs := alloc.GetResponseCache() - if id != m.TransactionID { - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeAllocMismatch}) - return buildAndSendErr(r.Conn, r.SrcAddr, errRelayAlreadyAllocatedForFiveTuple, msg...) + if id != stunMsg.TransactionID { + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeAllocMismatch}, + ) + + return buildAndSendErr(req.Conn, req.SrcAddr, errRelayAlreadyAllocatedForFiveTuple, msg...) } // A retry allocation - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(attrs, messageIntegrity)...) - return buildAndSend(r.Conn, r.SrcAddr, msg...) + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), + append(attrs, messageIntegrity)..., + ) + + return buildAndSend(req.Conn, req.SrcAddr, msg...) } // 3. The server checks if the request contains a REQUESTED-TRANSPORT @@ -62,11 +81,16 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // specifies a protocol other that UDP/TCP, the server rejects the // request with a 442 (Unsupported Transport Protocol) error. var requestedTransport proto.RequestedTransport - if err = requestedTransport.GetFrom(m); err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + if err = requestedTransport.GetFrom(stunMsg); err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } else if requestedTransport.Protocol != proto.ProtoUDP && requestedTransport.Protocol != proto.ProtoTCP { - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeUnsupportedTransProto}) - return buildAndSendErr(r.Conn, r.SrcAddr, errUnsupportedTransportProtocol, msg...) + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeUnsupportedTransProto}, + ) + + return buildAndSendErr(req.Conn, req.SrcAddr, errUnsupportedTransportProtocol, msg...) } // 4. The request may contain a DONT-FRAGMENT attribute. If it does, @@ -74,9 +98,15 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // bit set to 1 (see Section 12), then the server treats the DONT- // FRAGMENT attribute in the Allocate request as an unknown // comprehension-required attribute. - if m.Contains(stun.AttrDontFragment) { - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeUnknownAttribute}, &stun.UnknownAttributes{stun.AttrDontFragment}) - return buildAndSendErr(r.Conn, r.SrcAddr, errNoDontFragmentSupport, msg...) + if stunMsg.Contains(stun.AttrDontFragment) { + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeUnknownAttribute}, + &stun.UnknownAttributes{stun.AttrDontFragment}, + ) + + return buildAndSendErr(req.Conn, req.SrcAddr, errNoDontFragmentSupport, msg...) } // 5. The server checks if the request contains a RESERVATION-TOKEN @@ -88,10 +118,10 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // the token is not valid for some reason, the server rejects the // request with a 508 (Insufficient Capacity) error. var reservationTokenAttr proto.ReservationToken - if err = reservationTokenAttr.GetFrom(m); err == nil { + if err = reservationTokenAttr.GetFrom(stunMsg); err == nil { var evenPort proto.EvenPort - if err = evenPort.GetFrom(m); err == nil { - return buildAndSendErr(r.Conn, r.SrcAddr, errRequestWithReservationTokenAndEvenPort, badRequestMsg...) + if err = evenPort.GetFrom(stunMsg); err == nil { + return buildAndSendErr(req.Conn, req.SrcAddr, errRequestWithReservationTokenAndEvenPort, badRequestMsg...) } } @@ -102,11 +132,11 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // server rejects the request with a 508 (Insufficient Capacity) // error. var evenPort proto.EvenPort - if err = evenPort.GetFrom(m); err == nil { + if err = evenPort.GetFrom(stunMsg); err == nil { var randomPort int - randomPort, err = r.AllocationManager.GetRandomEvenPort() + randomPort, err = req.AllocationManager.GetRandomEvenPort() if err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...) + return buildAndSendErr(req.Conn, req.SrcAddr, err, insufficientCapacityMsg...) } requestedPort = randomPort reservationToken, err = randutil.GenerateCryptoRandomString(8, runesAlpha) @@ -126,14 +156,14 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // with a 300 (Try Alternate) error if it wishes to redirect the // client to a different server. The use of this error code and // attribute follow the specification in [RFC5389]. - lifetimeDuration := allocationLifeTime(m) - a, err := r.AllocationManager.CreateAllocation( + lifetimeDuration := allocationLifeTime(stunMsg) + alloc, err := req.AllocationManager.CreateAllocation( fiveTuple, - r.Conn, + req.Conn, requestedPort, lifetimeDuration) if err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...) + return buildAndSendErr(req.Conn, req.SrcAddr, err, insufficientCapacityMsg...) } // Once the allocation is created, the server replies with a success @@ -148,14 +178,14 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // * An XOR-MAPPED-ADDRESS attribute containing the client's IP address // and port (from the 5-tuple). - srcIP, srcPort, err := ipnet.AddrIPPort(r.SrcAddr) + srcIP, srcPort, err := ipnet.AddrIPPort(req.SrcAddr) if err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } - relayIP, relayPort, err := ipnet.AddrIPPort(a.RelayAddr) + relayIP, relayPort, err := ipnet.AddrIPPort(alloc.RelayAddr) if err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } responseAttrs := []stun.Setter{ @@ -173,90 +203,105 @@ func handleAllocateRequest(r Request, m *stun.Message) error { } if reservationToken != "" { - r.AllocationManager.CreateReservation(reservationToken, relayPort) + req.AllocationManager.CreateReservation(reservationToken, relayPort) responseAttrs = append(responseAttrs, proto.ReservationToken([]byte(reservationToken))) } - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(responseAttrs, messageIntegrity)...) - a.SetResponseCache(m.TransactionID, responseAttrs) - return buildAndSend(r.Conn, r.SrcAddr, msg...) + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), + append(responseAttrs, messageIntegrity)..., + ) + alloc.SetResponseCache(stunMsg.TransactionID, responseAttrs) + + return buildAndSend(req.Conn, req.SrcAddr, msg...) } -func handleRefreshRequest(r Request, m *stun.Message) error { - r.Log.Debugf("Received RefreshRequest from %s", r.SrcAddr) +func handleRefreshRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received RefreshRequest from %s", req.SrcAddr) - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodRefresh) + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodRefresh) if !hasAuth { return err } - lifetimeDuration := allocationLifeTime(m) + lifetimeDuration := allocationLifeTime(stunMsg) fiveTuple := &allocation.FiveTuple{ - SrcAddr: r.SrcAddr, - DstAddr: r.Conn.LocalAddr(), + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP, } if lifetimeDuration != 0 { - a := r.AllocationManager.GetAllocation(fiveTuple) + a := req.AllocationManager.GetAllocation(fiveTuple) if a == nil { - return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr()) + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) } a.Refresh(lifetimeDuration) } else { - r.AllocationManager.DeleteAllocation(fiveTuple) - } - - return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodRefresh, stun.ClassSuccessResponse), []stun.Setter{ - &proto.Lifetime{ - Duration: lifetimeDuration, - }, - messageIntegrity, - }...)...) + req.AllocationManager.DeleteAllocation(fiveTuple) + } + + return buildAndSend( + req.Conn, + req.SrcAddr, + buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodRefresh, stun.ClassSuccessResponse), + []stun.Setter{ + &proto.Lifetime{ + Duration: lifetimeDuration, + }, + messageIntegrity, + }..., + )..., + ) } -func handleCreatePermissionRequest(r Request, m *stun.Message) error { - r.Log.Debugf("Received CreatePermission from %s", r.SrcAddr) +func handleCreatePermissionRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received CreatePermission from %s", req.SrcAddr) - a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{ - SrcAddr: r.SrcAddr, - DstAddr: r.Conn.LocalAddr(), + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP, }) - if a == nil { - return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr()) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) } - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodCreatePermission) + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodCreatePermission) if !hasAuth { return err } addCount := 0 - if err := m.ForEach(stun.AttrXORPeerAddress, func(m *stun.Message) error { + if err := stunMsg.ForEach(stun.AttrXORPeerAddress, func(m *stun.Message) error { var peerAddress proto.PeerAddress if err := peerAddress.GetFrom(m); err != nil { return err } - if err := r.AllocationManager.GrantPermission(r.SrcAddr, peerAddress.IP); err != nil { - r.Log.Infof("permission denied for client %s to peer %s", r.SrcAddr, peerAddress.IP) + if err := req.AllocationManager.GrantPermission(req.SrcAddr, peerAddress.IP); err != nil { + req.Log.Infof("permission denied for client %s to peer %s", req.SrcAddr, peerAddress.IP) + return err } - r.Log.Debugf("Adding permission for %s", fmt.Sprintf("%s:%d", + req.Log.Debugf("Adding permission for %s", fmt.Sprintf("%s:%d", peerAddress.IP, peerAddress.Port)) - a.AddPermission(allocation.NewPermission( + alloc.AddPermission(allocation.NewPermission( &net.UDPAddr{ IP: peerAddress.IP, Port: peerAddress.Port, }, - r.Log, + req.Log, )) addCount++ + return nil }); err != nil { addCount = 0 @@ -267,115 +312,131 @@ func handleCreatePermissionRequest(r Request, m *stun.Message) error { respClass = stun.ClassErrorResponse } - return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), []stun.Setter{messageIntegrity}...)...) + return buildAndSend( + req.Conn, + req.SrcAddr, + buildMsg(stunMsg.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), + []stun.Setter{messageIntegrity}...)..., + ) } -func handleSendIndication(r Request, m *stun.Message) error { - r.Log.Debugf("Received SendIndication from %s", r.SrcAddr) - a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{ - SrcAddr: r.SrcAddr, - DstAddr: r.Conn.LocalAddr(), +func handleSendIndication(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received SendIndication from %s", req.SrcAddr) + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP, }) - if a == nil { - return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr()) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) } dataAttr := proto.Data{} - if err := dataAttr.GetFrom(m); err != nil { + if err := dataAttr.GetFrom(stunMsg); err != nil { return err } peerAddress := proto.PeerAddress{} - if err := peerAddress.GetFrom(m); err != nil { + if err := peerAddress.GetFrom(stunMsg); err != nil { return err } msgDst := &net.UDPAddr{IP: peerAddress.IP, Port: peerAddress.Port} - if perm := a.GetPermission(msgDst); perm == nil { + if perm := alloc.GetPermission(msgDst); perm == nil { return fmt.Errorf("%w: %v", errNoPermission, msgDst) } - l, err := a.RelaySocket.WriteTo(dataAttr, msgDst) + l, err := alloc.RelaySocket.WriteTo(dataAttr, msgDst) if l != len(dataAttr) { return fmt.Errorf("%w %d != %d (expected) err: %v", errShortWrite, l, len(dataAttr), err) //nolint:errorlint } + return err } -func handleChannelBindRequest(r Request, m *stun.Message) error { - r.Log.Debugf("Received ChannelBindRequest from %s", r.SrcAddr) +func handleChannelBindRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received ChannelBindRequest from %s", req.SrcAddr) - a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{ - SrcAddr: r.SrcAddr, - DstAddr: r.Conn.LocalAddr(), + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP, }) - if a == nil { - return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr()) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) } - badRequestMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}) + badRequestMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}, + ) - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodChannelBind) + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodChannelBind) if !hasAuth { return err } var channel proto.ChannelNumber - if err = channel.GetFrom(m); err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + if err = channel.GetFrom(stunMsg); err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } peerAddr := proto.PeerAddress{} - if err = peerAddr.GetFrom(m); err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + if err = peerAddr.GetFrom(stunMsg); err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } - if err = r.AllocationManager.GrantPermission(r.SrcAddr, peerAddr.IP); err != nil { - r.Log.Infof("permission denied for client %s to peer %s", r.SrcAddr, peerAddr.IP) + if err = req.AllocationManager.GrantPermission(req.SrcAddr, peerAddr.IP); err != nil { + req.Log.Infof("permission denied for client %s to peer %s", req.SrcAddr, peerAddr.IP) - unauthorizedRequestMsg := buildMsg(m.TransactionID, + unauthorizedRequestMsg := buildMsg(stunMsg.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeUnauthorized}) - return buildAndSendErr(r.Conn, r.SrcAddr, err, unauthorizedRequestMsg...) + + return buildAndSendErr(req.Conn, req.SrcAddr, err, unauthorizedRequestMsg...) } - r.Log.Debugf("Binding channel %d to %s", channel, peerAddr) - err = a.AddChannelBind(allocation.NewChannelBind( + req.Log.Debugf("Binding channel %d to %s", channel, peerAddr) + err = alloc.AddChannelBind(allocation.NewChannelBind( channel, &net.UDPAddr{IP: peerAddr.IP, Port: peerAddr.Port}, - r.Log, - ), r.ChannelBindTimeout) + req.Log, + ), req.ChannelBindTimeout) if err != nil { - return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } - return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), []stun.Setter{messageIntegrity}...)...) + return buildAndSend( + req.Conn, + req.SrcAddr, + buildMsg(stunMsg.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), + []stun.Setter{messageIntegrity}...)..., + ) } -func handleChannelData(r Request, c *proto.ChannelData) error { - r.Log.Debugf("Received ChannelData from %s", r.SrcAddr) +func handleChannelData(req Request, channelData *proto.ChannelData) error { + req.Log.Debugf("Received ChannelData from %s", req.SrcAddr) - a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{ - SrcAddr: r.SrcAddr, - DstAddr: r.Conn.LocalAddr(), + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP, }) - if a == nil { - return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr()) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) } - channel := a.GetChannelByNumber(c.Number) + channel := alloc.GetChannelByNumber(channelData.Number) if channel == nil { - return fmt.Errorf("%w %x", errNoSuchChannelBind, uint16(c.Number)) + return fmt.Errorf("%w %x", errNoSuchChannelBind, uint16(channelData.Number)) } - l, err := a.RelaySocket.WriteTo(c.Data, channel.Peer) + l, err := alloc.RelaySocket.WriteTo(channelData.Data, channel.Peer) if err != nil { return fmt.Errorf("%w: %s", errFailedWriteSocket, err.Error()) - } else if l != len(c.Data) { - return fmt.Errorf("%w %d != %d (expected)", errShortWrite, l, len(c.Data)) + } else if l != len(channelData.Data) { + return fmt.Errorf("%w %d != %d (expected)", errShortWrite, l, len(channelData.Data)) } return nil diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index e4a3b947..31a648da 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -18,7 +18,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAllocationLifeTime(t *testing.T) { +func TestAllocationLifeTime(t *testing.T) { // nolint:funlen t.Run("Parsing", func(t *testing.T) { lifetime := proto.Lifetime{ Duration: 5 * time.Second, @@ -55,22 +55,22 @@ func TestAllocationLifeTime(t *testing.T) { }) t.Run("DeletionZeroLifetime", func(t *testing.T) { - l, err := net.ListenPacket("udp4", "0.0.0.0:0") + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") assert.NoError(t, err) defer func() { - assert.NoError(t, l.Close()) + assert.NoError(t, conn.Close()) }() logger := logging.NewDefaultLoggerFactory().NewLogger("turn") allocationManager, err := allocation.NewManager(allocation.ManagerConfig{ AllocatePacketConn: func(network string, _ int) (net.PacketConn, net.Addr, error) { - conn, listenErr := net.ListenPacket(network, "0.0.0.0:0") + con, listenErr := net.ListenPacket(network, "0.0.0.0:0") if err != nil { return nil, nil, listenErr } - return conn, conn.LocalAddr(), nil + return con, con.LocalAddr(), nil }, AllocateConn: func(string, int) (net.Conn, net.Addr, error) { return nil, nil, nil @@ -84,10 +84,10 @@ func TestAllocationLifeTime(t *testing.T) { staticKey, err := nonceHash.Generate() assert.NoError(t, err) - r := Request{ + req := Request{ AllocationManager: allocationManager, NonceHash: nonceHash, - Conn: l, + Conn: conn, SrcAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5000}, Log: logger, AuthHandler: func(string, string, net.Addr) (key []byte, ok bool) { @@ -95,12 +95,12 @@ func TestAllocationLifeTime(t *testing.T) { }, } - fiveTuple := &allocation.FiveTuple{SrcAddr: r.SrcAddr, DstAddr: r.Conn.LocalAddr(), Protocol: allocation.UDP} + fiveTuple := &allocation.FiveTuple{SrcAddr: req.SrcAddr, DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP} - _, err = r.AllocationManager.CreateAllocation(fiveTuple, r.Conn, 0, time.Hour) + _, err = req.AllocationManager.CreateAllocation(fiveTuple, req.Conn, 0, time.Hour) assert.NoError(t, err) - assert.NotNil(t, r.AllocationManager.GetAllocation(fiveTuple)) + assert.NotNil(t, req.AllocationManager.GetAllocation(fiveTuple)) m := &stun.Message{} assert.NoError(t, (proto.Lifetime{}).AddTo(m)) @@ -109,7 +109,7 @@ func TestAllocationLifeTime(t *testing.T) { assert.NoError(t, (stun.Realm(staticKey)).AddTo(m)) assert.NoError(t, (stun.Username(staticKey)).AddTo(m)) - assert.NoError(t, handleRefreshRequest(r, m)) - assert.Nil(t, r.AllocationManager.GetAllocation(fiveTuple)) + assert.NoError(t, handleRefreshRequest(req, m)) + assert.Nil(t, req.AllocationManager.GetAllocation(fiveTuple)) }) } diff --git a/internal/server/util.go b/internal/server/util.go index 7c01d329..d3a63f33 100644 --- a/internal/server/util.go +++ b/internal/server/util.go @@ -14,7 +14,8 @@ import ( ) const ( - maximumAllocationLifetime = time.Hour // See: https://tools.ietf.org/html/rfc5766#section-6.2 defines 3600 seconds recommendation + // See: https://tools.ietf.org/html/rfc5766#section-6.2 defines 3600 seconds recommendation. + maximumAllocationLifetime = time.Hour ) func buildAndSend(conn net.PacketConn, dst net.Addr, attrs ...stun.Setter) error { @@ -30,71 +31,90 @@ func buildAndSend(conn net.PacketConn, dst net.Addr, attrs ...stun.Setter) error return err } -// Send a STUN packet and return the original error to the caller +// Send a STUN packet and return the original error to the caller. func buildAndSendErr(conn net.PacketConn, dst net.Addr, err error, attrs ...stun.Setter) error { if sendErr := buildAndSend(conn, dst, attrs...); sendErr != nil { err = fmt.Errorf("%w %v %v", errFailedToSendError, sendErr, err) //nolint:errorlint } + return err } -func buildMsg(transactionID [stun.TransactionIDSize]byte, msgType stun.MessageType, additional ...stun.Setter) []stun.Setter { +func buildMsg( + transactionID [stun.TransactionIDSize]byte, + msgType stun.MessageType, + additional ...stun.Setter, +) []stun.Setter { return append([]stun.Setter{&stun.Message{TransactionID: transactionID}, msgType}, additional...) } -func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) (stun.MessageIntegrity, bool, error) { +func authenticateRequest(req Request, stunMsg *stun.Message, callingMethod stun.Method) ( // nolint:funlen + stun.MessageIntegrity, + bool, + error, +) { respondWithNonce := func(responseCode stun.ErrorCode) (stun.MessageIntegrity, bool, error) { - nonce, err := r.NonceHash.Generate() + nonce, err := req.NonceHash.Generate() if err != nil { return nil, false, err } - return nil, false, buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, + return nil, false, buildAndSend(req.Conn, req.SrcAddr, buildMsg(stunMsg.TransactionID, stun.NewType(callingMethod, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: responseCode}, stun.NewNonce(nonce), - stun.NewRealm(r.Realm), + stun.NewRealm(req.Realm), )...) } - if !m.Contains(stun.AttrMessageIntegrity) { + if !stunMsg.Contains(stun.AttrMessageIntegrity) { return respondWithNonce(stun.CodeUnauthorized) } nonceAttr := &stun.Nonce{} usernameAttr := &stun.Username{} realmAttr := &stun.Realm{} - badRequestMsg := buildMsg(m.TransactionID, stun.NewType(callingMethod, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}) + badRequestMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(callingMethod, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}, + ) // No Auth handler is set, server is running in STUN only mode - // Respond with 400 so clients don't retry - if r.AuthHandler == nil { - sendErr := buildAndSend(r.Conn, r.SrcAddr, badRequestMsg...) + // Respond with 400 so clients don't retry. + if req.AuthHandler == nil { + sendErr := buildAndSend(req.Conn, req.SrcAddr, badRequestMsg...) + return nil, false, sendErr } - if err := nonceAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + if err := nonceAttr.GetFrom(stunMsg); err != nil { + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } - // Assert Nonce is signed and is not expired - if err := r.NonceHash.Validate(nonceAttr.String()); err != nil { + // Assert Nonce is signed and is not expired. + if err := req.NonceHash.Validate(nonceAttr.String()); err != nil { return respondWithNonce(stun.CodeStaleNonce) } - if err := realmAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) - } else if err := usernameAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + if err := realmAttr.GetFrom(stunMsg); err != nil { + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } else if err := usernameAttr.GetFrom(stunMsg); err != nil { + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } - ourKey, ok := r.AuthHandler(usernameAttr.String(), realmAttr.String(), r.SrcAddr) + ourKey, ok := req.AuthHandler(usernameAttr.String(), realmAttr.String(), req.SrcAddr) if !ok { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), badRequestMsg...) + return nil, false, buildAndSendErr( + req.Conn, + req.SrcAddr, + fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), + badRequestMsg..., + ) } - if err := stun.MessageIntegrity(ourKey).Check(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + if err := stun.MessageIntegrity(ourKey).Check(stunMsg); err != nil { + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) } return stun.MessageIntegrity(ourKey), true, nil diff --git a/lt_cred.go b/lt_cred.go index bd3197f1..33ca3e86 100644 --- a/lt_cred.go +++ b/lt_cred.go @@ -15,20 +15,26 @@ import ( //nolint:gci "github.com/pion/logging" ) -// GenerateLongTermCredentials can be used to create credentials valid for [duration] time +// GenerateLongTermCredentials can be used to create credentials valid for [duration] time. func GenerateLongTermCredentials(sharedSecret string, duration time.Duration) (string, string, error) { t := time.Now().Add(duration).Unix() username := strconv.FormatInt(t, 10) password, err := longTermCredentials(username, sharedSecret) + return username, password, err } -// GenerateLongTermTURNRESTCredentials can be used to create credentials valid for [duration] time -func GenerateLongTermTURNRESTCredentials(sharedSecret string, user string, duration time.Duration) (string, string, error) { +// GenerateLongTermTURNRESTCredentials can be used to create credentials valid for [duration] time. +func GenerateLongTermTURNRESTCredentials(sharedSecret string, user string, duration time.Duration) ( + string, + string, + error, +) { t := time.Now().Add(duration).Unix() timestamp := strconv.FormatInt(t, 10) username := timestamp + ":" + user password, err := longTermCredentials(username, sharedSecret) + return username, password, err } @@ -39,31 +45,38 @@ func longTermCredentials(username string, sharedSecret string) (string, error) { return "", err // Not sure if this will ever happen } password := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(password), nil } // NewLongTermAuthHandler returns a turn.AuthAuthHandler used with Long Term (or Time Windowed) Credentials. // See: https://datatracker.ietf.org/doc/html/rfc8489#section-9.2 -func NewLongTermAuthHandler(sharedSecret string, l logging.LeveledLogger) AuthHandler { - if l == nil { - l = logging.NewDefaultLoggerFactory().NewLogger("turn") +// . +func NewLongTermAuthHandler(sharedSecret string, logger logging.LeveledLogger) AuthHandler { + if logger == nil { + logger = logging.NewDefaultLoggerFactory().NewLogger("turn") } + return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - l.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) + logger.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) t, err := strconv.Atoi(username) if err != nil { - l.Errorf("Invalid time-windowed username %q", username) + logger.Errorf("Invalid time-windowed username %q", username) + return nil, false } if int64(t) < time.Now().Unix() { - l.Errorf("Expired time-windowed username %q", username) + logger.Errorf("Expired time-windowed username %q", username) + return nil, false } password, err := longTermCredentials(username, sharedSecret) if err != nil { - l.Error(err.Error()) + logger.Error(err.Error()) + return nil, false } + return GenerateAuthKey(username, realm, password), true } } @@ -74,27 +87,32 @@ func NewLongTermAuthHandler(sharedSecret string, l logging.LeveledLogger) AuthHa // // The supported format of is timestamp:username, where username is an arbitrary user id and the // timestamp specifies the expiry of the credential. -func LongTermTURNRESTAuthHandler(sharedSecret string, l logging.LeveledLogger) AuthHandler { - if l == nil { - l = logging.NewDefaultLoggerFactory().NewLogger("turn") +func LongTermTURNRESTAuthHandler(sharedSecret string, logger logging.LeveledLogger) AuthHandler { + if logger == nil { + logger = logging.NewDefaultLoggerFactory().NewLogger("turn") } + return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - l.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) + logger.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) timestamp := strings.Split(username, ":")[0] t, err := strconv.Atoi(timestamp) if err != nil { - l.Errorf("Invalid time-windowed username %q", username) + logger.Errorf("Invalid time-windowed username %q", username) + return nil, false } if int64(t) < time.Now().Unix() { - l.Errorf("Expired time-windowed username %q", username) + logger.Errorf("Expired time-windowed username %q", username) + return nil, false } password, err := longTermCredentials(username, sharedSecret) if err != nil { - l.Error(err.Error()) + logger.Error(err.Error()) + return nil, false } + return GenerateAuthKey(username, realm, password), true } } diff --git a/relay_address_generator_none.go b/relay_address_generator_none.go index b0974010..f60b1030 100644 --- a/relay_address_generator_none.go +++ b/relay_address_generator_none.go @@ -12,7 +12,7 @@ import ( "github.com/pion/transport/v3/stdnet" ) -// RelayAddressGeneratorNone returns the listener with no modifications +// RelayAddressGeneratorNone returns the listener with no modifications. type RelayAddressGeneratorNone struct { // Address is passed to Listen/ListenPacket when creating the Relay Address string @@ -20,7 +20,7 @@ type RelayAddressGeneratorNone struct { Net transport.Net } -// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured +// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured. func (r *RelayAddressGeneratorNone) Validate() error { if r.Net == nil { var err error @@ -38,8 +38,13 @@ func (r *RelayAddressGeneratorNone) Validate() error { } } -// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int) ( + net.PacketConn, + net.Addr, + error, +) { conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) if err != nil { return nil, nil, err @@ -48,7 +53,8 @@ func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requested return conn, conn.LocalAddr(), nil } -// AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with +// AllocateConn generates a new Conn to receive traffic on and the IP/Port +// to populate the allocation response with. func (r *RelayAddressGeneratorNone) AllocateConn(string, int) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/relay_address_generator_range.go b/relay_address_generator_range.go index d87a57f9..9c334099 100644 --- a/relay_address_generator_range.go +++ b/relay_address_generator_range.go @@ -35,7 +35,7 @@ type RelayAddressGeneratorPortRange struct { Net transport.Net } -// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured +// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured. func (r *RelayAddressGeneratorPortRange) Validate() error { if r.Net == nil { var err error @@ -67,24 +67,30 @@ func (r *RelayAddressGeneratorPortRange) Validate() error { } } -// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorPortRange) AllocatePacketConn( + network string, + requestedPort int, +) (net.PacketConn, net.Addr, error) { if requestedPort != 0 { conn, err := r.Net.ListenPacket(network, fmt.Sprintf("%s:%d", r.Address, requestedPort)) if err != nil { return nil, nil, err } + relayAddr, ok := conn.LocalAddr().(*net.UDPAddr) if !ok { return nil, nil, errNilConn } relayAddr.IP = r.RelayAddress + return conn, relayAddr, nil } for try := 0; try < r.MaxRetries; try++ { - port := r.MinPort + uint16(r.Rand.Intn(int((r.MaxPort+1)-r.MinPort))) + port := int(r.MinPort) + r.Rand.Intn(int((r.MaxPort+1)-r.MinPort)) conn, err := r.Net.ListenPacket(network, fmt.Sprintf("%s:%d", r.Address, port)) if err != nil { continue @@ -96,13 +102,15 @@ func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requ } relayAddr.IP = r.RelayAddress + return conn, relayAddr, nil } return nil, nil, errMaxRetriesExceeded } -// AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with +// AllocateConn generates a new Conn to receive traffic on and the IP/Port +// to populate the allocation response with. func (r *RelayAddressGeneratorPortRange) AllocateConn(string, int) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/relay_address_generator_static.go b/relay_address_generator_static.go index 39c68777..12ea5f25 100644 --- a/relay_address_generator_static.go +++ b/relay_address_generator_static.go @@ -13,7 +13,7 @@ import ( ) // RelayAddressGeneratorStatic can be used to return static IP address each time a relay is created. -// This can be used when you have a single static IP address that you want to use +// This can be used when you have a single static IP address that you want to use. type RelayAddressGeneratorStatic struct { // RelayAddress is the IP returned to the user when the relay is created RelayAddress net.IP @@ -24,7 +24,7 @@ type RelayAddressGeneratorStatic struct { Net transport.Net } -// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured +// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured. func (r *RelayAddressGeneratorStatic) Validate() error { if r.Net == nil { var err error @@ -44,8 +44,12 @@ func (r *RelayAddressGeneratorStatic) Validate() error { } } -// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorStatic) AllocatePacketConn( + network string, + requestedPort int, +) (net.PacketConn, net.Addr, error) { conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) if err != nil { return nil, nil, err @@ -62,7 +66,8 @@ func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, request return conn, relayAddr, nil } -// AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with +// AllocateConn generates a new Conn to receive traffic on and the IP/Port +// to populate the allocation response with. func (r *RelayAddressGeneratorStatic) AllocateConn(string, int) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/server.go b/server.go index 3b58938f..90fa48cf 100644 --- a/server.go +++ b/server.go @@ -20,7 +20,7 @@ const ( defaultInboundMTU = 1600 ) -// Server is an instance of the Pion TURN Server +// Server is an instance of the Pion TURN Server. type Server struct { log logging.LeveledLogger authHandler AuthHandler @@ -34,10 +34,8 @@ type Server struct { inboundMTU int } -// NewServer creates the Pion TURN server -// -//nolint:gocognit -func NewServer(config ServerConfig) (*Server, error) { +// NewServer creates the Pion TURN server. +func NewServer(config ServerConfig) (*Server, error) { // nolint:gocognit,cyclop,funlen if err := config.validate(); err != nil { return nil, err } @@ -57,7 +55,7 @@ func NewServer(config ServerConfig) (*Server, error) { return nil, err } - s := &Server{ + server := &Server{ log: loggerFactory.NewLogger("turn"), authHandler: config.AuthHandler, realm: config.Realm, @@ -68,53 +66,56 @@ func NewServer(config ServerConfig) (*Server, error) { inboundMTU: mtu, } - if s.channelBindTimeout == 0 { - s.channelBindTimeout = proto.DefaultLifetime + if server.channelBindTimeout == 0 { + server.channelBindTimeout = proto.DefaultLifetime } - for _, cfg := range s.packetConnConfigs { - am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler) + for _, cfg := range server.packetConnConfigs { + am, err := server.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler) if err != nil { return nil, fmt.Errorf("failed to create AllocationManager: %w", err) } go func(cfg PacketConnConfig, am *allocation.Manager) { - s.readLoop(cfg.PacketConn, am) + server.readLoop(cfg.PacketConn, am) if err := am.Close(); err != nil { - s.log.Errorf("Failed to close AllocationManager: %s", err) + server.log.Errorf("Failed to close AllocationManager: %s", err) } }(cfg, am) } - for _, cfg := range s.listenerConfigs { - am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler) + for _, cfg := range server.listenerConfigs { + am, err := server.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler) if err != nil { return nil, fmt.Errorf("failed to create AllocationManager: %w", err) } go func(cfg ListenerConfig, am *allocation.Manager) { - s.readListener(cfg.Listener, am) + server.readListener(cfg.Listener, am) if err := am.Close(); err != nil { - s.log.Errorf("Failed to close AllocationManager: %s", err) + server.log.Errorf("Failed to close AllocationManager: %s", err) } }(cfg, am) } - return s, nil + return server, nil } -// AllocationCount returns the number of active allocations. It can be used to drain the server before closing +// AllocationCount returns the number of active allocations. +// It can be used to drain the server before closing. func (s *Server) AllocationCount() int { allocs := 0 for _, am := range s.allocationManagers { allocs += am.AllocationCount() } + return allocs } -// Close stops the TURN Server. It cleans up any associated state and closes all connections it is managing +// Close stops the TURN Server. +// It cleans up any associated state and closes all connections it is managing. func (s *Server) Close() error { var errors []error @@ -147,6 +148,7 @@ func (s *Server) readListener(l net.Listener, am *allocation.Manager) { conn, err := l.Accept() if err != nil { s.log.Debugf("Failed to accept: %s", err) + return } @@ -179,7 +181,10 @@ func (n *nilAddressGenerator) AllocateConn(string, int) (net.Conn, net.Addr, err return nil, nil, errRelayAddressGeneratorNil } -func (s *Server) createAllocationManager(addrGenerator RelayAddressGenerator, handler PermissionHandler) (*allocation.Manager, error) { +func (s *Server) createAllocationManager( + addrGenerator RelayAddressGenerator, + handler PermissionHandler, +) (*allocation.Manager, error) { if handler == nil { handler = DefaultPermissionHandler } @@ -202,21 +207,23 @@ func (s *Server) createAllocationManager(addrGenerator RelayAddressGenerator, ha return am, err } -func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manager) { +func (s *Server) readLoop(conn net.PacketConn, allocationManager *allocation.Manager) { buf := make([]byte, s.inboundMTU) for { - n, addr, err := p.ReadFrom(buf) + n, addr, err := conn.ReadFrom(buf) switch { case err != nil: s.log.Debugf("Exit read loop on error: %s", err) + return case n >= s.inboundMTU: s.log.Debugf("Read bytes exceeded MTU, packet is possibly truncated") + continue } if err := server.HandleRequest(server.Request{ - Conn: p, + Conn: conn, SrcAddr: addr, Buff: buf[:n], Log: s.log, diff --git a/server_config.go b/server_config.go index eab2988e..8f276140 100644 --- a/server_config.go +++ b/server_config.go @@ -34,12 +34,13 @@ type RelayAddressGenerator interface { // of NATs that comply with [RFC4787], see https://tools.ietf.org/html/rfc5766#section-2.3. type PermissionHandler func(clientAddr net.Addr, peerIP net.IP) (ok bool) -// DefaultPermissionHandler is convince function that grants permission to all peers +// DefaultPermissionHandler is convince function that grants permission to all peers. func DefaultPermissionHandler(net.Addr, net.IP) (ok bool) { return true } -// PacketConnConfig is a single net.PacketConn to listen/write on. This will be used for UDP listeners +// PacketConnConfig is a single net.PacketConn to listen/write on. +// This will be used for UDP listeners. type PacketConnConfig struct { PacketConn net.PacketConn @@ -67,7 +68,8 @@ func (c *PacketConnConfig) validate() error { return nil } -// ListenerConfig is a single net.Listener to accept connections on. This will be used for TCP, TLS and DTLS listeners +// ListenerConfig is a single net.Listener to accept connections on. +// This will be used for TCP, TLS and DTLS listeners. type ListenerConfig struct { Listener net.Listener @@ -93,18 +95,20 @@ func (c *ListenerConfig) validate() error { return c.RelayAddressGenerator.Validate() } -// AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior +// AuthHandler is a callback used to handle incoming auth requests, +// allowing users to customize Pion TURN with custom behavior. type AuthHandler func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) -// GenerateAuthKey is a convenience function to easily generate keys in the format used by AuthHandler +// GenerateAuthKey is a convenience function to easily generate keys in the format used by AuthHandler. func GenerateAuthKey(username, realm, password string) []byte { // #nosec h := md5.New() fmt.Fprint(h, strings.Join([]string{username, realm, password}, ":")) // nolint: errcheck + return h.Sum(nil) } -// ServerConfig configures the Pion TURN Server +// ServerConfig configures the Pion TURN Server. type ServerConfig struct { // PacketConnConfigs and ListenerConfigs are a list of all the turn listeners // Each listener can have custom behavior around the creation of Relays @@ -117,7 +121,8 @@ type ServerConfig struct { // Realm sets the realm for this server Realm string - // AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior + // AuthHandler is a callback used to handle incoming auth requests, + // allowing users to customize Pion TURN with custom behavior AuthHandler AuthHandler // ChannelBindTimeout sets the lifetime of channel binding. Defaults to 10 minutes. diff --git a/server_test.go b/server_test.go index 44020db6..de6cb62d 100644 --- a/server_test.go +++ b/server_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestServer(t *testing.T) { +func TestServer(t *testing.T) { // nolint:funlen,maintidx lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -43,6 +43,7 @@ func TestServer(t *testing.T) { if pw, ok := credMap[username]; ok { return pw, true } + return nil, false }, PacketConnConfigs: []PacketConnConfig{ @@ -132,6 +133,7 @@ func TestServer(t *testing.T) { if pw, ok := credMap[username]; ok { return pw, true } + return nil, false }, ListenerConfigs: []ListenerConfig{ @@ -217,6 +219,7 @@ func TestServer(t *testing.T) { if pw, ok := credMap[username]; ok { return pw, true } + return nil, false }, PacketConnConfigs: []PacketConnConfig{ @@ -374,10 +377,11 @@ func (v *VNet) Close() error { if err := v.server.Close(); err != nil { return err } + return v.wan.Stop() } -func buildVNet() (*VNet, error) { +func buildVNet() (*VNet, error) { // nolint:cyclop,funlen loggerFactory := logging.NewDefaultLoggerFactory() // WAN @@ -457,6 +461,7 @@ func buildVNet() (*VNet, error) { if pw, ok := credMap[username]; ok { return pw, true } + return nil, false }, Realm: "pion.ly", @@ -552,10 +557,24 @@ func TestConsumeSingleTURNFrame(t *testing.T) { err error } cases := map[string]testCase{ - "channel data": {data: []byte{0x40, 0x01, 0x00, 0x08, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, err: nil}, - "partial data less than channel header": {data: []byte{1}, err: errIncompleteTURNFrame}, - "partial stun message": {data: []byte{0x0, 0x16, 0x02, 0xDC, 0x21, 0x12, 0xA4, 0x42, 0x0, 0x0, 0x0}, err: errIncompleteTURNFrame}, - "stun message": {data: []byte{0x0, 0x16, 0x00, 0x02, 0x21, 0x12, 0xA4, 0x42, 0xf7, 0x43, 0x81, 0xa3, 0xc9, 0xcd, 0x88, 0x89, 0x70, 0x58, 0xac, 0x73, 0x0, 0x0}}, + "channel data": { + data: []byte{0x40, 0x01, 0x00, 0x08, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + err: nil, + }, + "partial data less than channel header": { + data: []byte{1}, + err: errIncompleteTURNFrame, + }, + "partial stun message": { + data: []byte{0x0, 0x16, 0x02, 0xDC, 0x21, 0x12, 0xA4, 0x42, 0x0, 0x0, 0x0}, + err: errIncompleteTURNFrame, + }, + "stun message": { + data: []byte{ + 0x00, 0x16, 0x00, 0x02, 0x21, 0x12, 0xA4, 0x42, 0xf7, 0x43, 0x81, + 0xa3, 0xc9, 0xcd, 0x88, 0x89, 0x70, 0x58, 0xac, 0x73, 0x00, 0x00, + }, + }, } for name, cs := range cases { @@ -616,7 +635,9 @@ func TestSTUNOnly(t *testing.T) { assert.Equal(t, err.Error(), "Allocate error response (error 400: )") } -func RunBenchmarkServer(b *testing.B, clientNum int) { +func RunBenchmarkServer(b *testing.B, clientNum int) { // nolint:cyclop,funlen + b.Helper() + loggerFactory := logging.NewDefaultLoggerFactory() credMap := map[string][]byte{ "user": GenerateAuthKey("user", "pion.ly", "pass"), @@ -641,6 +662,7 @@ func RunBenchmarkServer(b *testing.B, clientNum int) { if pw, ok := credMap[username]; ok { return pw, true } + return nil, false }, PacketConnConfigs: []PacketConnConfig{{ @@ -729,7 +751,7 @@ func RunBenchmarkServer(b *testing.B, clientNum int) { } } -// BenchmarkServer will benchmark the server with multiple simultaneous client connections +// BenchmarkServer will benchmark the server with multiple simultaneous client connections. func BenchmarkServer(b *testing.B) { for i := 1; i <= 4; i++ { b.Run(fmt.Sprintf("client_num_%d", i), func(b *testing.B) { diff --git a/stun_conn.go b/stun_conn.go index 57543544..fd1f1ae6 100644 --- a/stun_conn.go +++ b/stun_conn.go @@ -20,7 +20,7 @@ var ( // STUNConn wraps a net.Conn and implements // net.PacketConn by being STUN aware and -// packetizing the stream +// packetizing the stream. type STUNConn struct { nextConn net.Conn buff []byte @@ -36,92 +36,93 @@ const ( ) // Given a buffer give the last offset of the TURN frame -// If the buffer isn't a valid STUN or ChannelData packet -// or the length doesn't match return false -func consumeSingleTURNFrame(p []byte) (int, error) { +// If the buffer isn't a valid STUN or ChannelData packet, +// or the length doesn't match return false. +func consumeSingleTURNFrame(b []byte) (int, error) { // Too short to determine if ChannelData or STUN - if len(p) < 9 { + if len(b) < 9 { return 0, errIncompleteTURNFrame } var datagramSize uint16 switch { - case stun.IsMessage(p): - datagramSize = binary.BigEndian.Uint16(p[2:4]) + stunHeaderSize - case proto.ChannelNumber(binary.BigEndian.Uint16(p[0:2])).Valid(): - datagramSize = binary.BigEndian.Uint16(p[channelDataNumberSize:channelDataHeaderSize]) + case stun.IsMessage(b): + datagramSize = binary.BigEndian.Uint16(b[2:4]) + stunHeaderSize + case proto.ChannelNumber(binary.BigEndian.Uint16(b[0:2])).Valid(): + datagramSize = binary.BigEndian.Uint16(b[channelDataNumberSize:channelDataHeaderSize]) if paddingOverflow := (datagramSize + channelDataPadding) % channelDataPadding; paddingOverflow != 0 { datagramSize = (datagramSize + channelDataPadding) - paddingOverflow } datagramSize += channelDataHeaderSize - case len(p) < stunHeaderSize: + case len(b) < stunHeaderSize: return 0, errIncompleteTURNFrame default: return 0, errInvalidTURNFrame } - if len(p) < int(datagramSize) { + if len(b) < int(datagramSize) { return 0, errIncompleteTURNFrame } return int(datagramSize), nil } -// ReadFrom implements ReadFrom from net.PacketConn -func (s *STUNConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +// ReadFrom implements ReadFrom from net.PacketConn. +func (s *STUNConn) ReadFrom(payload []byte) (n int, addr net.Addr, err error) { // First pass any buffered data from previous reads n, err = consumeSingleTURNFrame(s.buff) if errors.Is(err, errInvalidTURNFrame) { return 0, nil, err } else if err == nil { - copy(p, s.buff[:n]) + copy(payload, s.buff[:n]) s.buff = s.buff[n:] return n, s.nextConn.RemoteAddr(), nil } // Then read from the nextConn, appending to our buff - n, err = s.nextConn.Read(p) + n, err = s.nextConn.Read(payload) if err != nil { return 0, nil, err } - s.buff = append(s.buff, append([]byte{}, p[:n]...)...) - return s.ReadFrom(p) + s.buff = append(s.buff, append([]byte{}, payload[:n]...)...) + + return s.ReadFrom(payload) } -// WriteTo implements WriteTo from net.PacketConn -func (s *STUNConn) WriteTo(p []byte, _ net.Addr) (n int, err error) { - return s.nextConn.Write(p) +// WriteTo implements WriteTo from net.PacketConn. +func (s *STUNConn) WriteTo(payload []byte, _ net.Addr) (n int, err error) { + return s.nextConn.Write(payload) } -// Close implements Close from net.PacketConn +// Close implements Close from net.PacketConn. func (s *STUNConn) Close() error { return s.nextConn.Close() } -// LocalAddr implements LocalAddr from net.PacketConn +// LocalAddr implements LocalAddr from net.PacketConn. func (s *STUNConn) LocalAddr() net.Addr { return s.nextConn.LocalAddr() } -// SetDeadline implements SetDeadline from net.PacketConn +// SetDeadline implements SetDeadline from net.PacketConn. func (s *STUNConn) SetDeadline(t time.Time) error { return s.nextConn.SetDeadline(t) } -// SetReadDeadline implements SetReadDeadline from net.PacketConn +// SetReadDeadline implements SetReadDeadline from net.PacketConn. func (s *STUNConn) SetReadDeadline(t time.Time) error { return s.nextConn.SetReadDeadline(t) } -// SetWriteDeadline implements SetWriteDeadline from net.PacketConn +// SetWriteDeadline implements SetWriteDeadline from net.PacketConn. func (s *STUNConn) SetWriteDeadline(t time.Time) error { return s.nextConn.SetWriteDeadline(t) } -// NewSTUNConn creates a STUNConn +// NewSTUNConn creates a STUNConn. func NewSTUNConn(nextConn net.Conn) *STUNConn { return &STUNConn{nextConn: nextConn} }