diff --git a/crates/aura-transport/src/udp.rs b/crates/aura-transport/src/udp.rs index 06920de..38e7150 100644 --- a/crates/aura-transport/src/udp.rs +++ b/crates/aura-transport/src/udp.rs @@ -34,16 +34,19 @@ //! [`DatagramSender::seal`](aura_proto::DatagramSender::seal) output. Any trailing bytes are //! obfuscation padding and are ignored by the receiver (it reads exactly `rec_len`). //! -//! ## Single peer per accepted connection (v1) +//! ## Many peers per server, one bound socket (v2) //! -//! [`UdpServer::accept`] handles **one** client per call: it waits for a client's first HS datagram, -//! latches that source address, runs the handshake bound to it, and returns a [`UdpConnection`] -//! dedicated to that peer. A server that wants to serve many clients concurrently on one well-known -//! port would need a demuxing layer (route datagrams to per-peer connections by source address); -//! that is out of scope for v1. The client side always `.connect()`s its ephemeral socket to the -//! server, so it only ever talks to one peer. +//! A single [`UdpServer`] multiplexes **many** clients over one bound UDP port. A background +//! *master loop* owns the listening socket: every received datagram is routed by source address into +//! a per-peer mailbox (`tokio::sync::mpsc` channel). The first HS datagram from an unknown source +//! spawns a per-peer handshake task that runs [`server_handshake`] over the reliable adapter and, +//! on success, hands the established [`UdpConnection`] to whoever is calling +//! [`UdpServer::accept`]. The client side keeps a `connect()`ed ephemeral socket and talks to one +//! peer (the server). Per-peer state is cleaned up when either the handshake task ends or the +//! [`UdpConnection`] is dropped: dropping the peer's inbox receiver causes the master loop's next +//! `send` to fail with `Closed`, which evicts the entry. -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::io; use std::net::SocketAddr; use std::pin::Pin; @@ -55,7 +58,7 @@ use async_trait::async_trait; use bytes::Bytes; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex, RwLock}; use aura_proto::frame::{decode_header, HEADER_LEN}; use aura_proto::{ @@ -140,27 +143,60 @@ impl Default for UdpOpts { /// A UDP socket bound to a single peer address. /// -/// The client connects its ephemeral socket to the server, so it can use plain `send`/`recv`. The -/// server shares one listening socket and remembers the accepted client's address, so it uses -/// `send_to(peer)` and filters `recv_from` to that address. This type hides that asymmetry behind a -/// uniform datagram send/recv pair used by both the reliable handshake adapter and the data path. +/// Two flavours, hidden behind one [`send_dgram`](PeerSocket::send_dgram) / +/// [`recv_dgram`](PeerSocket::recv_dgram) pair so [`ReliableHsAdapter`] and the data path do not +/// care which side they are on: +/// +/// * **Client** ([`PeerSocketState::ConnectedClient`]): an ephemeral `connect()`ed socket; plain +/// `send`/`recv` reach the server directly. +/// * **Server** ([`PeerSocketState::SharedServer`]): the shared master listening socket plus the +/// peer's source address. Sends go out as `send_to(peer)`; receives are pulled from an +/// `mpsc::Receiver>` *inbox* that the server's master loop fills by routing every +/// incoming datagram on its source address. Filtering by source address therefore happens once, +/// in the master loop — not on every `recv_dgram`. #[derive(Debug)] struct PeerSocket { - socket: UdpSocket, - /// `Some(addr)` for the server (it must address the specific client and ignore strangers); - /// `None` for the client (the socket is already `connect()`ed to the server). - peer: Option, + state: PeerSocketState, +} + +/// The two variants of [`PeerSocket`]; see the type's docs for the contract. +enum PeerSocketState { + /// Client side: an ephemeral socket already `connect()`ed to the server. + ConnectedClient { socket: UdpSocket }, + /// Server side: the shared master socket addresses the peer; the master loop routes inbound + /// datagrams from `peer_addr` into `inbox`. + SharedServer { + master: Arc, + peer_addr: SocketAddr, + inbox: Mutex>>, + }, +} + +impl std::fmt::Debug for PeerSocketState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConnectedClient { .. } => { + f.debug_struct("ConnectedClient").finish_non_exhaustive() + } + Self::SharedServer { peer_addr, .. } => f + .debug_struct("SharedServer") + .field("peer_addr", peer_addr) + .finish_non_exhaustive(), + } + } } impl PeerSocket { /// Send one datagram to the bound peer. async fn send_dgram(&self, buf: &[u8]) -> io::Result<()> { - match self.peer { - Some(addr) => { - self.socket.send_to(buf, addr).await?; + match &self.state { + PeerSocketState::ConnectedClient { socket } => { + socket.send(buf).await?; } - None => { - self.socket.send(buf).await?; + PeerSocketState::SharedServer { + master, peer_addr, .. + } => { + master.send_to(buf, *peer_addr).await?; } } Ok(()) @@ -168,24 +204,23 @@ impl PeerSocket { /// Receive one datagram from the bound peer. /// - /// For the server, datagrams from a *different* source address are dropped (v1 serves a single - /// peer per connection), so this loops until a datagram from the latched peer arrives. + /// For the client, this is a plain `recv` on the `connect()`ed socket. For the server, this + /// pulls the next datagram from the per-peer inbox the master loop fills; if the inbox is + /// closed (master loop stopped or evicted us) it returns an `UnexpectedEof`. async fn recv_dgram(&self) -> io::Result> { - let mut buf = vec![0u8; RECV_BUF]; - match self.peer { - Some(expected) => loop { - let (n, from) = self.socket.recv_from(&mut buf).await?; - if from == expected { - buf.truncate(n); - return Ok(buf); - } - // Datagram from an unrelated source: ignore (single-peer connection). - }, - None => { - let n = self.socket.recv(&mut buf).await?; + match &self.state { + PeerSocketState::ConnectedClient { socket } => { + let mut buf = vec![0u8; RECV_BUF]; + let n = socket.recv(&mut buf).await?; buf.truncate(n); Ok(buf) } + PeerSocketState::SharedServer { inbox, .. } => { + let mut rx = inbox.lock().await; + rx.recv().await.ok_or_else(|| { + io::Error::new(io::ErrorKind::UnexpectedEof, "peer inbox closed") + }) + } } } } @@ -523,8 +558,10 @@ struct Established { /// /// `run_hs` is either [`client_handshake`] or [`server_handshake`] partially applied with config; it /// receives the adapter's reader and writer (two handles sharing `state` + `write_notify`) and -/// returns the established [`aura_proto::Session`] reduced to its datagram parts. `state` may be -/// pre-seeded (the server seeds the client's first datagram before calling this). +/// returns the established [`aura_proto::Session`] reduced to its datagram parts. `state` always +/// starts fresh: in the multi-peer server, the master loop has already pushed the client's first +/// HS datagram into the per-peer inbox, so the very first `pump_one_incoming` call will deliver it +/// into the reorder buffer just like any subsequent datagram. /// /// We spawn nothing: the handshake future and the I/O driver are raced with `tokio::select!` in a /// loop so that (a) outgoing whole messages are framed and flushed to datagrams as soon as @@ -567,8 +604,10 @@ where rto.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); rto.tick().await; // skip the immediate first tick - // If `state` was pre-seeded (server case), respond to it immediately rather than waiting for the - // first timer/recv: flush any reply the handshake future already queued and ack the seed. + // Kick the I/O once before entering the select loop: flush anything the handshake future + // already buffered synchronously (the client's ClientHello, mainly) and emit a bare ack if the + // state already has something to acknowledge. Both are no-ops on a fresh adapter, so this is + // safe regardless of which side we are on. driver.flush_outgoing().await; driver.maybe_bare_ack().await; @@ -677,16 +716,25 @@ pub struct UdpConnection { receiver: Mutex, peer_id: Option, opts: UdpOpts, + /// `Some` for server-side connections (keeps the [`UdpServer`]'s master loop alive past the + /// server handle being dropped); `None` for client-side connections (the ephemeral + /// `connect()`ed socket lives inside the [`PeerSocket`] and needs no external task). + _master_task: Option>, } impl UdpConnection { - fn from_established(est: Established, opts: UdpOpts) -> Self { + fn from_established( + est: Established, + opts: UdpOpts, + master_task: Option>, + ) -> Self { Self { socket: est.socket, sender: Mutex::new(est.sender), receiver: Mutex::new(est.receiver), peer_id: est.peer_id, opts, + _master_task: master_task, } } @@ -796,24 +844,57 @@ impl PacketConnection for UdpConnection { } } -/// An Aura UDP server: a bound UDP socket that accepts one authenticated [`UdpConnection`] per -/// [`accept`](UdpServer::accept). +/// Per-peer inbox capacity in the server's master loop demuxer. /// -/// v1 serves a **single peer per accepted connection** (see the module docs). Each `accept` waits -/// for a client's first HS datagram, latches that source address, runs [`server_handshake`] over the -/// reliable adapter, and returns the connection. To serve multiple clients, bind multiple sockets or -/// add a per-source demuxer (out of scope for v1). +/// 128 datagrams is comfortably more than a single handshake flight (a handful of messages) +/// and absorbs short bursts on the data path before the per-peer consumer drains them. When the +/// inbox is full the master loop drops the datagram and logs — UDP is best-effort by design and +/// the upper layers (handshake retransmit; the tunnel's own loss tolerance) recover. +const PEER_INBOX_CAPACITY: usize = 128; + +/// Capacity of the [`UdpServer::accept`] queue (handed-off ready connections). +/// +/// Small on purpose: the bound is just back-pressure for the unusual case where many handshakes +/// finish faster than the application calls `accept`. Established connections are tiny. +const ACCEPT_QUEUE_CAPACITY: usize = 32; + +/// Shared lifetime owner of the [`UdpServer`]'s master loop task. +/// +/// Both the [`UdpServer`] handle and every server-side [`UdpConnection`] hold an `Arc`, +/// so the master loop keeps running as long as *either* the server can still accept new peers or +/// any already-accepted connection is still in use. When the last `Arc` is dropped, `Drop` aborts +/// the task — at which point all per-peer inboxes close, and any pending `recv_dgram` returns the +/// canonical `peer inbox closed` `UnexpectedEof`. +struct MasterTask(tokio::task::JoinHandle<()>); + +impl Drop for MasterTask { + fn drop(&mut self) { + self.0.abort(); + } +} + +/// An Aura UDP server: a bound UDP socket multiplexing **many** authenticated peers. +/// +/// One background master loop owns the listening socket and routes every incoming datagram into the +/// per-peer inbox keyed by source address. The first HS datagram from an unknown source spawns a +/// dedicated handshake task; on success the resulting [`UdpConnection`] is pushed onto the +/// `accept` queue. Per-peer state is reclaimed when the handshake task fails (its inbox receiver +/// is dropped → the master loop sees `Closed` on next send and evicts the entry) or when the +/// [`UdpConnection`] is dropped (same path via the [`PeerSocket`] holding the inbox). pub struct UdpServer { - socket: Arc, - /// A std clone of the same bound socket, kept solely so [`accept`](UdpServer::accept) can safely - /// `try_clone` an independent handle for the per-connection [`PeerSocket`] (no `unsafe`). - std_socket: std::net::UdpSocket, - proto_cfg: Arc, - /// Live options: kept behind an `Arc` so the daily mask rotator can update the - /// padding profile (and any future per-rotation field) and the next [`Self::accept`] picks up - /// the change. Already-accepted [`UdpConnection`]s hold their own snapshot, so an in-flight - /// connection's wire behaviour does not change mid-stream. - opts: Arc>, + /// Cached local address (so we can reply to `local_addr()` after the master loop has taken + /// ownership of the socket). + local_addr: SocketAddr, + /// Queue of established connections ready to be handed to callers of [`Self::accept`]. + accept_rx: Mutex>, + /// Shared lifetime owner of the master loop: kept here AND in each accepted + /// [`UdpConnection`] so the master loop survives until both the server is dropped and the + /// last established connection is dropped. Without this, dropping the server (e.g. tests that + /// move ownership into the accept task) would tear down per-peer inboxes mid-connection. + _master_task: Arc, + /// Snapshotted by each spawned handshake task to keep wire behaviour stable for the lifetime + /// of that connection while still letting the rotator update what new peers will use. + opts: Arc>, } impl UdpServer { @@ -826,22 +907,38 @@ impl UdpServer { /// # Errors /// Returns an [`io::Error`] if the UDP socket cannot bind. pub fn bind(local: SocketAddr, proto_cfg: ServerConfig, opts: UdpOpts) -> io::Result { - let socket = std::net::UdpSocket::bind(local)?; - socket.set_nonblocking(true)?; - // Keep a safe std clone for per-connection handles; both refer to the same bound port. - let std_socket = socket.try_clone()?; - let socket = UdpSocket::from_std(socket)?; + let std_socket = std::net::UdpSocket::bind(local)?; + std_socket.set_nonblocking(true)?; + let socket = UdpSocket::from_std(std_socket)?; + let local_addr = socket.local_addr()?; + let master_socket = Arc::new(socket); + let opts = Arc::new(RwLock::new(opts)); + let proto_cfg = Arc::new(proto_cfg); + let (accept_tx, accept_rx) = mpsc::channel::(ACCEPT_QUEUE_CAPACITY); + // `Arc::new_cyclic` lets the spawned master loop hold a `Weak`. The master + // loop upgrades it when handing established connections to `from_established` so each + // `UdpConnection` keeps the task alive past the [`UdpServer`] being dropped. + let master_task: Arc = Arc::new_cyclic(|weak: &std::sync::Weak| { + let weak_for_loop = weak.clone(); + MasterTask(tokio::spawn(server_master_loop( + master_socket, + proto_cfg, + opts.clone(), + accept_tx, + weak_for_loop, + ))) + }); Ok(Self { - socket: Arc::new(socket), - std_socket, - proto_cfg: Arc::new(proto_cfg), - opts: Arc::new(tokio::sync::RwLock::new(opts)), + local_addr, + accept_rx: Mutex::new(accept_rx), + _master_task: master_task, + opts, }) } - /// Replace the server's accept-time options. The change applies to the **next** [`Self::accept`]; - /// already-accepted connections keep their snapshot. Used by the daily mask rotator to update - /// the padding profile new connections will use. + /// Replace the server's accept-time options. The change applies to the **next** handshake the + /// master loop kicks off; already-accepted connections keep their snapshot. Used by the daily + /// mask rotator to update the padding profile new connections will use. pub async fn set_opts(&self, new_opts: UdpOpts) { *self.opts.write().await = new_opts; } @@ -854,59 +951,139 @@ impl UdpServer { /// The local address (including the OS-assigned port) this server is bound to. /// /// # Errors - /// Returns an [`io::Error`] if the socket address cannot be read. + /// Returns an [`io::Error`] only for API symmetry with the old single-peer impl; the cached + /// value is read back here and never actually fails. pub fn local_addr(&self) -> io::Result { - self.socket.local_addr() + Ok(self.local_addr) } - /// Accept the next client: wait for its first HS datagram, then run the Aura mutual-auth - /// handshake bound to that peer over the reliable UDP adapter. + /// Wait for the next established connection from the master loop. /// - /// Returns a ready [`UdpConnection`] whose [`peer_id`](UdpConnection::peer_id) is the verified - /// client Common Name. + /// Returns the next [`UdpConnection`] whose [`peer_id`](UdpConnection::peer_id) is the verified + /// client Common Name. May be called from any number of tasks; calls observe a fair queue. /// /// # Errors - /// Returns an error if receiving fails or the Aura handshake fails (e.g. the client's - /// certificate does not verify against the CA, or the handshake times out). + /// Returns an error only if the server has been dropped (the master loop's task ended and the + /// channel closed). Individual handshake failures are logged and swallowed inside the master + /// loop — they do not propagate to `accept`, and the server keeps accepting other peers. pub async fn accept(&self) -> anyhow::Result { - // Wait for the first HS datagram and latch the client's address. We must NOT consume the - // datagram's content blindly: re-deliver it to the handshake by seeding the reorder buffer. - let (peer_addr, first) = loop { - let mut buf = vec![0u8; RECV_BUF]; - let (n, from) = self.socket.recv_from(&mut buf).await?; - buf.truncate(n); - if !buf.is_empty() && buf[0] == TYPE_HS && buf.len() >= HS_PREFIX_LEN { - break (from, buf); + let mut rx = self.accept_rx.lock().await; + rx.recv() + .await + .ok_or_else(|| anyhow::anyhow!("UdpServer closed")) + } +} + +/// The UDP server's demuxer + per-peer dispatcher. +/// +/// Loops forever (until the last `Arc` is dropped and the task is aborted) on +/// `master.recv_from`. Routing rules: +/// +/// * Datagram from a **known peer** → push into that peer's inbox via `try_send`. `Full` is +/// logged-and-dropped (UDP is best-effort); `Closed` evicts the entry so a future first-HS +/// from the same address can start fresh. +/// * Datagram from an **unknown peer** with a leading [`TYPE_HS`] byte → allocate an inbox, +/// push the first datagram into it, register the peer, and spawn a handshake task. On +/// success the established [`UdpConnection`] is sent to `accept_tx`. On failure the spawn +/// ends silently; its inbox receiver is dropped, the next master-loop send to that peer fails +/// `Closed`, and the entry is evicted on the next datagram from that address. +/// * Anything else (unknown source, non-HS first byte, or empty datagram) is dropped. +async fn server_master_loop( + master: Arc, + proto_cfg: Arc, + opts: Arc>, + accept_tx: mpsc::Sender, + master_task_weak: std::sync::Weak, +) { + let mut peers: HashMap>> = HashMap::new(); + let mut buf = vec![0u8; RECV_BUF]; + loop { + let (n, from) = match master.recv_from(&mut buf).await { + Ok(v) => v, + Err(e) => { + tracing::warn!("udp master recv failed: {e}"); + continue; } - // Ignore stray non-HS datagrams while waiting for a fresh client. }; + let dg = buf[..n].to_vec(); - // A peer-bound view over the same bound port: safely `try_clone` the std socket and rebuild - // an independent tokio handle for it. Both the handshake adapter and the data path use this - // handle, addressing the latched client and ignoring any stray sources. - let peer_std = self.std_socket.try_clone()?; - peer_std.set_nonblocking(true)?; - let peer_socket = Arc::new(PeerSocket { - socket: UdpSocket::from_std(peer_std)?, - peer: Some(peer_addr), + // Existing peer (handshake-in-progress OR established): hand it to that peer's inbox. + if let Some(tx) = peers.get(&from) { + match tx.try_send(dg) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + tracing::warn!("udp inbox full for {from}, dropping datagram"); + } + Err(mpsc::error::TrySendError::Closed(_)) => { + // Peer is gone (handshake failed or connection dropped). Evict so a *new* + // first-HS from this address can establish a fresh peer. + peers.remove(&from); + } + } + continue; + } + + // Unknown source: only a leading HS byte is allowed to spawn a fresh peer. Late stray + // data datagrams from sources we forgot are silently dropped. + if dg.is_empty() || dg[0] != TYPE_HS { + continue; + } + + // Register the peer and pre-load the inbox with its first datagram so the spawned + // handshake task picks it up on its first `recv_dgram`. + let (inbox_tx, inbox_rx) = mpsc::channel::>(PEER_INBOX_CAPACITY); + // Capacity > 0, so this `try_send` cannot fail; ignore the result defensively. + let _ = inbox_tx.try_send(dg); + peers.insert(from, inbox_tx); + + // Snapshot opts for this peer's lifetime so a concurrent rotation does not change wire + // behaviour mid-handshake (matches the single-peer impl's contract). + let opts_snap = *opts.read().await; + let cfg = proto_cfg.clone(); + let master_for_peer = master.clone(); + let acc = accept_tx.clone(); + let weak = master_task_weak.clone(); + tokio::spawn(async move { + let peer_socket = Arc::new(PeerSocket { + state: PeerSocketState::SharedServer { + master: master_for_peer, + peer_addr: from, + inbox: Mutex::new(inbox_rx), + }, + }); + let state = Arc::new(Mutex::new(HsState::new())); + let result = + run_reliable_handshake(peer_socket, state, opts_snap, move |r, w| async move { + let session = server_handshake(r, w, &cfg).await?; + Ok(session.into_datagram_parts()) + }) + .await; + match result { + Ok(est) => { + // Pin the master task alive while this connection lives: upgrading `Weak` + // succeeds as long as either the [`UdpServer`] or some other connection still + // holds the `Arc`. The upgrade can only return `None` if every + // owner has dropped between the master loop reading and us running here, in + // which case the task itself is about to be aborted — drop silently. + let Some(task_anchor) = weak.upgrade() else { + tracing::debug!( + "udp master task gone before handshake from {from} finished; dropping" + ); + return; + }; + let conn = UdpConnection::from_established(est, opts_snap, Some(task_anchor)); + if acc.send(conn).await.is_err() { + tracing::warn!("udp accept queue closed; dropping connection from {from}"); + } + } + Err(e) => { + tracing::warn!("udp handshake from {from} failed: {e:#}"); + } + } + // If the handshake failed, dropping the `PeerSocket` also drops the inbox receiver — + // so the next master-loop send to `from` returns `Closed` and the peer is evicted from + // the map (lazy cleanup, no extra signalling needed). }); - - // Seed the reorder buffer with the first datagram so its ClientHello is not lost. - let state = Arc::new(Mutex::new(HsState::new())); - seed_first_hs(&state, &first).await; - - let cfg = self.proto_cfg.clone(); - // Snapshot the current accept-time options once: the resulting connection keeps this exact - // copy for its lifetime, so a concurrent mask rotation does not change in-flight wire - // behaviour (only the *next* accept will see the new mask). - let opts = *self.opts.read().await; - let est = run_reliable_handshake(peer_socket, state, opts, move |r, w| async move { - let session = server_handshake(r, w, &cfg).await?; - Ok(session.into_datagram_parts()) - }) - .await?; - - Ok(UdpConnection::from_established(est, opts)) } } @@ -939,9 +1116,9 @@ impl UdpClient { let socket = UdpSocket::from_std(std_sock)?; socket.connect(server).await?; + // Connected client: plain send/recv on the ephemeral socket. let peer_socket = Arc::new(PeerSocket { - socket, - peer: None, // connected socket: plain send/recv to the server + state: PeerSocketState::ConnectedClient { socket }, }); // Fresh (unseeded) state: the client speaks first (ClientHello). @@ -952,27 +1129,9 @@ impl UdpClient { }) .await?; - Ok(UdpConnection::from_established(est, opts)) - } -} - -// --------------------------------------------------------------------------------------------- -// Internal helpers for socket sharing and seeding -// --------------------------------------------------------------------------------------------- - -/// Seed an [`HsState`] with the server's first received HS datagram so its message is delivered to -/// the handshake reader in order (its `hs_seq` is 0 for a fresh client). -async fn seed_first_hs(state: &Arc>, dg: &[u8]) { - if dg.len() < HS_PREFIX_LEN || dg[0] != TYPE_HS { - return; - } - let seq = u16::from_be_bytes([dg[1], dg[2]]); - let ack_upto = u16::from_be_bytes([dg[3], dg[4]]); - let msg = dg[HS_PREFIX_LEN..].to_vec(); - let mut st = state.lock().await; - st.prune_acked(ack_upto); - if !msg.is_empty() { - st.accept_incoming(seq, msg); + // Client side has no master loop to keep alive — the ephemeral connected socket lives in + // the [`PeerSocket`] itself, so no external anchor is needed. + Ok(UdpConnection::from_established(est, opts, None)) } } diff --git a/crates/aura-transport/tests/udp_multi_client.rs b/crates/aura-transport/tests/udp_multi_client.rs new file mode 100644 index 0000000..d406b4a --- /dev/null +++ b/crates/aura-transport/tests/udp_multi_client.rs @@ -0,0 +1,333 @@ +//! Multi-client integration tests for the Aura UDP transport (the v2 master-loop demuxer). +//! +//! These prove that a single bound [`UdpServer`] can simultaneously serve **many** peers, that bad +//! peers do not poison the server, and that established connections survive other peers coming and +//! going. The single-client and lossy-channel tests live in `udp_loopback.rs`; here we focus on +//! demuxer correctness. +//! +//! * [`udp_multi_client_two_concurrent`] — bind one server, drive two clients (different client CNs) +//! to it concurrently, accept twice, and verify both connections are independent (no cross-talk; +//! each side learns the correct peer id). +//! * [`udp_bad_ca_does_not_block_other_clients`] — a third client with a foreign CA fails the +//! handshake; the server must keep accepting subsequent legitimate clients on the same port. +//! * [`udp_dropped_connection_does_not_block_other_clients`] — drop one client's connection mid-flight +//! and prove the server keeps serving the other plus accepts a fresh one. + +use std::sync::Arc; +use std::time::Duration; + +use aura_pki::AuraCa; +use aura_proto::{ClientConfig, PacketConnection, ServerConfig}; +use aura_transport::{UdpClient, UdpConnection, UdpOpts, UdpServer}; + +const SERVER_NAME: &str = "localhost"; + +/// Mint a CA, a server cert, and a set of client certs whose CNs are taken from `client_ids`. +fn make_configs(client_ids: &[&str]) -> (ServerConfig, Vec) { + let ca = AuraCa::generate("Aura UDP Multi-Client Test CA").expect("generate CA"); + let server_cert = ca + .issue_server_cert(SERVER_NAME) + .expect("issue server cert"); + let ca_pem = ca.ca_cert_pem(); + let server_cfg = ServerConfig { + ca_cert_pem: ca_pem.clone(), + server_cert_pem: server_cert.cert_pem, + server_key_pem: server_cert.key_pem, + }; + let client_cfgs: Vec = client_ids + .iter() + .map(|id| { + let c = ca.issue_client_cert(id).expect("issue client cert"); + ClientConfig { + ca_cert_pem: ca_pem.clone(), + client_cert_pem: c.cert_pem, + client_key_pem: c.key_pem, + server_name: SERVER_NAME.to_string(), + } + }) + .collect(); + (server_cfg, client_cfgs) +} + +/// Mint a **separate** CA + matching client cert; the resulting `ClientConfig` will trust this CA +/// for the server (so it will reject the real server) and present a cert the real server will not +/// verify either. Used to drive a handshake failure that must NOT take down the server. +fn make_foreign_ca_client(server_name: &str, client_cn: &str) -> ClientConfig { + let foreign = AuraCa::generate("Foreign CA").expect("generate foreign CA"); + let client_cert = foreign + .issue_client_cert(client_cn) + .expect("issue client cert under foreign CA"); + ClientConfig { + ca_cert_pem: foreign.ca_cert_pem(), + client_cert_pem: client_cert.cert_pem, + client_key_pem: client_cert.key_pem, + server_name: server_name.to_string(), + } +} + +/// Round-trip a payload `pkt` from `tx` to `rx` and assert byte equality. +async fn round_trip(tx: &Arc, rx: &Arc, pkt: &[u8]) { + tx.send_packet(pkt).await.expect("send"); + let got = tokio::time::timeout(Duration::from_secs(5), rx.recv_packet()) + .await + .expect("recv did not arrive within 5s") + .expect("recv"); + assert_eq!(got, pkt, "payload mismatch over round trip"); +} + +#[tokio::test] +async fn udp_multi_client_two_concurrent() { + let (server_cfg, client_cfgs) = make_configs(&["client-a", "client-b"]); + let opts = UdpOpts::default(); + + let server = + UdpServer::bind("127.0.0.1:0".parse().unwrap(), server_cfg, opts).expect("bind server"); + let server_addr = server.local_addr().expect("server addr"); + let server = Arc::new(server); + + // Spawn two server-side accepts in parallel; they must each pull their own connection from the + // master-loop's accept queue. + let s_a = server.clone(); + let accept_a = tokio::spawn(async move { s_a.accept().await }); + let s_b = server.clone(); + let accept_b = tokio::spawn(async move { s_b.accept().await }); + + // Spawn the two clients concurrently. They share the server's bound port. + let cfg_a = client_cfgs[0].clone(); + let cfg_b = client_cfgs[1].clone(); + let connect_a = tokio::spawn(async move { UdpClient::connect(server_addr, cfg_a, opts).await }); + let connect_b = tokio::spawn(async move { UdpClient::connect(server_addr, cfg_b, opts).await }); + + // Wait for everything to settle (generous timeout — handshake should be sub-second on loopback). + let timeout = Duration::from_secs(15); + let server_a: UdpConnection = tokio::time::timeout(timeout, accept_a) + .await + .expect("accept_a within timeout") + .expect("accept_a join") + .expect("accept_a result"); + let server_b: UdpConnection = tokio::time::timeout(timeout, accept_b) + .await + .expect("accept_b within timeout") + .expect("accept_b join") + .expect("accept_b result"); + let client_a: UdpConnection = tokio::time::timeout(timeout, connect_a) + .await + .expect("connect_a within timeout") + .expect("connect_a join") + .expect("connect_a result"); + let client_b: UdpConnection = tokio::time::timeout(timeout, connect_b) + .await + .expect("connect_b within timeout") + .expect("connect_b join") + .expect("connect_b result"); + + // Each server-side connection has a `peer_id` of either `client-a` or `client-b`; the accept + // order is *not* guaranteed (whichever handshake finishes first), so detect which is which + // and pair them with the matching client connection. + let id_a = server_a.peer_id().map(str::to_owned); + let id_b = server_b.peer_id().map(str::to_owned); + let mut ids = vec![id_a.clone(), id_b.clone()]; + ids.sort(); + assert_eq!( + ids, + vec![Some("client-a".to_string()), Some("client-b".to_string())], + "the two server-side connections must carry client-a and client-b CNs (no duplicates)" + ); + + let (srv_for_a, srv_for_b) = if id_a.as_deref() == Some("client-a") { + (server_a, server_b) + } else { + (server_b, server_a) + }; + + // Each side sees its own peer id (client side sees the server name). + assert_eq!(client_a.peer_id(), Some(SERVER_NAME)); + assert_eq!(client_b.peer_id(), Some(SERVER_NAME)); + + let client_a: Arc = Arc::new(client_a); + let client_b: Arc = Arc::new(client_b); + let server_for_a: Arc = Arc::new(srv_for_a); + let server_for_b: Arc = Arc::new(srv_for_b); + + // No cross-talk: A's payload reaches A's server-side conn (not B's), and vice versa. + round_trip(&client_a, &server_for_a, b"hi from a").await; + round_trip(&client_b, &server_for_b, b"hi from b").await; + round_trip(&server_for_a, &client_a, b"reply to a").await; + round_trip(&server_for_b, &client_b, b"reply to b").await; + + // And both directions still work concurrently (no head-of-line blocking via the master loop). + let a_send = { + let c = client_a.clone(); + let s = server_for_a.clone(); + tokio::spawn(async move { + c.send_packet(b"a-concurrent").await.unwrap(); + s.recv_packet().await.unwrap() + }) + }; + let b_send = { + let c = client_b.clone(); + let s = server_for_b.clone(); + tokio::spawn(async move { + c.send_packet(b"b-concurrent").await.unwrap(); + s.recv_packet().await.unwrap() + }) + }; + assert_eq!(a_send.await.unwrap(), b"a-concurrent"); + assert_eq!(b_send.await.unwrap(), b"b-concurrent"); +} + +#[tokio::test] +async fn udp_bad_ca_does_not_block_other_clients() { + let (server_cfg, client_cfgs) = make_configs(&["client-good"]); + // Use a tighter handshake timeout so the failing peer fails quickly and the test finishes + // even if the rogue client retransmits its ClientHello for a while. + let opts = UdpOpts { + hs_timeout: Duration::from_secs(3), + ..UdpOpts::default() + }; + + let server = + UdpServer::bind("127.0.0.1:0".parse().unwrap(), server_cfg, opts).expect("bind server"); + let server_addr = server.local_addr().expect("server addr"); + let server = Arc::new(server); + + // A rogue client with a foreign CA: its server-side handshake task will fail. The server must + // log + drop and keep accepting OTHER peers. + let foreign_cfg = make_foreign_ca_client(SERVER_NAME, "rogue"); + let rogue = + tokio::spawn(async move { UdpClient::connect(server_addr, foreign_cfg, opts).await }); + + // Give the rogue task a head start so the server's master loop registers it first. + tokio::time::sleep(Duration::from_millis(50)).await; + + // Now the legitimate client connects. The server must still accept it. + let cfg = client_cfgs[0].clone(); + let s = server.clone(); + let accept_good = tokio::spawn(async move { s.accept().await }); + let connect_good = + tokio::spawn(async move { UdpClient::connect(server_addr, cfg, opts).await }); + + let timeout = Duration::from_secs(15); + let server_good: UdpConnection = tokio::time::timeout(timeout, accept_good) + .await + .expect("accept_good within timeout") + .expect("accept_good join") + .expect("accept_good result"); + let client_good: UdpConnection = tokio::time::timeout(timeout, connect_good) + .await + .expect("connect_good within timeout") + .expect("connect_good join") + .expect("connect_good result"); + + assert_eq!( + server_good.peer_id(), + Some("client-good"), + "server must learn the good client's CN despite the rogue peer" + ); + + let server_good: Arc = Arc::new(server_good); + let client_good: Arc = Arc::new(client_good); + round_trip(&client_good, &server_good, b"still serving").await; + round_trip(&server_good, &client_good, b"yes we are").await; + + // The rogue connect should eventually fail (foreign CA → server's handshake rejects, the + // client's handshake adapter then errors out on the deadline / chain mismatch). We do not + // care about the exact error; we only require that it *errors*, not that it succeeds. + let rogue_result = tokio::time::timeout(Duration::from_secs(10), rogue) + .await + .expect("rogue task should terminate") + .expect("rogue task join"); + assert!( + rogue_result.is_err(), + "rogue client (foreign CA) must NOT succeed in establishing a connection" + ); +} + +/// Establish ONE client against a running multi-client server, then verify the server-side conn +/// has the expected CN. The accept happens in its own spawned task to avoid blocking the connect. +async fn establish_one( + server: &Arc, + server_addr: std::net::SocketAddr, + cfg: ClientConfig, + opts: UdpOpts, + expect_cn: &str, +) -> (UdpConnection, UdpConnection) { + let s = server.clone(); + let acc = tokio::spawn(async move { s.accept().await }); + let con = tokio::spawn(async move { UdpClient::connect(server_addr, cfg, opts).await }); + let timeout = Duration::from_secs(15); + let srv = tokio::time::timeout(timeout, acc) + .await + .expect("accept timely") + .expect("accept join") + .expect("accept result"); + let cli = tokio::time::timeout(timeout, con) + .await + .expect("connect timely") + .expect("connect join") + .expect("connect result"); + assert_eq!( + srv.peer_id(), + Some(expect_cn), + "server learned wrong CN for this client" + ); + (srv, cli) +} + +#[tokio::test] +async fn udp_dropped_connection_does_not_block_other_clients() { + let (server_cfg, client_cfgs) = make_configs(&["client-1", "client-2", "client-3"]); + let opts = UdpOpts::default(); + + let server = + UdpServer::bind("127.0.0.1:0".parse().unwrap(), server_cfg, opts).expect("bind server"); + let server_addr = server.local_addr().expect("server addr"); + let server = Arc::new(server); + + // Connect clients sequentially so the (server-side, client-side) pairing is unambiguous. + let (srv1, cli1) = establish_one( + &server, + server_addr, + client_cfgs[0].clone(), + opts, + "client-1", + ) + .await; + let (srv2, cli2) = establish_one( + &server, + server_addr, + client_cfgs[1].clone(), + opts, + "client-2", + ) + .await; + + let srv2: Arc = Arc::new(srv2); + let cli2: Arc = Arc::new(cli2); + // Sanity: client-2 works. + round_trip(&cli2, &srv2, b"keep-1").await; + + // Drop both ends of client-1's pair: the server-side `UdpConnection` is dropped, its + // `PeerSocket` (with the master's per-peer inbox receiver) is dropped, and the master loop's + // next datagram from client-1's address — if any — will `Closed` and evict the entry. + drop(srv1); + drop(cli1); + tokio::time::sleep(Duration::from_millis(50)).await; + + // The other client must keep working. + round_trip(&cli2, &srv2, b"keep-2").await; + round_trip(&srv2, &cli2, b"keep-3").await; + + // A fresh client-3 must also still be accepted. + let (srv3, cli3) = establish_one( + &server, + server_addr, + client_cfgs[2].clone(), + opts, + "client-3", + ) + .await; + let srv3: Arc = Arc::new(srv3); + let cli3: Arc = Arc::new(cli3); + round_trip(&cli3, &srv3, b"hi-3").await; +}