Skip to content

Commit

Permalink
refactor(tcp): use SelectAll for driving listener streams (#3361)
Browse files Browse the repository at this point in the history
The PR optimizes polling of the listeners in the TCP transport by using `futures::SelectAll` instead of storing them in a queue and polling manually.

Resolves #2781.
  • Loading branch information
vnermolaev authored Jan 30, 2023
1 parent 47c1d5a commit c15e651
Showing 1 changed file with 118 additions and 140 deletions.
258 changes: 118 additions & 140 deletions transports/tcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub use provider::tokio;
use futures::{
future::{self, Ready},
prelude::*,
stream::SelectAll,
};
use futures_timer::Delay;
use if_watch::IfEvent;
Expand All @@ -55,7 +56,7 @@ use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener},
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
task::{Context, Poll, Waker},
time::Duration,
};

Expand Down Expand Up @@ -312,7 +313,7 @@ where
/// All the active listeners.
/// The [`ListenStream`] struct contains a stream that we want to be pinned. Since the `VecDeque`
/// can be resized, the only way is to use a `Pin<Box<>>`.
listeners: VecDeque<Pin<Box<ListenStream<T>>>>,
listeners: SelectAll<ListenStream<T>>,
/// Pending transport events to return from [`libp2p_core::Transport::poll`].
pending_events:
VecDeque<TransportEvent<<Self as libp2p_core::Transport>::ListenerUpgrade, io::Error>>,
Expand Down Expand Up @@ -419,7 +420,7 @@ where
Transport {
port_reuse,
config,
listeners: VecDeque::new(),
listeners: SelectAll::new(),
pending_events: VecDeque::new(),
}
}
Expand Down Expand Up @@ -447,18 +448,13 @@ where
let listener = self
.do_listen(id, socket_addr)
.map_err(TransportError::Other)?;
self.listeners.push_back(Box::pin(listener));
self.listeners.push(listener);
Ok(id)
}

fn remove_listener(&mut self, id: ListenerId) -> bool {
if let Some(index) = self.listeners.iter().position(|l| l.listener_id == id) {
self.listeners.remove(index);
self.pending_events
.push_back(TransportEvent::ListenerClosed {
listener_id: id,
reason: Ok(()),
});
if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) {
listener.close(Ok(()));
true
} else {
false
Expand Down Expand Up @@ -548,96 +544,14 @@ where
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(event);
}
// We remove each element from `listeners` one by one and add them back.
let mut remaining = self.listeners.len();
while let Some(mut listener) = self.listeners.pop_back() {
match TryStream::try_poll_next(listener.as_mut(), cx) {
Poll::Pending => {
self.listeners.push_front(listener);
remaining -= 1;
if remaining == 0 {
break;
}
}
Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade {
upgrade,
local_addr,
remote_addr,
}))) => {
let id = listener.listener_id;
self.listeners.push_front(listener);
return Poll::Ready(TransportEvent::Incoming {
listener_id: id,
upgrade,
local_addr,
send_back_addr: remote_addr,
});
}
Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(a)))) => {
let id = listener.listener_id;
self.listeners.push_front(listener);
return Poll::Ready(TransportEvent::NewAddress {
listener_id: id,
listen_addr: a,
});
}
Poll::Ready(Some(Ok(TcpListenerEvent::AddressExpired(a)))) => {
let id = listener.listener_id;
self.listeners.push_front(listener);
return Poll::Ready(TransportEvent::AddressExpired {
listener_id: id,
listen_addr: a,
});
}
Poll::Ready(Some(Ok(TcpListenerEvent::Error(error)))) => {
let id = listener.listener_id;
self.listeners.push_front(listener);
return Poll::Ready(TransportEvent::ListenerError {
listener_id: id,
error,
});
}
Poll::Ready(None) => {
return Poll::Ready(TransportEvent::ListenerClosed {
listener_id: listener.listener_id,
reason: Ok(()),
});
}
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(TransportEvent::ListenerClosed {
listener_id: listener.listener_id,
reason: Err(err),
});
}
}

