From 6bf60d266c0a11670c6f52420a0d73cd77abeb36 Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Tue, 3 Mar 2020 01:40:31 +0900 Subject: [PATCH] Add context wrapper for Conn.Read Add wrapper to cancel Read by context. Note that underlying Conn must support SetDeadline. --- internal/net/connctx/connctx.go | 156 +++++++++++++++++++++ internal/net/connctx/connctx_test.go | 197 +++++++++++++++++++++++++++ 2 files changed, 353 insertions(+) create mode 100644 internal/net/connctx/connctx.go create mode 100644 internal/net/connctx/connctx_test.go diff --git a/internal/net/connctx/connctx.go b/internal/net/connctx/connctx.go new file mode 100644 index 000000000..3503b42cc --- /dev/null +++ b/internal/net/connctx/connctx.go @@ -0,0 +1,156 @@ +// Package connctx wraps net.Conn using context.Context. +package connctx + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// ErrClosing is returned on Write to closed connection. +var ErrClosing = errors.New("use of closed network connection") + +// ConnCtx is a wrapper of net.Conn using context.Context. +type ConnCtx interface { + Read(context.Context, []byte) (int, error) + Write(context.Context, []byte) (int, error) + Close() error + LocalAddr() net.Addr + RemoteAddr() net.Addr + Conn() net.Conn +} + +type connCtx struct { + nextConn net.Conn + closed chan struct{} + closeOnce sync.Once + readMu sync.Mutex + writeMu sync.Mutex +} + +var veryOld = time.Unix(0, 1) + +// New creates a new ConnCtx wrapping given net.Conn. +func New(conn net.Conn) ConnCtx { + c := &connCtx{ + nextConn: conn, + closed: make(chan struct{}), + } + return c +} + +func (c *connCtx) Read(ctx context.Context, b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + select { + case <-c.closed: + return 0, io.EOF + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // context canceled + if err := c.nextConn.SetReadDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + return + } + <-done + if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + }() + + n, err := c.nextConn.Read(b) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2 := errSetDeadline.Load(); err == nil && err2 != nil { + err = err2.(error) + } + return n, err +} + +func (c *connCtx) Write(ctx context.Context, b []byte) (int, error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + select { + case <-c.closed: + return 0, ErrClosing + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + select { + case <-ctx.Done(): + // context canceled + if err := c.nextConn.SetWriteDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + return + } + <-done + if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + wg.Done() + }() + + n, err := c.nextConn.Write(b) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2 := errSetDeadline.Load(); err == nil && err2 != nil { + err = err2.(error) + } + return n, err +} + +func (c *connCtx) Close() error { + err := c.nextConn.Close() + c.closeOnce.Do(func() { + c.writeMu.Lock() + c.readMu.Lock() + close(c.closed) + c.readMu.Unlock() + c.writeMu.Unlock() + }) + return err +} + +func (c *connCtx) LocalAddr() net.Addr { + return c.nextConn.LocalAddr() +} + +func (c *connCtx) RemoteAddr() net.Addr { + return c.nextConn.LocalAddr() +} + +func (c *connCtx) Conn() net.Conn { + return c.nextConn +} diff --git a/internal/net/connctx/connctx_test.go b/internal/net/connctx/connctx_test.go new file mode 100644 index 000000000..cd906abc9 --- /dev/null +++ b/internal/net/connctx/connctx_test.go @@ -0,0 +1,197 @@ +package connctx + +import ( + "bytes" + "context" + "io" + "net" + "testing" + "time" +) + +func TestRead(t *testing.T) { + ca, cb := net.Pipe() + defer func() { + _ = ca.Close() + }() + + data := []byte{0x01, 0x02, 0xFF} + chErr := make(chan error) + + go func() { + _, err := cb.Write(data) + chErr <- err + }() + + c := New(ca) + b := make([]byte, 100) + n, err := c.Read(context.Background(), b) + if err != nil { + t.Fatal(err) + } + if n != len(data) { + t.Errorf("Wrong data length, expected %d, got %d", len(data), n) + } + if !bytes.Equal(data, b[:n]) { + t.Errorf("Wrong data, expected %v, got %v", data, b) + } + + err = <-chErr + if err != nil { + t.Fatal(err) + } +} + +func TestReadTImeout(t *testing.T) { + ca, _ := net.Pipe() + defer func() { + _ = ca.Close() + }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + c := New(ca) + b := make([]byte, 100) + n, err := c.Read(ctx, b) + if err == nil { + t.Error("Read unexpectedly successed") + } + if n != 0 { + t.Errorf("Wrong data length, expected %d, got %d", 0, n) + } +} + +func TestReadCancel(t *testing.T) { + ca, _ := net.Pipe() + defer func() { + _ = ca.Close() + }() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + c := New(ca) + b := make([]byte, 100) + n, err := c.Read(ctx, b) + if err == nil { + t.Error("Read unexpectedly successed") + } + if n != 0 { + t.Errorf("Wrong data length, expected %d, got %d", 0, n) + } +} + +func TestReadClosed(t *testing.T) { + ca, _ := net.Pipe() + + c := New(ca) + _ = c.Close() + + b := make([]byte, 100) + n, err := c.Read(context.Background(), b) + if err != io.EOF { + t.Errorf("Expected error '%v', got '%v'", io.EOF, err) + } + if n != 0 { + t.Errorf("Wrong data length, expected %d, got %d", 0, n) + } +} + +func TestWrite(t *testing.T) { + ca, cb := net.Pipe() + defer func() { + _ = ca.Close() + }() + + chErr := make(chan error) + chRead := make(chan []byte) + + go func() { + b := make([]byte, 100) + n, err := cb.Read(b) + chErr <- err + chRead <- b[:n] + }() + + c := New(ca) + data := []byte{0x01, 0x02, 0xFF} + n, err := c.Write(context.Background(), data) + if err != nil { + t.Fatal(err) + } + if n != len(data) { + t.Errorf("Wrong data length, expected %d, got %d", len(data), n) + } + + err = <-chErr + b := <-chRead + if !bytes.Equal(data, b) { + t.Errorf("Wrong data, expected %v, got %v", data, b) + } + if err != nil { + t.Fatal(err) + } +} + +func TestWriteTimeout(t *testing.T) { + ca, _ := net.Pipe() + defer func() { + _ = ca.Close() + }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + c := New(ca) + b := make([]byte, 100) + n, err := c.Write(ctx, b) + if err == nil { + t.Error("Write unexpectedly successed") + } + if n != 0 { + t.Errorf("Wrong data length, expected %d, got %d", 0, n) + } +} + +func TestWriteCancel(t *testing.T) { + ca, _ := net.Pipe() + defer func() { + _ = ca.Close() + }() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + c := New(ca) + b := make([]byte, 100) + n, err := c.Write(ctx, b) + if err == nil { + t.Error("Write unexpectedly successed") + } + if n != 0 { + t.Errorf("Wrong data length, expected %d, got %d", 0, n) + } +} + +func TestWriteClosed(t *testing.T) { + ca, _ := net.Pipe() + + c := New(ca) + _ = c.Close() + + b := make([]byte, 100) + n, err := c.Write(context.Background(), b) + if err != ErrClosing { + t.Errorf("Expected error '%v', got '%v'", ErrClosing, err) + } + if n != 0 { + t.Errorf("Wrong data length, expected %d, got %d", 0, n) + } +}