Skip to content

Commit

Permalink
Refactor errors
Browse files Browse the repository at this point in the history
Wrap all errors to implement net.Error.
Caller of Read/Write can check err.(net.Error).Temporary() to determine
whether the error is fatal or not, as with raw net.UDPConn.
  • Loading branch information
at-wat committed Mar 5, 2020
1 parent 3b0a286 commit 08e9c93
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 102 deletions.
32 changes: 18 additions & 14 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,20 +243,23 @@ func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con

// Read reads data from the connection.
func (c *Conn) Read(p []byte) (n int, err error) {
select {
case <-c.readDeadline.Done():
return 0, context.DeadlineExceeded
default:
if c.isConnectionClosed() {
return 0, io.EOF
}

if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
}

select {
case <-c.readDeadline.Done():
return 0, errDeadlineExceeded
default:
}

for {
select {
case <-c.readDeadline.Done():
return 0, context.DeadlineExceeded
return 0, errDeadlineExceeded
case <-c.closed.Done():
return 0, io.EOF
case out, ok := <-c.decrypted:
Expand All @@ -279,18 +282,19 @@ func (c *Conn) Read(p []byte) (n int, err error) {

// Write writes len(p) bytes from p to the DTLS connection
func (c *Conn) Write(p []byte) (int, error) {
if c.isConnectionClosed() {
return 0, ErrConnClosed
}

select {
case <-c.writeDeadline.Done():
return 0, context.DeadlineExceeded
return 0, errDeadlineExceeded
default:
}

if !c.isHandshakeCompletedSuccessfully() {
return 0, errHandshakeInProgress
}
if c.isConnectionClosed() {
return 0, ErrConnClosed
}

return len(p), c.writePackets(c.writeDeadline, []*packet{
{
Expand Down Expand Up @@ -405,7 +409,7 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {

for _, compactedRawPackets := range compactedRawPackets {
if _, err := c.nextConn.Write(ctx, compactedRawPackets); err != nil {
return err
return netError(err)
}
}

Expand Down Expand Up @@ -553,7 +557,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
b := *bufptr
i, err := c.nextConn.Read(ctx, b)
if err != nil {
return err
return netError(err)
}

pkts, err := unpackDatagram(b[:i])
Expand Down Expand Up @@ -793,7 +797,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
}
case error:
switch err {
case context.DeadlineExceeded, context.Canceled, io.EOF:
case errDeadlineExceeded, context.DeadlineExceeded, context.Canceled, io.EOF:
default:
if c.isHandshakeCompletedSuccessfully() {
// Keep read loop and pass the read error to Read()
Expand Down Expand Up @@ -841,7 +845,7 @@ func (c *Conn) translateHandshakeCtxError(err error) error {
return nil
}
return err
case context.DeadlineExceeded:
case errDeadlineExceeded, context.DeadlineExceeded:
return errHandshakeTimeout
}
return err
Expand Down
59 changes: 59 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
Expand Down Expand Up @@ -85,6 +86,64 @@ func TestRoutineLeakOnClose(t *testing.T) {
// inboundLoop routine should not be leaked.
}

func TestReadWriteDeadline(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(5 * time.Second)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ca, cb, err := pipeMemory()
if err != nil {
t.Fatal(err)
}

if err := ca.SetDeadline(time.Unix(0, 1)); err != nil {
t.Fatal(err)
}
_, werr := ca.Write(make([]byte, 100))
if e, ok := werr.(net.Error); ok {
if !e.Timeout() {
t.Error("Deadline exceeded Write must return Timeout error")
}
if !e.Temporary() {
t.Error("Deadline exceeded Write must return Temporary error")
}
} else {
t.Error("Write must return net.Error error")
}
_, rerr := ca.Read(make([]byte, 100))
if e, ok := rerr.(net.Error); ok {
if !e.Timeout() {
t.Error("Deadline exceeded Read must return Timeout error")
}
if !e.Temporary() {
t.Error("Deadline exceeded Read must return Temporary error")
}
} else {
t.Error("Read must return net.Error error")
}
if err := ca.SetDeadline(time.Time{}); err != nil {
t.Error(err)
}

if err := ca.Close(); err != nil {
t.Error(err)
}
if err := cb.Close(); err != nil {
t.Error(err)
}

if _, err := ca.Write(make([]byte, 100)); err != ErrConnClosed {
t.Errorf("Write must return %v after close, got %v", ErrConnClosed, err)
}
if _, err := ca.Read(make([]byte, 100)); err != io.EOF {
t.Errorf("Read must return %v after close, got %v", io.EOF, err)
}
}

func pipeMemory() (*Conn, *Conn, error) {
// In memory pipe
ca, cb := dpipe.Pipe()
Expand Down
Loading

0 comments on commit 08e9c93

Please sign in to comment.