-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add wrapper to cancel Read by context. Note that underlying Conn must support SetDeadline.
- Loading branch information
Showing
2 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
// Package connctx wraps net.Conn using context.Context. | ||
package connctx | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"io" | ||
"net" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
) | ||
|
||
// ErrClosing is returned on Write to closed connection. | ||
var ErrClosing = errors.New("use of closed network connection") | ||
|
||
// ConnCtx is a wrapper of net.Conn using context.Context. | ||
type ConnCtx interface { | ||
Read(context.Context, []byte) (int, error) | ||
Write(context.Context, []byte) (int, error) | ||
Close() error | ||
LocalAddr() net.Addr | ||
RemoteAddr() net.Addr | ||
Conn() net.Conn | ||
} | ||
|
||
type connCtx struct { | ||
nextConn net.Conn | ||
closed chan struct{} | ||
closeOnce sync.Once | ||
readMu sync.Mutex | ||
writeMu sync.Mutex | ||
} | ||
|
||
var veryOld = time.Unix(0, 1) | ||
|
||
// New creates a new ConnCtx wrapping given net.Conn. | ||
func New(conn net.Conn) ConnCtx { | ||
c := &connCtx{ | ||
nextConn: conn, | ||
closed: make(chan struct{}), | ||
} | ||
return c | ||
} | ||
|
||
func (c *connCtx) Read(ctx context.Context, b []byte) (int, error) { | ||
c.readMu.Lock() | ||
defer c.readMu.Unlock() | ||
|
||
select { | ||
case <-c.closed: | ||
return 0, io.EOF | ||
default: | ||
} | ||
|
||
done := make(chan struct{}) | ||
var wg sync.WaitGroup | ||
var errSetDeadline atomic.Value | ||
wg.Add(1) | ||
go func() { | ||
defer wg.Done() | ||
select { | ||
case <-ctx.Done(): | ||
// context canceled | ||
if err := c.nextConn.SetReadDeadline(veryOld); err != nil { | ||
errSetDeadline.Store(err) | ||
return | ||
} | ||
<-done | ||
if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil { | ||
errSetDeadline.Store(err) | ||
} | ||
case <-done: | ||
} | ||
}() | ||
|
||
n, err := c.nextConn.Read(b) | ||
|
||
close(done) | ||
wg.Wait() | ||
if e := ctx.Err(); e != nil && n == 0 { | ||
err = e | ||
} | ||
if err2 := errSetDeadline.Load(); err == nil && err2 != nil { | ||
err = err2.(error) | ||
} | ||
return n, err | ||
} | ||
|
||
func (c *connCtx) Write(ctx context.Context, b []byte) (int, error) { | ||
c.writeMu.Lock() | ||
defer c.writeMu.Unlock() | ||
|
||
select { | ||
case <-c.closed: | ||
return 0, ErrClosing | ||
default: | ||
} | ||
|
||
done := make(chan struct{}) | ||
var wg sync.WaitGroup | ||
var errSetDeadline atomic.Value | ||
wg.Add(1) | ||
go func() { | ||
select { | ||
case <-ctx.Done(): | ||
// context canceled | ||
if err := c.nextConn.SetWriteDeadline(veryOld); err != nil { | ||
errSetDeadline.Store(err) | ||
return | ||
} | ||
<-done | ||
if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil { | ||
errSetDeadline.Store(err) | ||
} | ||
case <-done: | ||
} | ||
wg.Done() | ||
}() | ||
|
||
n, err := c.nextConn.Write(b) | ||
|
||
close(done) | ||
wg.Wait() | ||
if e := ctx.Err(); e != nil && n == 0 { | ||
err = e | ||
} | ||
if err2 := errSetDeadline.Load(); err == nil && err2 != nil { | ||
err = err2.(error) | ||
} | ||
return n, err | ||
} | ||
|
||
func (c *connCtx) Close() error { | ||
err := c.nextConn.Close() | ||
c.closeOnce.Do(func() { | ||
c.writeMu.Lock() | ||
c.readMu.Lock() | ||
close(c.closed) | ||
c.readMu.Unlock() | ||
c.writeMu.Unlock() | ||
}) | ||
return err | ||
} | ||
|
||
func (c *connCtx) LocalAddr() net.Addr { | ||
return c.nextConn.LocalAddr() | ||
} | ||
|
||
func (c *connCtx) RemoteAddr() net.Addr { | ||
return c.nextConn.LocalAddr() | ||
} | ||
|
||
func (c *connCtx) Conn() net.Conn { | ||
return c.nextConn | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
package connctx | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"io" | ||
"net" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestRead(t *testing.T) { | ||
ca, cb := net.Pipe() | ||
defer func() { | ||
_ = ca.Close() | ||
}() | ||
|
||
data := []byte{0x01, 0x02, 0xFF} | ||
chErr := make(chan error) | ||
|
||
go func() { | ||
_, err := cb.Write(data) | ||
chErr <- err | ||
}() | ||
|
||
c := New(ca) | ||
b := make([]byte, 100) | ||
n, err := c.Read(context.Background(), b) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if n != len(data) { | ||
t.Errorf("Wrong data length, expected %d, got %d", len(data), n) | ||
} | ||
if !bytes.Equal(data, b[:n]) { | ||
t.Errorf("Wrong data, expected %v, got %v", data, b) | ||
} | ||
|
||
err = <-chErr | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
} | ||
|
||
func TestReadTImeout(t *testing.T) { | ||
ca, _ := net.Pipe() | ||
defer func() { | ||
_ = ca.Close() | ||
}() | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) | ||
defer cancel() | ||
|
||
c := New(ca) | ||
b := make([]byte, 100) | ||
n, err := c.Read(ctx, b) | ||
if err == nil { | ||
t.Error("Read unexpectedly successed") | ||
} | ||
if n != 0 { | ||
t.Errorf("Wrong data length, expected %d, got %d", 0, n) | ||
} | ||
} | ||
|
||
func TestReadCancel(t *testing.T) { | ||
ca, _ := net.Pipe() | ||
defer func() { | ||
_ = ca.Close() | ||
}() | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
go func() { | ||
time.Sleep(10 * time.Millisecond) | ||
cancel() | ||
}() | ||
|
||
c := New(ca) | ||
b := make([]byte, 100) | ||
n, err := c.Read(ctx, b) | ||
if err == nil { | ||
t.Error("Read unexpectedly successed") | ||
} | ||
if n != 0 { | ||
t.Errorf("Wrong data length, expected %d, got %d", 0, n) | ||
} | ||
} | ||
|
||
func TestReadClosed(t *testing.T) { | ||
ca, _ := net.Pipe() | ||
|
||
c := New(ca) | ||
_ = c.Close() | ||
|
||
b := make([]byte, 100) | ||
n, err := c.Read(context.Background(), b) | ||
if err != io.EOF { | ||
t.Errorf("Expected error '%v', got '%v'", io.EOF, err) | ||
} | ||
if n != 0 { | ||
t.Errorf("Wrong data length, expected %d, got %d", 0, n) | ||
} | ||
} | ||
|
||
func TestWrite(t *testing.T) { | ||
ca, cb := net.Pipe() | ||
defer func() { | ||
_ = ca.Close() | ||
}() | ||
|
||
chErr := make(chan error) | ||
chRead := make(chan []byte) | ||
|
||
go func() { | ||
b := make([]byte, 100) | ||
n, err := cb.Read(b) | ||
chErr <- err | ||
chRead <- b[:n] | ||
}() | ||
|
||
c := New(ca) | ||
data := []byte{0x01, 0x02, 0xFF} | ||
n, err := c.Write(context.Background(), data) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if n != len(data) { | ||
t.Errorf("Wrong data length, expected %d, got %d", len(data), n) | ||
} | ||
|
||
err = <-chErr | ||
b := <-chRead | ||
if !bytes.Equal(data, b) { | ||
t.Errorf("Wrong data, expected %v, got %v", data, b) | ||
} | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
} | ||
|
||
func TestWriteTimeout(t *testing.T) { | ||
ca, _ := net.Pipe() | ||
defer func() { | ||
_ = ca.Close() | ||
}() | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) | ||
defer cancel() | ||
|
||
c := New(ca) | ||
b := make([]byte, 100) | ||
n, err := c.Write(ctx, b) | ||
if err == nil { | ||
t.Error("Write unexpectedly successed") | ||
} | ||
if n != 0 { | ||
t.Errorf("Wrong data length, expected %d, got %d", 0, n) | ||
} | ||
} | ||
|
||
func TestWriteCancel(t *testing.T) { | ||
ca, _ := net.Pipe() | ||
defer func() { | ||
_ = ca.Close() | ||
}() | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
go func() { | ||
time.Sleep(10 * time.Millisecond) | ||
cancel() | ||
}() | ||
|
||
c := New(ca) | ||
b := make([]byte, 100) | ||
n, err := c.Write(ctx, b) | ||
if err == nil { | ||
t.Error("Write unexpectedly successed") | ||
} | ||
if n != 0 { | ||
t.Errorf("Wrong data length, expected %d, got %d", 0, n) | ||
} | ||
} | ||
|
||
func TestWriteClosed(t *testing.T) { | ||
ca, _ := net.Pipe() | ||
|
||
c := New(ca) | ||
_ = c.Close() | ||
|
||
b := make([]byte, 100) | ||
n, err := c.Write(context.Background(), b) | ||
if err != ErrClosing { | ||
t.Errorf("Expected error '%v', got '%v'", ErrClosing, err) | ||
} | ||
if n != 0 { | ||
t.Errorf("Wrong data length, expected %d, got %d", 0, n) | ||
} | ||
} |