Skip to content

Commit

Permalink
Make Conn fully net.Conn compatible
Browse files Browse the repository at this point in the history
* Implement Read/WriteDeadline
* Return remained data on Read after Close
* Test using nettest.TestConn
  • Loading branch information
at-wat committed Feb 10, 2020
1 parent 2cc05a1 commit 34883a7
Show file tree
Hide file tree
Showing 11 changed files with 420 additions and 91 deletions.
87 changes: 53 additions & 34 deletions conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dtls

import (
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
Expand All @@ -13,6 +14,7 @@ import (
"time"

"github.com/pion/dtls/v2/internal/closer"
"github.com/pion/dtls/v2/internal/net/deadline"
"github.com/pion/logging"
)

Expand Down Expand Up @@ -88,6 +90,9 @@ type Conn struct {
handshakeErr *atomicError // Error if one occurred during handshake
readErr *atomicError // Error if one occurred in inboundLoop

readDeadline *deadline.Deadline
writeDeadline *deadline.Deadline

log logging.LeveledLogger
}

Expand Down Expand Up @@ -166,6 +171,9 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage
log: logger,
handshakeErr: &atomicError{},
readErr: &atomicError{},

readDeadline: deadline.New(),
writeDeadline: deadline.New(),
}

// Use host from conn address when serverName is not provided
Expand Down Expand Up @@ -255,33 +263,26 @@ func Server(conn net.Conn, config *Config) (*Conn, error) {

// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
checkConnStatus := func() error {
if err := c.handshakeErr.load(); err != nil {
return err
}
if c.connectionClosed.Err() != nil {
return io.EOF
}
if err := c.readErr.load(); err != nil {
return err
var out []byte
var ok bool
select {
case out, ok = <-c.decrypted:
// inboundLoop has closed but error has not been set yet
if !ok {
if err := c.handshakeErr.load(); err != nil {
return 0, err
}
if c.connectionClosed.Err() != nil {
return 0, io.EOF
}
if err := c.readErr.load(); err != nil {
return 0, err
}
return 0, io.EOF
}

return nil
}

if err := checkConnStatus(); err != nil {
return 0, err
}
out, ok := <-c.decrypted
if err := checkConnStatus(); err != nil {
return 0, err
case <-c.readDeadline.Done():
return 0, context.DeadlineExceeded
}

// inboundLoop has closed but error has not been set yet
if !ok {
return 0, io.EOF
}

if len(p) < len(out) {
return 0, errBufferTooSmall
}
Expand All @@ -291,8 +292,11 @@ func (c *Conn) Read(p []byte) (n int, err error) {

// Write writes len(p) bytes from p to the DTLS connection
func (c *Conn) Write(p []byte) (int, error) {
c.lock.Lock()
defer c.lock.Unlock()
select {
case <-c.writeDeadline.Done():
return 0, context.DeadlineExceeded
default:
}

if err := c.handshakeErr.load(); err != nil {
return 0, err
Expand All @@ -304,6 +308,9 @@ func (c *Conn) Write(p []byte) (int, error) {
return 0, errHandshakeInProgress
}

c.lock.Lock()
defer c.lock.Unlock()

if err := c.bufferPacket(&packet{
record: &recordLayer{
recordLayerHeader: recordLayerHeader{
Expand Down Expand Up @@ -658,6 +665,13 @@ func (c *Conn) handleIncomingPacket(buf []byte) (*alert, error) {
return &alert{alertLevelFatal, alertUnexpectedMessage}, fmt.Errorf("ApplicationData with epoch of 0")
}

if err := c.handshakeErr.load(); err != nil {
return nil, err
}
if c.connectionClosed.Err() != nil {
return nil, io.EOF
}

select {
case c.decrypted <- content.data:
case <-c.connectionClosed.Done():
Expand Down Expand Up @@ -772,27 +786,32 @@ func (c *Conn) getRemoteEpoch() uint16 {
return c.state.remoteEpoch.Load().(uint16)
}

// LocalAddr is a stub
// LocalAddr implements net.Conn.LocalAddr
func (c *Conn) LocalAddr() net.Addr {
return c.nextConn.LocalAddr()
}

// RemoteAddr is a stub
// RemoteAddr implements net.Conn.RemoteAddr
func (c *Conn) RemoteAddr() net.Addr {
return c.nextConn.RemoteAddr()
}

// SetDeadline is a stub
// SetDeadline implements net.Conn.SetDeadline
func (c *Conn) SetDeadline(t time.Time) error {
return c.nextConn.SetDeadline(t)
c.readDeadline.Set(t)
return c.SetWriteDeadline(t)
}

// SetReadDeadline is a stub
// SetReadDeadline implements net.Conn.SetReadDeadline
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.nextConn.SetReadDeadline(t)
c.readDeadline.Set(t)
// Read deadline is fully managed by this layer.
// Don't set read deadline to underlying connection.
return nil
}

// SetWriteDeadline is a stub
// SetWriteDeadline implements net.Conn.SetWriteDeadline
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline.Set(t)
return c.nextConn.SetWriteDeadline(t)
}
18 changes: 0 additions & 18 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@ import (
"github.com/pion/transport/test"
)

// Seems to strict for out implementation at this point
// func TestNetTest(t *testing.T) {
// lim := test.TimeOut(time.Minute*1 + time.Second*10)
// defer lim.Stop()
//
// nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
// c1, c2, err = pipeMemory()
// if err != nil {
// return nil, nil, nil, err
// }
// stop = func() {
// c1.Close()
// c2.Close()
// }
// return
// })
// }

func TestStressDuplex(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ require (
github.com/pion/logging v0.2.2
github.com/pion/transport v0.8.10
golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3
)

go 1.13
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d h1:9FCpayM9Egr1baVnV1SX0H87m+XB0B8S0hAMi99X/3U=
golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
77 changes: 77 additions & 0 deletions internal/net/deadline/deadline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Package deadline provides deadline timer used to implement
// net.Conn compatible connection
package deadline

import (
"sync"
"time"
)

// Deadline signals updatable deadline timer.
type Deadline struct {
exceeded chan struct{}
stop chan struct{}
stopped chan bool
mu sync.RWMutex
}

// New creates new deadline timer.
func New() *Deadline {
d := &Deadline{
exceeded: make(chan struct{}),
stop: make(chan struct{}),
stopped: make(chan bool, 1),
}
d.stopped <- true
return d
}

// Set new deadline. Zero value means no deadline.
func (d *Deadline) Set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()

close(d.stop)

select {
case <-d.exceeded:
d.exceeded = make(chan struct{})
default:
stopped := <-d.stopped
if !stopped {
d.exceeded = make(chan struct{})
}
}
d.stop = make(chan struct{})
d.stopped = make(chan bool, 1)

if t.IsZero() {
d.stopped <- true
return
}

if dur := time.Until(t); dur > 0 {
exceeded := d.exceeded
stopped := d.stopped
go func() {
select {
case <-time.After(dur):
close(exceeded)
stopped <- false
case <-d.stop:
stopped <- true
}
}()
return
}

close(d.exceeded)
d.stopped <- false
}

// Done receives deadline signal.
func (d *Deadline) Done() <-chan struct{} {
d.mu.RLock()
defer d.mu.RUnlock()
return d.exceeded
}
Loading

0 comments on commit 34883a7

Please sign in to comment.