|
| 1 | +// Copyright 2018 Parity Technologies (UK) Ltd. |
| 2 | +// This file is part of Polkadot. |
| 3 | + |
| 4 | +// Polkadot is free software: you can redistribute it and/or modify |
| 5 | +// it under the terms of the GNU General Public License as published by |
| 6 | +// the Free Software Foundation, either version 3 of the License, or |
| 7 | +// (at your option) any later version. |
| 8 | + |
| 9 | +// Polkadot is distributed in the hope that it will be useful, |
| 10 | +// but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | +// GNU General Public License for more details. |
| 13 | + |
| 14 | +// You should have received a copy of the GNU General Public License |
| 15 | +// along with Polkadot. If not, see <http://www.gnu.org/licenses/>.? |
| 16 | + |
| 17 | +use bytes::{Bytes, BytesMut}; |
| 18 | +use network::ProtocolId; |
| 19 | +use libp2p::core::{Multiaddr, ConnectionUpgrade, Endpoint}; |
| 20 | +use network::PacketId; |
| 21 | +use std::io::Error as IoError; |
| 22 | +use std::vec::IntoIter as VecIntoIter; |
| 23 | +use futures::{future, Future, stream, Stream, Sink}; |
| 24 | +use futures::sync::mpsc; |
| 25 | +use tokio_io::{AsyncRead, AsyncWrite}; |
| 26 | +use varint::VarintCodec; |
| 27 | + |
| 28 | +/// Connection upgrade for a single protocol. |
| 29 | +/// |
| 30 | +/// Note that "a single protocol" here refers to `par` for example. However |
| 31 | +/// each protocol can have multiple different versions for networking purposes. |
| 32 | +#[derive(Clone)] |
| 33 | +pub struct RegisteredProtocol<T> { |
| 34 | + /// Id of the protocol for API purposes. |
| 35 | + id: ProtocolId, |
| 36 | + /// Base name of the protocol as advertised on the network. |
| 37 | + /// Ends with `/` so that we can append a version number behind. |
| 38 | + base_name: Bytes, |
| 39 | + /// List of protocol versions that we support, plus their packet count. |
| 40 | + /// Ordered in descending order so that the best comes first. |
| 41 | + /// The packet count is used to filter out invalid messages. |
| 42 | + supported_versions: Vec<(u8, u8)>, |
| 43 | + /// Custom data. |
| 44 | + custom_data: T, |
| 45 | +} |
| 46 | + |
| 47 | +/// Output of a `RegisteredProtocol` upgrade. |
| 48 | +pub struct RegisteredProtocolOutput<T> { |
| 49 | + /// Data passed to `RegisteredProtocol::new`. |
| 50 | + pub custom_data: T, |
| 51 | + |
| 52 | + /// Id of the protocol. |
| 53 | + pub protocol_id: ProtocolId, |
| 54 | + |
| 55 | + /// Version of the protocol that was negotiated. |
| 56 | + pub protocol_version: u8, |
| 57 | + |
| 58 | + /// Channel to sender outgoing messages to. Closing this channel closes the |
| 59 | + /// connection. |
| 60 | + // TODO: consider assembling packet_id here |
| 61 | + pub outgoing: mpsc::UnboundedSender<Bytes>, |
| 62 | + |
| 63 | + /// Stream where incoming messages are received. The stream ends whenever |
| 64 | + /// either side is closed. |
| 65 | + pub incoming: Box<Stream<Item = (PacketId, Bytes), Error = IoError>>, |
| 66 | +} |
| 67 | + |
| 68 | +impl<T> RegisteredProtocol<T> { |
| 69 | + /// Creates a new `RegisteredProtocol`. The `custom_data` parameter will be |
| 70 | + /// passed inside the `RegisteredProtocolOutput`. |
| 71 | + pub fn new(custom_data: T, protocol: ProtocolId, versions: &[(u8, u8)]) |
| 72 | + -> Self { |
| 73 | + let mut proto_name = Bytes::from_static(b"/substrate/"); |
| 74 | + proto_name.extend_from_slice(&protocol); |
| 75 | + proto_name.extend_from_slice(b"/"); |
| 76 | + |
| 77 | + RegisteredProtocol { |
| 78 | + base_name: proto_name, |
| 79 | + id: protocol, |
| 80 | + supported_versions: { |
| 81 | + let mut tmp: Vec<_> = versions.iter().rev().cloned().collect(); |
| 82 | + tmp.sort_unstable_by(|a, b| b.1.cmp(&a.1)); |
| 83 | + tmp |
| 84 | + }, |
| 85 | + custom_data: custom_data, |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + /// Returns the ID of the protocol. |
| 90 | + pub fn id(&self) -> ProtocolId { |
| 91 | + self.id |
| 92 | + } |
| 93 | + |
| 94 | + /// Returns the custom data that was passed to `new`. |
| 95 | + pub fn custom_data(&self) -> &T { |
| 96 | + &self.custom_data |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +// `Maf` is short for `MultiaddressFuture` |
| 101 | +impl<T, C, Maf> ConnectionUpgrade<C, Maf> for RegisteredProtocol<T> |
| 102 | +where C: AsyncRead + AsyncWrite + 'static, // TODO: 'static :-/ |
| 103 | + Maf: Future<Item = Multiaddr, Error = IoError> + 'static, // TODO: 'static :( |
| 104 | +{ |
| 105 | + type NamesIter = VecIntoIter<(Bytes, Self::UpgradeIdentifier)>; |
| 106 | + type UpgradeIdentifier = u8; // Protocol version |
| 107 | + |
| 108 | + #[inline] |
| 109 | + fn protocol_names(&self) -> Self::NamesIter { |
| 110 | + // Report each version as an individual protocol. |
| 111 | + self.supported_versions.iter().map(|&(ver, _)| { |
| 112 | + let num = ver.to_string(); |
| 113 | + let mut name = self.base_name.clone(); |
| 114 | + name.extend_from_slice(num.as_bytes()); |
| 115 | + (name, ver) |
| 116 | + }).collect::<Vec<_>>().into_iter() |
| 117 | + } |
| 118 | + |
| 119 | + type Output = RegisteredProtocolOutput<T>; |
| 120 | + type MultiaddrFuture = Maf; |
| 121 | + type Future = future::FutureResult<(Self::Output, Self::MultiaddrFuture), IoError>; |
| 122 | + |
| 123 | + fn upgrade( |
| 124 | + self, |
| 125 | + socket: C, |
| 126 | + protocol_version: Self::UpgradeIdentifier, |
| 127 | + endpoint: Endpoint, |
| 128 | + remote_addr: Maf |
| 129 | + ) -> Self::Future { |
| 130 | + let packet_count = self.supported_versions |
| 131 | + .iter() |
| 132 | + .find(|&(v, _)| *v == protocol_version) |
| 133 | + .expect("negotiated protocol version that wasn't advertised ; \ |
| 134 | + programmer error") |
| 135 | + .1; |
| 136 | + |
| 137 | + // This function is called whenever we successfully negotiated a |
| 138 | + // protocol with a remote (both if initiated by us or by the remote) |
| 139 | + |
| 140 | + // This channel is used to send outgoing packets to the custom_data |
| 141 | + // for this open substream. |
| 142 | + let (msg_tx, msg_rx) = mpsc::unbounded(); |
| 143 | + |
| 144 | + // Build the sink for outgoing network bytes, and the stream for |
| 145 | + // incoming instructions. `stream` implements `Stream<Item = Message>`. |
| 146 | + enum Message { |
| 147 | + /// Received data from the network. |
| 148 | + RecvSocket(BytesMut), |
| 149 | + /// Data to send to the network. |
| 150 | + /// The packet_id must already be inside the `Bytes`. |
| 151 | + SendReq(Bytes), |
| 152 | + /// The socket has been closed. |
| 153 | + Finished, |
| 154 | + } |
| 155 | + |
| 156 | + let (sink, stream) = { |
| 157 | + let framed = AsyncRead::framed(socket, VarintCodec::default()); |
| 158 | + let msg_rx = msg_rx.map(Message::SendReq) |
| 159 | + .chain(stream::once(Ok(Message::Finished))) |
| 160 | + .map_err(|()| unreachable!("mpsc::UnboundedReceiver never errors")); |
| 161 | + let (sink, stream) = framed.split(); |
| 162 | + let stream = stream.map(Message::RecvSocket) |
| 163 | + .chain(stream::once(Ok(Message::Finished))); |
| 164 | + (sink, msg_rx.select(stream)) |
| 165 | + }; |
| 166 | + |
| 167 | + let incoming = stream::unfold((sink, stream, false), move |(sink, stream, finished)| { |
| 168 | + if finished { |
| 169 | + return None |
| 170 | + } |
| 171 | + |
| 172 | + Some(stream |
| 173 | + .into_future() |
| 174 | + .map_err(|(err, _)| err) |
| 175 | + .and_then(move |(message, stream)| |
| 176 | + match message { |
| 177 | + Some(Message::RecvSocket(mut data)) => { |
| 178 | + // The `data` should be prefixed by the packet ID, |
| 179 | + // therefore an empty packet is invalid. |
| 180 | + if data.is_empty() { |
| 181 | + debug!(target: "sub-libp2p", "ignoring incoming \ |
| 182 | + packet because it was empty"); |
| 183 | + let f = future::ok((None, (sink, stream, false))); |
| 184 | + return future::Either::A(f) |
| 185 | + } |
| 186 | + |
| 187 | + let packet_id = data[0]; |
| 188 | + let data = data.split_off(1); |
| 189 | + |
| 190 | + if packet_id >= packet_count { |
| 191 | + debug!(target: "sub-libp2p", "ignoring incoming packet \ |
| 192 | + because packet_id {} is too large", packet_id); |
| 193 | + let f = future::ok((None, (sink, stream, false))); |
| 194 | + future::Either::A(f) |
| 195 | + } else { |
| 196 | + let out = Some((packet_id, data.freeze())); |
| 197 | + let f = future::ok((out, (sink, stream, false))); |
| 198 | + future::Either::A(f) |
| 199 | + } |
| 200 | + }, |
| 201 | + |
| 202 | + Some(Message::SendReq(data)) => { |
| 203 | + let fut = sink.send(data) |
| 204 | + .map(move |sink| (None, (sink, stream, false))); |
| 205 | + future::Either::B(fut) |
| 206 | + }, |
| 207 | + |
| 208 | + Some(Message::Finished) | None => { |
| 209 | + let f = future::ok((None, (sink, stream, true))); |
| 210 | + future::Either::A(f) |
| 211 | + }, |
| 212 | + } |
| 213 | + )) |
| 214 | + }).filter_map(|v| v); |
| 215 | + |
| 216 | + let out = RegisteredProtocolOutput { |
| 217 | + custom_data: self.custom_data, |
| 218 | + protocol_id: self.id, |
| 219 | + protocol_version: protocol_version, |
| 220 | + outgoing: msg_tx, |
| 221 | + incoming: Box::new(incoming), |
| 222 | + }; |
| 223 | + |
| 224 | + future::ok((out, remote_addr)) |
| 225 | + } |
| 226 | +} |
| 227 | + |
| 228 | +// Connection upgrade for all the protocols contained in it. |
| 229 | +#[derive(Clone)] |
| 230 | +pub struct RegisteredProtocols<T>(pub Vec<RegisteredProtocol<T>>); |
| 231 | + |
| 232 | +impl<T> RegisteredProtocols<T> { |
| 233 | + /// Finds a protocol in the list by its id. |
| 234 | + pub fn find_protocol(&self, protocol: ProtocolId) |
| 235 | + -> Option<&RegisteredProtocol<T>> { |
| 236 | + self.0.iter().find(|p| p.id == protocol) |
| 237 | + } |
| 238 | + |
| 239 | + /// Returns true if the given protocol is in the list. |
| 240 | + pub fn has_protocol(&self, protocol: ProtocolId) -> bool { |
| 241 | + self.0.iter().any(|p| p.id == protocol) |
| 242 | + } |
| 243 | +} |
| 244 | + |
| 245 | +impl<T> Default for RegisteredProtocols<T> { |
| 246 | + fn default() -> Self { |
| 247 | + RegisteredProtocols(Vec::new()) |
| 248 | + } |
| 249 | +} |
| 250 | + |
| 251 | +impl<T, C, Maf> ConnectionUpgrade<C, Maf> for RegisteredProtocols<T> |
| 252 | +where C: AsyncRead + AsyncWrite + 'static, // TODO: 'static :-/ |
| 253 | + Maf: Future<Item = Multiaddr, Error = IoError> + 'static, // TODO: 'static :( |
| 254 | +{ |
| 255 | + type NamesIter = VecIntoIter<(Bytes, Self::UpgradeIdentifier)>; |
| 256 | + type UpgradeIdentifier = (usize, |
| 257 | + <RegisteredProtocol<T> as ConnectionUpgrade<C, Maf>>::UpgradeIdentifier); |
| 258 | + |
| 259 | + fn protocol_names(&self) -> Self::NamesIter { |
| 260 | + // We concat the lists of `RegisteredProtocol::protocol_names` for |
| 261 | + // each protocol. |
| 262 | + self.0.iter().enumerate().flat_map(|(n, proto)| |
| 263 | + ConnectionUpgrade::<C, Maf>::protocol_names(proto) |
| 264 | + .map(move |(name, id)| (name, (n, id))) |
| 265 | + ).collect::<Vec<_>>().into_iter() |
| 266 | + } |
| 267 | + |
| 268 | + type Output = <RegisteredProtocol<T> as ConnectionUpgrade<C, Maf>>::Output; |
| 269 | + type MultiaddrFuture = <RegisteredProtocol<T> as |
| 270 | + ConnectionUpgrade<C, Maf>>::MultiaddrFuture; |
| 271 | + type Future = <RegisteredProtocol<T> as ConnectionUpgrade<C, Maf>>::Future; |
| 272 | + |
| 273 | + #[inline] |
| 274 | + fn upgrade( |
| 275 | + self, |
| 276 | + socket: C, |
| 277 | + upgrade_identifier: Self::UpgradeIdentifier, |
| 278 | + endpoint: Endpoint, |
| 279 | + remote_addr: Maf |
| 280 | + ) -> Self::Future { |
| 281 | + let (protocol_index, inner_proto_id) = upgrade_identifier; |
| 282 | + self.0.into_iter() |
| 283 | + .nth(protocol_index) |
| 284 | + .expect("invalid protocol index ; programmer logic error") |
| 285 | + .upgrade(socket, inner_proto_id, endpoint, remote_addr) |
| 286 | + } |
| 287 | +} |
0 commit comments