Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-91166: zero copy SelectorSocketTransport transport implementation #31871

Merged
merged 23 commits into from
Dec 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
39b538f
zero copy transport implementation
kumaraditya303 Mar 14, 2022
abd2dc3
WIP sendmsg
kumaraditya303 Mar 15, 2022
669b661
writelines implementation
kumaraditya303 Mar 15, 2022
0692952
use sysconf
kumaraditya303 Mar 15, 2022
2725334
fix tests
kumaraditya303 Mar 15, 2022
bed096d
skip test if sendmsg does not exists
kumaraditya303 Oct 24, 2022
f090e8d
📜🤖 Added by blurb_it.
blurb-it[bot] Oct 24, 2022
d6c77cd
fix check on other platforms
kumaraditya303 Oct 24, 2022
f2ee404
rename some vars
kumaraditya303 Nov 28, 2022
effab03
_HAVE_SENDMSG -> _HAS_SENDMSG
kumaraditya303 Nov 28, 2022
e1e4362
fix tests
kumaraditya303 Nov 28, 2022
bdb1bda
fix tests
kumaraditya303 Nov 28, 2022
d1fae6c
fix writelines and add comments
kumaraditya303 Nov 28, 2022
cd45016
optimize calling
kumaraditya303 Nov 28, 2022
5b962f5
use send if sendmsg does not exists in writelines
kumaraditya303 Nov 28, 2022
85d6909
Merge branch 'main' into asyncio-zero-copy
kumaraditya303 Nov 28, 2022
152b748
Update Lib/asyncio/selector_events.py
kumaraditya303 Dec 2, 2022
2c62bcb
more tests
kumaraditya303 Dec 13, 2022
9b92cff
check fatal error
kumaraditya303 Dec 13, 2022
7e05c2c
Merge branch 'main' into asyncio-zero-copy
kumaraditya303 Dec 13, 2022
97de955
Merge branch 'asyncio-zero-copy' of https://github.com/kumaraditya303…
kumaraditya303 Dec 13, 2022
57b1ba0
Merge branch 'main' of https://github.com/python/cpython into asyncio…
kumaraditya303 Dec 22, 2022
2ca3571
code review
kumaraditya303 Dec 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 75 additions & 11 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import collections
import errno
import functools
import itertools
import os
import selectors
import socket
import warnings
Expand All @@ -28,6 +30,14 @@
from . import trsock
from .log import logger

_HAS_SENDMSG = hasattr(socket.socket, 'sendmsg')

if _HAS_SENDMSG:
try:
SC_IOV_MAX = os.sysconf('SC_IOV_MAX')
except OSError:
# Fallback to send
_HAS_SENDMSG = False

