diff --git a/Cargo.toml b/Cargo.toml index 76a5ab54e20..e458dbbd122 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ default = [ "websocket", "yamux", ] + autonat = ["dep:libp2p-autonat"] dcutr = ["dep:libp2p-dcutr", "libp2p-metrics?/dcutr"] deflate = ["dep:libp2p-deflate"] diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 7bea0f3f35a..a3d91a680e1 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -5,9 +5,14 @@ in favor of forcing `StreamMuxer::Substream` to implement `AsyncRead + AsyncWrite`. See [PR 2707]. - Replace `Into` bound on `StreamMuxer::Error` with `std::error::Error`. See [PR 2710]. +- Remove the concept of individual `Transport::Listener` streams from `Transport`. + Instead the `Transport` is polled directly via `Transport::poll`. The + `Transport` is now responsible for driving its listeners. See [PR 2652]. + [PR 2691]: https://github.com/libp2p/rust-libp2p/pull/2691 [PR 2707]: https://github.com/libp2p/rust-libp2p/pull/2707 [PR 2710]: https://github.com/libp2p/rust-libp2p/pull/2710 +[PR 2652]: https://github.com/libp2p/rust-libp2p/pull/2652 # 0.33.0 diff --git a/core/src/connection.rs b/core/src/connection.rs index 3a8d54d04a1..91008408fe2 100644 --- a/core/src/connection.rs +++ b/core/src/connection.rs @@ -43,25 +43,6 @@ impl std::ops::Add for ConnectionId { } } -/// The ID of a single listener. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct ListenerId(u64); - -impl ListenerId { - /// Creates a `ListenerId` from a non-negative integer. - pub fn new(id: u64) -> Self { - Self(id) - } -} - -impl std::ops::Add for ListenerId { - type Output = Self; - - fn add(self, other: u64) -> Self { - Self(self.0 + other) - } -} - /// The endpoint roles associated with a peer-to-peer communication channel. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Endpoint { diff --git a/core/src/either.rs b/core/src/either.rs index 8ce9046f6f8..bce6e05aadf 100644 --- a/core/src/either.rs +++ b/core/src/either.rs @@ -20,7 +20,7 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerEvent}, - transport::{ListenerEvent, Transport, TransportError}, + transport::{ListenerId, Transport, TransportError, TransportEvent}, Multiaddr, ProtocolName, }; use futures::{ @@ -274,48 +274,6 @@ pub enum EitherOutbound { B(B::OutboundSubstream), } -/// Implements `Stream` and dispatches all method calls to either `First` or `Second`. -#[pin_project(project = EitherListenStreamProj)] -#[derive(Debug, Copy, Clone)] -#[must_use = "futures do nothing unless polled"] -pub enum EitherListenStream { - First(#[pin] A), - Second(#[pin] B), -} - -impl Stream - for EitherListenStream -where - AStream: TryStream, Error = AError>, - BStream: TryStream, Error = BError>, -{ - type Item = Result< - ListenerEvent, EitherError>, - EitherError, - >; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { - EitherListenStreamProj::First(a) => match TryStream::try_poll_next(a, cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le - .map(EitherFuture::First) - .map_err(EitherError::A)))), - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), - }, - EitherListenStreamProj::Second(a) => match TryStream::try_poll_next(a, cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le - .map(EitherFuture::Second) - .map_err(EitherError::B)))), - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::B(err)))), - }, - } - } -} - /// Implements `Future` and dispatches all method calls to either `First` or `Second`. #[pin_project(project = EitherFutureProj)] #[derive(Debug, Copy, Clone)] @@ -385,11 +343,12 @@ impl ProtocolName for EitherName { } } } - -#[derive(Debug, Copy, Clone)] +#[pin_project(project = EitherTransportProj)] +#[derive(Debug)] +#[must_use = "transports do nothing unless polled"] pub enum EitherTransport { - Left(A), - Right(B), + Left(#[pin] A), + Right(#[pin] B), } impl Transport for EitherTransport @@ -399,29 +358,54 @@ where { type Output = EitherOutput; type Error = EitherError; - type Listener = EitherListenStream; type ListenerUpgrade = EitherFuture; type Dial = EitherFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - use TransportError::*; - match self { - EitherTransport::Left(a) => match a.listen_on(addr) { - Ok(listener) => Ok(EitherListenStream::First(listener)), - Err(MultiaddrNotSupported(addr)) => Err(MultiaddrNotSupported(addr)), - Err(Other(err)) => Err(Other(EitherError::A(err))), + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + EitherTransportProj::Left(a) => match a.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(event) => Poll::Ready( + event + .map_upgrade(EitherFuture::First) + .map_err(EitherError::A), + ), }, - EitherTransport::Right(b) => match b.listen_on(addr) { - Ok(listener) => Ok(EitherListenStream::Second(listener)), - Err(MultiaddrNotSupported(addr)) => Err(MultiaddrNotSupported(addr)), - Err(Other(err)) => Err(Other(EitherError::B(err))), + EitherTransportProj::Right(b) => match b.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(event) => Poll::Ready( + event + .map_upgrade(EitherFuture::Second) + .map_err(EitherError::B), + ), }, } } + fn remove_listener(&mut self, id: ListenerId) -> bool { + match self { + EitherTransport::Left(t) => t.remove_listener(id), + EitherTransport::Right(t) => t.remove_listener(id), + } + } + + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + use TransportError::*; + match self { + EitherTransport::Left(a) => a.listen_on(addr).map_err(|e| match e { + MultiaddrNotSupported(addr) => MultiaddrNotSupported(addr), + Other(err) => Other(EitherError::A(err)), + }), + EitherTransport::Right(b) => b.listen_on(addr).map_err(|e| match e { + MultiaddrNotSupported(addr) => MultiaddrNotSupported(addr), + Other(err) => Other(EitherError::B(err)), + }), + } + } + fn dial(&mut self, addr: Multiaddr) -> Result> { use TransportError::*; match self { diff --git a/core/src/transport.rs b/core/src/transport.rs index a625e3b552c..df6e094ed80 100644 --- a/core/src/transport.rs +++ b/core/src/transport.rs @@ -25,10 +25,14 @@ //! any desired protocols. The rest of the module defines combinators for //! modifying a transport through composition with other transports or protocol upgrades. -use crate::connection::ConnectedPoint; use futures::prelude::*; use multiaddr::Multiaddr; -use std::{error::Error, fmt}; +use std::{ + error::Error, + fmt, + pin::Pin, + task::{Context, Poll}, +}; pub mod and_then; pub mod choice; @@ -42,6 +46,8 @@ pub mod upgrade; mod boxed; mod optional; +use crate::ConnectedPoint; + pub use self::boxed::Boxed; pub use self::choice::OrTransport; pub use self::memory::MemoryTransport; @@ -87,21 +93,8 @@ pub trait Transport { /// An error that occurred during connection setup. type Error: Error; - /// A stream of [`Output`](Transport::Output)s for inbound connections. - /// - /// An item should be produced whenever a connection is received at the lowest level of the - /// transport stack. The item must be a [`ListenerUpgrade`](Transport::ListenerUpgrade) future - /// that resolves to an [`Output`](Transport::Output) value once all protocol upgrades - /// have been applied. - /// - /// If this stream produces an error, it is considered fatal and the listener is killed. It - /// is possible to report non-fatal errors by producing a [`ListenerEvent::Error`]. - type Listener: Stream< - Item = Result, Self::Error>, - >; - /// A pending [`Output`](Transport::Output) for an inbound connection, - /// obtained from the [`Listener`](Transport::Listener) stream. + /// obtained from the [`Transport`] stream. /// /// After a connection has been accepted by the transport, it may need to go through /// asynchronous post-processing (i.e. protocol upgrade negotiations). Such @@ -115,22 +108,20 @@ pub trait Transport { /// obtained from [dialing](Transport::dial). type Dial: Future>; - /// Listens on the given [`Multiaddr`], producing a stream of pending, inbound connections - /// and addresses this transport is listening on (cf. [`ListenerEvent`]). + /// Listens on the given [`Multiaddr`] for inbound connections. + fn listen_on(&mut self, addr: Multiaddr) -> Result>; + + /// Remove a listener. /// - /// Returning an error from the stream is considered fatal. The listener can also report - /// non-fatal errors by producing a [`ListenerEvent::Error`]. - fn listen_on(&mut self, addr: Multiaddr) -> Result> - where - Self: Sized; + /// Return `true` if there was a listener with this Id, `false` + /// otherwise. + fn remove_listener(&mut self, id: ListenerId) -> bool; /// Dials the given [`Multiaddr`], returning a future for a pending outbound connection. /// /// If [`TransportError::MultiaddrNotSupported`] is returned, it may be desirable to /// try an alternative [`Transport`], if available. - fn dial(&mut self, addr: Multiaddr) -> Result> - where - Self: Sized; + fn dial(&mut self, addr: Multiaddr) -> Result>; /// As [`Transport::dial`] but has the local node act as a listener on the outgoing connection. /// @@ -140,9 +131,23 @@ pub trait Transport { fn dial_as_listener( &mut self, addr: Multiaddr, - ) -> Result> - where - Self: Sized; + ) -> Result>; + + /// Poll for [`TransportEvent`]s. + /// + /// A [`TransportEvent::Incoming`] should be produced whenever a connection is received at the lowest + /// level of the transport stack. The item must be a [`ListenerUpgrade`](Transport::ListenerUpgrade) + /// future that resolves to an [`Output`](Transport::Output) value once all protocol upgrades have + /// been applied. + /// + /// Transports are expected to produce [`TransportEvent::Incoming`] events only for + /// listen addresses which have previously been announced via + /// a [`TransportEvent::NewAddress`] event and which have not been invalidated by + /// an [`TransportEvent::AddressExpired`] event yet. + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; /// Performs a transport-specific mapping of an address `observed` by /// a remote onto a local `listen` address to yield an address for @@ -152,9 +157,8 @@ pub trait Transport { /// Boxes the transport, including custom transport errors. fn boxed(self) -> boxed::Boxed where - Self: Transport + Sized + Send + 'static, + Self: Sized + Send + Unpin + 'static, Self::Dial: Send + 'static, - Self::Listener: Send + 'static, Self::ListenerUpgrade: Send + 'static, Self::Error: Send + Sync, { @@ -221,149 +225,277 @@ pub trait Transport { } } -/// Event produced by [`Transport::Listener`]s. -/// -/// Transports are expected to produce `Upgrade` events only for -/// listen addresses which have previously been announced via -/// a `NewAddress` event and which have not been invalidated by -/// an `AddressExpired` event yet. -#[derive(Clone, Debug, PartialEq)] -pub enum ListenerEvent { - /// The transport is listening on a new additional [`Multiaddr`]. - NewAddress(Multiaddr), - /// An upgrade, consisting of the upgrade future, the listener address and the remote address. - Upgrade { - /// The upgrade. +/// The ID of a single listener. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct ListenerId(u64); + +impl ListenerId { + /// Creates a new `ListenerId`. + pub fn new() -> Self { + ListenerId(rand::random()) + } +} + +impl Default for ListenerId { + fn default() -> Self { + Self::new() + } +} + +/// Event produced by [`Transport`]s. +pub enum TransportEvent { + /// A new address is being listened on. + NewAddress { + /// The listener that is listening on the new address. + listener_id: ListenerId, + /// The new address that is being listened on. + listen_addr: Multiaddr, + }, + /// An address is no longer being listened on. + AddressExpired { + /// The listener that is no longer listening on the address. + listener_id: ListenerId, + /// The new address that is being listened on. + listen_addr: Multiaddr, + }, + /// A connection is incoming on one of the listeners. + Incoming { + /// The listener that produced the upgrade. + listener_id: ListenerId, + /// The produced upgrade. upgrade: TUpgr, - /// The local address which produced this upgrade. + /// Local connection address. local_addr: Multiaddr, - /// The remote address which produced this upgrade. - remote_addr: Multiaddr, + /// Address used to send back data to the incoming client. + send_back_addr: Multiaddr, + }, + /// A listener closed. + ListenerClosed { + /// The ID of the listener that closed. + listener_id: ListenerId, + /// Reason for the closure. Contains `Ok(())` if the stream produced `None`, or `Err` + /// if the stream produced an error. + reason: Result<(), TErr>, }, - /// A [`Multiaddr`] is no longer used for listening. - AddressExpired(Multiaddr), - /// A non-fatal error has happened on the listener. + /// A listener errored. /// - /// This event should be generated in order to notify the user that something wrong has - /// happened. The listener, however, continues to run. - Error(TErr), + /// The listener will continue to be polled for new events and the event + /// is for informational purposes only. + ListenerError { + /// The ID of the listener that errored. + listener_id: ListenerId, + /// The error value. + error: TErr, + }, } -impl ListenerEvent { - /// In case this [`ListenerEvent`] is an upgrade, apply the given function - /// to the upgrade and multiaddress and produce another listener event - /// based the the function's result. - pub fn map(self, f: impl FnOnce(TUpgr) -> U) -> ListenerEvent { +impl TransportEvent { + /// In case this [`TransportEvent`] is an upgrade, apply the given function + /// to the upgrade and produce another transport event based the the function's result. + pub fn map_upgrade(self, map: impl FnOnce(TUpgr) -> U) -> TransportEvent { match self { - ListenerEvent::Upgrade { + TransportEvent::Incoming { + listener_id, upgrade, local_addr, - remote_addr, - } => ListenerEvent::Upgrade { - upgrade: f(upgrade), + send_back_addr, + } => TransportEvent::Incoming { + listener_id, + upgrade: map(upgrade), local_addr, - remote_addr, + send_back_addr, + }, + TransportEvent::NewAddress { + listen_addr, + listener_id, + } => TransportEvent::NewAddress { + listen_addr, + listener_id, + }, + TransportEvent::AddressExpired { + listen_addr, + listener_id, + } => TransportEvent::AddressExpired { + listen_addr, + listener_id, + }, + TransportEvent::ListenerError { listener_id, error } => { + TransportEvent::ListenerError { listener_id, error } + } + TransportEvent::ListenerClosed { + listener_id, + reason, + } => TransportEvent::ListenerClosed { + listener_id, + reason, }, - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(e), } } - /// In case this [`ListenerEvent`] is an [`Error`](ListenerEvent::Error), - /// apply the given function to the error and produce another listener event based on the - /// function's result. - pub fn map_err(self, f: impl FnOnce(TErr) -> U) -> ListenerEvent { + /// In case this [`TransportEvent`] is an [`ListenerError`](TransportEvent::ListenerError), + /// or [`ListenerClosed`](TransportEvent::ListenerClosed) apply the given function to the + /// error and produce another transport event based on the function's result. + pub fn map_err(self, map_err: impl FnOnce(TErr) -> E) -> TransportEvent { match self { - ListenerEvent::Upgrade { + TransportEvent::Incoming { + listener_id, upgrade, local_addr, - remote_addr, - } => ListenerEvent::Upgrade { + send_back_addr, + } => TransportEvent::Incoming { + listener_id, upgrade, local_addr, - remote_addr, + send_back_addr, + }, + TransportEvent::NewAddress { + listen_addr, + listener_id, + } => TransportEvent::NewAddress { + listen_addr, + listener_id, + }, + TransportEvent::AddressExpired { + listen_addr, + listener_id, + } => TransportEvent::AddressExpired { + listen_addr, + listener_id, + }, + TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError { + listener_id, + error: map_err(error), + }, + TransportEvent::ListenerClosed { + listener_id, + reason, + } => TransportEvent::ListenerClosed { + listener_id, + reason: reason.map_err(map_err), }, - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(f(e)), } } - /// Returns `true` if this is an `Upgrade` listener event. + /// Returns `true` if this is an [`Incoming`](TransportEvent::Incoming) transport event. pub fn is_upgrade(&self) -> bool { - matches!(self, ListenerEvent::Upgrade { .. }) + matches!(self, TransportEvent::Incoming { .. }) } - /// Try to turn this listener event into upgrade parts. + /// Try to turn this transport event into the upgrade parts of the + /// incoming connection. /// - /// Returns `None` if the event is not actually an upgrade, + /// Returns `None` if the event is not actually an incoming connection, /// otherwise the upgrade and the remote address. - pub fn into_upgrade(self) -> Option<(TUpgr, Multiaddr)> { - if let ListenerEvent::Upgrade { + pub fn into_incoming(self) -> Option<(TUpgr, Multiaddr)> { + if let TransportEvent::Incoming { upgrade, - remote_addr, + send_back_addr, .. } = self { - Some((upgrade, remote_addr)) + Some((upgrade, send_back_addr)) } else { None } } - /// Returns `true` if this is a `NewAddress` listener event. + /// Returns `true` if this is a [`TransportEvent::NewAddress`]. pub fn is_new_address(&self) -> bool { - matches!(self, ListenerEvent::NewAddress(_)) + matches!(self, TransportEvent::NewAddress { .. }) } - /// Try to turn this listener event into the `NewAddress` part. + /// Try to turn this transport event into the new `Multiaddr`. /// - /// Returns `None` if the event is not actually a `NewAddress`, + /// Returns `None` if the event is not actually a [`TransportEvent::NewAddress`], /// otherwise the address. pub fn into_new_address(self) -> Option { - if let ListenerEvent::NewAddress(a) = self { - Some(a) + if let TransportEvent::NewAddress { listen_addr, .. } = self { + Some(listen_addr) } else { None } } - /// Returns `true` if this is an `AddressExpired` listener event. + /// Returns `true` if this is an [`TransportEvent::AddressExpired`]. pub fn is_address_expired(&self) -> bool { - matches!(self, ListenerEvent::AddressExpired(_)) + matches!(self, TransportEvent::AddressExpired { .. }) } - /// Try to turn this listener event into the `AddressExpired` part. + /// Try to turn this transport event into the expire `Multiaddr`. /// - /// Returns `None` if the event is not actually a `AddressExpired`, + /// Returns `None` if the event is not actually a [`TransportEvent::AddressExpired`], /// otherwise the address. pub fn into_address_expired(self) -> Option { - if let ListenerEvent::AddressExpired(a) = self { - Some(a) + if let TransportEvent::AddressExpired { listen_addr, .. } = self { + Some(listen_addr) } else { None } } - /// Returns `true` if this is an `Error` listener event. - pub fn is_error(&self) -> bool { - matches!(self, ListenerEvent::Error(_)) + /// Returns `true` if this is an [`TransportEvent::ListenerError`] transport event. + pub fn is_listener_error(&self) -> bool { + matches!(self, TransportEvent::ListenerError { .. }) } - /// Try to turn this listener event into the `Error` part. + /// Try to turn this transport event into the listener error. /// - /// Returns `None` if the event is not actually a `Error`, + /// Returns `None` if the event is not actually a [`TransportEvent::ListenerError`]`, /// otherwise the error. - pub fn into_error(self) -> Option { - if let ListenerEvent::Error(err) = self { - Some(err) + pub fn into_listener_error(self) -> Option { + if let TransportEvent::ListenerError { error, .. } = self { + Some(error) } else { None } } } +impl fmt::Debug for TransportEvent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + TransportEvent::NewAddress { + listener_id, + listen_addr, + } => f + .debug_struct("TransportEvent::NewAddress") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + TransportEvent::AddressExpired { + listener_id, + listen_addr, + } => f + .debug_struct("TransportEvent::AddressExpired") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + TransportEvent::Incoming { + listener_id, + local_addr, + .. + } => f + .debug_struct("TransportEvent::Incoming") + .field("listener_id", listener_id) + .field("local_addr", local_addr) + .finish(), + TransportEvent::ListenerClosed { + listener_id, + reason, + } => f + .debug_struct("TransportEvent::Closed") + .field("listener_id", listener_id) + .field("reason", reason) + .finish(), + TransportEvent::ListenerError { listener_id, error } => f + .debug_struct("TransportEvent::ListenerError") + .field("listener_id", listener_id) + .field("error", error) + .finish(), + } + } +} + /// An error during [dialing][Transport::dial] or [listening][Transport::listen_on] /// on a [`Transport`]. #[derive(Debug, Clone)] diff --git a/core/src/transport/and_then.rs b/core/src/transport/and_then.rs index f73a0caf8e6..561a2f281ff 100644 --- a/core/src/transport/and_then.rs +++ b/core/src/transport/and_then.rs @@ -21,15 +21,17 @@ use crate::{ connection::{ConnectedPoint, Endpoint}, either::EitherError, - transport::{ListenerEvent, Transport, TransportError}, + transport::{ListenerId, Transport, TransportError, TransportEvent}, }; use futures::{future::Either, prelude::*}; use multiaddr::Multiaddr; use std::{error, marker::PhantomPinned, pin::Pin, task::Context, task::Poll}; -/// See the `Transport::and_then` method. +/// See the [`Transport::and_then`] method. +#[pin_project::pin_project] #[derive(Debug, Clone)] pub struct AndThen { + #[pin] transport: T, fun: C, } @@ -49,27 +51,17 @@ where { type Output = O; type Error = EitherError; - type Listener = AndThenStream; type ListenerUpgrade = AndThenFuture; type Dial = AndThenFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self - .transport + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.transport .listen_on(addr) - .map_err(|err| err.map(EitherError::A))?; - // Try to negotiate the protocol. - // Note that failing to negotiate a protocol will never produce a future with an error. - // Instead the `stream` will produce `Ok(Err(...))`. - // `stream` can only produce an `Err` if `listening_stream` produces an `Err`. - let stream = AndThenStream { - stream: listener, - fun: self.fun.clone(), - }; - Ok(stream) + .map_err(|err| err.map(EitherError::A)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -116,68 +108,40 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} -/// Custom `Stream` to avoid boxing. -/// -/// Applies a function to every stream item. -#[pin_project::pin_project] -#[derive(Debug, Clone)] -pub struct AndThenStream { - #[pin] - stream: TListener, - fun: TMap, -} - -impl Stream - for AndThenStream -where - TListener: TryStream, Error = TTransErr>, - TListUpgr: TryFuture, - TMap: FnOnce(TTransOut, ConnectedPoint) -> TMapOut + Clone, - TMapOut: TryFuture, -{ - type Item = Result< - ListenerEvent< - AndThenFuture, - EitherError, - >, - EitherError, - >; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - match TryStream::try_poll_next(this.stream, cx) { - Poll::Ready(Some(Ok(event))) => { - let event = match event { - ListenerEvent::Upgrade { - upgrade, - local_addr, - remote_addr, - } => { - let point = ConnectedPoint::Listener { - local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone(), - }; - ListenerEvent::Upgrade { - upgrade: AndThenFuture { - inner: Either::Left(Box::pin(upgrade)), - args: Some((this.fun.clone(), point)), - _marker: PhantomPinned, - }, - local_addr, - remote_addr, - } - } - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(EitherError::A(e)), + match this.transport.poll(cx) { + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade, + local_addr, + send_back_addr, + }) => { + let point = ConnectedPoint::Listener { + local_addr: local_addr.clone(), + send_back_addr: send_back_addr.clone(), }; - - Poll::Ready(Some(Ok(event))) + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade: AndThenFuture { + inner: Either::Left(Box::pin(upgrade)), + args: Some((this.fun.clone(), point)), + _marker: PhantomPinned, + }, + local_addr, + send_back_addr, + }) + } + Poll::Ready(other) => { + let mapped = other + .map_upgrade(|_upgrade| unreachable!("case already matched")) + .map_err(EitherError::A); + Poll::Ready(mapped) } - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), - Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } diff --git a/core/src/transport/boxed.rs b/core/src/transport/boxed.rs index 8a804aa40be..b2560c4a662 100644 --- a/core/src/transport/boxed.rs +++ b/core/src/transport/boxed.rs @@ -18,18 +18,22 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{ListenerEvent, Transport, TransportError}; -use futures::prelude::*; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; +use futures::{prelude::*, stream::FusedStream}; use multiaddr::Multiaddr; -use std::{error::Error, fmt, io, pin::Pin}; +use std::{ + error::Error, + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; /// Creates a new [`Boxed`] transport from the given transport. pub fn boxed(transport: T) -> Boxed where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + Sync, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, { Boxed { @@ -41,19 +45,22 @@ where /// and `ListenerUpgrade` futures are `Box`ed and only the `Output` /// and `Error` types are captured in type variables. pub struct Boxed { - inner: Box + Send>, + inner: Box + Send + Unpin>, } type Dial = Pin> + Send>>; -type Listener = - Pin, io::Error>>> + Send>>; type ListenerUpgrade = Pin> + Send>>; trait Abstract { - fn listen_on(&mut self, addr: Multiaddr) -> Result, TransportError>; + fn listen_on(&mut self, addr: Multiaddr) -> Result>; + fn remove_listener(&mut self, id: ListenerId) -> bool; fn dial(&mut self, addr: Multiaddr) -> Result, TransportError>; fn dial_as_listener(&mut self, addr: Multiaddr) -> Result, TransportError>; fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option; + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, io::Error>>; } impl Abstract for T @@ -61,22 +68,14 @@ where T: Transport + 'static, T::Error: Send + Sync, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, { - fn listen_on(&mut self, addr: Multiaddr) -> Result, TransportError> { - let listener = Transport::listen_on(self, addr).map_err(|e| e.map(box_err))?; - let fut = listener - .map_ok(|event| { - event - .map(|upgrade| { - let up = upgrade.map_err(box_err); - Box::pin(up) as ListenerUpgrade - }) - .map_err(box_err) - }) - .map_err(box_err); - Ok(Box::pin(fut)) + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + Transport::listen_on(self, addr).map_err(|e| e.map(box_err)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + Transport::remove_listener(self, id) } fn dial(&mut self, addr: Multiaddr) -> Result, TransportError> { @@ -96,6 +95,20 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { Transport::address_translation(self, server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, io::Error>> { + self.poll(cx).map(|event| { + event + .map_upgrade(|upgrade| { + let up = upgrade.map_err(box_err); + Box::pin(up) as ListenerUpgrade + }) + .map_err(box_err) + }) + } } impl fmt::Debug for Boxed { @@ -107,17 +120,17 @@ impl fmt::Debug for Boxed { impl Transport for Boxed { type Output = O; type Error = io::Error; - type Listener = Listener; type ListenerUpgrade = ListenerUpgrade; type Dial = Dial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { self.inner.listen_on(addr) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) + } + fn dial(&mut self, addr: Multiaddr) -> Result> { self.inner.dial(addr) } @@ -132,6 +145,27 @@ impl Transport for Boxed { fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.address_translation(server, observed) } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.inner.as_mut()).poll(cx) + } +} + +impl Stream for Boxed { + type Item = TransportEvent, io::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Transport::poll(self, cx).map(Some) + } +} + +impl FusedStream for Boxed { + fn is_terminated(&self) -> bool { + false + } } fn box_err(e: E) -> io::Error { diff --git a/core/src/transport/choice.rs b/core/src/transport/choice.rs index f1d21cfa30c..17528c1d4a8 100644 --- a/core/src/transport/choice.rs +++ b/core/src/transport/choice.rs @@ -18,13 +18,15 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::either::{EitherError, EitherFuture, EitherListenStream, EitherOutput}; -use crate::transport::{Transport, TransportError}; +use crate::either::{EitherError, EitherFuture, EitherOutput}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use multiaddr::Multiaddr; +use std::{pin::Pin, task::Context, task::Poll}; /// Struct returned by `or_transport()`. #[derive(Debug, Copy, Clone)] -pub struct OrTransport(A, B); +#[pin_project::pin_project] +pub struct OrTransport(#[pin] A, #[pin] B); impl OrTransport { pub fn new(a: A, b: B) -> OrTransport { @@ -39,33 +41,27 @@ where { type Output = EitherOutput; type Error = EitherError; - type Listener = EitherListenStream; type ListenerUpgrade = EitherFuture; type Dial = EitherFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let addr = match self.0.listen_on(addr) { - Ok(listener) => return Ok(EitherListenStream::First(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => { - return Err(TransportError::Other(EitherError::A(err))) - } + res => return res.map_err(|err| err.map(EitherError::A)), }; let addr = match self.1.listen_on(addr) { - Ok(listener) => return Ok(EitherListenStream::Second(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => { - return Err(TransportError::Other(EitherError::B(err))) - } + res => return res.map_err(|err| err.map(EitherError::B)), }; Err(TransportError::MultiaddrNotSupported(addr)) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.0.remove_listener(id) || self.1.remove_listener(id) + } + fn dial(&mut self, addr: Multiaddr) -> Result> { let addr = match self.0.dial(addr) { Ok(connec) => return Ok(EitherFuture::First(connec)), @@ -116,4 +112,24 @@ where self.1.address_translation(server, observed) } } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.project(); + match this.0.poll(cx) { + Poll::Ready(ev) => { + return Poll::Ready(ev.map_upgrade(EitherFuture::First).map_err(EitherError::A)) + } + Poll::Pending => {} + } + match this.1.poll(cx) { + Poll::Ready(ev) => { + return Poll::Ready(ev.map_upgrade(EitherFuture::Second).map_err(EitherError::B)) + } + Poll::Pending => {} + } + Poll::Pending + } } diff --git a/core/src/transport/dummy.rs b/core/src/transport/dummy.rs index 5862348b0d4..a7d1cab9089 100644 --- a/core/src/transport/dummy.rs +++ b/core/src/transport/dummy.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{ListenerEvent, Transport, TransportError}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use crate::Multiaddr; use futures::{prelude::*, task::Context, task::Poll}; use std::{fmt, io, marker::PhantomData, pin::Pin}; @@ -56,19 +56,17 @@ impl Clone for DummyTransport { impl Transport for DummyTransport { type Output = TOut; type Error = io::Error; - type Listener = futures::stream::Pending< - Result, Self::Error>, - >; type ListenerUpgrade = futures::future::Pending>; type Dial = futures::future::Pending>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { Err(TransportError::MultiaddrNotSupported(addr)) } + fn remove_listener(&mut self, _id: ListenerId) -> bool { + false + } + fn dial(&mut self, addr: Multiaddr) -> Result> { Err(TransportError::MultiaddrNotSupported(addr)) } @@ -83,6 +81,13 @@ impl Transport for DummyTransport { fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + Poll::Pending + } } /// Implementation of `AsyncRead` and `AsyncWrite`. Not meant to be instanciated. diff --git a/core/src/transport/map.rs b/core/src/transport/map.rs index 703e1ea430b..50f7b826d36 100644 --- a/core/src/transport/map.rs +++ b/core/src/transport/map.rs @@ -20,15 +20,19 @@ use crate::{ connection::{ConnectedPoint, Endpoint}, - transport::{ListenerEvent, Transport, TransportError}, + transport::{Transport, TransportError, TransportEvent}, }; use futures::prelude::*; use multiaddr::Multiaddr; use std::{pin::Pin, task::Context, task::Poll}; +use super::ListenerId; + /// See `Transport::map`. #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct Map { + #[pin] transport: T, fun: F, } @@ -54,19 +58,15 @@ where { type Output = D; type Error = T::Error; - type Listener = MapStream; type ListenerUpgrade = MapFuture; type Dial = MapFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let stream = self.transport.listen_on(addr)?; - Ok(MapStream { - stream, - fun: self.fun.clone(), - }) + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.transport.listen_on(addr) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -99,58 +99,37 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} - -/// Custom `Stream` implementation to avoid boxing. -/// -/// Maps a function over every stream item. -#[pin_project::pin_project] -#[derive(Clone, Debug)] -pub struct MapStream { - #[pin] - stream: T, - fun: F, -} -impl Stream for MapStream -where - T: TryStream, Error = E>, - X: TryFuture, - F: FnOnce(A, ConnectedPoint) -> B + Clone, -{ - type Item = Result, E>, E>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - match TryStream::try_poll_next(this.stream, cx) { - Poll::Ready(Some(Ok(event))) => { - let event = match event { - ListenerEvent::Upgrade { - upgrade, - local_addr, - remote_addr, - } => { - let point = ConnectedPoint::Listener { - local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone(), - }; - ListenerEvent::Upgrade { - upgrade: MapFuture { - inner: upgrade, - args: Some((this.fun.clone(), point)), - }, - local_addr, - remote_addr, - } - } - ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), - ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), - ListenerEvent::Error(e) => ListenerEvent::Error(e), + match this.transport.poll(cx) { + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade, + local_addr, + send_back_addr, + }) => { + let point = ConnectedPoint::Listener { + local_addr: local_addr.clone(), + send_back_addr: send_back_addr.clone(), }; - Poll::Ready(Some(Ok(event))) + Poll::Ready(TransportEvent::Incoming { + listener_id, + upgrade: MapFuture { + inner: upgrade, + args: Some((this.fun.clone(), point)), + }, + local_addr, + send_back_addr, + }) + } + Poll::Ready(other) => { + let mapped = other.map_upgrade(|_upgrade| unreachable!("case already matched")); + Poll::Ready(mapped) } - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), - Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } diff --git a/core/src/transport/map_err.rs b/core/src/transport/map_err.rs index 6cc2c5c3662..99f2912447f 100644 --- a/core/src/transport/map_err.rs +++ b/core/src/transport/map_err.rs @@ -18,14 +18,16 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{ListenerEvent, Transport, TransportError}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use futures::prelude::*; use multiaddr::Multiaddr; use std::{error, pin::Pin, task::Context, task::Poll}; /// See `Transport::map_err`. #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct MapErr { + #[pin] transport: T, map: F, } @@ -45,19 +47,16 @@ where { type Output = T::Output; type Error = TErr; - type Listener = MapErrListener; type ListenerUpgrade = MapErrListenerUpgrade; type Dial = MapErrDial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let map = self.map.clone(); - match self.transport.listen_on(addr) { - Ok(stream) => Ok(MapErrListener { inner: stream, map }), - Err(err) => Err(err.map(map)), - } + self.transport.listen_on(addr).map_err(|err| err.map(map)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -88,41 +87,20 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} -/// Listening stream for `MapErr`. -#[pin_project::pin_project] -pub struct MapErrListener { - #[pin] - inner: T::Listener, - map: F, -} - -impl Stream for MapErrListener -where - T: Transport, - F: FnOnce(T::Error) -> TErr + Clone, - TErr: error::Error, -{ - type Item = Result, TErr>, TErr>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - match TryStream::try_poll_next(this.inner, cx) { - Poll::Ready(Some(Ok(event))) => { - let map = &*this.map; - let event = event - .map(move |value| MapErrListenerUpgrade { - inner: value, - map: Some(map.clone()), - }) - .map_err(|err| (map.clone())(err)); - Poll::Ready(Some(Ok(event))) - } - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err((this.map.clone())(err)))), - } + let map = &*this.map; + this.transport.poll(cx).map(|ev| { + ev.map_upgrade(move |value| MapErrListenerUpgrade { + inner: value, + map: Some(map.clone()), + }) + .map_err(map.clone()) + }) } } diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index 40bc5d3da15..8121bfc895a 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -18,10 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{ - transport::{ListenerEvent, TransportError}, - Transport, -}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use fnv::FnvHashMap; use futures::{ channel::mpsc, @@ -34,7 +31,12 @@ use lazy_static::lazy_static; use multiaddr::{Multiaddr, Protocol}; use parking_lot::Mutex; use rw_stream_sink::RwStreamSink; -use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin}; +use std::{ + collections::{hash_map::Entry, VecDeque}, + error, fmt, io, + num::NonZeroU64, + pin::Pin, +}; lazy_static! { static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default())); @@ -91,8 +93,16 @@ impl Hub { } /// Transport that supports `/memory/N` multiaddresses. -#[derive(Debug, Copy, Clone, Default)] -pub struct MemoryTransport; +#[derive(Default)] +pub struct MemoryTransport { + listeners: VecDeque>>, +} + +impl MemoryTransport { + pub fn new() -> Self { + Self::default() + } +} /// Connection to a `MemoryTransport` currently being opened. pub struct DialFuture { @@ -168,14 +178,10 @@ impl Future for DialFuture { impl Transport for MemoryTransport { type Output = Channel>; type Error = MemoryTransportError; - type Listener = Listener; type ListenerUpgrade = Ready>; type Dial = DialFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let port = if let Ok(port) = parse_memory_addr(&addr) { port } else { @@ -187,14 +193,29 @@ impl Transport for MemoryTransport { None => return Err(TransportError::Other(MemoryTransportError::Unreachable)), }; + let id = ListenerId::new(); let listener = Listener { + id, port, addr: Protocol::Memory(port.get()).into(), receiver: rx, tell_listen_addr: true, }; + self.listeners.push_back(Box::pin(listener)); - Ok(listener) + Ok(id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) { + let listener = self.listeners.get_mut(index).unwrap(); + let val_in = HUB.unregister_port(&listener.port); + debug_assert!(val_in.is_some()); + listener.receiver.close(); + true + } else { + false + } } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -221,6 +242,56 @@ impl Transport for MemoryTransport { fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Sized, + { + let mut remaining = self.listeners.len(); + while let Some(mut listener) = self.listeners.pop_back() { + if listener.tell_listen_addr { + listener.tell_listen_addr = false; + let listen_addr = listener.addr.clone(); + let listener_id = listener.id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::NewAddress { + listen_addr, + listener_id, + }); + } + + let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) { + Poll::Pending => None, + Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming { + listener_id: listener.id, + upgrade: future::ready(Ok(channel)), + local_addr: listener.addr.clone(), + send_back_addr: Protocol::Memory(dial_port.get()).into(), + }), + Poll::Ready(None) => { + // Listener was closed. + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: listener.id, + reason: Ok(()), + }); + } + }; + + self.listeners.push_front(listener); + if let Some(event) = event { + return Poll::Ready(event); + } else { + remaining -= 1; + if remaining == 0 { + break; + } + } + } + Poll::Pending + } } /// Error that can be produced from the `MemoryTransport`. @@ -245,51 +316,17 @@ impl error::Error for MemoryTransportError {} /// Listener for memory connections. pub struct Listener { + id: ListenerId, /// Port we're listening on. port: NonZeroU64, /// The address we are listening on. addr: Multiaddr, /// Receives incoming connections. receiver: ChannelReceiver, - /// Generate `ListenerEvent::NewAddress` to inform about our listen address. + /// Generate [`TransportEvent::NewAddress`] to inform about our listen address. tell_listen_addr: bool, } -impl Stream for Listener { - type Item = Result< - ListenerEvent>, MemoryTransportError>>, MemoryTransportError>, - MemoryTransportError, - >; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.tell_listen_addr { - self.tell_listen_addr = false; - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))); - } - - let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => panic!("Alive listeners always have a sender."), - Poll::Ready(Some(v)) => v, - }; - - let event = ListenerEvent::Upgrade { - upgrade: future::ready(Ok(channel)), - local_addr: self.addr.clone(), - remote_addr: Protocol::Memory(dial_port.get()).into(), - }; - - Poll::Ready(Some(Ok(event))) - } -} - -impl Drop for Listener { - fn drop(&mut self) { - let val_in = HUB.unregister_port(&self.port); - debug_assert!(val_in.is_some()); - } -} - /// If the address is `/memory/n`, returns the value of `n`. fn parse_memory_addr(a: &Multiaddr) -> Result { let mut protocols = a.iter(); @@ -418,28 +455,34 @@ mod tests { #[test] fn listening_twice() { let mut transport = MemoryTransport::default(); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); - let _listener = transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .unwrap(); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_err()); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_err()); - drop(_listener); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); - assert!(transport - .listen_on("/memory/1639174018481".parse().unwrap()) - .is_ok()); + + let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap(); + let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap(); + + let listener_id_1 = transport.listen_on(addr_1.clone()).unwrap(); + assert!( + transport.remove_listener(listener_id_1), + "Listener doesn't exist." + ); + + let listener_id_2 = transport.listen_on(addr_1.clone()).unwrap(); + let listener_id_3 = transport.listen_on(addr_2.clone()).unwrap(); + + assert!(transport.listen_on(addr_1.clone()).is_err()); + assert!(transport.listen_on(addr_2.clone()).is_err()); + + assert!( + transport.remove_listener(listener_id_2), + "Listener doesn't exist." + ); + assert!(transport.listen_on(addr_1).is_ok()); + assert!(transport.listen_on(addr_2.clone()).is_err()); + + assert!( + transport.remove_listener(listener_id_3), + "Listener doesn't exist." + ); + assert!(transport.listen_on(addr_2).is_ok()); } #[test] @@ -456,6 +499,35 @@ mod tests { .is_ok()); } + #[test] + fn stop_listening() { + let rand_port = rand::random::().saturating_add(1); + let addr: Multiaddr = format!("/memory/{}", rand_port).parse().unwrap(); + + let mut transport = MemoryTransport::default().boxed(); + futures::executor::block_on(async { + let listener_id = transport.listen_on(addr.clone()).unwrap(); + let reported_addr = transport + .select_next_some() + .await + .into_new_address() + .expect("new address"); + assert_eq!(addr, reported_addr); + assert!(transport.remove_listener(listener_id)); + match transport.select_next_some().await { + TransportEvent::ListenerClosed { + listener_id: id, + reason, + } => { + assert_eq!(id, listener_id); + assert!(reason.is_ok()) + } + other => panic!("Unexpected transport event: {:?}", other), + } + assert!(!transport.remove_listener(listener_id)); + }) + } + #[test] fn communicating_between_dialer_and_listener() { let msg = [1, 2, 3]; @@ -466,16 +538,16 @@ mod tests { let t1_addr: Multiaddr = format!("/memory/{}", rand_port).parse().unwrap(); let cloned_t1_addr = t1_addr.clone(); - let mut t1 = MemoryTransport::default(); + let mut t1 = MemoryTransport::default().boxed(); let listener = async move { - let listener = t1.listen_on(t1_addr.clone()).unwrap(); - - let upgrade = listener - .filter_map(|ev| futures::future::ready(ListenerEvent::into_upgrade(ev.unwrap()))) - .next() - .await - .unwrap(); + t1.listen_on(t1_addr.clone()).unwrap(); + let upgrade = loop { + let event = t1.select_next_some().await; + if let Some(upgrade) = event.into_incoming() { + break upgrade; + } + }; let mut socket = upgrade.0.await.unwrap(); @@ -504,14 +576,16 @@ mod tests { Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); - let mut listener_transport = MemoryTransport::default(); + let mut listener_transport = MemoryTransport::default().boxed(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); - while let Some(ev) = listener.next().await { - if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { + listener_transport.listen_on(listener_addr.clone()).unwrap(); + loop { + if let TransportEvent::Incoming { send_back_addr, .. } = + listener_transport.select_next_some().await + { assert!( - remote_addr != listener_addr, + send_back_addr != listener_addr, "Expect dialer address not to equal listener address." ); return; @@ -539,14 +613,16 @@ mod tests { Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); - let mut listener_transport = MemoryTransport::default(); + let mut listener_transport = MemoryTransport::default().boxed(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); - while let Some(ev) = listener.next().await { - if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { + listener_transport.listen_on(listener_addr.clone()).unwrap(); + loop { + if let TransportEvent::Incoming { send_back_addr, .. } = + listener_transport.select_next_some().await + { let dialer_port = - NonZeroU64::new(parse_memory_addr(&remote_addr).unwrap()).unwrap(); + NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap(); assert!( HUB.get(&dialer_port).is_some(), diff --git a/core/src/transport/optional.rs b/core/src/transport/optional.rs index cb10c35e133..2d93077659c 100644 --- a/core/src/transport/optional.rs +++ b/core/src/transport/optional.rs @@ -18,8 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{Transport, TransportError}; +use crate::transport::{ListenerId, Transport, TransportError, TransportEvent}; use multiaddr::Multiaddr; +use std::{pin::Pin, task::Context, task::Poll}; /// Transport that is possibly disabled. /// @@ -28,7 +29,8 @@ use multiaddr::Multiaddr; /// enabled (read: contains `Some`), then dialing and listening will be handled by the inner /// transport. #[derive(Debug, Copy, Clone)] -pub struct OptionalTransport(Option); +#[pin_project::pin_project] +pub struct OptionalTransport(#[pin] Option); impl OptionalTransport { /// Builds an `OptionalTransport` with the given transport in an enabled @@ -55,14 +57,10 @@ where { type Output = T::Output; type Error = T::Error; - type Listener = T::Listener; type ListenerUpgrade = T::ListenerUpgrade; type Dial = T::Dial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { if let Some(inner) = self.0.as_mut() { inner.listen_on(addr) } else { @@ -70,6 +68,14 @@ where } } + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(inner) = self.0.as_mut() { + inner.remove_listener(id) + } else { + false + } + } + fn dial(&mut self, addr: Multiaddr) -> Result> { if let Some(inner) = self.0.as_mut() { inner.dial(addr) @@ -96,4 +102,15 @@ where None } } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(inner) = self.project().0.as_pin_mut() { + inner.poll(cx) + } else { + Poll::Pending + } + } } diff --git a/core/src/transport/timeout.rs b/core/src/transport/timeout.rs index bb413cf8909..5c3867b3c01 100644 --- a/core/src/transport/timeout.rs +++ b/core/src/transport/timeout.rs @@ -25,7 +25,7 @@ // TODO: add example use crate::{ - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Multiaddr, Transport, }; use futures::prelude::*; @@ -38,7 +38,9 @@ use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration}; /// **Note**: `listen_on` is never subject to a timeout, only the setup of each /// individual accepted connection. #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct TransportTimeout { + #[pin] inner: InnerTrans, outgoing_timeout: Duration, incoming_timeout: Duration, @@ -80,25 +82,17 @@ where { type Output = InnerTrans::Output; type Error = TransportTimeoutError; - type Listener = TimeoutListener; type ListenerUpgrade = Timeout; type Dial = Timeout; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self - .inner + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner .listen_on(addr) - .map_err(|err| err.map(TransportTimeoutError::Other))?; - - let listener = TimeoutListener { - inner: listener, - timeout: self.incoming_timeout, - }; + .map_err(|err| err.map(TransportTimeoutError::Other)) + } - Ok(listener) + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -129,45 +123,21 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.address_translation(server, observed) } -} -// TODO: can be removed and replaced with an `impl Stream` once impl Trait is fully stable -// in Rust (https://github.com/rust-lang/rust/issues/34511) -#[pin_project::pin_project] -pub struct TimeoutListener { - #[pin] - inner: InnerStream, - timeout: Duration, -} - -impl Stream for TimeoutListener -where - InnerStream: TryStream, Error = E>, -{ - type Item = - Result, TransportTimeoutError>, TransportTimeoutError>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let this = self.project(); - - let poll_out = match TryStream::try_poll_next(this.inner, cx) { - Poll::Ready(Some(Err(err))) => { - return Poll::Ready(Some(Err(TransportTimeoutError::Other(err)))) - } - Poll::Ready(Some(Ok(v))) => v, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - }; - - let timeout = *this.timeout; - let event = poll_out - .map(move |inner_fut| Timeout { - inner: inner_fut, - timer: Delay::new(timeout), - }) - .map_err(TransportTimeoutError::Other); - - Poll::Ready(Some(Ok(event))) + let timeout = *this.incoming_timeout; + this.inner.poll(cx).map(|event| { + event + .map_upgrade(move |inner_fut| Timeout { + inner: inner_fut, + timer: Delay::new(timeout), + }) + .map_err(TransportTimeoutError::Other) + }) } } diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index 964045ad33f..c872ec955e4 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -26,8 +26,8 @@ use crate::{ connection::ConnectedPoint, muxing::{StreamMuxer, StreamMuxerBox}, transport::{ - and_then::AndThen, boxed::boxed, timeout::TransportTimeout, ListenerEvent, Transport, - TransportError, + and_then::AndThen, boxed::boxed, timeout::TransportTimeout, ListenerId, Transport, + TransportError, TransportEvent, }, upgrade::{ self, apply_inbound, apply_outbound, InboundUpgrade, InboundUpgradeApply, OutboundUpgrade, @@ -287,16 +287,16 @@ where /// A authenticated and multiplexed transport, obtained from /// [`Authenticated::multiplex`]. #[derive(Clone)] -pub struct Multiplexed(T); +#[pin_project::pin_project] +pub struct Multiplexed(#[pin] T); impl Multiplexed { /// Boxes the authenticated, multiplexed transport, including /// the [`StreamMuxer`] and custom transport errors. pub fn boxed(self) -> super::Boxed<(PeerId, StreamMuxerBox)> where - T: Transport + Sized + Send + 'static, + T: Transport + Sized + Send + Unpin + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Error: Send + Sync, M: StreamMuxer + Send + Sync + 'static, @@ -332,7 +332,6 @@ where { type Output = T::Output; type Error = T::Error; - type Listener = T::Listener; type ListenerUpgrade = T::ListenerUpgrade; type Dial = T::Dial; @@ -340,6 +339,10 @@ where self.0.dial(addr) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.0.remove_listener(id) + } + fn dial_as_listener( &mut self, addr: Multiaddr, @@ -347,16 +350,20 @@ where self.0.dial_as_listener(addr) } - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { self.0.listen_on(addr) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.0.address_translation(server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().0.poll(cx) + } } /// An inbound or outbound upgrade. @@ -366,7 +373,9 @@ type EitherUpgrade = future::Either, OutboundUpg /// /// See [`Transport::upgrade`] #[derive(Debug, Copy, Clone)] +#[pin_project::pin_project] pub struct Upgrade { + #[pin] inner: T, upgrade: U, } @@ -388,7 +397,6 @@ where { type Output = (PeerId, D); type Error = TransportUpgradeError; - type Listener = ListenerStream; type ListenerUpgrade = ListenerUpgradeFuture; type Dial = DialUpgradeFuture; @@ -403,6 +411,10 @@ where }) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) + } + fn dial_as_listener( &mut self, addr: Multiaddr, @@ -417,23 +429,31 @@ where }) } - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let stream = self - .inner + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner .listen_on(addr) - .map_err(|err| err.map(TransportUpgradeError::Transport))?; - Ok(ListenerStream { - stream: Box::pin(stream), - upgrade: self.upgrade.clone(), - }) + .map_err(|err| err.map(TransportUpgradeError::Transport)) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.address_translation(server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.project(); + let upgrade = this.upgrade.clone(); + this.inner.poll(cx).map(|event| { + event + .map_upgrade(move |future| ListenerUpgradeFuture { + future: Box::pin(future), + upgrade: future::Either::Left(Some(upgrade)), + }) + .map_err(TransportUpgradeError::Transport) + }) + } } /// Errors produced by a transport upgrade. @@ -478,7 +498,7 @@ where C: AsyncRead + AsyncWrite + Unpin, { future: Pin>, - upgrade: future::Either, (Option, OutboundUpgradeApply)>, + upgrade: future::Either, (PeerId, OutboundUpgradeApply)>, } impl Future for DialUpgradeFuture @@ -507,18 +527,15 @@ where let u = up .take() .expect("DialUpgradeFuture is constructed with Either::Left(Some)."); - future::Either::Right((Some(i), apply_outbound(c, u, upgrade::Version::V1))) + future::Either::Right((i, apply_outbound(c, u, upgrade::Version::V1))) } - future::Either::Right((ref mut i, ref mut up)) => { + future::Either::Right((i, ref mut up)) => { let d = match ready!( Future::poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade) ) { Ok(d) => d, Err(err) => return Poll::Ready(Err(err)), }; - let i = i - .take() - .expect("DialUpgradeFuture polled after completion."); return Poll::Ready(Ok((i, d))); } } @@ -533,43 +550,6 @@ where { } -/// The [`Transport::Listener`] stream of an [`Upgrade`]d transport. -pub struct ListenerStream { - stream: Pin>, - upgrade: U, -} - -impl Stream for ListenerStream -where - S: TryStream, Error = E>, - F: TryFuture, - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = D> + Clone, -{ - type Item = Result< - ListenerEvent, TransportUpgradeError>, - TransportUpgradeError, - >; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(TryStream::try_poll_next(self.stream.as_mut(), cx)) { - Some(Ok(event)) => { - let event = event - .map(move |future| ListenerUpgradeFuture { - future: Box::pin(future), - upgrade: future::Either::Left(Some(self.upgrade.clone())), - }) - .map_err(TransportUpgradeError::Transport); - Poll::Ready(Some(Ok(event))) - } - Some(Err(err)) => Poll::Ready(Some(Err(TransportUpgradeError::Transport(err)))), - None => Poll::Ready(None), - } - } -} - -impl Unpin for ListenerStream {} - /// The [`Transport::ListenerUpgrade`] future of an [`Upgrade`]d transport. pub struct ListenerUpgradeFuture where @@ -577,7 +557,7 @@ where U: InboundUpgrade>, { future: Pin>, - upgrade: future::Either, (Option, InboundUpgradeApply)>, + upgrade: future::Either, (PeerId, InboundUpgradeApply)>, } impl Future for ListenerUpgradeFuture @@ -606,18 +586,15 @@ where let u = up .take() .expect("ListenerUpgradeFuture is constructed with Either::Left(Some)."); - future::Either::Right((Some(i), apply_inbound(c, u))) + future::Either::Right((i, apply_inbound(c, u))) } - future::Either::Right((ref mut i, ref mut up)) => { + future::Either::Right((i, ref mut up)) => { let d = match ready!(TryFuture::try_poll(Pin::new(up), cx) .map_err(TransportUpgradeError::Upgrade)) { Ok(v) => v, Err(err) => return Poll::Ready(Err(err)), }; - let i = i - .take() - .expect("ListenerUpgradeFuture polled after completion."); return Poll::Ready(Ok((i, d))); } } diff --git a/core/tests/transport_upgrade.rs b/core/tests/transport_upgrade.rs index 9fd1e8eaabb..ecba64dfb2f 100644 --- a/core/tests/transport_upgrade.rs +++ b/core/tests/transport_upgrade.rs @@ -95,7 +95,8 @@ fn upgrade_pipeline() { // Gracefully close the connection to allow protocol // negotiation to complete. util::CloseMuxer::new(mplex).map_ok(move |mplex| (peer, mplex)) - }); + }) + .boxed(); let dialer_keys = identity::Keypair::generate_ed25519(); let dialer_id = dialer_keys.public().to_peer_id(); @@ -113,17 +114,18 @@ fn upgrade_pipeline() { // Gracefully close the connection to allow protocol // negotiation to complete. util::CloseMuxer::new(mplex).map_ok(move |mplex| (peer, mplex)) - }); + }) + .boxed(); let listen_addr1 = Multiaddr::from(Protocol::Memory(random::())); let listen_addr2 = listen_addr1.clone(); - let mut listener = listener_transport.listen_on(listen_addr1).unwrap(); + listener_transport.listen_on(listen_addr1).unwrap(); let server = async move { loop { - let (upgrade, _remote_addr) = - match listener.next().await.unwrap().unwrap().into_upgrade() { + let (upgrade, _send_back_addr) = + match listener_transport.select_next_some().await.into_incoming() { Some(u) => u, None => continue, }; diff --git a/examples/chat-tokio.rs b/examples/chat-tokio.rs index 2400c8a98a5..66c25205246 100644 --- a/examples/chat-tokio.rs +++ b/examples/chat-tokio.rs @@ -45,13 +45,14 @@ use libp2p::{ mplex, noise, swarm::{dial_opts::DialOpts, NetworkBehaviourEventProcess, SwarmBuilder, SwarmEvent}, - // `TokioTcpConfig` is available through the `tcp-tokio` feature. - tcp::TokioTcpConfig, + // `TokioTcpTransport` is available through the `tcp-tokio` feature. + tcp::TokioTcpTransport, Multiaddr, NetworkBehaviour, PeerId, Transport, }; +use libp2p_tcp::GenTcpConfig; use std::error::Error; use tokio::io::{self, AsyncBufReadExt}; @@ -72,8 +73,7 @@ async fn main() -> Result<(), Box> { // Create a tokio-based TCP transport use noise for authenticated // encryption and Mplex for multiplexing of substreams on a TCP stream. - let transport = TokioTcpConfig::new() - .nodelay(true) + let transport = TokioTcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(mplex::MplexConfig::new()) diff --git a/examples/ipfs-private.rs b/examples/ipfs-private.rs index fdeed494141..113bdf988f2 100644 --- a/examples/ipfs-private.rs +++ b/examples/ipfs-private.rs @@ -44,10 +44,11 @@ use libp2p::{ noise, ping, pnet::{PnetConfig, PreSharedKey}, swarm::{NetworkBehaviourEventProcess, SwarmEvent}, - tcp::TcpConfig, + tcp::TcpTransport, yamux::YamuxConfig, Multiaddr, NetworkBehaviour, PeerId, Swarm, Transport, }; +use libp2p_tcp::GenTcpConfig; use std::{env, error::Error, fs, path::Path, str::FromStr, time::Duration}; /// Builds the transport that serves as a common ground for all connections. @@ -61,7 +62,7 @@ pub fn build_transport( let noise_config = noise::NoiseConfig::xx(noise_keys).into_authenticated(); let yamux_config = YamuxConfig::default(); - let base_transport = TcpConfig::new().nodelay(true); + let base_transport = TcpTransport::new(GenTcpConfig::default().nodelay(true)); let maybe_encrypted = match psk { Some(psk) => EitherTransport::Left( base_transport.and_then(move |socket, _| PnetConfig::new(psk).handshake(socket)), diff --git a/muxers/mplex/benches/split_send_size.rs b/muxers/mplex/benches/split_send_size.rs index a15c15c3c92..f5bf771d1ab 100644 --- a/muxers/mplex/benches/split_send_size.rs +++ b/muxers/mplex/benches/split_send_size.rs @@ -23,15 +23,16 @@ use async_std::task; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; -use futures::channel::oneshot; use futures::future::poll_fn; use futures::prelude::*; +use futures::{channel::oneshot, future::join}; use libp2p_core::{ identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, StreamMuxer, Transport, }; use libp2p_mplex as mplex; use libp2p_plaintext::PlainText2Config; +use libp2p_tcp::GenTcpConfig; use std::pin::Pin; use std::time::Duration; @@ -58,11 +59,13 @@ fn prepare(c: &mut Criterion) { let tcp_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1)), Tcp(0u16)]; for &size in BENCH_SIZES.iter() { tcp.throughput(Throughput::Bytes(payload.len() as u64)); - let mut trans = tcp_transport(size); + let mut receiver_transport = tcp_transport(size); + let mut sender_transport = tcp_transport(size); tcp.bench_function(format!("{}", size), |b| { b.iter(|| { run( - black_box(&mut trans), + black_box(&mut receiver_transport), + black_box(&mut sender_transport), black_box(&payload), black_box(&tcp_addr), ) @@ -75,11 +78,13 @@ fn prepare(c: &mut Criterion) { let mem_addr = multiaddr![Memory(0u64)]; for &size in BENCH_SIZES.iter() { mem.throughput(Throughput::Bytes(payload.len() as u64)); - let mut trans = mem_transport(size); + let mut receiver_transport = mem_transport(size); + let mut sender_transport = mem_transport(size); mem.bench_function(format!("{}", size), |b| { b.iter(|| { run( - black_box(&mut trans), + black_box(&mut receiver_transport), + black_box(&mut sender_transport), black_box(&payload), black_box(&mem_addr), ) @@ -90,20 +95,24 @@ fn prepare(c: &mut Criterion) { } /// Transfers the given payload between two nodes using the given transport. -fn run(transport: &mut BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { - let mut listener = transport.listen_on(listen_addr.clone()).unwrap(); +fn run( + receiver_trans: &mut BenchTransport, + sender_trans: &mut BenchTransport, + payload: &Vec, + listen_addr: &Multiaddr, +) { + receiver_trans.listen_on(listen_addr.clone()).unwrap(); let (addr_sender, addr_receiver) = oneshot::channel(); let mut addr_sender = Some(addr_sender); let payload_len = payload.len(); - // Spawn the receiver. - let receiver = task::spawn(async move { + let receiver = async move { loop { - match listener.next().await.unwrap().unwrap() { - transport::ListenerEvent::NewAddress(a) => { - addr_sender.take().unwrap().send(a).unwrap(); + match receiver_trans.next().await.unwrap() { + transport::TransportEvent::NewAddress { listen_addr, .. } => { + addr_sender.take().unwrap().send(listen_addr).unwrap(); } - transport::ListenerEvent::Upgrade { upgrade, .. } => { + transport::TransportEvent::Incoming { upgrade, .. } => { let (_peer, conn) = upgrade.await.unwrap(); let mut s = poll_fn(|cx| conn.poll_event(cx)) .await @@ -125,15 +134,15 @@ fn run(transport: &mut BenchTransport, payload: &Vec, listen_addr: &Multiadd } } } - _ => panic!("Unexpected listener event"), + _ => panic!("Unexpected transport event"), } } - }); + }; // Spawn and block on the sender, i.e. until all data is sent. - task::block_on(async move { + let sender = async move { let addr = addr_receiver.await.unwrap(); - let (_peer, conn) = transport.dial(addr).unwrap().await.unwrap(); + let (_peer, conn) = sender_trans.dial(addr).unwrap().await.unwrap(); let mut handle = conn.open_outbound(); let mut stream = poll_fn(|cx| conn.poll_outbound(cx, &mut handle)) .await @@ -151,10 +160,10 @@ fn run(transport: &mut BenchTransport, payload: &Vec, listen_addr: &Multiadd return; } } - }); + }; // Wait for all data to be received. - task::block_on(receiver); + task::block_on(join(sender, receiver)); } fn tcp_transport(split_send_size: usize) -> BenchTransport { @@ -164,8 +173,7 @@ fn tcp_transport(split_send_size: usize) -> BenchTransport { let mut mplex = mplex::MplexConfig::default(); mplex.set_split_send_size(split_send_size); - libp2p_tcp::TcpConfig::new() - .nodelay(true) + libp2p_tcp::TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(PlainText2Config { local_public_key }) .multiplex(mplex) diff --git a/muxers/mplex/tests/async_write.rs b/muxers/mplex/tests/async_write.rs index d59f8b279f9..9dbda1a198d 100644 --- a/muxers/mplex/tests/async_write.rs +++ b/muxers/mplex/tests/async_write.rs @@ -21,7 +21,7 @@ use futures::future::poll_fn; use futures::{channel::oneshot, prelude::*}; use libp2p_core::{upgrade, StreamMuxer, Transport}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use std::sync::Arc; #[test] @@ -33,29 +33,28 @@ fn async_write() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let client = listener + let client = transport .next() .await - .unwrap() - .unwrap() - .into_upgrade() + .expect("some event") + .into_incoming() .unwrap() .0 .await @@ -73,7 +72,7 @@ fn async_write() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() + let mut transport = TcpTransport::default() .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); diff --git a/muxers/mplex/tests/two_peers.rs b/muxers/mplex/tests/two_peers.rs index a6438feaff9..4283452fe07 100644 --- a/muxers/mplex/tests/two_peers.rs +++ b/muxers/mplex/tests/two_peers.rs @@ -21,7 +21,7 @@ use futures::future::poll_fn; use futures::{channel::oneshot, prelude::*}; use libp2p_core::{upgrade, StreamMuxer, Transport}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use std::sync::Arc; #[test] @@ -33,29 +33,28 @@ fn client_to_server_outbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let client = listener + let client = transport .next() .await - .unwrap() - .unwrap() - .into_upgrade() + .expect("some event") + .into_incoming() .unwrap() .0 .await @@ -73,8 +72,9 @@ fn client_to_server_outbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); let mut inbound = loop { @@ -102,30 +102,29 @@ fn client_to_server_inbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); let client = Arc::new( - listener + transport .next() .await - .unwrap() - .unwrap() - .into_upgrade() + .expect("some event") + .into_incoming() .unwrap() .0 .await @@ -149,8 +148,9 @@ fn client_to_server_inbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); @@ -172,29 +172,28 @@ fn protocol_not_match() { let _bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let client = listener + let client = transport .next() .await - .unwrap() - .unwrap() - .into_upgrade() + .expect("some event") + .into_incoming() .unwrap() .0 .await @@ -214,8 +213,9 @@ fn protocol_not_match() { // Make sure they do not connect when protocols do not match let mut mplex = libp2p_mplex::MplexConfig::new(); mplex.set_protocol_name(b"/mplextest/1.0.0"); - let mut transport = TcpConfig::new() - .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let mut transport = TcpTransport::default() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)) + .boxed(); match transport.dial(rx.await.unwrap()).unwrap().await { Ok(_) => { assert!(false, "Dialing should fail here as protocols do not match") diff --git a/protocols/autonat/src/behaviour.rs b/protocols/autonat/src/behaviour.rs index 86c65bb3416..f98bf5af532 100644 --- a/protocols/autonat/src/behaviour.rs +++ b/protocols/autonat/src/behaviour.rs @@ -29,9 +29,8 @@ pub use as_server::{InboundProbeError, InboundProbeEvent}; use futures_timer::Delay; use instant::Instant; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - multiaddr::Protocol, - ConnectedPoint, Endpoint, Multiaddr, PeerId, + connection::ConnectionId, multiaddr::Protocol, transport::ListenerId, ConnectedPoint, Endpoint, + Multiaddr, PeerId, }; use libp2p_request_response::{ handler::RequestResponseHandlerEvent, ProtocolSupport, RequestId, RequestResponse, diff --git a/protocols/dcutr/examples/client.rs b/protocols/dcutr/examples/client.rs index 94bed7c2d7f..dd73b7d3ac3 100644 --- a/protocols/dcutr/examples/client.rs +++ b/protocols/dcutr/examples/client.rs @@ -32,7 +32,7 @@ use libp2p::noise; use libp2p::ping::{Ping, PingConfig, PingEvent}; use libp2p::relay::v2::client::{self, Client}; use libp2p::swarm::{SwarmBuilder, SwarmEvent}; -use libp2p::tcp::TcpConfig; +use libp2p::tcp::{GenTcpConfig, TcpTransport}; use libp2p::Transport; use libp2p::{identity, NetworkBehaviour, PeerId}; use log::info; @@ -95,7 +95,10 @@ fn main() -> Result<(), Box> { let transport = OrTransport::new( relay_transport, - block_on(DnsConfig::system(TcpConfig::new().port_reuse(true))).unwrap(), + block_on(DnsConfig::system(TcpTransport::new( + GenTcpConfig::default().port_reuse(true), + ))) + .unwrap(), ) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 1bfab2d84e5..9ed56b3265e 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -22,9 +22,8 @@ use crate::handler::{IdentifyHandlerEvent, IdentifyHandlerProto, IdentifyPush}; use crate::protocol::{IdentifyInfo, ReplySubstream, UpgradeError}; use futures::prelude::*; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - multiaddr::Protocol, - ConnectedPoint, Multiaddr, PeerId, PublicKey, + connection::ConnectionId, multiaddr::Protocol, transport::ListenerId, ConnectedPoint, + Multiaddr, PeerId, PublicKey, }; use libp2p_swarm::{ dial_opts::{self, DialOpts}, @@ -515,7 +514,7 @@ mod tests { use futures::pin_mut; use libp2p::mplex::MplexConfig; use libp2p::noise; - use libp2p::tcp::TcpConfig; + use libp2p::tcp::{GenTcpConfig, TcpTransport}; use libp2p_core::{identity, muxing::StreamMuxerBox, transport, upgrade, PeerId, Transport}; use libp2p_swarm::{Swarm, SwarmEvent}; use std::time::Duration; @@ -529,8 +528,7 @@ mod tests { .into_authentic(&id_keys) .unwrap(); let pubkey = id_keys.public(); - let transport = TcpConfig::new() - .nodelay(true) + let transport = TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(MplexConfig::new()) diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index 11a33f652a3..735fbcb342b 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -287,7 +287,7 @@ pub enum UpgradeError { mod tests { use super::*; use futures::channel::oneshot; - use libp2p::tcp::TcpConfig; + use libp2p::tcp::TcpTransport; use libp2p_core::{ identity, upgrade::{self, apply_inbound, apply_outbound}, @@ -304,27 +304,25 @@ mod tests { let (tx, rx) = oneshot::channel(); let bg_task = async_std::task::spawn(async move { - let mut transport = TcpConfig::new(); + let mut transport = TcpTransport::default().boxed(); - let mut listener = transport + transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener + let addr = transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); tx.send(addr).unwrap(); - let socket = listener + let socket = transport .next() .await - .unwrap() - .unwrap() - .into_upgrade() + .expect("some event") + .into_incoming() .unwrap() .0 .await @@ -349,7 +347,7 @@ mod tests { }); async_std::task::block_on(async move { - let mut transport = TcpConfig::new(); + let mut transport = TcpTransport::default(); let socket = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); let info = apply_outbound(socket, IdentifyProtocol, upgrade::Version::V1) diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index d5b322e096a..59d63b36e9d 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -40,8 +40,7 @@ use crate::K_VALUE; use fnv::{FnvHashMap, FnvHashSet}; use instant::Instant; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - ConnectedPoint, Multiaddr, PeerId, + connection::ConnectionId, transport::ListenerId, ConnectedPoint, Multiaddr, PeerId, }; use libp2p_swarm::{ dial_opts::{self, DialOpts}, diff --git a/protocols/kad/src/protocol.rs b/protocols/kad/src/protocol.rs index 648f7fc9e07..656917b54f6 100644 --- a/protocols/kad/src/protocol.rs +++ b/protocols/kad/src/protocol.rs @@ -603,7 +603,7 @@ where mod tests { /*// TODO: restore - use self::libp2p_tcp::TcpConfig; + use self::libp2p_tcp::TcpTransport; use self::tokio::runtime::current_thread::Runtime; use futures::{Future, Sink, Stream}; use libp2p_core::{PeerId, PublicKey, Transport}; @@ -658,10 +658,10 @@ mod tests { let (tx, rx) = mpsc::channel(); let bg_thread = thread::spawn(move || { - let transport = TcpConfig::new().with_upgrade(KademliaProtocolConfig); + let transport = TcpTransport::default().with_upgrade(KademliaProtocolConfig); let (listener, addr) = transport - .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .listen_on( "/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); tx.send(addr).unwrap(); @@ -678,7 +678,7 @@ mod tests { let _ = rt.block_on(future).unwrap(); }); - let transport = TcpConfig::new().with_upgrade(KademliaProtocolConfig); + let transport = TcpTransport::default().with_upgrade(KademliaProtocolConfig); let future = transport .dial(rx.recv().unwrap()) diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index c844f742af9..244b2b784dd 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -25,7 +25,7 @@ use crate::MdnsConfig; use async_io::Timer; use futures::prelude::*; use if_watch::{IfEvent, IfWatcher}; -use libp2p_core::connection::ListenerId; +use libp2p_core::transport::ListenerId; use libp2p_core::{Multiaddr, PeerId}; use libp2p_swarm::{ handler::DummyConnectionHandler, ConnectionHandler, NetworkBehaviour, NetworkBehaviourAction, diff --git a/protocols/ping/src/protocol.rs b/protocols/ping/src/protocol.rs index ae60f67a858..499c5ad4a0f 100644 --- a/protocols/ping/src/protocol.rs +++ b/protocols/ping/src/protocol.rs @@ -115,9 +115,10 @@ where #[cfg(test)] mod tests { use super::*; + use futures::StreamExt; use libp2p_core::{ multiaddr::multiaddr, - transport::{memory::MemoryTransport, ListenerEvent, Transport}, + transport::{memory::MemoryTransport, Transport}, }; use rand::{thread_rng, Rng}; use std::time::Duration; @@ -125,24 +126,28 @@ mod tests { #[test] fn ping_pong() { let mem_addr = multiaddr![Memory(thread_rng().gen::())]; - let mut listener = MemoryTransport.listen_on(mem_addr).unwrap(); + let mut transport = MemoryTransport::new().boxed(); + transport.listen_on(mem_addr).unwrap(); - let listener_addr = - if let Some(Some(Ok(ListenerEvent::NewAddress(a)))) = listener.next().now_or_never() { - a - } else { - panic!("MemoryTransport not listening on an address!"); - }; + let listener_addr = transport + .select_next_some() + .now_or_never() + .and_then(|ev| ev.into_new_address()) + .expect("MemoryTransport not listening on an address!"); async_std::task::spawn(async move { - let listener_event = listener.next().await.unwrap(); - let (listener_upgrade, _) = listener_event.unwrap().into_upgrade().unwrap(); + let transport_event = transport.next().await.unwrap(); + let (listener_upgrade, _) = transport_event.into_incoming().unwrap(); let conn = listener_upgrade.await.unwrap(); recv_ping(conn).await.unwrap(); }); async_std::task::block_on(async move { - let c = MemoryTransport.dial(listener_addr).unwrap().await.unwrap(); + let c = MemoryTransport::new() + .dial(listener_addr) + .unwrap() + .await + .unwrap(); let (_, rtt) = send_ping(c).await.unwrap(); assert!(rtt > Duration::from_secs(0)); }); diff --git a/protocols/ping/tests/ping.rs b/protocols/ping/tests/ping.rs index dbde7db608d..ac45949ced7 100644 --- a/protocols/ping/tests/ping.rs +++ b/protocols/ping/tests/ping.rs @@ -31,7 +31,7 @@ use libp2p_mplex as mplex; use libp2p_noise as noise; use libp2p_ping as ping; use libp2p_swarm::{DummyBehaviour, KeepAlive, Swarm, SwarmEvent}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::{GenTcpConfig, TcpTransport}; use libp2p_yamux as yamux; use quickcheck::*; use rand::prelude::*; @@ -248,8 +248,7 @@ fn mk_transport(muxer: MuxerChoice) -> (PeerId, transport::Boxed<(PeerId, Stream .unwrap(); ( peer_id, - TcpConfig::new() - .nodelay(true) + TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(match muxer { diff --git a/protocols/relay/CHANGELOG.md b/protocols/relay/CHANGELOG.md index 2f5930201aa..50f6a771f40 100644 --- a/protocols/relay/CHANGELOG.md +++ b/protocols/relay/CHANGELOG.md @@ -6,7 +6,11 @@ - Do not duplicate the p2p/xxx component with the relay PeerId when a client requests a reservation. See [PR 2701]. +- Drive the `RelayListener`s within the `ClientTransport`. Add `Transport::poll` and `Transport::remove_listener` + for `ClientTransport`. See [PR 2652]. + [PR 2701]: https://github.com/libp2p/rust-libp2p/pull/2701/ +[PR 2652]: https://github.com/libp2p/rust-libp2p/pull/2652 # 0.9.1 diff --git a/protocols/relay/examples/relay_v2.rs b/protocols/relay/examples/relay_v2.rs index 8a4ee914fce..25d0bb7fc94 100644 --- a/protocols/relay/examples/relay_v2.rs +++ b/protocols/relay/examples/relay_v2.rs @@ -28,7 +28,7 @@ use libp2p::multiaddr::Protocol; use libp2p::ping::{Ping, PingConfig, PingEvent}; use libp2p::relay::v2::relay::{self, Relay}; use libp2p::swarm::{Swarm, SwarmEvent}; -use libp2p::tcp::TcpConfig; +use libp2p::tcp::TcpTransport; use libp2p::Transport; use libp2p::{identity, NetworkBehaviour, PeerId}; use libp2p::{noise, Multiaddr}; @@ -46,7 +46,7 @@ fn main() -> Result<(), Box> { let local_peer_id = PeerId::from(local_key.public()); println!("Local peer id: {:?}", local_peer_id); - let tcp_transport = TcpConfig::new(); + let tcp_transport = TcpTransport::default(); let noise_keys = noise::Keypair::::new() .into_authentic(&local_key) diff --git a/protocols/relay/src/v2/client/transport.rs b/protocols/relay/src/v2/client/transport.rs index 5414353786c..18e4a989597 100644 --- a/protocols/relay/src/v2/client/transport.rs +++ b/protocols/relay/src/v2/client/transport.rs @@ -23,12 +23,13 @@ use crate::v2::client::RelayedConnection; use crate::v2::RequestId; use futures::channel::mpsc; use futures::channel::oneshot; -use futures::future::{ready, BoxFuture, Future, FutureExt, Ready}; +use futures::future::{ready, BoxFuture, FutureExt, Ready}; use futures::ready; use futures::sink::SinkExt; +use futures::stream::SelectAll; use futures::stream::{Stream, StreamExt}; use libp2p_core::multiaddr::{Multiaddr, Protocol}; -use libp2p_core::transport::{ListenerEvent, TransportError}; +use libp2p_core::transport::{ListenerId, TransportError, TransportEvent}; use libp2p_core::{PeerId, Transport}; use std::collections::VecDeque; use std::pin::Pin; @@ -85,9 +86,10 @@ use thiserror::Error; /// .with(Protocol::P2pCircuit); // Signal to listen via remote relay node. /// transport.listen_on(relay_addr).unwrap(); /// ``` -#[derive(Clone)] pub struct ClientTransport { to_behaviour: mpsc::Sender, + pending_to_behaviour: VecDeque, + listeners: SelectAll, } impl ClientTransport { @@ -112,22 +114,22 @@ impl ClientTransport { /// ``` pub(crate) fn new() -> (Self, mpsc::Receiver) { let (to_behaviour, from_transport) = mpsc::channel(0); - - (ClientTransport { to_behaviour }, from_transport) + let transport = ClientTransport { + to_behaviour, + pending_to_behaviour: VecDeque::new(), + listeners: SelectAll::new(), + }; + (transport, from_transport) } } impl Transport for ClientTransport { type Output = RelayedConnection; type Error = RelayError; - type Listener = RelayListener; type ListenerUpgrade = Ready>; type Dial = RelayedDial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let (relay_peer_id, relay_addr) = match parse_relayed_multiaddr(addr)? { RelayedMultiaddr { relay_peer_id: None, @@ -147,25 +149,31 @@ impl Transport for ClientTransport { }; let (to_listener, from_behaviour) = mpsc::channel(0); - let mut to_behaviour = self.to_behaviour.clone(); - let msg_to_behaviour = Some( - async move { - to_behaviour - .send(TransportToBehaviourMsg::ListenReq { - relay_peer_id, - relay_addr, - to_listener, - }) - .await - } - .boxed(), - ); - - Ok(RelayListener { - queued_new_addresses: Default::default(), + self.pending_to_behaviour + .push_back(TransportToBehaviourMsg::ListenReq { + relay_peer_id, + relay_addr, + to_listener, + }); + + let listener_id = ListenerId::new(); + let listener = RelayListener { + listener_id, + queued_events: Default::default(), from_behaviour, - msg_to_behaviour, - }) + is_closed: false, + }; + self.listeners.push(listener); + Ok(listener_id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) { + listener.close(Ok(())); + true + } else { + false + } } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -217,6 +225,35 @@ impl Transport for ClientTransport { fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> + where + Self: Sized, + { + loop { + if !self.pending_to_behaviour.is_empty() { + match self.to_behaviour.poll_ready(cx) { + Poll::Ready(Ok(())) => { + let msg = self + .pending_to_behaviour + .pop_front() + .expect("Called !is_empty()."); + let _ = self.to_behaviour.start_send(msg); + continue; + } + Poll::Ready(Err(_)) => unreachable!("Receiver is never dropped."), + Poll::Pending => {} + } + } + match self.listeners.poll_next_unpin(cx) { + Poll::Ready(Some(event)) => return Poll::Ready(event), + _ => return Poll::Pending, + } + } + } } #[derive(Default)] @@ -282,64 +319,87 @@ fn parse_relayed_multiaddr( } pub struct RelayListener { - queued_new_addresses: VecDeque, + listener_id: ListenerId, + /// Queue of events to report when polled. + queued_events: VecDeque<::Item>, + /// Channel for messages from the behaviour [`Handler`][super::handler::Handler]. from_behaviour: mpsc::Receiver, - msg_to_behaviour: Option>>, + /// The listener can be closed either manually with [`Transport::remove_listener`] or if + /// the sender side of the `from_behaviour` channel is dropped. + is_closed: bool, } -impl Unpin for RelayListener {} +impl RelayListener { + /// Close the listener. + /// + /// This will create a [`TransportEvent::ListenerClosed`] event + /// and terminate the stream once all remaining events in queue have + /// been reported. + fn close(&mut self, reason: Result<(), RelayError>) { + self.queued_events + .push_back(TransportEvent::ListenerClosed { + listener_id: self.listener_id, + reason, + }); + self.is_closed = true; + } +} impl Stream for RelayListener { - type Item = - Result>, RelayError>, RelayError>; + type Item = TransportEvent<::ListenerUpgrade, RelayError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - if let Some(msg) = &mut self.msg_to_behaviour { - match Future::poll(msg.as_mut(), cx) { - Poll::Ready(Ok(())) => self.msg_to_behaviour = None, - Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))), - Poll::Pending => {} - } + if let Some(event) = self.queued_events.pop_front() { + return Poll::Ready(Some(event)); } - if let Some(addr) = self.queued_new_addresses.pop_front() { - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(addr)))); + if self.is_closed { + // Terminate the stream if the listener closed and all remaining events have been reported. + return Poll::Ready(None); } let msg = match ready!(self.from_behaviour.poll_next_unpin(cx)) { Some(msg) => msg, None => { // Sender of `from_behaviour` has been dropped, signaling listener to close. - return Poll::Ready(None); + self.close(Ok(())); + continue; } }; - let result = match msg { + match msg { ToListenerMsg::Reservation(Ok(Reservation { addrs })) => { debug_assert!( - self.queued_new_addresses.is_empty(), + self.queued_events.is_empty(), "Assert empty due to previous `pop_front` attempt." ); // Returned as [`ListenerEvent::NewAddress`] in next iteration of loop. - self.queued_new_addresses = addrs.into(); - - continue; + self.queued_events = addrs + .into_iter() + .map(|listen_addr| TransportEvent::NewAddress { + listener_id: self.listener_id, + listen_addr, + }) + .collect(); } ToListenerMsg::IncomingRelayedConnection { stream, src_peer_id, relay_addr, relay_peer_id: _, - } => Ok(ListenerEvent::Upgrade { - upgrade: ready(Ok(stream)), - local_addr: relay_addr.with(Protocol::P2pCircuit), - remote_addr: Protocol::P2p(src_peer_id.into()).into(), - }), - ToListenerMsg::Reservation(Err(())) => Err(RelayError::Reservation), + } => { + let listener_id = self.listener_id; + + self.queued_events.push_back(TransportEvent::Incoming { + upgrade: ready(Ok(stream)), + listener_id, + local_addr: relay_addr.with(Protocol::P2pCircuit), + send_back_addr: Protocol::P2p(src_peer_id.into()).into(), + }) + } + ToListenerMsg::Reservation(Err(())) => self.close(Err(RelayError::Reservation)), }; - - return Poll::Ready(Some(result)); } } } diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 6cd6a732d4e..8cbc06e7444 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -32,7 +32,7 @@ use libp2p_core::{ use libp2p_noise::{Keypair, NoiseConfig, X25519Spec}; use libp2p_request_response::*; use libp2p_swarm::{Swarm, SwarmEvent}; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::{GenTcpConfig, TcpTransport}; use rand::{self, Rng}; use std::{io, iter}; @@ -300,8 +300,7 @@ fn mk_transport() -> (PeerId, transport::Boxed<(PeerId, StreamMuxerBox)>) { .unwrap(); ( peer_id, - TcpConfig::new() - .nodelay(true) + TcpTransport::new(GenTcpConfig::default().nodelay(true)) .upgrade(upgrade::Version::V1) .authenticate(NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(libp2p_yamux::YamuxConfig::default()) diff --git a/src/bandwidth.rs b/src/bandwidth.rs index 2e22d73b163..a58eec95ddb 100644 --- a/src/bandwidth.rs +++ b/src/bandwidth.rs @@ -20,7 +20,7 @@ use crate::{ core::{ - transport::{ListenerEvent, TransportError}, + transport::{TransportError, TransportEvent}, Transport, }, Multiaddr, @@ -31,6 +31,7 @@ use futures::{ prelude::*, ready, }; +use libp2p_core::transport::ListenerId; use std::{ convert::TryFrom as _, io, @@ -45,7 +46,9 @@ use std::{ /// Wraps around a `Transport` and counts the number of bytes that go through all the opened /// connections. #[derive(Clone)] +#[pin_project::pin_project] pub struct BandwidthLogging { + #[pin] inner: TInner, sinks: Arc, } @@ -73,18 +76,32 @@ where { type Output = BandwidthConnecLogging; type Error = TInner::Error; - type Listener = BandwidthListener; type ListenerUpgrade = BandwidthFuture; type Dial = BandwidthFuture; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let sinks = self.sinks.clone(); - self.inner - .listen_on(addr) - .map(move |inner| BandwidthListener { inner, sinks }) + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.project(); + match this.inner.poll(cx) { + Poll::Ready(event) => { + let event = event.map_upgrade({ + let sinks = this.sinks.clone(); + |inner| BandwidthFuture { inner, sinks } + }); + Poll::Ready(event) + } + Poll::Pending => Poll::Pending, + } + } + + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner.listen_on(addr) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -109,39 +126,6 @@ where } } -/// Wraps around a `Stream` that produces connections. Wraps each connection around a bandwidth -/// counter. -#[pin_project::pin_project] -pub struct BandwidthListener { - #[pin] - inner: TInner, - sinks: Arc, -} - -impl Stream for BandwidthListener -where - TInner: TryStream, Error = TErr>, -{ - type Item = Result, TErr>, TErr>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - let event = if let Some(event) = ready!(this.inner.try_poll_next(cx)?) { - event - } else { - return Poll::Ready(None); - }; - - let event = event.map({ - let sinks = this.sinks.clone(); - |inner| BandwidthFuture { inner, sinks } - }); - - Poll::Ready(Some(Ok(event))) - } -} - /// Wraps around a `Future` that produces a connection. Wraps the connection around a bandwidth /// counter. #[pin_project::pin_project] diff --git a/src/lib.rs b/src/lib.rs index 45174b66af7..6bb577b1f52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -201,9 +201,15 @@ pub async fn development_transport( keypair: identity::Keypair, ) -> std::io::Result> { let transport = { - let dns_tcp = dns::DnsConfig::system(tcp::TcpConfig::new().nodelay(true)).await?; + let dns_tcp = dns::DnsConfig::system(tcp::TcpTransport::new( + tcp::GenTcpConfig::new().nodelay(true), + )) + .await?; let ws_dns_tcp = websocket::WsConfig::new( - dns::DnsConfig::system(tcp::TcpConfig::new().nodelay(true)).await?, + dns::DnsConfig::system(tcp::TcpTransport::new( + tcp::GenTcpConfig::new().nodelay(true), + )) + .await?, ); dns_tcp.or_transport(ws_dns_tcp) }; @@ -259,9 +265,11 @@ pub fn tokio_development_transport( keypair: identity::Keypair, ) -> std::io::Result> { let transport = { - let dns_tcp = dns::TokioDnsConfig::system(tcp::TokioTcpConfig::new().nodelay(true))?; + let dns_tcp = dns::TokioDnsConfig::system(tcp::TokioTcpTransport::new( + tcp::GenTcpConfig::new().nodelay(true), + ))?; let ws_dns_tcp = websocket::WsConfig::new(dns::TokioDnsConfig::system( - tcp::TokioTcpConfig::new().nodelay(true), + tcp::TokioTcpTransport::new(tcp::GenTcpConfig::new().nodelay(true)), )?); dns_tcp.or_transport(ws_dns_tcp) }; diff --git a/swarm-derive/src/lib.rs b/swarm-derive/src/lib.rs index e81bd8ae4d1..1216add96c0 100644 --- a/swarm-derive/src/lib.rs +++ b/swarm-derive/src/lib.rs @@ -57,7 +57,7 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { let connection_id = quote! {::libp2p::core::connection::ConnectionId}; let dial_errors = quote! {Option<&Vec<::libp2p::core::Multiaddr>>}; let connected_point = quote! {::libp2p::core::ConnectedPoint}; - let listener_id = quote! {::libp2p::core::connection::ListenerId}; + let listener_id = quote! {::libp2p::core::transport::ListenerId}; let dial_error = quote! {::libp2p::swarm::DialError}; let poll_parameters = quote! {::libp2p::swarm::PollParameters}; diff --git a/swarm/CHANGELOG.md b/swarm/CHANGELOG.md index 463d406add7..300e4e96a09 100644 --- a/swarm/CHANGELOG.md +++ b/swarm/CHANGELOG.md @@ -4,7 +4,10 @@ - Extend log message when exceeding inbound negotiating streams with peer ID and limit. See [PR 2716]. +- Remove `connection::ListenersStream` and poll the `Transport` directly. See [PR 2652]. + [PR 2716]: https://github.com/libp2p/rust-libp2p/pull/2716/ +[PR 2652]: https://github.com/libp2p/rust-libp2p/pull/2652 # 0.36.1 diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index d09427fbc6b..3dd6ddf9588 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -25,8 +25,7 @@ use crate::dial_opts::DialOpts; use crate::handler::{ConnectionHandler, IntoConnectionHandler}; use crate::{AddressRecord, AddressScore, DialError}; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - ConnectedPoint, Multiaddr, PeerId, + connection::ConnectionId, transport::ListenerId, ConnectedPoint, Multiaddr, PeerId, }; use std::{task::Context, task::Poll}; diff --git a/swarm/src/behaviour/either.rs b/swarm/src/behaviour/either.rs index 479534f6a8f..54e60e77b3a 100644 --- a/swarm/src/behaviour/either.rs +++ b/swarm/src/behaviour/either.rs @@ -25,8 +25,7 @@ use crate::{ }; use either::Either; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - ConnectedPoint, Multiaddr, PeerId, + connection::ConnectionId, transport::ListenerId, ConnectedPoint, Multiaddr, PeerId, }; use std::{task::Context, task::Poll}; diff --git a/swarm/src/behaviour/toggle.rs b/swarm/src/behaviour/toggle.rs index 25183b932c7..50ea6487770 100644 --- a/swarm/src/behaviour/toggle.rs +++ b/swarm/src/behaviour/toggle.rs @@ -29,8 +29,9 @@ use crate::{ }; use either::Either; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, + connection::ConnectionId, either::{EitherError, EitherOutput}, + transport::ListenerId, upgrade::{DeniedUpgrade, EitherUpgrade}, ConnectedPoint, Multiaddr, PeerId, }; diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs index b09cb0480bd..733016dceb0 100644 --- a/swarm/src/connection.rs +++ b/swarm/src/connection.rs @@ -20,7 +20,6 @@ mod error; mod handler_wrapper; -mod listeners; mod substream; pub(crate) mod pool; @@ -29,7 +28,6 @@ pub use error::{ ConnectionError, PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError, }; -pub use listeners::{ListenersEvent, ListenersStream}; pub use pool::{ConnectionCounters, ConnectionLimits}; pub use pool::{EstablishedConnection, PendingConnection}; pub use substream::{Close, SubstreamEndpoint}; diff --git a/swarm/src/connection/listeners.rs b/swarm/src/connection/listeners.rs deleted file mode 100644 index 484a36dc15d..00000000000 --- a/swarm/src/connection/listeners.rs +++ /dev/null @@ -1,554 +0,0 @@ -// Copyright 2018 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Manage listening on multiple multiaddresses at once. - -use crate::{ - transport::{ListenerEvent, TransportError}, - Multiaddr, Transport, -}; -use futures::{prelude::*, task::Context, task::Poll}; -use libp2p_core::connection::ListenerId; -use log::debug; -use smallvec::SmallVec; -use std::{collections::VecDeque, fmt, mem, pin::Pin}; - -/// Implementation of `futures::Stream` that allows listening on multiaddresses. -/// -/// To start using a [`ListenersStream`], create one with [`ListenersStream::new`] by passing an -/// implementation of [`Transport`]. This [`Transport`] will be used to start listening, therefore -/// you want to pass a [`Transport`] that supports the protocols you wish you listen on. -/// -/// Then, call [`ListenersStream::listen_on`] for all addresses you want to start listening on. -/// -/// The [`ListenersStream`] never ends and never produces errors. If a listener errors or closes, an -/// event is generated on the stream and the listener is then dropped, but the [`ListenersStream`] -/// itself continues. -pub struct ListenersStream -where - TTrans: Transport, -{ - /// Transport used to spawn listeners. - transport: TTrans, - /// All the active listeners. - /// The `Listener` struct contains a stream that we want to be pinned. Since the `VecDeque` - /// can be resized, the only way is to use a `Pin>`. - listeners: VecDeque>>>, - /// The next listener ID to assign. - next_id: ListenerId, - /// Pending listeners events to return from [`ListenersStream::poll`]. - pending_events: VecDeque>, -} - -/// A single active listener. -#[pin_project::pin_project] -#[derive(Debug)] -struct Listener -where - TTrans: Transport, -{ - /// The ID of this listener. - id: ListenerId, - /// The object that actually listens. - #[pin] - listener: TTrans::Listener, - /// Addresses it is listening on. - addresses: SmallVec<[Multiaddr; 4]>, -} - -/// Event that can happen on the `ListenersStream`. -pub enum ListenersEvent -where - TTrans: Transport, -{ - /// A new address is being listened on. - NewAddress { - /// The listener that is listening on the new address. - listener_id: ListenerId, - /// The new address that is being listened on. - listen_addr: Multiaddr, - }, - /// An address is no longer being listened on. - AddressExpired { - /// The listener that is no longer listening on the address. - listener_id: ListenerId, - /// The new address that is being listened on. - listen_addr: Multiaddr, - }, - /// A connection is incoming on one of the listeners. - Incoming { - /// The listener that produced the upgrade. - listener_id: ListenerId, - /// The produced upgrade. - upgrade: TTrans::ListenerUpgrade, - /// Local connection address. - local_addr: Multiaddr, - /// Address used to send back data to the incoming client. - send_back_addr: Multiaddr, - }, - /// A listener closed. - Closed { - /// The ID of the listener that closed. - listener_id: ListenerId, - /// The addresses that the listener was listening on. - addresses: Vec, - /// Reason for the closure. Contains `Ok(())` if the stream produced `None`, or `Err` - /// if the stream produced an error. - reason: Result<(), TTrans::Error>, - }, - /// A listener errored. - /// - /// The listener will continue to be polled for new events and the event - /// is for informational purposes only. - Error { - /// The ID of the listener that errored. - listener_id: ListenerId, - /// The error value. - error: TTrans::Error, - }, -} - -impl ListenersStream -where - TTrans: Transport, -{ - /// Starts a new stream of listeners. - pub fn new(transport: TTrans) -> Self { - ListenersStream { - transport, - listeners: VecDeque::new(), - next_id: ListenerId::new(1), - pending_events: VecDeque::new(), - } - } - - /// Start listening on a multiaddress. - /// - /// Returns an error if the transport doesn't support the given multiaddress. - pub fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self.transport.listen_on(addr)?; - self.listeners.push_back(Box::pin(Listener { - id: self.next_id, - listener, - addresses: SmallVec::new(), - })); - let id = self.next_id; - self.next_id = self.next_id + 1; - Ok(id) - } - - /// Remove the listener matching the given `ListenerId`. - /// - /// Returns `true` if there was a listener with this ID, `false` - /// otherwise. - pub fn remove_listener(&mut self, id: ListenerId) -> bool { - if let Some(i) = self.listeners.iter().position(|l| l.id == id) { - let mut listener = self - .listeners - .remove(i) - .expect("Index can not be out of bounds."); - let listener_project = listener.as_mut().project(); - let addresses = mem::take(listener_project.addresses).into_vec(); - self.pending_events.push_back(ListenersEvent::Closed { - listener_id: *listener_project.id, - addresses, - reason: Ok(()), - }); - true - } else { - false - } - } - - /// Returns a reference to the transport passed when building this object. - pub fn transport(&self) -> &TTrans { - &self.transport - } - - /// Returns a mutable reference to the transport passed when building this object. - pub fn transport_mut(&mut self) -> &mut TTrans { - &mut self.transport - } - - /// Returns an iterator that produces the list of addresses we're listening on. - pub fn listen_addrs(&self) -> impl Iterator { - self.listeners.iter().flat_map(|l| l.addresses.iter()) - } - - /// Provides an API similar to `Stream`, except that it cannot end. - pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Return pending events from closed listeners. - if let Some(event) = self.pending_events.pop_front() { - return Poll::Ready(event); - } - // We remove each element from `listeners` one by one and add them back. - let mut remaining = self.listeners.len(); - while let Some(mut listener) = self.listeners.pop_back() { - let mut listener_project = listener.as_mut().project(); - match TryStream::try_poll_next(listener_project.listener.as_mut(), cx) { - Poll::Pending => { - self.listeners.push_front(listener); - remaining -= 1; - if remaining == 0 { - break; - } - } - Poll::Ready(Some(Ok(ListenerEvent::Upgrade { - upgrade, - local_addr, - remote_addr, - }))) => { - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::Incoming { - listener_id: id, - upgrade, - local_addr, - send_back_addr: remote_addr, - }); - } - Poll::Ready(Some(Ok(ListenerEvent::NewAddress(a)))) => { - if listener_project.addresses.contains(&a) { - debug!("Transport has reported address {} multiple times", a) - } else { - listener_project.addresses.push(a.clone()); - } - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::NewAddress { - listener_id: id, - listen_addr: a, - }); - } - Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(a)))) => { - listener_project.addresses.retain(|x| x != &a); - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::AddressExpired { - listener_id: id, - listen_addr: a, - }); - } - Poll::Ready(Some(Ok(ListenerEvent::Error(error)))) => { - let id = *listener_project.id; - self.listeners.push_front(listener); - return Poll::Ready(ListenersEvent::Error { - listener_id: id, - error, - }); - } - Poll::Ready(None) => { - let addresses = mem::take(listener_project.addresses).into_vec(); - return Poll::Ready(ListenersEvent::Closed { - listener_id: *listener_project.id, - addresses, - reason: Ok(()), - }); - } - Poll::Ready(Some(Err(err))) => { - let addresses = mem::take(listener_project.addresses).into_vec(); - return Poll::Ready(ListenersEvent::Closed { - listener_id: *listener_project.id, - addresses, - reason: Err(err), - }); - } - } - } - - // We register the current task to be woken up if a new listener is added. - Poll::Pending - } -} - -impl Stream for ListenersStream -where - TTrans: Transport, -{ - type Item = ListenersEvent; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ListenersStream::poll(self, cx).map(Option::Some) - } -} - -impl Unpin for ListenersStream where TTrans: Transport {} - -impl fmt::Debug for ListenersStream -where - TTrans: Transport + fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - f.debug_struct("ListenersStream") - .field("transport", &self.transport) - .field("listen_addrs", &self.listen_addrs().collect::>()) - .finish() - } -} - -impl fmt::Debug for ListenersEvent -where - TTrans: Transport, - TTrans::Error: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - ListenersEvent::NewAddress { - listener_id, - listen_addr, - } => f - .debug_struct("ListenersEvent::NewAddress") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish(), - ListenersEvent::AddressExpired { - listener_id, - listen_addr, - } => f - .debug_struct("ListenersEvent::AddressExpired") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish(), - ListenersEvent::Incoming { - listener_id, - local_addr, - .. - } => f - .debug_struct("ListenersEvent::Incoming") - .field("listener_id", listener_id) - .field("local_addr", local_addr) - .finish(), - ListenersEvent::Closed { - listener_id, - addresses, - reason, - } => f - .debug_struct("ListenersEvent::Closed") - .field("listener_id", listener_id) - .field("addresses", addresses) - .field("reason", reason) - .finish(), - ListenersEvent::Error { listener_id, error } => f - .debug_struct("ListenersEvent::Error") - .field("listener_id", listener_id) - .field("error", error) - .finish(), - } - } -} - -#[cfg(test)] -mod tests { - use futures::{future::BoxFuture, stream::BoxStream}; - - use super::*; - use crate::transport; - - #[test] - fn incoming_event() { - async_std::task::block_on(async move { - let mut mem_transport = transport::MemoryTransport::default(); - - let mut listeners = ListenersStream::new(mem_transport); - listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - let address = { - let event = listeners.next().await.unwrap(); - if let ListenersEvent::NewAddress { listen_addr, .. } = event { - listen_addr - } else { - panic!("Was expecting the listen address to be reported") - } - }; - - let address2 = address.clone(); - async_std::task::spawn(async move { - mem_transport.dial(address2).unwrap().await.unwrap(); - }); - - match listeners.next().await.unwrap() { - ListenersEvent::Incoming { - local_addr, - send_back_addr, - .. - } => { - assert_eq!(local_addr, address); - assert!(send_back_addr != address); - } - _ => panic!(), - } - }); - } - - #[test] - fn listener_event_error_isnt_fatal() { - // Tests that a listener continues to be polled even after producing - // a `ListenerEvent::Error`. - - #[derive(Clone)] - struct DummyTrans; - impl transport::Transport for DummyTrans { - type Output = (); - type Error = std::io::Error; - type Listener = BoxStream< - 'static, - Result, std::io::Error>, - >; - type ListenerUpgrade = BoxFuture<'static, Result>; - type Dial = BoxFuture<'static, Result>; - - fn listen_on( - &mut self, - _: Multiaddr, - ) -> Result> { - Ok(Box::pin(stream::unfold((), |()| async move { - Some(( - Ok(ListenerEvent::Error(std::io::Error::from( - std::io::ErrorKind::Other, - ))), - (), - )) - }))) - } - - fn dial( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn dial_as_listener( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { - None - } - } - - async_std::task::block_on(async move { - let transport = DummyTrans; - let mut listeners = ListenersStream::new(transport); - listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - for _ in 0..10 { - match listeners.next().await.unwrap() { - ListenersEvent::Error { .. } => {} - _ => panic!(), - } - } - }); - } - - #[test] - fn listener_error_is_fatal() { - // Tests that a listener stops after producing an error on the stream itself. - - #[derive(Clone)] - struct DummyTrans; - impl transport::Transport for DummyTrans { - type Output = (); - type Error = std::io::Error; - type Listener = BoxStream< - 'static, - Result, std::io::Error>, - >; - type ListenerUpgrade = BoxFuture<'static, Result>; - type Dial = BoxFuture<'static, Result>; - - fn listen_on( - &mut self, - _: Multiaddr, - ) -> Result> { - Ok(Box::pin(stream::unfold((), |()| async move { - Some((Err(std::io::Error::from(std::io::ErrorKind::Other)), ())) - }))) - } - - fn dial( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn dial_as_listener( - &mut self, - _: Multiaddr, - ) -> Result> { - panic!() - } - - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { - None - } - } - - async_std::task::block_on(async move { - let transport = DummyTrans; - let mut listeners = ListenersStream::new(transport); - listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - match listeners.next().await.unwrap() { - ListenersEvent::Closed { .. } => {} - _ => panic!(), - } - }); - } - - #[test] - fn listener_closed() { - async_std::task::block_on(async move { - let mem_transport = transport::MemoryTransport::default(); - - let mut listeners = ListenersStream::new(mem_transport); - let id = listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - - let event = listeners.next().await.unwrap(); - let addr; - if let ListenersEvent::NewAddress { listen_addr, .. } = event { - addr = listen_addr - } else { - panic!("Was expecting the listen address to be reported") - } - - assert!(listeners.remove_listener(id)); - - match listeners.next().await.unwrap() { - ListenersEvent::Closed { - listener_id, - addresses, - reason: Ok(()), - } => { - assert_eq!(listener_id, id); - assert!(addresses.contains(&addr)); - } - other => panic!("Unexpected listeners event: {:?}", other), - } - }); - } -} diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index c5ef1b0aa88..f32c2df56bf 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -79,24 +79,24 @@ pub use handler::{ pub use registry::{AddAddressResult, AddressRecord, AddressScore}; use connection::pool::{Pool, PoolConfig, PoolEvent}; -use connection::{EstablishedConnection, IncomingInfo, ListenersEvent, ListenersStream}; +use connection::{EstablishedConnection, IncomingInfo}; use dial_opts::{DialOpts, PeerCondition}; use either::Either; use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream}; use libp2p_core::connection::{ConnectionId, PendingPoint}; use libp2p_core::muxing::SubstreamBox; use libp2p_core::{ - connection::{ConnectedPoint, ListenerId}, + connection::ConnectedPoint, multiaddr::Protocol, multihash::Multihash, muxing::StreamMuxerBox, - transport::{self, TransportError}, + transport::{self, ListenerId, TransportError, TransportEvent}, upgrade::ProtocolName, Endpoint, Executor, Multiaddr, Negotiated, PeerId, Transport, }; use registry::{AddressIntoIter, Addresses}; use smallvec::SmallVec; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::iter; use std::num::{NonZeroU32, NonZeroU8, NonZeroUsize}; use std::{ @@ -258,8 +258,8 @@ pub struct Swarm where TBehaviour: NetworkBehaviour, { - /// Listeners for incoming connections. - listeners: ListenersStream>, + /// [`Transport`] for dialing remote peers and listening for incoming connection. + transport: transport::Boxed<(PeerId, StreamMuxerBox)>, /// The nodes currently active. pool: Pool, transport::Boxed<(PeerId, StreamMuxerBox)>>, @@ -274,8 +274,8 @@ where /// List of protocols that the behaviour says it supports. supported_protocols: SmallVec<[Vec; 16]>, - /// List of multiaddresses we're listening on. - listened_addrs: SmallVec<[Multiaddr; 8]>, + /// Multiaddresses that our listeners are listening on, + listened_addrs: HashMap>, /// List of multiaddresses we're listening on, after account for external IP addresses and /// similar mechanisms. @@ -327,7 +327,7 @@ where /// Listeners report their new listening addresses as [`SwarmEvent::NewListenAddr`]. /// Depending on the underlying transport, one listener may have multiple listening addresses. pub fn listen_on(&mut self, addr: Multiaddr) -> Result> { - let id = self.listeners.listen_on(addr)?; + let id = self.transport.listen_on(addr)?; self.behaviour.inject_new_listener(id); Ok(id) } @@ -336,8 +336,8 @@ where /// /// Returns `true` if there was a listener with this ID, `false` /// otherwise. - pub fn remove_listener(&mut self, id: ListenerId) -> bool { - self.listeners.remove_listener(id) + pub fn remove_listener(&mut self, listener_id: ListenerId) -> bool { + self.transport.remove_listener(listener_id) } /// Dial a known or unknown peer. @@ -446,8 +446,9 @@ where }; let mut unique_addresses = HashSet::new(); - addresses.retain(|a| { - !self.listened_addrs.contains(a) && unique_addresses.insert(a.clone()) + addresses.retain(|addr| { + !self.listened_addrs.values().flatten().any(|a| a == addr) + && unique_addresses.insert(addr.clone()) }); if addresses.is_empty() { @@ -507,11 +508,8 @@ where .map(|a| match p2p_addr(peer_id, a) { Ok(address) => { let dial = match role_override { - Endpoint::Dialer => self.listeners.transport_mut().dial(address.clone()), - Endpoint::Listener => self - .listeners - .transport_mut() - .dial_as_listener(address.clone()), + Endpoint::Dialer => self.transport.dial(address.clone()), + Endpoint::Listener => self.transport.dial_as_listener(address.clone()), }; match dial { Ok(fut) => fut @@ -546,7 +544,7 @@ where /// Returns an iterator that produces the list of addresses we're listening on. pub fn listeners(&self) -> impl Iterator { - self.listeners.listen_addrs() + self.listened_addrs.values().flatten() } /// Returns the peer ID of the swarm passed as parameter. @@ -830,12 +828,15 @@ where None } - fn handle_listeners_event( + fn handle_transport_event( &mut self, - event: ListenersEvent>, + event: TransportEvent< + as Transport>::ListenerUpgrade, + io::Error, + >, ) -> Option>> { match event { - ListenersEvent::Incoming { + TransportEvent::Incoming { listener_id: _, upgrade, local_addr, @@ -863,13 +864,14 @@ where } }; } - ListenersEvent::NewAddress { + TransportEvent::NewAddress { listener_id, listen_addr, } => { log::debug!("Listener {:?}; New address: {:?}", listener_id, listen_addr); - if !self.listened_addrs.contains(&listen_addr) { - self.listened_addrs.push(listen_addr.clone()) + let addrs = self.listened_addrs.entry(listener_id).or_default(); + if !addrs.contains(&listen_addr) { + addrs.push(listen_addr.clone()) } self.behaviour .inject_new_listen_addr(listener_id, &listen_addr); @@ -878,7 +880,7 @@ where address: listen_addr, }); } - ListenersEvent::AddressExpired { + TransportEvent::AddressExpired { listener_id, listen_addr, } => { @@ -887,7 +889,9 @@ where listener_id, listen_addr ); - self.listened_addrs.retain(|a| a != &listen_addr); + if let Some(addrs) = self.listened_addrs.get_mut(&listener_id) { + addrs.retain(|a| a != &listen_addr); + } self.behaviour .inject_expired_listen_addr(listener_id, &listen_addr); return Some(SwarmEvent::ExpiredListenAddr { @@ -895,13 +899,13 @@ where address: listen_addr, }); } - ListenersEvent::Closed { + TransportEvent::ListenerClosed { listener_id, - addresses, reason, } => { log::debug!("Listener {:?}; Closed by {:?}.", listener_id, reason); - for addr in addresses.iter() { + let addrs = self.listened_addrs.remove(&listener_id).unwrap_or_default(); + for addr in addrs.iter() { self.behaviour.inject_expired_listen_addr(listener_id, addr); } self.behaviour.inject_listener_closed( @@ -913,11 +917,11 @@ where ); return Some(SwarmEvent::ListenerClosed { listener_id, - addresses, + addresses: addrs.to_vec(), reason, }); } - ListenersEvent::Error { listener_id, error } => { + TransportEvent::ListenerError { listener_id, error } => { self.behaviour.inject_listener_error(listener_id, &error); return Some(SwarmEvent::ListenerError { listener_id, error }); } @@ -974,11 +978,11 @@ where // // The translation is transport-specific. See [`Transport::address_translation`]. let translated_addresses = { - let transport = self.listeners.transport(); let mut addrs: Vec<_> = self - .listeners - .listen_addrs() - .filter_map(move |server| transport.address_translation(server, &address)) + .listened_addrs + .values() + .flatten() + .filter_map(|server| self.transport.address_translation(server, &address)) .collect(); // remove duplicates @@ -1060,7 +1064,7 @@ where let mut parameters = SwarmPollParameters { local_peer_id: &this.local_peer_id, supported_protocols: &this.supported_protocols, - listened_addrs: &this.listened_addrs, + listened_addrs: this.listened_addrs.values().flatten().collect(), external_addrs: &this.external_addrs, }; this.behaviour.poll(cx, &mut parameters) @@ -1093,10 +1097,10 @@ where }; // Poll the listener(s) for new connections. - match ListenersStream::poll(Pin::new(&mut this.listeners), cx) { + match Pin::new(&mut this.transport).poll(cx) { Poll::Pending => {} - Poll::Ready(listeners_event) => { - if let Some(swarm_event) = this.handle_listeners_event(listeners_event) { + Poll::Ready(transport_event) => { + if let Some(swarm_event) = this.handle_transport_event(transport_event) { return Poll::Ready(swarm_event); } @@ -1231,13 +1235,13 @@ where pub struct SwarmPollParameters<'a> { local_peer_id: &'a PeerId, supported_protocols: &'a [Vec], - listened_addrs: &'a [Multiaddr], + listened_addrs: Vec<&'a Multiaddr>, external_addrs: &'a Addresses, } impl<'a> PollParameters for SwarmPollParameters<'a> { type SupportedProtocolsIter = std::iter::Cloned>>; - type ListenedAddressesIter = std::iter::Cloned>; + type ListenedAddressesIter = std::iter::Cloned>; type ExternalAddressesIter = AddressIntoIter; fn supported_protocols(&self) -> Self::SupportedProtocolsIter { @@ -1245,7 +1249,7 @@ impl<'a> PollParameters for SwarmPollParameters<'a> { } fn listened_addresses(&self) -> Self::ListenedAddressesIter { - self.listened_addrs.iter().cloned() + self.listened_addrs.clone().into_iter().cloned() } fn external_addresses(&self) -> Self::ExternalAddressesIter { @@ -1401,11 +1405,11 @@ where Swarm { local_peer_id: self.local_peer_id, - listeners: ListenersStream::new(self.transport), + transport: self.transport, pool: Pool::new(self.local_peer_id, pool_config, self.connection_limits), behaviour: self.behaviour, supported_protocols, - listened_addrs: SmallVec::new(), + listened_addrs: HashMap::new(), external_addrs: Addresses::default(), banned_peers: HashSet::new(), banned_peer_connections: HashSet::new(), @@ -1618,7 +1622,7 @@ mod tests { use libp2p::plaintext; use libp2p::yamux; use libp2p_core::multiaddr::multiaddr; - use libp2p_core::transport::ListenerEvent; + use libp2p_core::transport::TransportEvent; use libp2p_core::Endpoint; use quickcheck::{quickcheck, Arbitrary, Gen, QuickCheck}; use rand::prelude::SliceRandom; @@ -2067,20 +2071,19 @@ mod tests { // `+ 2` to ensure a subset of addresses is dialed by network_2. let num_listen_addrs = concurrency_factor.0.get() + 2; let mut listen_addresses = Vec::new(); - let mut listeners = Vec::new(); + let mut transports = Vec::new(); for _ in 0..num_listen_addrs { - let mut listener = transport::MemoryTransport {} - .listen_on("/memory/0".parse().unwrap()) - .unwrap(); + let mut transport = transport::MemoryTransport::default().boxed(); + transport.listen_on("/memory/0".parse().unwrap()).unwrap(); - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(address) => { - listen_addresses.push(address); + match transport.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { + listen_addresses.push(listen_addr); } _ => panic!("Expected `NewListenAddr` event."), } - listeners.push(listener); + transports.push(transport); } // Have swarm dial each listener and wait for each listener to receive the incoming @@ -2092,14 +2095,16 @@ mod tests { .build(), ) .unwrap(); - for mut listener in listeners.into_iter() { + for mut transport in transports.into_iter() { loop { - match futures::future::select(listener.next(), swarm.next()).await { - Either::Left((Some(Ok(ListenerEvent::Upgrade { .. })), _)) => { + match futures::future::select(transport.select_next_some(), swarm.next()) + .await + { + Either::Left((TransportEvent::Incoming { .. }, _)) => { break; } Either::Left(_) => { - panic!("Unexpected listener event.") + panic!("Unexpected transport event.") } Either::Right((e, _)) => { panic!("Expect swarm to not emit any event {:?}", e) diff --git a/swarm/src/test.rs b/swarm/src/test.rs index e201432ec39..166e9185a47 100644 --- a/swarm/src/test.rs +++ b/swarm/src/test.rs @@ -23,9 +23,7 @@ use crate::{ PollParameters, }; use libp2p_core::{ - connection::{ConnectionId, ListenerId}, - multiaddr::Multiaddr, - ConnectedPoint, PeerId, + connection::ConnectionId, multiaddr::Multiaddr, transport::ListenerId, ConnectedPoint, PeerId, }; use std::collections::HashMap; use std::task::{Context, Poll}; diff --git a/transports/deflate/tests/test.rs b/transports/deflate/tests/test.rs index 743d464c1cf..06d91ddd808 100644 --- a/transports/deflate/tests/test.rs +++ b/transports/deflate/tests/test.rs @@ -21,7 +21,7 @@ use futures::{future, prelude::*}; use libp2p_core::{transport::Transport, upgrade}; use libp2p_deflate::DeflateConfig; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use quickcheck::{QuickCheck, RngCore, TestResult}; #[test] @@ -44,38 +44,39 @@ fn lot_of_data() { } async fn run(message1: Vec) { - let mut transport = TcpConfig::new().and_then(|conn, endpoint| { - upgrade::apply( - conn, - DeflateConfig::default(), - endpoint, - upgrade::Version::V1, - ) - }); - - let mut listener = transport + let new_transport = || { + TcpTransport::default() + .and_then(|conn, endpoint| { + upgrade::apply( + conn, + DeflateConfig::default(), + endpoint, + upgrade::Version::V1, + ) + }) + .boxed() + }; + let mut listener_transport = new_transport(); + listener_transport .listen_on("/ip4/0.0.0.0/tcp/0".parse().expect("multiaddr")) .expect("listener"); - let listen_addr = listener - .by_ref() + let listen_addr = listener_transport .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("new address"); let message2 = message1.clone(); let listener_task = async_std::task::spawn(async move { - let mut conn = listener - .filter(|e| future::ready(e.as_ref().map(|e| e.is_upgrade()).unwrap_or(false))) + let mut conn = listener_transport + .filter(|e| future::ready(e.is_upgrade())) .next() .await .expect("some event") - .expect("no error") - .into_upgrade() + .into_incoming() .expect("upgrade") .0 .await @@ -89,7 +90,8 @@ async fn run(message1: Vec) { conn.close().await.expect("close") }); - let mut conn = transport + let mut dialer_transport = new_transport(); + let mut conn = dialer_transport .dial(listen_addr) .expect("dialer") .await diff --git a/transports/dns/src/lib.rs b/transports/dns/src/lib.rs index 45806a7b772..0ee89f78373 100644 --- a/transports/dns/src/lib.rs +++ b/transports/dns/src/lib.rs @@ -60,15 +60,23 @@ use futures::{future::BoxFuture, prelude::*}; use libp2p_core::{ connection::Endpoint, multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Transport, }; use parking_lot::Mutex; use smallvec::SmallVec; #[cfg(any(feature = "async-std", feature = "tokio"))] use std::io; -use std::sync::Arc; -use std::{convert::TryFrom, error, fmt, iter, net::IpAddr, str}; +use std::{ + convert::TryFrom, + error, fmt, iter, + net::IpAddr, + ops::DerefMut, + pin::Pin, + str, + sync::Arc, + task::{Context, Poll}, +}; #[cfg(any(feature = "async-std", feature = "tokio"))] use trust_dns_resolver::system_conf; use trust_dns_resolver::{proto::xfer::dns_handle::DnsHandle, AsyncResolver, ConnectionProvider}; @@ -174,7 +182,7 @@ where impl Transport for GenDnsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send, T::Dial: Send, C: DnsHandle, @@ -182,38 +190,21 @@ where { type Output = T::Output; type Error = DnsErr; - type Listener = stream::MapErr< - stream::MapOk< - T::Listener, - fn( - ListenerEvent, - ) -> ListenerEvent, - >, - fn(T::Error) -> Self::Error, - >; type ListenerUpgrade = future::MapErr Self::Error>; type Dial = future::Either< future::MapErr Self::Error>, BoxFuture<'static, Result>, >; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { - let listener = self - .inner + fn listen_on(&mut self, addr: Multiaddr) -> Result> { + self.inner .lock() .listen_on(addr) - .map_err(|err| err.map(DnsErr::Transport))?; - let listener = listener - .map_ok::<_, fn(_) -> _>(|event| { - event - .map(|upgr| upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport)) - .map_err(DnsErr::Transport) - }) - .map_err::<_, fn(_) -> _>(DnsErr::Transport); - Ok(listener) + .map_err(|e| e.map(DnsErr::Transport)) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.inner.lock().remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -230,11 +221,23 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.inner.lock().address_translation(server, observed) } + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut inner = self.inner.lock(); + Transport::poll(Pin::new(inner.deref_mut()), cx).map(|event| { + event + .map_upgrade(|upgr| upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport)) + .map_err(DnsErr::Transport) + }) + } } impl GenDnsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send, T::Dial: Send, C: DnsHandle, @@ -571,11 +574,10 @@ fn invalid_data(e: impl Into>) -> io::E #[cfg(test)] mod tests { use super::*; - use futures::{future::BoxFuture, stream::BoxStream}; + use futures::future::BoxFuture; use libp2p_core::{ multiaddr::{Multiaddr, Protocol}, - transport::ListenerEvent, - transport::TransportError, + transport::{TransportError, TransportEvent}, PeerId, Transport, }; @@ -589,20 +591,20 @@ mod tests { impl Transport for CustomTransport { type Output = (); type Error = std::io::Error; - type Listener = BoxStream< - 'static, - Result, Self::Error>, - >; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; fn listen_on( &mut self, _: Multiaddr, - ) -> Result> { + ) -> Result> { unreachable!() } + fn remove_listener(&mut self, _: ListenerId) -> bool { + false + } + fn dial(&mut self, addr: Multiaddr) -> Result> { // Check that all DNS components have been resolved, i.e. replaced. assert!(!addr.iter().any(|p| match p { @@ -625,13 +627,20 @@ mod tests { fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { None } + + fn poll( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + unreachable!() + } } async fn run(mut transport: GenDnsConfig) where C: DnsHandle, P: ConnectionProvider, - T: Transport + Clone + Send + 'static, + T: Transport + Clone + Send + Unpin + 'static, T::Error: Send, T::Dial: Send, { diff --git a/transports/noise/src/lib.rs b/transports/noise/src/lib.rs index f4cc85dea4a..ee609fd028d 100644 --- a/transports/noise/src/lib.rs +++ b/transports/noise/src/lib.rs @@ -40,14 +40,14 @@ //! //! ``` //! use libp2p_core::{identity, Transport, upgrade}; -//! use libp2p_tcp::TcpConfig; +//! use libp2p_tcp::TcpTransport; //! use libp2p_noise::{Keypair, X25519Spec, NoiseConfig}; //! //! # fn main() { //! let id_keys = identity::Keypair::generate_ed25519(); //! let dh_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); //! let noise = NoiseConfig::xx(dh_keys).into_authenticated(); -//! let builder = TcpConfig::new().upgrade(upgrade::Version::V1).authenticate(noise); +//! let builder = TcpTransport::default().upgrade(upgrade::Version::V1).authenticate(noise); //! // let transport = builder.multiplex(...); //! # } //! ``` diff --git a/transports/noise/tests/smoke.rs b/transports/noise/tests/smoke.rs index 5c745c9463c..0148d03b4d6 100644 --- a/transports/noise/tests/smoke.rs +++ b/transports/noise/tests/smoke.rs @@ -24,12 +24,12 @@ use futures::{ prelude::*, }; use libp2p_core::identity; -use libp2p_core::transport::{ListenerEvent, Transport}; +use libp2p_core::transport::{self, Transport}; use libp2p_core::upgrade::{self, apply_inbound, apply_outbound, Negotiated}; use libp2p_noise::{ Keypair, NoiseConfig, NoiseError, NoiseOutput, RemoteIdentity, X25519Spec, X25519, }; -use libp2p_tcp::TcpConfig; +use libp2p_tcp::TcpTransport; use log::info; use quickcheck::QuickCheck; use std::{convert::TryInto, io, net::TcpStream}; @@ -41,7 +41,7 @@ fn core_upgrade_compat() { let id_keys = identity::Keypair::generate_ed25519(); let dh_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); let noise = NoiseConfig::xx(dh_keys).into_authenticated(); - let _ = TcpConfig::new() + let _ = TcpTransport::default() .upgrade(upgrade::Version::V1) .authenticate(noise); } @@ -60,7 +60,7 @@ fn xx_spec() { let server_dh = Keypair::::new() .into_authentic(&server_id) .unwrap(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -69,12 +69,13 @@ fn xx_spec() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new() .into_authentic(&client_id) .unwrap(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -83,7 +84,8 @@ fn xx_spec() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &server_id_public)); + .and_then(move |out, _| expect_identity(out, &server_id_public)) + .boxed(); run(server_transport, client_transport, messages); true @@ -105,7 +107,7 @@ fn xx() { let client_id_public = client_id.public(); let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -114,10 +116,11 @@ fn xx() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -126,7 +129,8 @@ fn xx() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &server_id_public)); + .and_then(move |out, _| expect_identity(out, &server_id_public)) + .boxed(); run(server_transport, client_transport, messages); true @@ -148,7 +152,7 @@ fn ix() { let client_id_public = client_id.public(); let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -157,10 +161,11 @@ fn ix() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { upgrade::apply( output, @@ -169,7 +174,8 @@ fn ix() { upgrade::Version::V1, ) }) - .and_then(move |out, _| expect_identity(out, &server_id_public)); + .and_then(move |out, _| expect_identity(out, &server_id_public)) + .boxed(); run(server_transport, client_transport, messages); true @@ -192,7 +198,7 @@ fn ik_xx() { let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); let server_dh_public = server_dh.public().clone(); - let server_transport = TcpConfig::new() + let server_transport = TcpTransport::default() .and_then(move |output, endpoint| { if endpoint.is_listener() { Either::Left(apply_inbound(output, NoiseConfig::ik_listener(server_dh))) @@ -204,11 +210,12 @@ fn ik_xx() { )) } }) - .and_then(move |out, _| expect_identity(out, &client_id_public)); + .and_then(move |out, _| expect_identity(out, &client_id_public)) + .boxed(); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); let server_id_public2 = server_id_public.clone(); - let client_transport = TcpConfig::new() + let client_transport = TcpTransport::default() .and_then(move |output, endpoint| { if endpoint.is_dialer() { Either::Left(apply_outbound( @@ -220,7 +227,8 @@ fn ik_xx() { Either::Right(apply_inbound(output, NoiseConfig::xx(client_dh))) } }) - .and_then(move |out, _| expect_identity(out, &server_id_public2)); + .and_then(move |out, _| expect_identity(out, &server_id_public2)) + .boxed(); run(server_transport, client_transport, messages); true @@ -232,34 +240,28 @@ fn ik_xx() { type Output = (RemoteIdentity, NoiseOutput>>); -fn run(mut server_transport: T, mut client_transport: U, messages: I) -where - T: Transport>, - T::Dial: Send + 'static, - T::Listener: Send + Unpin + 'static, - T::ListenerUpgrade: Send + 'static, - U: Transport>, - U::Dial: Send + 'static, - U::Listener: Send + 'static, - U::ListenerUpgrade: Send + 'static, +fn run( + mut server: transport::Boxed>, + mut client: transport::Boxed>, + messages: I, +) where I: IntoIterator + Clone, { futures::executor::block_on(async { - let mut server: T::Listener = server_transport + server .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); let server_address = server - .try_next() + .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); let outbound_msgs = messages.clone(); let client_fut = async { - let mut client_session = client_transport + let mut client_session = client .dial(server_address.clone()) .unwrap() .await @@ -276,13 +278,12 @@ where let server_fut = async { let mut server_session = server - .try_next() + .next() .await .expect("some event") - .map(ListenerEvent::into_upgrade) - .expect("no error") - .map(|client| client.0) + .into_incoming() .expect("listener upgrade") + .0 .await .map(|(_, session)| session) .expect("no error"); diff --git a/transports/plaintext/tests/smoke.rs b/transports/plaintext/tests/smoke.rs index ec20e8ff20e..ea62f0a9dfa 100644 --- a/transports/plaintext/tests/smoke.rs +++ b/transports/plaintext/tests/smoke.rs @@ -18,14 +18,11 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::io::{AsyncReadExt, AsyncWriteExt}; -use futures::stream::TryStreamExt; -use libp2p_core::{ - identity, - multiaddr::Multiaddr, - transport::{ListenerEvent, Transport}, - upgrade, +use futures::{ + io::{AsyncReadExt, AsyncWriteExt}, + StreamExt, }; +use libp2p_core::{identity, multiaddr::Multiaddr, transport::Transport, upgrade}; use libp2p_plaintext::PlainText2Config; use log::debug; use quickcheck::QuickCheck; @@ -45,8 +42,8 @@ fn variable_msg_length() { let client_id_public = client_id.public(); futures::executor::block_on(async { - let mut server_transport = - libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { + let mut server = libp2p_core::transport::MemoryTransport::new() + .and_then(move |output, endpoint| { upgrade::apply( output, PlainText2Config { @@ -55,10 +52,11 @@ fn variable_msg_length() { endpoint, libp2p_core::upgrade::Version::V1, ) - }); + }) + .boxed(); - let mut client_transport = - libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { + let mut client = libp2p_core::transport::MemoryTransport::new() + .and_then(move |output, endpoint| { upgrade::apply( output, PlainText2Config { @@ -67,31 +65,28 @@ fn variable_msg_length() { endpoint, libp2p_core::upgrade::Version::V1, ) - }); + }) + .boxed(); let server_address: Multiaddr = format!("/memory/{}", std::cmp::Ord::max(1, rand::random::())) .parse() .unwrap(); - let mut server = server_transport.listen_on(server_address.clone()).unwrap(); + server.listen_on(server_address.clone()).unwrap(); // Ignore server listen address event. let _ = server - .try_next() + .next() .await .expect("some event") - .expect("no error") .into_new_address() .expect("listen address"); let client_fut = async { debug!("dialing {:?}", server_address); - let (received_server_id, mut client_channel) = client_transport - .dial(server_address) - .unwrap() - .await - .unwrap(); + let (received_server_id, mut client_channel) = + client.dial(server_address).unwrap().await.unwrap(); assert_eq!(received_server_id, server_id.public().to_peer_id()); debug!("Client: writing message."); @@ -105,13 +100,12 @@ fn variable_msg_length() { let server_fut = async { let mut server_channel = server - .try_next() + .next() .await .expect("some event") - .map(ListenerEvent::into_upgrade) + .into_incoming() .expect("no error") - .map(|client| client.0) - .expect("listener upgrade xyz") + .0 .await .map(|(_, session)| session) .expect("no error"); diff --git a/transports/tcp/CHANGELOG.md b/transports/tcp/CHANGELOG.md index 1d119005696..f479cdef9e4 100644 --- a/transports/tcp/CHANGELOG.md +++ b/transports/tcp/CHANGELOG.md @@ -6,6 +6,12 @@ establishment errors early. See also [PR 2458] for the related async-io change. +- Split `GenTcpConfig` into `GenTcpConfig` and `GenTcpTransport`. Drive the `TcpListenStream`s + within the `GenTcpTransport`. Add `Transport::poll` and `Transport::remove_listener` + for `GenTcpTransport`. See [PR 2652]. + +[PR 2652]: https://github.com/libp2p/rust-libp2p/pull/2652 + # 0.33.0 - Update to `libp2p-core` `v0.33.0`. diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index a2650f7d216..981c896bcb5 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -22,7 +22,7 @@ //! //! # Usage //! -//! This crate provides a `TcpConfig` and `TokioTcpConfig`, depending on +//! This crate provides a `TcpTransport` and `TokioTcpTransport`, depending on //! the enabled features, which implement the `Transport` trait for use as a //! transport with `libp2p-core` or `libp2p-swarm`. @@ -31,16 +31,16 @@ mod provider; #[cfg(feature = "async-io")] pub use provider::async_io; -/// The type of a [`GenTcpConfig`] using the `async-io` implementation. +/// The type of a [`GenTcpTransport`] using the `async-io` implementation. #[cfg(feature = "async-io")] -pub type TcpConfig = GenTcpConfig; +pub type TcpTransport = GenTcpTransport; #[cfg(feature = "tokio")] pub use provider::tokio; -/// The type of a [`GenTcpConfig`] using the `tokio` implementation. +/// The type of a [`GenTcpTransport`] using the `tokio` implementation. #[cfg(feature = "tokio")] -pub type TokioTcpConfig = GenTcpConfig; +pub type TokioTcpTransport = GenTcpTransport; use futures::{ future::{self, BoxFuture, Ready}, @@ -51,11 +51,11 @@ use futures_timer::Delay; use libp2p_core::{ address_translation, multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, Transport, TransportError}, + transport::{ListenerId, Transport, TransportError, TransportEvent}, }; use socket2::{Domain, Socket, Type}; use std::{ - collections::HashSet, + collections::{HashSet, VecDeque}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener}, pin::Pin, @@ -67,18 +67,16 @@ use std::{ use provider::{IfEvent, Provider}; /// The configuration for a TCP/IP transport capability for libp2p. -#[derive(Debug)] -pub struct GenTcpConfig { - /// The type of the I/O provider. - _impl: std::marker::PhantomData, +#[derive(Clone, Debug)] +pub struct GenTcpConfig { /// TTL to set for opened sockets, or `None` to keep default. ttl: Option, /// `TCP_NODELAY` to set for opened sockets, or `None` to keep default. nodelay: Option, /// Size of the listen backlog for listen sockets. backlog: u32, - /// The configuration of port reuse when dialing. - port_reuse: PortReuse, + /// Whether port reuse should be enabled. + enable_port_reuse: bool, } type Port = u16; @@ -159,10 +157,7 @@ impl PortReuse { } } -impl GenTcpConfig -where - T: Provider + Send, -{ +impl GenTcpConfig { /// Creates a new configuration for a TCP/IP transport: /// /// * Nagle's algorithm, i.e. `TCP_NODELAY`, is _enabled_. @@ -178,8 +173,7 @@ where ttl: None, nodelay: None, backlog: 1024, - port_reuse: PortReuse::Disabled, - _impl: std::marker::PhantomData, + enable_port_reuse: false, } } @@ -238,29 +232,29 @@ where /// > a single outgoing connection to a particular address and port /// > of a peer per local listening socket address. /// - /// `GenTcpConfig` keeps track of the listen socket addresses as they - /// are reported by polling [`TcpListenStream`]s obtained from - /// [`GenTcpConfig::listen_on()`]. It is possible to listen on multiple + /// [`GenTcpTransport`] keeps track of the listen socket addresses as they + /// are reported by polling it. It is possible to listen on multiple /// addresses, enabling port reuse for each, knowing exactly which listen - /// address is reused when dialing with a specific `GenTcpConfig`, as in the + /// address is reused when dialing with a specific [`GenTcpTransport`], as in the /// following example: /// /// ```no_run - /// # use libp2p_core::transport::ListenerEvent; + /// # use futures::StreamExt; + /// # use libp2p_core::transport::{ListenerId, TransportEvent}; /// # use libp2p_core::{Multiaddr, Transport}; - /// # use futures::stream::StreamExt; + /// # use std::pin::Pin; /// #[cfg(feature = "async-io")] /// #[async_std::main] /// async fn main() -> std::io::Result<()> { - /// use libp2p_tcp::TcpConfig; + /// use libp2p_tcp::{GenTcpConfig, TcpTransport}; /// /// let listen_addr1: Multiaddr = "/ip4/127.0.0.1/tcp/9001".parse().unwrap(); /// let listen_addr2: Multiaddr = "/ip4/127.0.0.1/tcp/9002".parse().unwrap(); /// - /// let mut tcp1 = TcpConfig::new().port_reuse(true); - /// let mut listener1 = tcp1.listen_on(listen_addr1.clone()).expect("listener"); - /// match listener1.next().await.expect("event")? { - /// ListenerEvent::NewAddress(listen_addr) => { + /// let mut tcp1 = TcpTransport::new(GenTcpConfig::new().port_reuse(true)).boxed(); + /// tcp1.listen_on( listen_addr1.clone()).expect("listener"); + /// match tcp1.select_next_some().await { + /// TransportEvent::NewAddress { listen_addr, .. } => { /// println!("Listening on {:?}", listen_addr); /// let mut stream = tcp1.dial(listen_addr2.clone()).unwrap().await?; /// // `stream` has `listen_addr1` as its local socket address. @@ -268,10 +262,10 @@ where /// _ => {} /// } /// - /// let mut tcp2 = TcpConfig::new().port_reuse(true); - /// let mut listener2 = tcp2.listen_on(listen_addr2).expect("listener"); - /// match listener2.next().await.expect("event")? { - /// ListenerEvent::NewAddress(listen_addr) => { + /// let mut tcp2 = TcpTransport::new(GenTcpConfig::new().port_reuse(true)).boxed(); + /// tcp2.listen_on( listen_addr2).expect("listener"); + /// match tcp2.select_next_some().await { + /// TransportEvent::NewAddress { listen_addr, .. } => { /// println!("Listening on {:?}", listen_addr); /// let mut socket = tcp2.dial(listen_addr1).unwrap().await?; /// // `stream` has `listen_addr2` as its local socket address. @@ -287,7 +281,7 @@ where /// case, one is chosen whose IP protocol version and loopback status is the /// same as that of the remote address. Consequently, for maximum control of /// the local listening addresses and ports that are used for outgoing - /// connections, a new `GenTcpConfig` should be created for each listening + /// connections, a new [`GenTcpTransport`] should be created for each listening /// socket, avoiding the use of wildcard addresses which bind a socket to /// all network interfaces. /// @@ -295,15 +289,50 @@ where /// option `SO_REUSEPORT` is set, if available, to permit /// reuse of listening ports for multiple sockets. pub fn port_reuse(mut self, port_reuse: bool) -> Self { - self.port_reuse = if port_reuse { + self.enable_port_reuse = port_reuse; + self + } +} + +impl Default for GenTcpConfig { + fn default() -> Self { + Self::new() + } +} + +pub struct GenTcpTransport +where + T: Provider + Send, +{ + config: GenTcpConfig, + + /// The configuration of port reuse when dialing. + port_reuse: PortReuse, + /// All the active listeners. + /// The `TcpListenStream` struct contains a stream that we want to be pinned. Since the `VecDeque` + /// can be resized, the only way is to use a `Pin>`. + listeners: VecDeque>>>, + /// Pending transport events to return from [`GenTcpTransport::poll`]. + pending_events: VecDeque::ListenerUpgrade, io::Error>>, +} + +impl GenTcpTransport +where + T: Provider + Send, +{ + pub fn new(config: GenTcpConfig) -> Self { + let port_reuse = if config.enable_port_reuse { PortReuse::Enabled { listen_addrs: Arc::new(RwLock::new(HashSet::new())), } } else { PortReuse::Disabled }; - - self + GenTcpTransport { + config, + port_reuse, + ..Default::default() + } } fn create_socket(&self, socket_addr: &SocketAddr) -> io::Result { @@ -316,10 +345,10 @@ where if socket_addr.is_ipv6() { socket.set_only_v6(true)?; } - if let Some(ttl) = self.ttl { + if let Some(ttl) = self.config.ttl { socket.set_ttl(ttl)?; } - if let Some(nodelay) = self.nodelay { + if let Some(nodelay) = self.config.nodelay { socket.set_nodelay(nodelay)?; } socket.set_reuse_address(true)?; @@ -330,22 +359,42 @@ where Ok(socket) } - fn do_listen(&mut self, socket_addr: SocketAddr) -> io::Result> { + fn do_listen( + &mut self, + id: ListenerId, + socket_addr: SocketAddr, + ) -> io::Result> { let socket = self.create_socket(&socket_addr)?; socket.bind(&socket_addr.into())?; - socket.listen(self.backlog as _)?; + socket.listen(self.config.backlog as _)?; socket.set_nonblocking(true)?; - TcpListenStream::::new(socket.into(), self.port_reuse.clone()) + TcpListenStream::::new(id, socket.into(), self.port_reuse.clone()) } } -impl Default for GenTcpConfig { +impl Default for GenTcpTransport +where + T: Provider + Send, +{ fn default() -> Self { - Self::new() + let config = GenTcpConfig::default(); + let port_reuse = if config.enable_port_reuse { + PortReuse::Enabled { + listen_addrs: Arc::new(RwLock::new(HashSet::new())), + } + } else { + PortReuse::Disabled + }; + GenTcpTransport { + port_reuse, + config, + listeners: VecDeque::new(), + pending_events: VecDeque::new(), + } } } -impl Transport for GenTcpConfig +impl Transport for GenTcpTransport where T: Provider + Send + 'static, T::Listener: Unpin, @@ -355,20 +404,35 @@ where type Output = T::Stream; type Error = io::Error; type Dial = Pin> + Send>>; - type Listener = TcpListenStream; type ListenerUpgrade = Ready>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let socket_addr = if let Ok(sa) = multiaddr_to_socketaddr(addr.clone()) { sa } else { return Err(TransportError::MultiaddrNotSupported(addr)); }; + let id = ListenerId::new(); log::debug!("listening on {}", socket_addr); - self.do_listen(socket_addr).map_err(TransportError::Other) + let listener = self + .do_listen(id, socket_addr) + .map_err(TransportError::Other)?; + self.listeners.push_back(Box::pin(listener)); + Ok(id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(index) = self.listeners.iter().position(|l| l.listener_id != id) { + self.listeners.remove(index); + self.pending_events + .push_back(TransportEvent::ListenerClosed { + listener_id: id, + reason: Ok(()), + }); + true + } else { + false + } } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -441,9 +505,105 @@ where PortReuse::Enabled { .. } => Some(observed.clone()), } } + + /// Poll all listeners. + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // Return pending events from closed listeners. + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(event); + } + // We remove each element from `listeners` one by one and add them back. + let mut remaining = self.listeners.len(); + while let Some(mut listener) = self.listeners.pop_back() { + match TryStream::try_poll_next(listener.as_mut(), cx) { + Poll::Pending => { + self.listeners.push_front(listener); + remaining -= 1; + if remaining == 0 { + break; + } + } + Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + }))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::Incoming { + listener_id: id, + upgrade, + local_addr, + send_back_addr: remote_addr, + }); + } + Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(a)))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::NewAddress { + listener_id: id, + listen_addr: a, + }); + } + Poll::Ready(Some(Ok(TcpListenerEvent::AddressExpired(a)))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::AddressExpired { + listener_id: id, + listen_addr: a, + }); + } + Poll::Ready(Some(Ok(TcpListenerEvent::Error(error)))) => { + let id = listener.listener_id; + self.listeners.push_front(listener); + return Poll::Ready(TransportEvent::ListenerError { + listener_id: id, + error, + }); + } + Poll::Ready(None) => { + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: listener.listener_id, + reason: Ok(()), + }); + } + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: listener.listener_id, + reason: Err(err), + }); + } + } + } + Poll::Pending + } } -type TcpListenerEvent = ListenerEvent>, io::Error>; +/// Event produced by a [`TcpListenStream`]. +#[derive(Debug)] +pub enum TcpListenerEvent { + /// The listener is listening on a new additional [`Multiaddr`]. + NewAddress(Multiaddr), + /// An upgrade, consisting of the upgrade future, the listener address and the remote address. + Upgrade { + /// The upgrade. + upgrade: Ready>, + /// The local address which produced this upgrade. + local_addr: Multiaddr, + /// The remote address which produced this upgrade. + remote_addr: Multiaddr, + }, + /// A [`Multiaddr`] is no longer used for listening. + AddressExpired(Multiaddr), + /// A non-fatal error has happened on the listener. + /// + /// This event should be generated in order to notify the user that something wrong has + /// happened. The listener, however, continues to run. + Error(io::Error), +} enum IfWatch { Pending(BoxFuture<'static, io::Result>), @@ -469,6 +629,8 @@ pub struct TcpListenStream where T: Provider, { + /// The ID of this listener. + listener_id: ListenerId, /// The socket address that the listening socket is bound to, /// which may be a "wildcard address" like `INADDR_ANY` or `IN6ADDR_ANY` /// when listening on all interfaces for IPv4 respectively IPv6 connections. @@ -499,9 +661,13 @@ impl TcpListenStream where T: Provider, { - /// Constructs a `TcpListenStream` for incoming connections around - /// the given `TcpListener`. - fn new(listener: TcpListener, port_reuse: PortReuse) -> io::Result { + /// Constructs a [`TcpListenStream`] for incoming connections around + /// the given [`TcpListener`]. + fn new( + listener_id: ListenerId, + listener: TcpListener, + port_reuse: PortReuse, + ) -> io::Result { let listen_addr = listener.local_addr()?; let in_addr = if match &listen_addr { @@ -526,6 +692,7 @@ where Ok(TcpListenStream { port_reuse, listener, + listener_id, listen_addr, in_addr, pause: None, @@ -590,7 +757,7 @@ where }; *if_watch = IfWatch::Pending(T::if_watcher()); me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); + return Poll::Ready(Some(Ok(TcpListenerEvent::Error(err)))); } }, // Consume all events for up/down interface changes. @@ -604,9 +771,9 @@ where let ma = ip_to_multiaddr(ip, me.listen_addr.port()); log::debug!("New listen address: {}", ma); me.port_reuse.register(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress( - ma, - )))); + return Poll::Ready(Some(Ok( + TcpListenerEvent::NewAddress(ma), + ))); } } Ok(IfEvent::Down(inet)) => { @@ -617,7 +784,7 @@ where log::debug!("Expired listen address: {}", ma); me.port_reuse.unregister(ip, me.listen_addr.port()); return Poll::Ready(Some(Ok( - ListenerEvent::AddressExpired(ma), + TcpListenerEvent::AddressExpired(ma), ))); } } @@ -627,7 +794,7 @@ where err }; me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); + return Poll::Ready(Some(Ok(TcpListenerEvent::Error(err)))); } } } @@ -638,7 +805,7 @@ where InAddr::One { addr, out } => { if let Some(multiaddr) = out.take() { me.port_reuse.register(*addr, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(multiaddr)))); + return Poll::Ready(Some(Ok(TcpListenerEvent::NewAddress(multiaddr)))); } } } @@ -661,7 +828,7 @@ where // These errors are non-fatal for the listener stream. log::error!("error accepting incoming connection: {}", e); me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(e)))); + return Poll::Ready(Some(Ok(TcpListenerEvent::Error(e)))); } }; @@ -671,7 +838,7 @@ where log::debug!("Incoming connection from {} at {}", remote_addr, local_addr); - return Poll::Ready(Some(Ok(ListenerEvent::Upgrade { + return Poll::Ready(Some(Ok(TcpListenerEvent::Upgrade { upgrade: future::ok(incoming.stream), local_addr, remote_addr, @@ -718,7 +885,10 @@ fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr { #[cfg(test)] mod tests { use super::*; - use futures::channel::{mpsc, oneshot}; + use futures::{ + channel::{mpsc, oneshot}, + future::poll_fn, + }; #[test] fn multiaddr_to_tcp_conversion() { @@ -774,14 +944,14 @@ mod tests { env_logger::try_init().ok(); async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { - let mut tcp = GenTcpConfig::::new(); - let mut listener = tcp.listen_on(addr).unwrap(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); loop { - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(listen_addr) => { + match tcp.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { ready_tx.send(listen_addr).await.unwrap(); } - ListenerEvent::Upgrade { upgrade, .. } => { + TransportEvent::Incoming { upgrade, .. } => { let mut upgrade = upgrade.await.unwrap(); let mut buf = [0u8; 3]; upgrade.read_exact(&mut buf).await.unwrap(); @@ -789,14 +959,14 @@ mod tests { upgrade.write_all(&[4, 5, 6]).await.unwrap(); return; } - e => panic!("Unexpected listener event: {:?}", e), + e => panic!("Unexpected transport event: {:?}", e), } } } async fn dialer(mut ready_rx: mpsc::Receiver) { let addr = ready_rx.next().await.unwrap(); - let mut tcp = GenTcpConfig::::new(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()); // Obtain a future socket through dialing let mut socket = tcp.dial(addr.clone()).unwrap().await.unwrap(); @@ -843,13 +1013,13 @@ mod tests { env_logger::try_init().ok(); async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { - let mut tcp = GenTcpConfig::::new(); - let mut listener = tcp.listen_on(addr).unwrap(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); loop { - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(a) => { - let mut iter = a.iter(); + match tcp.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { + let mut iter = listen_addr.iter(); match iter.next().expect("ip address") { Protocol::Ip4(ip) => assert!(!ip.is_unspecified()), Protocol::Ip6(ip) => assert!(!ip.is_unspecified()), @@ -858,11 +1028,11 @@ mod tests { if let Protocol::Tcp(port) = iter.next().expect("port") { assert_ne!(0, port) } else { - panic!("No TCP port in address: {}", a) + panic!("No TCP port in address: {}", listen_addr) } - ready_tx.send(a).await.ok(); + ready_tx.send(listen_addr).await.ok(); } - ListenerEvent::Upgrade { .. } => { + TransportEvent::Incoming { .. } => { return; } _ => {} @@ -872,7 +1042,7 @@ mod tests { async fn dialer(mut ready_rx: mpsc::Receiver) { let dest_addr = ready_rx.next().await.unwrap(); - let mut tcp = GenTcpConfig::::new(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()); tcp.dial(dest_addr).unwrap().await.unwrap(); } @@ -916,22 +1086,22 @@ mod tests { mut ready_tx: mpsc::Sender, port_reuse_rx: oneshot::Receiver>, ) { - let mut tcp = GenTcpConfig::::new(); - let mut listener = tcp.listen_on(addr).unwrap(); + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); loop { - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(listen_addr) => { + match tcp.select_next_some().await { + TransportEvent::NewAddress { listen_addr, .. } => { ready_tx.send(listen_addr).await.ok(); } - ListenerEvent::Upgrade { + TransportEvent::Incoming { upgrade, - local_addr: _, - mut remote_addr, + mut send_back_addr, + .. } => { // Receive the dialer tcp port reuse let remote_port_reuse = port_reuse_rx.await.unwrap(); // And check it is the same as the remote port used for upgrade - assert_eq!(remote_addr.pop().unwrap(), remote_port_reuse); + assert_eq!(send_back_addr.pop().unwrap(), remote_port_reuse); let mut upgrade = upgrade.await.unwrap(); let mut buf = [0u8; 3]; @@ -951,11 +1121,12 @@ mod tests { port_reuse_tx: oneshot::Sender>, ) { let dest_addr = ready_rx.next().await.unwrap(); - let mut tcp = GenTcpConfig::::new().port_reuse(true); - let mut listener = tcp.listen_on(addr).unwrap(); - match listener.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(_) => { + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new().port_reuse(true)); + tcp.listen_on(addr).unwrap(); + match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await { + TransportEvent::NewAddress { .. } => { // Check that tcp and listener share the same port reuse SocketAddr + let listener = tcp.listeners.front().unwrap(); let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip()); let port_reuse_listener = listener .port_reuse @@ -976,7 +1147,7 @@ mod tests { socket.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, [4, 5, 6]); } - e => panic!("Unexpected listener event: {:?}", e), + e => panic!("Unexpected transport event: {:?}", e), } } @@ -1018,11 +1189,13 @@ mod tests { env_logger::try_init().ok(); async fn listen_twice(addr: Multiaddr) { - let mut tcp = GenTcpConfig::::new().port_reuse(true); - let mut listener1 = tcp.listen_on(addr).unwrap(); - match listener1.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(addr1) => { - // Check that tcp and listener share the same port reuse SocketAddr + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new().port_reuse(true)); + tcp.listen_on(addr).unwrap(); + match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await { + TransportEvent::NewAddress { + listen_addr: addr1, .. + } => { + let listener1 = tcp.listeners.front().unwrap(); let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip()); let port_reuse_listener1 = listener1 @@ -1032,16 +1205,18 @@ mod tests { assert_eq!(port_reuse_tcp, port_reuse_listener1); // Listen on the same address a second time. - let mut listener2 = tcp.listen_on(addr1.clone()).unwrap(); - match listener2.next().await.unwrap().unwrap() { - ListenerEvent::NewAddress(addr2) => { + tcp.listen_on(addr1.clone()).unwrap(); + match poll_fn(|cx| Pin::new(&mut tcp).poll(cx)).await { + TransportEvent::NewAddress { + listen_addr: addr2, .. + } => { assert_eq!(addr1, addr2); return; } - e => panic!("Unexpected listener event: {:?}", e), + e => panic!("Unexpected transport event: {:?}", e), } } - e => panic!("Unexpected listener event: {:?}", e), + e => panic!("Unexpected transport event: {:?}", e), } } @@ -1071,13 +1246,10 @@ mod tests { env_logger::try_init().ok(); async fn listen(addr: Multiaddr) -> Multiaddr { - GenTcpConfig::::new() - .listen_on(addr) - .unwrap() - .next() + let mut tcp = GenTcpTransport::::new(GenTcpConfig::new()).boxed(); + tcp.listen_on(addr).unwrap(); + tcp.select_next_some() .await - .expect("some event") - .expect("no error") .into_new_address() .expect("listen address") } @@ -1111,13 +1283,13 @@ mod tests { fn test(addr: Multiaddr) { #[cfg(feature = "async-io")] { - let mut tcp = TcpConfig::new(); + let mut tcp = TcpTransport::new(GenTcpConfig::new()); assert!(tcp.listen_on(addr.clone()).is_err()); } #[cfg(feature = "tokio")] { - let mut tcp = TokioTcpConfig::new(); + let mut tcp = TokioTcpTransport::new(GenTcpConfig::new()); assert!(tcp.listen_on(addr.clone()).is_err()); } } diff --git a/transports/uds/CHANGELOG.md b/transports/uds/CHANGELOG.md index affbd60a156..e122885226f 100644 --- a/transports/uds/CHANGELOG.md +++ b/transports/uds/CHANGELOG.md @@ -2,6 +2,10 @@ - Update dependencies. - Update to `libp2p-core` `v0.34.0`. +- Add `Transport::poll` and `Transport::remove_listener` and remove `Transport::Listener` for + `UdsConfig` Drive listener streams in `UdsConfig` directly. See [PR 2652]. + +[PR 2652]: https://github.com/libp2p/rust-libp2p/pull/2652 # 0.32.0 [2022-01-27] diff --git a/transports/uds/src/lib.rs b/transports/uds/src/lib.rs index 54d7a6f7ffa..492f1afb029 100644 --- a/transports/uds/src/lib.rs +++ b/transports/uds/src/lib.rs @@ -43,113 +43,194 @@ use futures::{ future::{BoxFuture, Ready}, prelude::*, }; +use libp2p_core::transport::ListenerId; use libp2p_core::{ multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, + transport::{TransportError, TransportEvent}, Transport, }; use log::debug; +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{io, path::PathBuf}; +pub type Listener = BoxStream< + 'static, + Result< + TransportEvent<::ListenerUpgrade, ::Error>, + Result<(), ::Error>, + >, +>; + macro_rules! codegen { ($feature_name:expr, $uds_config:ident, $build_listener:expr, $unix_stream:ty, $($mut_or_not:tt)*) => { + /// Represents the configuration for a Unix domain sockets transport capability for libp2p. + #[cfg_attr(docsrs, doc(cfg(feature = $feature_name)))] + pub struct $uds_config { + listeners: VecDeque<(ListenerId, Listener)>, + } -/// Represents the configuration for a Unix domain sockets transport capability for libp2p. -#[cfg_attr(docsrs, doc(cfg(feature = $feature_name)))] -#[derive(Debug, Clone)] -pub struct $uds_config { -} + impl $uds_config { + /// Creates a new configuration object for Unix domain sockets. + pub fn new() -> $uds_config { + $uds_config { + listeners: VecDeque::new(), + } + } + } -impl $uds_config { - /// Creates a new configuration object for Unix domain sockets. - pub fn new() -> $uds_config { - $uds_config {} - } -} + impl Default for $uds_config { + fn default() -> Self { + Self::new() + } + } -impl Default for $uds_config { - fn default() -> Self { - Self::new() - } -} + impl Transport for $uds_config { + type Output = $unix_stream; + type Error = io::Error; + type ListenerUpgrade = Ready>; + type Dial = BoxFuture<'static, Result>; -impl Transport for $uds_config { - type Output = $unix_stream; - type Error = io::Error; - type Listener = BoxStream<'static, Result, Self::Error>>; - type ListenerUpgrade = Ready>; - type Dial = BoxFuture<'static, Result>; + fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> { + if let Ok(path) = multiaddr_to_path(&addr) { + let id = ListenerId::new(); + let listener = $build_listener(path) + .map_err(Err) + .map_ok(move |listener| { + stream::once({ + let addr = addr.clone(); + async move { + debug!("Now listening on {}", addr); + Ok(TransportEvent::NewAddress { + listener_id: id, + listen_addr: addr, + }) + } + }) + .chain(stream::unfold( + listener, + move |listener| { + let addr = addr.clone(); + async move { + let event = match listener.accept().await { + Ok((stream, _)) => { + debug!("incoming connection on {}", addr); + TransportEvent::Incoming { + upgrade: future::ok(stream), + local_addr: addr.clone(), + send_back_addr: addr.clone(), + listener_id: id, + } + } + Err(error) => TransportEvent::ListenerError { + listener_id: id, + error, + }, + }; + Some((Ok(event), listener)) + } + }, + )) + }) + .try_flatten_stream() + .boxed(); + self.listeners.push_back((id, listener)); + Ok(id) + } else { + Err(TransportError::MultiaddrNotSupported(addr)) + } + } - fn listen_on(&mut self, addr: Multiaddr) -> Result> { - if let Ok(path) = multiaddr_to_path(&addr) { - Ok(async move { $build_listener(&path).await } - .map_ok(move |listener| { - stream::once({ - let addr = addr.clone(); - async move { - debug!("Now listening on {}", addr); - Ok(ListenerEvent::NewAddress(addr)) - } - }).chain(stream::unfold(listener, move |$($mut_or_not)* listener| { - let addr = addr.clone(); - async move { - let (stream, _) = match listener.accept().await { - Ok(v) => v, - Err(err) => return Some((Err(err), listener)) - }; - debug!("incoming connection on {}", addr); - let event = ListenerEvent::Upgrade { - upgrade: future::ok(stream), - local_addr: addr.clone(), - remote_addr: addr.clone() - }; - Some((Ok(event), listener)) - } - })) - }) - .try_flatten_stream() - .boxed()) - } else { - Err(TransportError::MultiaddrNotSupported(addr)) - } - } + fn remove_listener(&mut self, id: ListenerId) -> bool { + if let Some(index) = self + .listeners + .iter() + .position(|(listener_id, _)| listener_id == &id) + { + let listener_stream = self.listeners.get_mut(index).unwrap(); + let report_closed_stream = stream::once(async { Err(Ok(())) }).boxed(); + *listener_stream = (id, report_closed_stream); + true + } else { + false + } + } - fn dial(&mut self, addr: Multiaddr) -> Result> { - // TODO: Should we dial at all? - if let Ok(path) = multiaddr_to_path(&addr) { - debug!("Dialing {}", addr); - Ok(async move { <$unix_stream>::connect(&path).await }.boxed()) - } else { - Err(TransportError::MultiaddrNotSupported(addr)) - } - } + fn dial(&mut self, addr: Multiaddr) -> Result> { + // TODO: Should we dial at all? + if let Ok(path) = multiaddr_to_path(&addr) { + debug!("Dialing {}", addr); + Ok(async move { <$unix_stream>::connect(&path).await }.boxed()) + } else { + Err(TransportError::MultiaddrNotSupported(addr)) + } + } - fn dial_as_listener(&mut self, addr: Multiaddr) -> Result> { - self.dial(addr) - } + fn dial_as_listener( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.dial(addr) + } - fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { - None - } -} + fn address_translation( + &self, + _server: &Multiaddr, + _observed: &Multiaddr, + ) -> Option { + None + } -}; + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut remaining = self.listeners.len(); + while let Some((id, mut listener)) = self.listeners.pop_back() { + let event = match Stream::poll_next(Pin::new(&mut listener), cx) { + Poll::Pending => None, + Poll::Ready(None) => panic!("Alive listeners always have a sender."), + Poll::Ready(Some(Ok(event))) => Some(event), + Poll::Ready(Some(Err(reason))) => { + return Poll::Ready(TransportEvent::ListenerClosed { + listener_id: id, + reason, + }) + } + }; + self.listeners.push_front((id, listener)); + if let Some(event) = event { + return Poll::Ready(event); + } else { + remaining -= 1; + if remaining == 0 { + break; + } + } + } + Poll::Pending + } + } + }; } #[cfg(feature = "async-std")] codegen!( "async-std", UdsConfig, - |addr| async move { async_std::os::unix::net::UnixListener::bind(addr).await }, + |addr| async move { async_std::os::unix::net::UnixListener::bind(&addr).await }, async_std::os::unix::net::UnixStream, ); #[cfg(feature = "tokio")] codegen!( "tokio", TokioUdsConfig, - |addr| async move { tokio::net::UnixListener::bind(addr) }, + |addr| async move { tokio::net::UnixListener::bind(&addr) }, tokio::net::UnixStream, - mut ); /// Turns a `Multiaddr` containing a single `Unix` component into a path. @@ -212,24 +293,22 @@ mod tests { let (tx, rx) = oneshot::channel(); async_std::task::spawn(async move { - let mut listener = UdsConfig::new().listen_on(addr).unwrap(); + let mut transport = UdsConfig::new().boxed(); + transport.listen_on(addr).unwrap(); - let listen_addr = listener - .try_next() + let listen_addr = transport + .select_next_some() .await - .unwrap() - .expect("some event") .into_new_address() .expect("listen address"); tx.send(listen_addr).unwrap(); - let (sock, _addr) = listener - .try_filter_map(|e| future::ok(e.into_upgrade())) - .try_next() + let (sock, _addr) = transport + .select_next_some() .await - .unwrap() - .expect("some event"); + .into_incoming() + .expect("incoming stream"); let mut sock = sock.await.unwrap(); let mut buf = [0u8; 3]; diff --git a/transports/wasm-ext/CHANGELOG.md b/transports/wasm-ext/CHANGELOG.md index af2c271a42d..05a188697d6 100644 --- a/transports/wasm-ext/CHANGELOG.md +++ b/transports/wasm-ext/CHANGELOG.md @@ -1,6 +1,10 @@ # 0.34.0 [unreleased] - Update to `libp2p-core` `v0.34.0`. +- Add `Transport::poll` and `Transport::remove_listener` and remove `Transport::Listener` + for `ExtTransport`. Drive the `Listen` streams within `ExtTransport`. See [PR 2652]. + +[PR 2652]: https://github.com/libp2p/rust-libp2p/pull/2652 # 0.33.0 diff --git a/transports/wasm-ext/src/lib.rs b/transports/wasm-ext/src/lib.rs index 64deb877858..bb1e3ea0653 100644 --- a/transports/wasm-ext/src/lib.rs +++ b/transports/wasm-ext/src/lib.rs @@ -32,10 +32,10 @@ //! module. //! -use futures::{future::Ready, prelude::*}; +use futures::{future::Ready, prelude::*, ready, stream::SelectAll}; use libp2p_core::{ connection::Endpoint, - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Multiaddr, Transport, }; use parity_send_wrapper::SendWrapper; @@ -147,6 +147,7 @@ pub mod ffi { /// Implementation of `Transport` whose implementation is handled by some FFI. pub struct ExtTransport { inner: SendWrapper, + listeners: SelectAll, } impl ExtTransport { @@ -154,8 +155,10 @@ impl ExtTransport { pub fn new(transport: ffi::Transport) -> Self { ExtTransport { inner: SendWrapper::new(transport), + listeners: SelectAll::new(), } } + fn do_dial( &mut self, addr: Multiaddr, @@ -187,25 +190,13 @@ impl fmt::Debug for ExtTransport { } } -impl Clone for ExtTransport { - fn clone(&self) -> Self { - ExtTransport { - inner: SendWrapper::new(self.inner.clone().into()), - } - } -} - impl Transport for ExtTransport { type Output = Connection; type Error = JsErr; - type Listener = Listen; type ListenerUpgrade = Ready>; type Dial = Dial; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let iter = self.inner.listen_on(&addr.to_string()).map_err(|err| { if is_not_supported_error(&err) { TransportError::MultiaddrNotSupported(addr) @@ -213,34 +204,52 @@ impl Transport for ExtTransport { TransportError::Other(JsErr::from(err)) } })?; - - Ok(Listen { + let listener_id = ListenerId::new(); + let listen = Listen { + listener_id, iterator: SendWrapper::new(iter), next_event: None, pending_events: VecDeque::new(), - }) + is_closed: false, + }; + self.listeners.push(listen); + Ok(listener_id) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + match self.listeners.iter_mut().find(|l| l.listener_id == id) { + Some(listener) => { + listener.close(Ok(())); + true + } + None => false, + } } - fn dial(&mut self, addr: Multiaddr) -> Result> - where - Self: Sized, - { + fn dial(&mut self, addr: Multiaddr) -> Result> { self.do_dial(addr, Endpoint::Dialer) } fn dial_as_listener( &mut self, addr: Multiaddr, - ) -> Result> - where - Self: Sized, - { + ) -> Result> { self.do_dial(addr, Endpoint::Listener) } fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option { None } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match ready!(self.listeners.poll_next_unpin(cx)) { + Some(event) => Poll::Ready(event), + None => Poll::Pending, + } + } } /// Future that dial a remote through an external transport. @@ -271,27 +280,47 @@ impl Future for Dial { /// Stream that listens for incoming connections through an external transport. #[must_use = "futures do nothing unless polled"] pub struct Listen { + listener_id: ListenerId, /// Iterator of `ListenEvent`s. iterator: SendWrapper, /// Promise that will yield the next `ListenEvent`. next_event: Option>, /// List of events that we are waiting to propagate. - pending_events: VecDeque>, JsErr>>, + pending_events: VecDeque<::Item>, + /// If the iterator is done close the listener. + is_closed: bool, +} + +impl Listen { + /// Report the listener as closed and terminate its stream. + fn close(&mut self, reason: Result<(), JsErr>) { + self.pending_events + .push_back(TransportEvent::ListenerClosed { + listener_id: self.listener_id, + reason, + }); + self.is_closed = true; + } } impl fmt::Debug for Listen { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Listen").finish() + f.debug_tuple("Listen").field(&self.listener_id).finish() } } impl Stream for Listen { - type Item = Result>, JsErr>, JsErr>; + type Item = TransportEvent<::ListenerUpgrade, JsErr>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { if let Some(ev) = self.pending_events.pop_front() { - return Poll::Ready(Some(Ok(ev))); + return Poll::Ready(Some(ev)); + } + + if self.is_closed { + // Terminate the stream if the listener closed and all remaining events have been reported. + return Poll::Ready(None); } // Try to fill `self.next_event` if necessary and possible. If we fail, then @@ -309,30 +338,59 @@ impl Stream for Listen { let e = match Future::poll(Pin::new(&mut **next_event), cx) { Poll::Ready(Ok(ev)) => ffi::ListenEvent::from(ev), Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))), + Poll::Ready(Err(err)) => { + self.close(Err(err.into())); + continue; + } }; self.next_event = None; e } else { - return Poll::Ready(None); + self.close(Ok(())); + continue; }; + let listener_id = self.listener_id; + if let Some(addrs) = event.new_addrs() { for addr in addrs.iter() { - let addr = js_value_to_addr(addr)?; - self.pending_events - .push_back(ListenerEvent::NewAddress(addr)); + match js_value_to_addr(addr) { + Ok(addr) => self.pending_events.push_back(TransportEvent::NewAddress { + listener_id, + listen_addr: addr, + }), + Err(err) => self + .pending_events + .push_back(TransportEvent::ListenerError { + listener_id, + error: err, + }), + }; } } if let Some(upgrades) = event.new_connections() { for upgrade in upgrades.iter().cloned() { let upgrade: ffi::ConnectionEvent = upgrade.into(); - self.pending_events.push_back(ListenerEvent::Upgrade { - local_addr: upgrade.local_addr().parse()?, - remote_addr: upgrade.observed_addr().parse()?, - upgrade: futures::future::ok(Connection::new(upgrade.connection())), - }); + match upgrade.local_addr().parse().and_then(|local| { + let observed = upgrade.observed_addr().parse()?; + Ok((local, observed)) + }) { + Ok((local_addr, send_back_addr)) => { + self.pending_events.push_back(TransportEvent::Incoming { + listener_id, + local_addr, + send_back_addr, + upgrade: futures::future::ok(Connection::new(upgrade.connection())), + }) + } + Err(err) => self + .pending_events + .push_back(TransportEvent::ListenerError { + listener_id, + error: err.into(), + }), + } } } @@ -341,8 +399,16 @@ impl Stream for Listen { match js_value_to_addr(addr) { Ok(addr) => self .pending_events - .push_back(ListenerEvent::NewAddress(addr)), - Err(err) => self.pending_events.push_back(ListenerEvent::Error(err)), + .push_back(TransportEvent::AddressExpired { + listener_id, + listen_addr: addr, + }), + Err(err) => self + .pending_events + .push_back(TransportEvent::ListenerError { + listener_id, + error: err, + }), } } } diff --git a/transports/websocket/CHANGELOG.md b/transports/websocket/CHANGELOG.md index 1594580f119..9eca77aa0ce 100644 --- a/transports/websocket/CHANGELOG.md +++ b/transports/websocket/CHANGELOG.md @@ -1,6 +1,10 @@ # 0.36.0 [unreleased] - Update to `libp2p-core` `v0.34.0`. +- Add `Transport::poll` and `Transport::remove_listener` and remove `Transport::Listener` + for `WsConfig`. See [PR 2652]. + +[PR 2652]: https://github.com/libp2p/rust-libp2p/pull/2652 # 0.35.0 diff --git a/transports/websocket/src/framed.rs b/transports/websocket/src/framed.rs index c04e7354587..65b4db93fd6 100644 --- a/transports/websocket/src/framed.rs +++ b/transports/websocket/src/framed.rs @@ -26,7 +26,7 @@ use libp2p_core::{ connection::Endpoint, either::EitherOutput, multiaddr::{Multiaddr, Protocol}, - transport::{ListenerEvent, TransportError}, + transport::{ListenerId, TransportError, TransportEvent}, Transport, }; use log::{debug, trace}; @@ -36,7 +36,7 @@ use soketto::{ extension::deflate::Deflate, handshake, }; -use std::sync::Arc; +use std::{collections::HashMap, ops::DerefMut, sync::Arc}; use std::{convert::TryInto, fmt, io, mem, pin::Pin, task::Context, task::Poll}; use url::Url; @@ -53,18 +53,11 @@ pub struct WsConfig { tls_config: tls::Config, max_redirects: u8, use_deflate: bool, -} - -impl Clone for WsConfig { - fn clone(&self) -> Self { - Self { - transport: self.transport.clone(), - max_data_size: self.max_data_size, - tls_config: self.tls_config.clone(), - max_redirects: self.max_redirects, - use_deflate: self.use_deflate, - } - } + /// Websocket protocol of the inner listener. + /// + /// This is the suffix of the address provided in `listen_on`. + /// Can only be [`Protocol::Ws`] or [`Protocol::Wss`]. + listener_protos: HashMap>, } impl WsConfig { @@ -76,6 +69,7 @@ impl WsConfig { tls_config: tls::Config::client(), max_redirects: 0, use_deflate: false, + listener_protos: HashMap::new(), } } @@ -118,149 +112,45 @@ type TlsOrPlain = EitherOutput, server::Tls impl Transport for WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = Connection; type Error = Error; - type Listener = - BoxStream<'static, Result, Self::Error>>; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { let mut inner_addr = addr.clone(); - - let (use_tls, proto) = match inner_addr.pop() { + let proto = match inner_addr.pop() { Some(p @ Protocol::Wss(_)) => { if self.tls_config.server.is_some() { - (true, p) + p } else { debug!("/wss address but TLS server support is not configured"); return Err(TransportError::MultiaddrNotSupported(addr)); } } - Some(p @ Protocol::Ws(_)) => (false, p), + Some(p @ Protocol::Ws(_)) => p, _ => { debug!("{} is not a websocket multiaddr", addr); return Err(TransportError::MultiaddrNotSupported(addr)); } }; + match self.transport.lock().listen_on(inner_addr) { + Ok(id) => { + self.listener_protos.insert(id, proto); + Ok(id) + } + Err(e) => Err(e.map(Error::Transport)), + } + } - let tls_config = self.tls_config.clone(); - let max_size = self.max_data_size; - let use_deflate = self.use_deflate; - let transport = self - .transport - .lock() - .listen_on(inner_addr) - .map_err(|e| e.map(Error::Transport))?; - let listen = transport - .map_err(Error::Transport) - .map_ok(move |event| match event { - ListenerEvent::NewAddress(mut a) => { - a = a.with(proto.clone()); - debug!("Listening on {}", a); - ListenerEvent::NewAddress(a) - } - ListenerEvent::AddressExpired(mut a) => { - a = a.with(proto.clone()); - ListenerEvent::AddressExpired(a) - } - ListenerEvent::Error(err) => ListenerEvent::Error(Error::Transport(err)), - ListenerEvent::Upgrade { - upgrade, - mut local_addr, - mut remote_addr, - } => { - local_addr = local_addr.with(proto.clone()); - remote_addr = remote_addr.with(proto.clone()); - let remote1 = remote_addr.clone(); // used for logging - let remote2 = remote_addr.clone(); // used for logging - let tls_config = tls_config.clone(); - - let upgrade = async move { - let stream = upgrade.map_err(Error::Transport).await?; - trace!("incoming connection from {}", remote1); - - let stream = if use_tls { - // begin TLS session - let server = tls_config - .server - .expect("for use_tls we checked server is not none"); - - trace!("awaiting TLS handshake with {}", remote1); - - let stream = server - .accept(stream) - .map_err(move |e| { - debug!("TLS handshake with {} failed: {}", remote1, e); - Error::Tls(tls::Error::from(e)) - }) - .await?; - - let stream: TlsOrPlain<_> = - EitherOutput::First(EitherOutput::Second(stream)); - - stream - } else { - // continue with plain stream - EitherOutput::Second(stream) - }; - - trace!("receiving websocket handshake request from {}", remote2); - - let mut server = handshake::Server::new(stream); - - if use_deflate { - server.add_extension(Box::new(Deflate::new(connection::Mode::Server))); - } - - let ws_key = { - let request = server - .receive_request() - .map_err(|e| Error::Handshake(Box::new(e))) - .await?; - request.key() - }; - - trace!("accepting websocket handshake request from {}", remote2); - - let response = handshake::server::Response::Accept { - key: ws_key, - protocol: None, - }; - - server - .send_response(&response) - .map_err(|e| Error::Handshake(Box::new(e))) - .await?; - - let conn = { - let mut builder = server.into_builder(); - builder.set_max_message_size(max_size); - builder.set_max_frame_size(max_size); - Connection::new(builder) - }; - - Ok(conn) - }; - - ListenerEvent::Upgrade { - upgrade: Box::pin(upgrade) as BoxFuture<'static, _>, - local_addr, - remote_addr, - } - } - }); - Ok(Box::pin(listen)) + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.lock().remove_listener(id) } fn dial(&mut self, addr: Multiaddr) -> Result> { @@ -277,14 +167,100 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.lock().address_translation(server, observed) } + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let inner_event = { + let mut transport = self.transport.lock(); + match Transport::poll(Pin::new(transport.deref_mut()), cx) { + Poll::Ready(ev) => ev, + Poll::Pending => return Poll::Pending, + } + }; + let event = match inner_event { + TransportEvent::NewAddress { + listener_id, + mut listen_addr, + } => { + // Append the ws / wss protocol back to the inner address. + let proto = self + .listener_protos + .get(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + listen_addr.push(proto.clone()); + debug!("Listening on {}", listen_addr); + TransportEvent::NewAddress { + listener_id, + listen_addr, + } + } + TransportEvent::AddressExpired { + listener_id, + mut listen_addr, + } => { + let proto = self + .listener_protos + .get(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + listen_addr.push(proto.clone()); + TransportEvent::AddressExpired { + listener_id, + listen_addr, + } + } + TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError { + listener_id, + error: Error::Transport(error), + }, + TransportEvent::ListenerClosed { + listener_id, + reason, + } => { + self.listener_protos + .remove(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + TransportEvent::ListenerClosed { + listener_id, + reason: reason.map_err(Error::Transport), + } + } + TransportEvent::Incoming { + listener_id, + upgrade, + mut local_addr, + mut send_back_addr, + } => { + let proto = self + .listener_protos + .get(&listener_id) + .expect("Protocol was inserted in Transport::listen_on."); + let use_tls = match proto { + Protocol::Wss(_) => true, + Protocol::Ws(_) => false, + _ => unreachable!("Map contains only ws and wss protocols."), + }; + local_addr.push(proto.clone()); + send_back_addr.push(proto.clone()); + let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls); + TransportEvent::Incoming { + listener_id, + upgrade, + local_addr, + send_back_addr, + } + } + }; + Poll::Ready(event) + } } impl WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -304,13 +280,25 @@ where // We are looping here in order to follow redirects (if any): let mut remaining_redirects = self.max_redirects; - let mut this = self.clone(); + let transport = self.transport.clone(); + let tls_config = self.tls_config.clone(); + let use_deflate = self.use_deflate; + let max_redirects = self.max_redirects; + let future = async move { loop { - match this.dial_once(addr, role_override).await { + match Self::dial_once( + transport.clone(), + addr, + tls_config.clone(), + use_deflate, + role_override, + ) + .await + { Ok(Either::Left(redirect)) => { if remaining_redirects == 0 { - debug!("Too many redirects (> {})", this.max_redirects); + debug!("Too many redirects (> {})", max_redirects); return Err(Error::TooManyRedirects); } remaining_redirects -= 1; @@ -324,17 +312,20 @@ where Ok(Box::pin(future)) } + /// Attempts to dial the given address and perform a websocket handshake. async fn dial_once( - &mut self, + transport: Arc>, addr: WsAddress, + tls_config: tls::Config, + use_deflate: bool, role_override: Endpoint, ) -> Result>, Error> { trace!("Dialing websocket address: {:?}", addr); let dial = match role_override { - Endpoint::Dialer => self.transport.lock().dial(addr.tcp_addr), - Endpoint::Listener => self.transport.lock().dial_as_listener(addr.tcp_addr), + Endpoint::Dialer => transport.lock().dial(addr.tcp_addr), + Endpoint::Listener => transport.lock().dial_as_listener(addr.tcp_addr), } .map_err(|e| match e { TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a), @@ -350,8 +341,7 @@ where .dns_name .expect("for use_tls we have checked that dns_name is some"); trace!("Starting TLS handshake with {:?}", dns_name); - let stream = self - .tls_config + let stream = tls_config .client .connect(dns_name.clone(), stream) .map_err(|e| { @@ -371,7 +361,7 @@ where let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref()); - if self.use_deflate { + if use_deflate { client.add_extension(Box::new(Deflate::new(connection::Mode::Client))); } @@ -400,6 +390,91 @@ where } } } + + fn map_upgrade( + &self, + upgrade: T::ListenerUpgrade, + remote_addr: Multiaddr, + use_tls: bool, + ) -> ::ListenerUpgrade { + let remote_addr2 = remote_addr.clone(); // used for logging + let tls_config = self.tls_config.clone(); + let max_size = self.max_data_size; + let use_deflate = self.use_deflate; + + async move { + let stream = upgrade.map_err(Error::Transport).await?; + trace!("incoming connection from {}", remote_addr); + + let stream = if use_tls { + // begin TLS session + let server = tls_config + .server + .expect("for use_tls we checked server is not none"); + + trace!("awaiting TLS handshake with {}", remote_addr); + + let stream = server + .accept(stream) + .map_err(move |e| { + debug!("TLS handshake with {} failed: {}", remote_addr, e); + Error::Tls(tls::Error::from(e)) + }) + .await?; + + let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::Second(stream)); + + stream + } else { + // continue with plain stream + EitherOutput::Second(stream) + }; + + trace!( + "receiving websocket handshake request from {}", + remote_addr2 + ); + + let mut server = handshake::Server::new(stream); + + if use_deflate { + server.add_extension(Box::new(Deflate::new(connection::Mode::Server))); + } + + let ws_key = { + let request = server + .receive_request() + .map_err(|e| Error::Handshake(Box::new(e))) + .await?; + request.key() + }; + + trace!( + "accepting websocket handshake request from {}", + remote_addr2 + ); + + let response = handshake::server::Response::Accept { + key: ws_key, + protocol: None, + }; + + server + .send_response(&response) + .map_err(|e| Error::Handshake(Box::new(e))) + .await?; + + let conn = { + let mut builder = server.into_builder(); + builder.set_max_message_size(max_size); + builder.set_max_frame_size(max_size); + Connection::new(builder) + }; + + Ok(conn) + } + .boxed() + } } #[derive(Debug)] diff --git a/transports/websocket/src/lib.rs b/transports/websocket/src/lib.rs index 770e559a633..bf75e389648 100644 --- a/transports/websocket/src/lib.rs +++ b/transports/websocket/src/lib.rs @@ -26,14 +26,11 @@ pub mod tls; use error::Error; use framed::{Connection, Incoming}; -use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream}; +use futures::{future::BoxFuture, prelude::*, ready}; use libp2p_core::{ connection::ConnectedPoint, multiaddr::Multiaddr, - transport::{ - map::{MapFuture, MapStream}, - ListenerEvent, TransportError, - }, + transport::{map::MapFuture, ListenerId, TransportError, TransportEvent}, Transport, }; use rw_stream_sink::RwStreamSink; @@ -55,10 +52,9 @@ where impl WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static, { @@ -114,26 +110,25 @@ where impl Transport for WsConfig where - T: Transport + Send + 'static, + T: Transport + Send + Unpin + 'static, T::Error: Send + 'static, T::Dial: Send + 'static, - T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = RwStreamSink>; type Error = Error; - type Listener = MapStream, WrapperFn>; type ListenerUpgrade = MapFuture, WrapperFn>; type Dial = MapFuture, WrapperFn>; - fn listen_on( - &mut self, - addr: Multiaddr, - ) -> Result> { + fn listen_on(&mut self, addr: Multiaddr) -> Result> { self.transport.listen_on(addr) } + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.transport.remove_listener(id) + } + fn dial(&mut self, addr: Multiaddr) -> Result> { self.transport.dial(addr) } @@ -148,11 +143,14 @@ where fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { self.transport.address_translation(server, observed) } -} -/// Type alias corresponding to `framed::WsConfig::Listener`. -pub type InnerStream = - BoxStream<'static, Result, Error>, Error>>; + fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.transport).poll(cx) + } +} /// Type alias corresponding to `framed::WsConfig::Dial` and `framed::WsConfig::ListenerUpgrade`. pub type InnerFuture = BoxFuture<'static, Result, Error>>; @@ -236,15 +234,17 @@ mod tests { futures::executor::block_on(connect(a)) } - async fn connect(listen_addr: Multiaddr) { - let ws_config = || WsConfig::new(tcp::TcpConfig::new()); + fn new_ws_config() -> WsConfig { + WsConfig::new(tcp::TcpTransport::new(tcp::GenTcpConfig::default())) + } - let mut listener = ws_config().listen_on(listen_addr).expect("listener"); + async fn connect(listen_addr: Multiaddr) { + let mut ws_config = new_ws_config().boxed(); + ws_config.listen_on(listen_addr).expect("listener"); - let addr = listener - .try_next() + let addr = ws_config + .next() .await - .expect("some event") .expect("no error") .into_new_address() .expect("listen address"); @@ -253,16 +253,16 @@ mod tests { assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1)); let inbound = async move { - let (conn, _addr) = listener - .try_filter_map(|e| future::ready(Ok(e.into_upgrade()))) - .try_next() + let (conn, _addr) = ws_config + .select_next_some() + .map(|ev| ev.into_incoming()) .await - .unwrap() .unwrap(); conn.await }; - let outbound = ws_config() + let outbound = new_ws_config() + .boxed() .dial(addr.with(Protocol::P2p(PeerId::random().into()))) .unwrap();