Skip to content

Commit

Permalink
Update handshaker to handle CID extension
Browse files Browse the repository at this point in the history
Updates handshaker to handle negotiating CIDs. Local connection ID is
only set if the local party generates one and the remote indicates
support. Remote connection id is only set if remote generates one and
connection IDs are supported locally

Signed-off-by: Daniel Mangum <[email protected]>
  • Loading branch information
hasheddan committed Aug 9, 2023
1 parent 8922879 commit e5420de
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 3 deletions.
12 changes: 12 additions & 0 deletions flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,21 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak
state.serverName = e.ServerName // remote server name
case *extension.ALPN:
state.peerSupportedProtocols = e.ProtocolNameList
case *extension.ConnectionID:
// Only set connection ID to be sent if server supports connection
// IDs.
if cfg.connectionIDGenerator != nil {
state.remoteConnectionID = e.CID
}
}
}

// If the client doesn't support connection IDs, the server should not
// expect one to be sent.
if state.remoteConnectionID == nil {
state.localConnectionID = nil
}

if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS
}
Expand Down
8 changes: 8 additions & 0 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha
}
}

// If we have a connection ID generator, use it. The CID may be zero length,
// in which case we are just requesting that the server send us a CID to
// use.
if cfg.connectionIDGenerator != nil {
state.localConnectionID = cfg.connectionIDGenerator()
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
}

return []*packet{
{
record: &recordlayer.RecordLayer{
Expand Down
18 changes: 18 additions & 0 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,20 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error?
}
state.NegotiatedProtocol = e.ProtocolNameList[0]
case *extension.ConnectionID:
// Only set connection ID to be sent if client supports connection
// IDs.
if cfg.connectionIDGenerator != nil {
state.remoteConnectionID = e.CID
}
}
}
// If the server doesn't support connection IDs, the client should not
// expect one to be sent.
if state.remoteConnectionID == nil {
state.localConnectionID = nil
}

if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
}
Expand Down Expand Up @@ -268,6 +280,12 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
}

// If we sent a connection ID on the first ClientHello, send it on the
// second.
if state.localConnectionID != nil {
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
}

return []*packet{
{
record: &recordlayer.RecordLayer{
Expand Down
11 changes: 10 additions & 1 deletion flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
return flight6, nil, nil
}

func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit
extensions := []extension.Extension{&extension.RenegotiationInfo{
RenegotiatedConnection: 0,
}}
Expand Down Expand Up @@ -250,6 +250,15 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
state.NegotiatedProtocol = selectedProto
}

// If we have a connection ID generator, we are willing to use connection
// IDs. We already know whether the client supports connection IDs from
// parsing the ClientHello, so avoid setting local connection ID if the
// client won't send it.
if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil {
state.localConnectionID = cfg.connectionIDGenerator()
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
}

var pkts []*packet
cipherSuiteID := uint16(state.cipherSuite.ID())

Expand Down
1 change: 1 addition & 0 deletions flight5handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han
},
},
},
shouldWrapCID: len(state.remoteConnectionID) > 0,
shouldEncrypt: true,
resetLocalSequenceNumber: true,
})
Expand Down
1 change: 1 addition & 0 deletions flight6handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func flight6Generate(_ flightConn, state *State, cache *handshakeCache, cfg *han
},
},
},
shouldWrapCID: len(state.remoteConnectionID) > 0,
shouldEncrypt: true,
resetLocalSequenceNumber: true,
},
Expand Down
2 changes: 1 addition & 1 deletion fragment_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (f *fragmentBuffer) push(buf []byte) (bool, error) {
return false, nil
}

for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) {
for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) {
if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
return false, err
}
Expand Down
1 change: 1 addition & 0 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ type handshakeConfig struct {
customCipherSuites func() []CipherSuite
ellipticCurves []elliptic.Curve
insecureSkipHelloVerify bool
connectionIDGenerator func() []byte

onFlightState func(flightVal, handshakeState)
log logging.LeveledLogger
Expand Down
2 changes: 1 addition & 1 deletion handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error {
return err
}

c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
c.handshakeCache.push(handshakeRaw[recordlayer.FixedHeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)

content, err := h.Message.Marshal()
if err != nil {
Expand Down

0 comments on commit e5420de

Please sign in to comment.