def _test_selector_event(selector, fd, event):
# Test if the selector is monitoring 'event' events
Expand Down Expand Up @@ -757,8 +767,6 @@ class _SelectorTransport(transports._FlowControlMixin,

max_size = 256 * 1024 # Buffer size passed to recv().

_buffer_factory = bytearray # Constructs initial value for self._buffer.

# Attribute used in the destructor: it must be set even if the constructor
# is not called (see _SelectorSslTransport which may start by raising an
# exception)
Expand All @@ -783,7 +791,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
self.set_protocol(protocol)

self._server = server
self._buffer = self._buffer_factory()
self._buffer = collections.deque()
self._conn_lost = 0 # Set when call to connection_lost scheduled.
self._closing = False # Set when close() called.
if self._server is not None:
Expand Down Expand Up @@ -887,7 +895,7 @@ def _call_connection_lost(self, exc):
self._server = None

def get_write_buffer_size(self):
return len(self._buffer)
return sum(map(len, self._buffer))

def _add_reader(self, fd, callback, *args):
if self._closing:
Expand All @@ -909,7 +917,10 @@ def __init__(self, loop, sock, protocol, waiter=None,
self._eof = False
self._paused = False
self._empty_waiter = None

if _HAS_SENDMSG:
self._write_ready = self._write_sendmsg
else:
self._write_ready = self._write_send
# Disable the Nagle algorithm -- small writes will be
# sent without waiting for the TCP ACK. This generally
# decreases the latency (in some cases significantly.)
Expand Down Expand Up @@ -1066,23 +1077,68 @@ def write(self, data):
self._fatal_error(exc, 'Fatal write error on socket transport')
return
else:
data = data[n:]
data = memoryview(data)[n:]
if not data:
return
# Not all was written; register write handler.
self._loop._add_writer(self._sock_fd, self._write_ready)

# Add it to the buffer.
self._buffer.extend(data)
self._buffer.append(data)
self._maybe_pause_protocol()

def _write_ready(self):
def _get_sendmsg_buffer(self):
return itertools.islice(self._buffer, SC_IOV_MAX)

def _write_sendmsg(self):
assert self._buffer, 'Data should not be empty'
if self._conn_lost:
return
try:
nbytes = self._sock.sendmsg(self._get_sendmsg_buffer())
self._adjust_leftover_buffer(nbytes)
except (BlockingIOError, InterruptedError):
pass
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
self._loop._remove_writer(self._sock_fd)
self._buffer.clear()
self._fatal_error(exc, 'Fatal write error on socket transport')
if self._empty_waiter is not None:
self._empty_waiter.set_exception(exc)
else:
self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop._remove_writer(self._sock_fd)
if self._empty_waiter is not None:
self._empty_waiter.set_result(None)
if self._closing:
self._call_connection_lost(None)
elif self._eof:
self._sock.shutdown(socket.SHUT_WR)

def _adjust_leftover_buffer(self, nbytes: int) -> None:
buffer = self._buffer
while nbytes:
b = buffer.popleft()
b_len = len(b)
if b_len <= nbytes:
nbytes -= b_len
else:
buffer.appendleft(b[nbytes:])
break

def _write_send(self):
assert self._buffer, 'Data should not be empty'
if self._conn_lost:
return
try:
n = self._sock.send(self._buffer)
buffer = self._buffer.popleft()
n = self._sock.send(buffer)
kumaraditya303 marked this conversation as resolved.
Show resolved Hide resolved
if n != len(buffer):
# Not all data was written
self._buffer.appendleft(buffer[n:])
except (BlockingIOError, InterruptedError):
pass
except (SystemExit, KeyboardInterrupt):
Expand All @@ -1094,8 +1150,6 @@ def _write_ready(self):
if self._empty_waiter is not None:
self._empty_waiter.set_exception(exc)
else:
if n:
del self._buffer[:n]
self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop._remove_writer(self._sock_fd)
Expand All @@ -1113,6 +1167,16 @@ def write_eof(self):
if not self._buffer:
self._sock.shutdown(socket.SHUT_WR)

def writelines(self, list_of_data):
kumaraditya303 marked this conversation as resolved.
Show resolved Hide resolved
if self._eof:
raise RuntimeError('Cannot call writelines() after write_eof()')
if self._empty_waiter is not None:
raise RuntimeError('unable to writelines; sendfile is in progress')
if not list_of_data:
return
self._buffer.extend([memoryview(data) for data in list_of_data])
self._write_ready()
kumaraditya303 marked this conversation as resolved.
Show resolved Hide resolved

def can_write_eof(self):
return True

Expand Down
117 changes: 99 additions & 18 deletions Lib/test/test_asyncio/test_selector_events.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
"""Tests for selector_events.py"""

import sys
import collections
import selectors
import socket
import sys
import unittest
from asyncio import selector_events
from unittest import mock

try:
import ssl
except ImportError:
ssl = None

import asyncio
from asyncio.selector_events import BaseSelectorEventLoop
from asyncio.selector_events import _SelectorTransport
from asyncio.selector_events import _SelectorSocketTransport
from asyncio.selector_events import _SelectorDatagramTransport
from asyncio.selector_events import (BaseSelectorEventLoop,
_SelectorDatagramTransport,
_SelectorSocketTransport,
_SelectorTransport)
from test.test_asyncio import utils as test_utils


MOCK_ANY = mock.ANY


Expand All @@ -37,7 +39,10 @@ def _close_self_pipe(self):


def list_to_buffer(l=()):
return bytearray().join(l)
buffer = collections.deque()
buffer.extend((memoryview(i) for i in l))
return buffer



def close_transport(transport):
Expand Down Expand Up @@ -493,9 +498,13 @@ def setUp(self):
self.sock = mock.Mock(socket.socket)
self.sock_fd = self.sock.fileno.return_value = 7

def socket_transport(self, waiter=None):
def socket_transport(self, waiter=None, sendmsg=False):
transport = _SelectorSocketTransport(self.loop, self.sock,
self.protocol, waiter=waiter)
if sendmsg:
transport._write_ready = transport._write_sendmsg
else:
transport._write_ready = transport._write_send
self.addCleanup(close_transport, transport)
return transport

Expand Down Expand Up @@ -664,14 +673,14 @@ def test_write_memoryview(self):

def test_write_no_data(self):
transport = self.socket_transport()
transport._buffer.extend(b'data')
transport._buffer.append(memoryview(b'data'))
transport.write(b'')
self.assertFalse(self.sock.send.called)
self.assertEqual(list_to_buffer([b'data']), transport._buffer)

def test_write_buffer(self):
transport = self.socket_transport()
transport._buffer.extend(b'data1')
transport._buffer.append(b'data1')
transport.write(b'data2')
self.assertFalse(self.sock.send.called)
self.assertEqual(list_to_buffer([b'data1', b'data2']),
Expand Down Expand Up @@ -729,6 +738,77 @@ def test_write_tryagain(self):
self.loop.assert_writer(7, transport._write_ready)
self.assertEqual(list_to_buffer([b'data']), transport._buffer)

def test_write_sendmsg_no_data(self):
self.sock.sendmsg = mock.Mock()
self.sock.sendmsg.return_value = 0
transport = self.socket_transport(sendmsg=True)
transport._buffer.append(memoryview(b'data'))
transport.write(b'')
self.assertFalse(self.sock.sendmsg.called)
self.assertEqual(list_to_buffer([b'data']), transport._buffer)

@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_full(self):
data = memoryview(b'data')
self.sock.sendmsg = mock.Mock()
self.sock.sendmsg.return_value = len(data)

transport = self.socket_transport(sendmsg=True)
transport._buffer.append(data)
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()
self.assertTrue(self.sock.sendmsg.called)
self.assertFalse(self.loop.writers)

@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_partial(self):

data = memoryview(b'data')
self.sock.sendmsg = mock.Mock()
# Sent partial data
self.sock.sendmsg.return_value = 2

transport = self.socket_transport(sendmsg=True)
transport._buffer.append(data)
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()
self.assertTrue(self.sock.sendmsg.called)
self.assertTrue(self.loop.writers)
self.assertEqual(list_to_buffer([b'ta']), transport._buffer)

@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_half_buffer(self):
data = [memoryview(b'data1'), memoryview(b'data2')]
self.sock.sendmsg = mock.Mock()
# Sent partial data
self.sock.sendmsg.return_value = 2

transport = self.socket_transport(sendmsg=True)
transport._buffer.extend(data)
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()
kumaraditya303 marked this conversation as resolved.
Show resolved Hide resolved
self.assertTrue(self.sock.sendmsg.called)
self.assertTrue(self.loop.writers)
self.assertEqual(list_to_buffer([b'ta1', b'data2']), transport._buffer)
kumaraditya303 marked this conversation as resolved.
Show resolved Hide resolved

@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_OSError(self):
data = memoryview(b'data')
self.sock.sendmsg = mock.Mock()
err = self.sock.sendmsg.side_effect = OSError()

transport = self.socket_transport(sendmsg=True)
transport._fatal_error = mock.Mock()
transport._buffer.extend(data)
# Calls _fatal_error and clears the buffer
transport._write_ready()
kumaraditya303 marked this conversation as resolved.
Show resolved Hide resolved
self.assertTrue(self.sock.sendmsg.called)
self.assertFalse(self.loop.writers)
self.assertEqual(list_to_buffer([]), transport._buffer)
transport._fatal_error.assert_called_with(
err,
'Fatal write error on socket transport')

@mock.patch('asyncio.selector_events.logger')
def test_write_exception(self, m_log):
err = self.sock.send.side_effect = OSError()
Expand Down Expand Up @@ -768,19 +848,19 @@ def test_write_ready(self):
self.sock.send.return_value = len(data)

transport = self.socket_transport()
transport._buffer.extend(data)
transport._buffer.append(data)
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()
self.assertTrue(self.sock.send.called)
self.assertFalse(self.loop.writers)

def test_write_ready_closing(self):
data = b'data'
data = memoryview(b'data')
self.sock.send.return_value = len(data)

transport = self.socket_transport()
transport._closing = True
transport._buffer.extend(data)
transport._buffer.append(data)
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()
self.assertTrue(self.sock.send.called)
Expand All @@ -795,11 +875,11 @@ def test_write_ready_no_data(self):
self.assertRaises(AssertionError, transport._write_ready)

def test_write_ready_partial(self):
data = b'data'
data = memoryview(b'data')
self.sock.send.return_value = 2

transport = self.socket_transport()
transport._buffer.extend(data)
transport._buffer.append(data)
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()
self.loop.assert_writer(7, transport._write_ready)
Expand All @@ -810,7 +890,7 @@ def test_write_ready_partial_none(self):
self.sock.send.return_value = 0

transport = self.socket_transport()
transport._buffer.extend(data)
transport._buffer.append(data)
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()
self.loop.assert_writer(7, transport._write_ready)
Expand All @@ -820,12 +900,13 @@ def test_write_ready_tryagain(self):
self.sock.send.side_effect = BlockingIOError

transport = self.socket_transport()
transport._buffer = list_to_buffer([b'data1', b'data2'])
buffer = list_to_buffer([b'data1', b'data2'])
transport._buffer = buffer
self.loop._add_writer(7, transport._write_ready)
transport._write_ready()

self.loop.assert_writer(7, transport._write_ready)
self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
self.assertEqual(buffer, transport._buffer)

def test_write_ready_exception(self):
err = self.sock.send.side_effect = OSError()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:mod:`asyncio` is optimized to avoid excessive copying when writing to socket and use :meth:`~socket.socket.sendmsg` if the platform supports it. Patch by Kumar Aditya.