match self.listeners.poll_next_unpin(cx) {
Poll::Ready(Some(transport_event)) => Poll::Ready(transport_event),
_ => Poll::Pending,
}
Poll::Pending
}
}

/// Event produced by a [`ListenStream`].
#[derive(Debug)]
enum TcpListenerEvent<S> {
/// The listener is listening on a new additional [`Multiaddr`].
NewAddress(Multiaddr),
/// An upgrade, consisting of the upgrade future, the listener address and the remote address.
Upgrade {
/// The upgrade.
upgrade: Ready<Result<S, io::Error>>,
/// The local address which produced this upgrade.
local_addr: Multiaddr,
/// The remote address which produced this upgrade.
remote_addr: Multiaddr,
},
/// A [`Multiaddr`] is no longer used for listening.
AddressExpired(Multiaddr),
/// A non-fatal error has happened on the listener.
///
/// This event should be generated in order to notify the user that something wrong has
/// happened. The listener, however, continues to run.
Error(io::Error),
}

/// A stream of incoming connections on one or more interfaces.
struct ListenStream<T>
where
Expand Down Expand Up @@ -669,6 +583,12 @@ where
sleep_on_error: Duration,
/// The current pause, if any.
pause: Option<Delay>,
/// Pending event to reported.
pending_event: Option<<Self as Stream>::Item>,
/// The listener can be manually closed with [`Transport::remove_listener`](libp2p_core::Transport::remove_listener).
is_closed: bool,
/// The stream must be awaken after it has been closed to deliver the last event.
close_listener_waker: Option<Waker>,
}

impl<T> ListenStream<T>
Expand All @@ -694,6 +614,9 @@ where
if_watcher,
pause: None,
sleep_on_error: Duration::from_millis(100),
pending_event: None,
is_closed: false,
close_listener_waker: None,
})
}

Expand All @@ -716,6 +639,74 @@ where
.unregister(self.listen_addr.ip(), self.listen_addr.port()),
}
}

/// Close the listener.
///
/// This will create a [`TransportEvent::ListenerClosed`] and
/// terminate the stream once the event has been reported.
fn close(&mut self, reason: Result<(), io::Error>) {
if self.is_closed {
return;
}
self.pending_event = Some(TransportEvent::ListenerClosed {
listener_id: self.listener_id,
reason,
});
self.is_closed = true;

// Wake the stream to deliver the last event.
if let Some(waker) = self.close_listener_waker.take() {
waker.wake();
}
}

/// Poll for a next If Event.
fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
let if_watcher = match self.if_watcher.as_mut() {
Some(if_watcher) => if_watcher,
None => return Poll::Pending,
};

let my_listen_addr_port = self.listen_addr.port();

while let Poll::Ready(Some(event)) = if_watcher.poll_next_unpin(cx) {
match event {
Ok(IfEvent::Up(inet)) => {
let ip = inet.addr();
if self.listen_addr.is_ipv4() == ip.is_ipv4() {
let ma = ip_to_multiaddr(ip, my_listen_addr_port);
log::debug!("New listen address: {}", ma);
self.port_reuse.register(ip, my_listen_addr_port);
return Poll::Ready(TransportEvent::NewAddress {
listener_id: self.listener_id,
listen_addr: ma,
});
}
}
Ok(IfEvent::Down(inet)) => {
let ip = inet.addr();
if self.listen_addr.is_ipv4() == ip.is_ipv4() {
let ma = ip_to_multiaddr(ip, my_listen_addr_port);
log::debug!("Expired listen address: {}", ma);
self.port_reuse.unregister(ip, my_listen_addr_port);
return Poll::Ready(TransportEvent::AddressExpired {
listener_id: self.listener_id,
listen_addr: ma,
});
}
}
Err(error) => {
self.pause = Some(Delay::new(self.sleep_on_error));
return Poll::Ready(TransportEvent::ListenerError {
listener_id: self.listener_id,
error,
});
}
}
}

Poll::Pending
}
}

impl<T> Drop for ListenStream<T>
Expand All @@ -733,52 +724,34 @@ where
T::Listener: Unpin,
T::Stream: Unpin,
{
type Item = Result<TcpListenerEvent<T::Stream>, io::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let me = Pin::into_inner(self);
type Item = TransportEvent<Ready<Result<T::Stream, io::Error>>, io::Error>;

if let Some(mut pause) = me.pause.take() {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
if let Some(mut pause) = self.pause.take() {
match pause.poll_unpin(cx) {
Poll::Ready(_) => {}
Poll::Pending => {
me.pause = Some(pause);
self.pause = Some(pause);
return Poll::Pending;
}
}
}

if let Some(if_watcher) = me.if_watcher.as_mut() {
while let Poll::Ready(Some(event)) = if_watcher.poll_next_unpin(cx) {
match event {
Ok(IfEvent::Up(inet)) => {
let ip = inet.addr();
if me.listen_addr.is_ipv4() == ip.is_ipv4() {
let ma = ip_to_multiaddr(ip, me.listen_addr.port());
log::debug!("New listen address: {}", ma);
me.port_reuse.register(ip, me.listen_addr.port());
return Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(ma))));
}
}
Ok(IfEvent::Down(inet)) => {
let ip = inet.addr();
if me.listen_addr.is_ipv4() == ip.is_ipv4() {
let ma = ip_to_multiaddr(ip, me.listen_addr.port());
log::debug!("Expired listen address: {}", ma);
me.port_reuse.unregister(ip, me.listen_addr.port());
return Poll::Ready(Some(Ok(TcpListenerEvent::AddressExpired(ma))));
}
}
Err(err) => {
me.pause = Some(Delay::new(me.sleep_on_error));
return Poll::Ready(Some(Ok(TcpListenerEvent::Error(err))));
}
}
}
if let Some(event) = self.pending_event.take() {
return Poll::Ready(Some(event));
}

if self.is_closed {
// Terminate the stream if the listener closed and all remaining events have been reported.
return Poll::Ready(None);
}

if let Poll::Ready(event) = self.poll_if_addr(cx) {
return Poll::Ready(Some(event));
}

// Take the pending connection from the backlog.
match T::poll_accept(&mut me.listener, cx) {
match T::poll_accept(&mut self.listener, cx) {
Poll::Ready(Ok(Incoming {
local_addr,
remote_addr,
Expand All @@ -789,20 +762,25 @@ where

log::debug!("Incoming connection from {} at {}", remote_addr, local_addr);

return Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade {
return Poll::Ready(Some(TransportEvent::Incoming {
listener_id: self.listener_id,
upgrade: future::ok(stream),
local_addr,
remote_addr,
})));
send_back_addr: remote_addr,
}));
}
Poll::Ready(Err(e)) => {
Poll::Ready(Err(error)) => {
// These errors are non-fatal for the listener stream.
me.pause = Some(Delay::new(me.sleep_on_error));
return Poll::Ready(Some(Ok(TcpListenerEvent::Error(e))));
self.pause = Some(Delay::new(self.sleep_on_error));
return Poll::Ready(Some(TransportEvent::ListenerError {
listener_id: self.listener_id,
error,
}));
}
Poll::Pending => {}
};
}

self.close_listener_waker = Some(cx.waker().clone());
Poll::Pending
}
}
Expand Down Expand Up @@ -1119,7 +1097,7 @@ mod tests {
match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await {
TransportEvent::NewAddress { .. } => {
// Check that tcp and listener share the same port reuse SocketAddr
let listener = tcp.listeners.front().unwrap();
let listener = tcp.listeners.iter().next().unwrap();
let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip());
let port_reuse_listener = listener
.port_reuse
Expand Down Expand Up @@ -1188,7 +1166,7 @@ mod tests {
TransportEvent::NewAddress {
listen_addr: addr1, ..
} => {
let listener1 = tcp.listeners.front().unwrap();
let listener1 = tcp.listeners.iter().next().unwrap();
let port_reuse_tcp =
tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip());
let port_reuse_listener1 = listener1
Expand Down

0 comments on commit c15e651

Please sign in to comment.