diff --git a/Cargo.lock b/Cargo.lock index 32ae0c9..c41585f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -248,11 +248,13 @@ dependencies = [ "bytes", "hmac", "rand 0.8.6", + "ring", "rustls-pki-types", "serde", "sha2", "thiserror 1.0.69", "tokio", + "x509-parser 0.16.0", "zeroize", ] diff --git a/crates/aura-proto/Cargo.toml b/crates/aura-proto/Cargo.toml index edba2b5..9906110 100644 --- a/crates/aura-proto/Cargo.toml +++ b/crates/aura-proto/Cargo.toml @@ -17,6 +17,15 @@ sha2.workspace = true rand.workspace = true rustls-pki-types.workspace = true thiserror.workspace = true +# Handshake signatures (ECDSA P-256 / SHA-256, ASN.1 DER). Already in the workspace lockfile. +ring = "0.17" +# Parse leaf cert DER (extract the EC SubjectPublicKeyInfo point) and decode PEM blocks +# (certificates + PKCS#8 keys) to DER. Already a workspace dependency and used by aura-pki, so +# this adds no new resolution and lets us avoid pulling in rustls-pemfile. +x509-parser.workspace = true +# The handshake and session are async over tokio::io::{AsyncRead, AsyncWrite}, so tokio must be a +# normal dependency (available via the workspace `full` feature set), not only a dev-dependency. +tokio.workspace = true [dev-dependencies] tokio.workspace = true diff --git a/crates/aura-proto/src/frame.rs b/crates/aura-proto/src/frame.rs new file mode 100644 index 0000000..12f577a --- /dev/null +++ b/crates/aura-proto/src/frame.rs @@ -0,0 +1,371 @@ +//! Wire format: the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3). +//! +//! Every Aura protocol message on the wire is a 5-byte header followed by a payload: +//! +//! ```text +//! byte 0 : msg_type (u8) +//! bytes 1..4 : length (u24, big-endian) = payload length in bytes +//! byte 4 : version = 0x01 +//! bytes 5.. : payload (length bytes) +//! ``` +//! +//! [`Frame`] is the post-handshake application payload. Each `Frame` is serialized with +//! [`Frame::encode`], AEAD-sealed, and shipped inside a [`MsgType::Data`] record (see +//! [`crate::session`]). + +use bytes::Bytes; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::ProtoError; + +/// Length in bytes of the protocol frame header. +pub const HEADER_LEN: usize = 5; + +/// Protocol version carried in byte 4 of every header. +pub const PROTOCOL_VERSION: u8 = 0x01; + +/// Largest payload expressible by the u24 length field. +pub const MAX_PAYLOAD_LEN: usize = 0x00FF_FFFF; + +/// Message types carried in byte 0 of the header (§6.1). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum MsgType { + /// Handshake message 1 (C->S): hybrid public key + client nonce. + ClientHello = 0x01, + /// Handshake message 2 (S->C): hybrid ciphertext + server nonce. + ServerHello = 0x02, + /// Handshake message 4 (C->S, encrypted): client cert + signature. + ClientAuth = 0x03, + /// Handshake message 3 (S->C, encrypted): server cert + signature. + ServerAuth = 0x04, + /// Handshake Finished (encrypted): HMAC over the handshake hash. + Finished = 0x05, + /// Application data (encrypted): an AEAD-sealed [`Frame`]. + Data = 0x06, + /// Fatal alert / error notification. + Alert = 0xFF, +} + +impl MsgType { + /// Map the on-wire byte to a [`MsgType`]. + /// + /// # Errors + /// Returns [`ProtoError::UnknownMsgType`] for an unrecognized byte. + pub fn from_u8(b: u8) -> Result { + Ok(match b { + 0x01 => Self::ClientHello, + 0x02 => Self::ServerHello, + 0x03 => Self::ClientAuth, + 0x04 => Self::ServerAuth, + 0x05 => Self::Finished, + 0x06 => Self::Data, + 0xFF => Self::Alert, + other => return Err(ProtoError::UnknownMsgType(other)), + }) + } +} + +/// Build a 5-byte header for `msg_type` carrying a payload of `payload_len` bytes. +/// +/// # Errors +/// Returns [`ProtoError::FrameTooLarge`] if `payload_len` does not fit in the u24 length field. +pub fn encode_header( + msg_type: MsgType, + payload_len: usize, +) -> Result<[u8; HEADER_LEN], ProtoError> { + if payload_len > MAX_PAYLOAD_LEN { + return Err(ProtoError::FrameTooLarge(payload_len)); + } + let len = payload_len as u32; + Ok([ + msg_type as u8, + ((len >> 16) & 0xFF) as u8, + ((len >> 8) & 0xFF) as u8, + (len & 0xFF) as u8, + PROTOCOL_VERSION, + ]) +} + +/// Parse a 5-byte header into `(msg_type, payload_len)`. +/// +/// # Errors +/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized type byte or +/// [`ProtoError::BadVersion`] if byte 4 is not [`PROTOCOL_VERSION`]. +pub fn decode_header(header: &[u8; HEADER_LEN]) -> Result<(MsgType, usize), ProtoError> { + let msg_type = MsgType::from_u8(header[0])?; + let version = header[4]; + if version != PROTOCOL_VERSION { + return Err(ProtoError::BadVersion(version)); + } + let len = ((header[1] as usize) << 16) | ((header[2] as usize) << 8) | (header[3] as usize); + Ok((msg_type, len)) +} + +/// Write one full frame (`header || payload`) to `writer`. +/// +/// # Errors +/// Returns [`ProtoError::FrameTooLarge`] if the payload is too long, or [`ProtoError::Io`] on a +/// write failure. +pub async fn write_frame( + writer: &mut W, + msg_type: MsgType, + payload: &[u8], +) -> Result<(), ProtoError> +where + W: AsyncWrite + Unpin, +{ + let header = encode_header(msg_type, payload.len())?; + writer.write_all(&header).await?; + writer.write_all(payload).await?; + writer.flush().await?; + Ok(()) +} + +/// A frame read off the wire: its type, the raw header bytes (useful as AEAD AAD and for the +/// handshake transcript hash), and the payload. +#[derive(Clone, Debug)] +pub struct RawFrame { + /// The decoded message type. + pub msg_type: MsgType, + /// The 5 header bytes exactly as transmitted. + pub header: [u8; HEADER_LEN], + /// The payload bytes. + pub payload: Vec, +} + +impl RawFrame { + /// The full serialized frame (`header || payload`) exactly as it appeared on the wire. + /// + /// Used to feed the handshake transcript hash, which must hash the bytes as transmitted. + #[must_use] + pub fn wire_bytes(&self) -> Vec { + let mut out = Vec::with_capacity(HEADER_LEN + self.payload.len()); + out.extend_from_slice(&self.header); + out.extend_from_slice(&self.payload); + out + } +} + +/// Read one full frame (`header || payload`) from `reader`. +/// +/// # Errors +/// Returns [`ProtoError::Io`] on a read failure (including a truncated frame / EOF), or a header +/// decode error from [`decode_header`]. +pub async fn read_frame(reader: &mut R) -> Result +where + R: AsyncRead + Unpin, +{ + let mut header = [0u8; HEADER_LEN]; + reader.read_exact(&mut header).await?; + let (msg_type, len) = decode_header(&header)?; + let mut payload = vec![0u8; len]; + reader.read_exact(&mut payload).await?; + Ok(RawFrame { + msg_type, + header, + payload, + }) +} + +/// Frame type tags used in the application [`Frame`] encoding (§6.3). +mod frame_tag { + pub const DATA: u8 = 0x01; + pub const PING: u8 = 0x02; + pub const PONG: u8 = 0x03; + pub const CLOSE: u8 = 0x04; +} + +/// Application-level frames carried inside encrypted [`MsgType::Data`] records (§6.3). +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Frame { + /// A stream data payload. + Data { + /// Logical stream identifier. + stream_id: u32, + /// Opaque application bytes. + payload: Bytes, + }, + /// Liveness probe. + Ping { + /// Monotonic sequence number echoed back in the matching [`Frame::Pong`]. + seq: u32, + }, + /// Reply to a [`Frame::Ping`]. + Pong { + /// Sequence number copied from the [`Frame::Ping`]. + seq: u32, + }, + /// Orderly shutdown of the logical connection. + Close { + /// Application-defined close code. + code: u8, + /// Human-readable reason (UTF-8). + reason: String, + }, +} + +impl Frame { + /// Serialize this frame to its compact byte encoding. + /// + /// Layout (all multi-byte integers big-endian): + /// * `Data` : `0x01 || stream_id(u32) || payload` + /// * `Ping` : `0x02 || seq(u32)` + /// * `Pong` : `0x03 || seq(u32)` + /// * `Close` : `0x04 || code(u8) || reason_len(u32) || reason_utf8` + #[must_use] + pub fn encode(&self) -> Vec { + let mut out = Vec::new(); + match self { + Frame::Data { stream_id, payload } => { + out.push(frame_tag::DATA); + out.extend_from_slice(&stream_id.to_be_bytes()); + out.extend_from_slice(payload); + } + Frame::Ping { seq } => { + out.push(frame_tag::PING); + out.extend_from_slice(&seq.to_be_bytes()); + } + Frame::Pong { seq } => { + out.push(frame_tag::PONG); + out.extend_from_slice(&seq.to_be_bytes()); + } + Frame::Close { code, reason } => { + out.push(frame_tag::CLOSE); + out.push(*code); + let bytes = reason.as_bytes(); + out.extend_from_slice(&(bytes.len() as u32).to_be_bytes()); + out.extend_from_slice(bytes); + } + } + out + } + + /// Parse a frame from its byte encoding (the inverse of [`Frame::encode`]). + /// + /// # Errors + /// Returns [`ProtoError::MalformedFrame`] if the buffer is truncated, has an unknown tag, or + /// (for `Close`) does not contain valid UTF-8. + pub fn decode(buf: &[u8]) -> Result { + let (&tag, rest) = buf + .split_first() + .ok_or(ProtoError::MalformedFrame("empty frame"))?; + match tag { + frame_tag::DATA => { + let stream_id = read_u32(rest, "Data.stream_id")?; + let payload = Bytes::copy_from_slice(&rest[4..]); + Ok(Frame::Data { stream_id, payload }) + } + frame_tag::PING => Ok(Frame::Ping { + seq: read_u32(rest, "Ping.seq")?, + }), + frame_tag::PONG => Ok(Frame::Pong { + seq: read_u32(rest, "Pong.seq")?, + }), + frame_tag::CLOSE => { + let code = *rest + .first() + .ok_or(ProtoError::MalformedFrame("Close: missing code"))?; + let reason_len = read_u32(&rest[1..], "Close.reason_len")? as usize; + let reason_bytes = rest + .get(5..5 + reason_len) + .ok_or(ProtoError::MalformedFrame("Close: truncated reason"))?; + let reason = String::from_utf8(reason_bytes.to_vec()) + .map_err(|_| ProtoError::MalformedFrame("Close: reason not UTF-8"))?; + Ok(Frame::Close { code, reason }) + } + _ => Err(ProtoError::MalformedFrame("unknown frame tag")), + } + } +} + +/// Read a big-endian u32 from the start of `buf`, erroring if it is too short. +fn read_u32(buf: &[u8], what: &'static str) -> Result { + let bytes: [u8; 4] = buf + .get(..4) + .ok_or(ProtoError::MalformedFrame(what))? + .try_into() + .expect("slice of length 4 converts to [u8; 4]"); + Ok(u32::from_be_bytes(bytes)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn header_roundtrip_all_types() { + for (ty, byte) in [ + (MsgType::ClientHello, 0x01u8), + (MsgType::ServerHello, 0x02), + (MsgType::ClientAuth, 0x03), + (MsgType::ServerAuth, 0x04), + (MsgType::Finished, 0x05), + (MsgType::Data, 0x06), + (MsgType::Alert, 0xFF), + ] { + let h = encode_header(ty, 0x0012_3456).unwrap(); + assert_eq!(h[0], byte); + assert_eq!(h[4], PROTOCOL_VERSION); + let (got_ty, got_len) = decode_header(&h).unwrap(); + assert_eq!(got_ty, ty); + assert_eq!(got_len, 0x0012_3456); + } + } + + #[test] + fn header_rejects_oversize_and_bad_version() { + assert!(matches!( + encode_header(MsgType::Data, MAX_PAYLOAD_LEN + 1), + Err(ProtoError::FrameTooLarge(_)) + )); + let mut h = encode_header(MsgType::Data, 1).unwrap(); + h[4] = 0x02; + assert!(matches!( + decode_header(&h), + Err(ProtoError::BadVersion(0x02)) + )); + h[0] = 0x77; + assert!(matches!( + decode_header(&h), + Err(ProtoError::UnknownMsgType(0x77)) + )); + } + + #[test] + fn frame_roundtrip() { + let frames = vec![ + Frame::Data { + stream_id: 0xDEAD_BEEF, + payload: Bytes::from_static(b"hello world"), + }, + Frame::Data { + stream_id: 0, + payload: Bytes::new(), + }, + Frame::Ping { seq: 42 }, + Frame::Pong { seq: 0xFFFF_FFFF }, + Frame::Close { + code: 7, + reason: "going away \u{1f44b}".to_string(), + }, + Frame::Close { + code: 0, + reason: String::new(), + }, + ]; + for f in frames { + let encoded = f.encode(); + let decoded = Frame::decode(&encoded).unwrap(); + assert_eq!(f, decoded); + } + } + + #[test] + fn frame_decode_rejects_garbage() { + assert!(Frame::decode(&[]).is_err()); + assert!(Frame::decode(&[0x99]).is_err()); + assert!(Frame::decode(&[frame_tag::PING, 0x00]).is_err()); // truncated u32 + assert!(Frame::decode(&[frame_tag::CLOSE]).is_err()); // missing code + } +} diff --git a/crates/aura-proto/src/handshake.rs b/crates/aura-proto/src/handshake.rs new file mode 100644 index 0000000..ea0c49f --- /dev/null +++ b/crates/aura-proto/src/handshake.rs @@ -0,0 +1,453 @@ +//! The Aura handshake state machine (§6.2): a hybrid X25519 + ML-KEM-768 key exchange with mutual +//! X.509 authentication and a Finished-MAC transcript binding. +//! +//! Both entry points — [`client_handshake`] and [`server_handshake`] — are generic over a separate +//! [`tokio::io::AsyncRead`] reader and [`tokio::io::AsyncWrite`] writer so they drive an in-memory +//! duplex pipe (tests) or quinn's split `RecvStream` / `SendStream` (the QUIC transport) +//! identically. On success each returns an established [`Session`]. +//! +//! ## Message order (resolves the spec diagram ambiguity) +//! +//! ```text +//! 1. C->S ClientHello (plaintext): x25519_pub[32] || kyber_ek[1184] || client_nonce[32] +//! 2. S->C ServerHello (plaintext): x25519_ephemeral[32] || kyber_ct[1088] || server_nonce[32] +//! -- both sides derive the hybrid shared secret and the directional SessionKeys -- +//! 3. S->C ServerAuth (encrypted): u16(cert_der_len) || server_leaf_cert_der || sig(transcript) +//! 4. C->S ClientAuth (encrypted): u16(cert_der_len) || client_leaf_cert_der || sig(transcript) +//! 5. C->S Finished (encrypted): HMAC-SHA256(key = key_c2s, transcript) +//! 6. S->C Finished (encrypted): HMAC-SHA256(key = key_s2c, transcript) +//! ``` +//! +//! `transcript = SHA-256(ClientHello_frame_bytes || ServerHello_frame_bytes)` over the FULL +//! serialized frames (header + payload) exactly as transmitted. The same two [`AeadSession`]s +//! protect messages 3–6 and all subsequent application Data — their counters stay continuous. + +use aura_crypto::{ + derive_session_keys, AeadSession, HybridCiphertext, HybridPrivateKey, HybridPublicKey, + SessionKeys, +}; +use aura_pki::AuraCertVerifier; +use hmac::{Hmac, Mac}; +use rand::RngCore; +use rustls_pki_types::CertificateDer; +use sha2::{Digest, Sha256}; + +use crate::frame::{self, MsgType, RawFrame, HEADER_LEN}; +use crate::session::Session; +use crate::{ClientConfig, ProtoError, ServerConfig}; + +/// X25519 public key / ephemeral / shared-secret length. +const X25519_LEN: usize = 32; +/// ML-KEM-768 encapsulation (public) key length. +const KYBER_EK_LEN: usize = 1184; +/// ML-KEM-768 ciphertext length. +const KYBER_CT_LEN: usize = 1088; +/// Handshake nonce length. +const NONCE_LEN: usize = 32; + +/// Exact ClientHello payload length: x25519_pub || kyber_ek || client_nonce. +const CLIENT_HELLO_LEN: usize = X25519_LEN + KYBER_EK_LEN + NONCE_LEN; +/// Exact ServerHello payload length: x25519_ephemeral || kyber_ct || server_nonce. +const SERVER_HELLO_LEN: usize = X25519_LEN + KYBER_CT_LEN + NONCE_LEN; + +/// HMAC-SHA256 alias for the Finished MAC. +type HmacSha256 = Hmac; + +/// Each direction's AEAD seals exactly two encrypted handshake messages (an Auth and a Finished) +/// before application Data begins, so both AEAD counters start Data at this value. +const POST_HANDSHAKE_COUNTER: u64 = 2; + +// --------------------------------------------------------------------------------------------- +// Client +// --------------------------------------------------------------------------------------------- + +/// Drive the client side of the Aura handshake to completion. +/// +/// Generates a hybrid keypair, exchanges hello messages, derives session keys, authenticates the +/// server, proves possession of the client key, and exchanges Finished MACs. On success returns an +/// established [`Session`] whose AEAD counters continue from the handshake. +/// +/// # Errors +/// Returns a [`ProtoError`] for any transport, cryptographic, certificate, signature, or +/// transcript-MAC failure (see the variants of [`ProtoError`]). +pub async fn client_handshake( + mut reader: R, + mut writer: W, + cfg: &ClientConfig, +) -> Result, ProtoError> +where + R: tokio::io::AsyncRead + Unpin, + W: tokio::io::AsyncWrite + Unpin, +{ + // (1) C->S ClientHello: generate our hybrid keypair; send public half + a fresh nonce. + let (priv_key, pub_key): (HybridPrivateKey, HybridPublicKey) = HybridPrivateKey::generate(); + let client_nonce = random_nonce(); + + let mut ch_payload = Vec::with_capacity(CLIENT_HELLO_LEN); + ch_payload.extend_from_slice(&pub_key.x25519); + ch_payload.extend_from_slice(&pub_key.kyber); + ch_payload.extend_from_slice(&client_nonce); + debug_assert_eq!(ch_payload.len(), CLIENT_HELLO_LEN); + + let ch_header = frame::encode_header(MsgType::ClientHello, ch_payload.len())?; + frame::write_frame(&mut writer, MsgType::ClientHello, &ch_payload).await?; + let client_hello_wire = concat_frame(&ch_header, &ch_payload); + + // (2) S->C ServerHello: ciphertext + server nonce. Capture raw bytes for the transcript. + let sh = read_expect(&mut reader, MsgType::ServerHello).await?; + if sh.payload.len() != SERVER_HELLO_LEN { + return Err(ProtoError::MalformedHandshake("ServerHello: wrong length")); + } + let (sh_x_eph, rest) = sh.payload.split_at(X25519_LEN); + let (sh_kyber_ct, sh_nonce) = rest.split_at(KYBER_CT_LEN); + let mut server_nonce = [0u8; NONCE_LEN]; + server_nonce.copy_from_slice(sh_nonce); + + let ciphertext = HybridCiphertext { + x25519_ephemeral: sh_x_eph.try_into().expect("32-byte x25519 ephemeral"), + kyber_ciphertext: sh_kyber_ct.to_vec(), + }; + let shared = priv_key.decapsulate(&ciphertext)?; + let keys: SessionKeys = derive_session_keys(&shared, &client_nonce, &server_nonce); + + // Transcript hash over the two hello frames exactly as transmitted. + let transcript = transcript_hash(&client_hello_wire, &sh.wire_bytes()); + + // Two AEADs: client seals c2s, opens s2c. + let mut aead_c2s = AeadSession::new(keys.client_to_server); + let mut aead_s2c = AeadSession::new(keys.server_to_client); + + // (3) S->C ServerAuth (encrypted under s2c): verify cert chain + signature over transcript. + let server_auth = open_handshake_msg(&mut reader, MsgType::ServerAuth, &mut aead_s2c).await?; + let (server_cert_der, server_sig) = split_cert_and_sig(&server_auth)?; + let verifier = AuraCertVerifier::new(&cfg.ca_cert_pem)?; + let chain = [CertificateDer::from(server_cert_der.to_vec())]; + verifier.verify_server_cert(&chain, &cfg.server_name)?; + verify_signature(server_cert_der, &transcript, server_sig)?; + + // (4) C->S ClientAuth (encrypted under c2s): our leaf + our signature over transcript. + let client_cert_der = pem_cert_to_der(&cfg.client_cert_pem)?; + let client_sig = sign_transcript(&cfg.client_key_pem, &transcript)?; + let client_auth = build_cert_and_sig(&client_cert_der, &client_sig); + seal_handshake_msg( + &mut writer, + MsgType::ClientAuth, + &mut aead_c2s, + &client_auth, + ) + .await?; + + // (5) C->S Finished (encrypted under c2s): HMAC(key_c2s, transcript). + let client_finished = finished_mac(&keys.client_to_server, &transcript); + seal_handshake_msg( + &mut writer, + MsgType::Finished, + &mut aead_c2s, + &client_finished, + ) + .await?; + + // (6) S->C Finished (encrypted under s2c): verify HMAC(key_s2c, transcript). + let server_finished = open_handshake_msg(&mut reader, MsgType::Finished, &mut aead_s2c).await?; + verify_finished(&keys.server_to_client, &transcript, &server_finished)?; + + Ok(Session::new( + reader, + writer, + aead_c2s, + aead_s2c, + POST_HANDSHAKE_COUNTER, + Some(cfg.server_name.clone()), + )) +} + +// --------------------------------------------------------------------------------------------- +// Server +// --------------------------------------------------------------------------------------------- + +/// Drive the server side of the Aura handshake to completion. +/// +/// Receives the client hello, encapsulates to the client's hybrid public key, derives session +/// keys, authenticates itself, verifies the client certificate (capturing the client id) and its +/// signature, and exchanges Finished MACs. On success returns an established [`Session`] whose +/// [`Session::peer_id`] is the verified client id. +/// +/// # Errors +/// Returns a [`ProtoError`] for any transport, cryptographic, certificate, signature, or +/// transcript-MAC failure. +pub async fn server_handshake( + mut reader: R, + mut writer: W, + cfg: &ServerConfig, +) -> Result, ProtoError> +where + R: tokio::io::AsyncRead + Unpin, + W: tokio::io::AsyncWrite + Unpin, +{ + // (1) S receives ClientHello: reconstruct the client's hybrid public key + nonce. + let ch = read_expect(&mut reader, MsgType::ClientHello).await?; + if ch.payload.len() != CLIENT_HELLO_LEN { + return Err(ProtoError::MalformedHandshake("ClientHello: wrong length")); + } + let (ch_x, rest) = ch.payload.split_at(X25519_LEN); + let (ch_kyber_ek, ch_nonce) = rest.split_at(KYBER_EK_LEN); + let mut client_nonce = [0u8; NONCE_LEN]; + client_nonce.copy_from_slice(ch_nonce); + + let client_pub = HybridPublicKey { + x25519: ch_x.try_into().expect("32-byte x25519 public key"), + kyber: ch_kyber_ek.to_vec(), + }; + + // (2) S->C ServerHello: encapsulate to the client's public key; send ciphertext + nonce. + let (ciphertext, shared) = client_pub.encapsulate(); + let server_nonce = random_nonce(); + let mut sh_payload = Vec::with_capacity(SERVER_HELLO_LEN); + sh_payload.extend_from_slice(&ciphertext.x25519_ephemeral); + sh_payload.extend_from_slice(&ciphertext.kyber_ciphertext); + sh_payload.extend_from_slice(&server_nonce); + debug_assert_eq!(sh_payload.len(), SERVER_HELLO_LEN); + + let sh_header = frame::encode_header(MsgType::ServerHello, sh_payload.len())?; + frame::write_frame(&mut writer, MsgType::ServerHello, &sh_payload).await?; + let server_hello_wire = concat_frame(&sh_header, &sh_payload); + + let keys: SessionKeys = derive_session_keys(&shared, &client_nonce, &server_nonce); + let transcript = transcript_hash(&ch.wire_bytes(), &server_hello_wire); + + // Two AEADs: server seals s2c, opens c2s. + let mut aead_c2s = AeadSession::new(keys.client_to_server); + let mut aead_s2c = AeadSession::new(keys.server_to_client); + + // (3) S->C ServerAuth (encrypted under s2c): our leaf + our signature over transcript. + let server_cert_der = pem_cert_to_der(&cfg.server_cert_pem)?; + let server_sig = sign_transcript(&cfg.server_key_pem, &transcript)?; + let server_auth = build_cert_and_sig(&server_cert_der, &server_sig); + seal_handshake_msg( + &mut writer, + MsgType::ServerAuth, + &mut aead_s2c, + &server_auth, + ) + .await?; + + // (4) C->S ClientAuth (encrypted under c2s): verify cert chain (=> client_id) + signature. + let client_auth = open_handshake_msg(&mut reader, MsgType::ClientAuth, &mut aead_c2s).await?; + let (client_cert_der, client_sig) = split_cert_and_sig(&client_auth)?; + let verifier = AuraCertVerifier::new(&cfg.ca_cert_pem)?; + let chain = [CertificateDer::from(client_cert_der.to_vec())]; + let client_id = verifier.verify_client_cert(&chain)?; + verify_signature(client_cert_der, &transcript, client_sig)?; + + // (5) C->S Finished (encrypted under c2s): verify HMAC(key_c2s, transcript). + let client_finished = open_handshake_msg(&mut reader, MsgType::Finished, &mut aead_c2s).await?; + verify_finished(&keys.client_to_server, &transcript, &client_finished)?; + + // (6) S->C Finished (encrypted under s2c): HMAC(key_s2c, transcript). + let server_finished = finished_mac(&keys.server_to_client, &transcript); + seal_handshake_msg( + &mut writer, + MsgType::Finished, + &mut aead_s2c, + &server_finished, + ) + .await?; + + Ok(Session::new( + reader, + writer, + aead_s2c, + aead_c2s, + POST_HANDSHAKE_COUNTER, + Some(client_id), + )) +} + +// --------------------------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------------------------- + +/// Generate a fresh 32-byte handshake nonce from the OS RNG. +fn random_nonce() -> [u8; NONCE_LEN] { + let mut n = [0u8; NONCE_LEN]; + rand::thread_rng().fill_bytes(&mut n); + n +} + +/// Concatenate a header and payload into the full serialized frame bytes. +fn concat_frame(header: &[u8; HEADER_LEN], payload: &[u8]) -> Vec { + let mut out = Vec::with_capacity(HEADER_LEN + payload.len()); + out.extend_from_slice(header); + out.extend_from_slice(payload); + out +} + +/// `transcript = SHA-256(client_hello_frame || server_hello_frame)`. +fn transcript_hash(client_hello_frame: &[u8], server_hello_frame: &[u8]) -> [u8; 32] { + let mut h = Sha256::new(); + h.update(client_hello_frame); + h.update(server_hello_frame); + h.finalize().into() +} + +/// Read a frame and require it to be `expected`, mapping an Alert to [`ProtoError::Alert`]. +async fn read_expect(reader: &mut R, expected: MsgType) -> Result +where + R: tokio::io::AsyncRead + Unpin, +{ + let raw = frame::read_frame(reader).await?; + if raw.msg_type == MsgType::Alert { + let code = raw.payload.first().copied().unwrap_or(0); + return Err(ProtoError::Alert(code)); + } + if raw.msg_type != expected { + return Err(ProtoError::UnexpectedMsg { + expected, + got: raw.msg_type, + }); + } + Ok(raw) +} + +/// Seal `plaintext` and write it as an encrypted handshake frame of type `msg_type`. +/// +/// The AAD is the 5-byte frame header (binding type + length), matching the Data-record convention. +async fn seal_handshake_msg( + writer: &mut W, + msg_type: MsgType, + aead: &mut AeadSession, + plaintext: &[u8], +) -> Result<(), ProtoError> +where + W: tokio::io::AsyncWrite + Unpin, +{ + let sealed_len = plaintext.len() + 16; // Poly1305 tag + let header = frame::encode_header(msg_type, sealed_len)?; + let ciphertext = aead.seal(plaintext, &header); + debug_assert_eq!(ciphertext.len(), sealed_len); + frame::write_frame(writer, msg_type, &ciphertext).await +} + +/// Read an encrypted handshake frame of type `msg_type` and AEAD-open it, returning the plaintext. +async fn open_handshake_msg( + reader: &mut R, + msg_type: MsgType, + aead: &mut AeadSession, +) -> Result, ProtoError> +where + R: tokio::io::AsyncRead + Unpin, +{ + let raw = read_expect(reader, msg_type).await?; + let plaintext = aead.open(&raw.payload, &raw.header)?; + Ok(plaintext) +} + +/// Build an Auth payload: `u16_be(cert_der_len) || cert_der || signature`. +fn build_cert_and_sig(cert_der: &[u8], sig: &[u8]) -> Vec { + let mut out = Vec::with_capacity(2 + cert_der.len() + sig.len()); + out.extend_from_slice(&(cert_der.len() as u16).to_be_bytes()); + out.extend_from_slice(cert_der); + out.extend_from_slice(sig); + out +} + +/// Parse an Auth payload into `(cert_der, signature)`. +fn split_cert_and_sig(buf: &[u8]) -> Result<(&[u8], &[u8]), ProtoError> { + let len_bytes: [u8; 2] = buf + .get(..2) + .ok_or(ProtoError::MalformedHandshake("Auth: missing cert length"))? + .try_into() + .expect("2-byte length prefix"); + let cert_len = u16::from_be_bytes(len_bytes) as usize; + let cert_der = buf + .get(2..2 + cert_len) + .ok_or(ProtoError::MalformedHandshake("Auth: truncated cert"))?; + let sig = buf + .get(2 + cert_len..) + .ok_or(ProtoError::MalformedHandshake("Auth: missing signature"))?; + if sig.is_empty() { + return Err(ProtoError::MalformedHandshake("Auth: empty signature")); + } + Ok((cert_der, sig)) +} + +/// Compute the Finished MAC: `HMAC-SHA256(key, transcript)`. +fn finished_mac(key: &[u8; 32], transcript: &[u8; 32]) -> Vec { + let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts a 32-byte key"); + mac.update(transcript); + mac.finalize().into_bytes().to_vec() +} + +/// Verify a received Finished MAC in constant time. +fn verify_finished( + key: &[u8; 32], + transcript: &[u8; 32], + received: &[u8], +) -> Result<(), ProtoError> { + let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts a 32-byte key"); + mac.update(transcript); + // `verify_slice` is constant-time and also checks the length. + mac.verify_slice(received) + .map_err(|_| ProtoError::FinishedMismatch) +} + +/// Sign the transcript with a PKCS#8 PEM key (ECDSA P-256 / SHA-256, ASN.1 DER signature). +fn sign_transcript(key_pem: &str, transcript: &[u8; 32]) -> Result, ProtoError> { + let pkcs8_der = pem_key_to_der(key_pem)?; + let rng = ring::rand::SystemRandom::new(); + let key_pair = ring::signature::EcdsaKeyPair::from_pkcs8( + &ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING, + &pkcs8_der, + &rng, + ) + .map_err(|_| ProtoError::Signature("invalid PKCS#8 signing key"))?; + let sig = key_pair + .sign(&rng, transcript) + .map_err(|_| ProtoError::Signature("ECDSA signing failed"))?; + Ok(sig.as_ref().to_vec()) +} + +/// Convenience wrapper used by the client/server symmetric calls. +fn verify_signature( + cert_der: &[u8], + transcript: &[u8; 32], + signature: &[u8], +) -> Result<(), ProtoError> { + let ec_point = ec_public_key_from_cert(cert_der)?; + ring::signature::UnparsedPublicKey::new(&ring::signature::ECDSA_P256_SHA256_ASN1, ec_point) + .verify(transcript, signature) + .map_err(|_| ProtoError::Signature("handshake signature did not verify")) +} + +/// Extract the uncompressed EC public-key point (`04 || X || Y`) from a leaf certificate DER. +fn ec_public_key_from_cert(cert_der: &[u8]) -> Result, ProtoError> { + use x509_parser::prelude::FromDer; + let (_, cert) = x509_parser::certificate::X509Certificate::from_der(cert_der) + .map_err(|_| ProtoError::Signature("could not parse peer certificate DER"))?; + Ok(cert.public_key().subject_public_key.data.to_vec()) +} + +/// Decode the first `CERTIFICATE` PEM block to DER. +fn pem_cert_to_der(pem: &str) -> Result, ProtoError> { + pem_block_to_der(pem, &["CERTIFICATE"]) + .ok_or(ProtoError::Signature("no CERTIFICATE block in PEM")) +} + +/// Decode the first private-key PEM block (PKCS#8 `PRIVATE KEY`, or the EC/RSA variants) to DER. +fn pem_key_to_der(pem: &str) -> Result, ProtoError> { + pem_block_to_der(pem, &["PRIVATE KEY", "EC PRIVATE KEY", "RSA PRIVATE KEY"]) + .ok_or(ProtoError::Signature("no private-key block in PEM")) +} + +/// Decode the first PEM block whose label is one of `labels` into its DER contents. +/// +/// Uses `x509-parser`'s generic PEM container reader, so we do not need `rustls-pemfile`. +fn pem_block_to_der(pem: &str, labels: &[&str]) -> Option> { + for item in x509_parser::pem::Pem::iter_from_buffer(pem.as_bytes()) { + let item = item.ok()?; + if labels.contains(&item.label.as_str()) { + return Some(item.contents); + } + } + None +} diff --git a/crates/aura-proto/src/lib.rs b/crates/aura-proto/src/lib.rs index 56b4c0d..b97fd7a 100644 --- a/crates/aura-proto/src/lib.rs +++ b/crates/aura-proto/src/lib.rs @@ -1 +1,138 @@ -//! aura-proto — protocol wire format and handshake (skeleton; implemented in Wave 2). +//! aura-proto — the Aura VPN protocol: wire format, hybrid-KEM + mutual-X.509 handshake, and the +//! post-handshake encrypted session. +//! +//! This crate sits on top of [`aura_crypto`] (hybrid KEM, HKDF, AEAD) and [`aura_pki`] (mutual +//! X.509 verification) and implements project §6: +//! +//! * [`frame`] — the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3). +//! * [`handshake`] — [`client_handshake`] / [`server_handshake`], the hybrid-KEM + mutual-auth +//! state machine (§6.2). +//! * [`session`] — [`Session`], post-handshake encrypted [`Frame`] send/recv with replay +//! protection. +//! +//! ## Transport-agnostic by design +//! +//! The handshake is generic over any [`tokio::io::AsyncRead`] reader and [`tokio::io::AsyncWrite`] +//! writer, supplied as **separate** halves. That matches both an in-memory duplex pipe (tests) and +//! quinn's split `RecvStream` / `SendStream` (the Wave 3 QUIC transport), so the same code drives +//! both. See [`client_handshake`] / [`server_handshake`] and [`Session`]. +//! +//! The handshake takes the reader and writer **by value** (`reader: R, writer: W`) rather than the +//! `&mut R, &mut W` sketched in the brief: the returned [`Session`] *owns* both halves (so it can +//! keep `send_frame` / `recv_frame` going), and an owning return cannot be built from borrows +//! without a `Default` placeholder that generic `R`/`W` lack. Callers pass owned halves anyway — +//! `tokio::io::split(...)` and quinn's `(SendStream, RecvStream)` both yield owned values — so this +//! is ergonomically identical while being sound. +//! +//! ## Handshake message order (resolving the spec diagram) +//! +//! The spec diagram is ambiguous about the ordering of the encrypted auth/Finished messages. This +//! implementation fixes the exact order below and both peers follow it lock-step: +//! +//! ```text +//! 1. C->S ClientHello (plaintext) +//! 2. S->C ServerHello (plaintext) -- both derive SessionKeys here -- +//! 3. S->C ServerAuth (encrypted) +//! 4. C->S ClientAuth (encrypted) +//! 5. C->S Finished (encrypted) +//! 6. S->C Finished (encrypted) -- encrypted Data channel open both ways -- +//! ``` + +#![forbid(unsafe_code)] +#![warn(missing_docs)] + +pub mod frame; +pub mod handshake; +pub mod session; + +pub use frame::{Frame, MsgType}; +pub use handshake::{client_handshake, server_handshake}; +pub use session::Session; + +use thiserror::Error; + +/// Errors produced by the Aura protocol layer. +#[derive(Debug, Error)] +pub enum ProtoError { + /// An I/O error on the underlying transport (read/write/EOF). + #[error("transport I/O error: {0}")] + Io(#[from] std::io::Error), + + /// An error from the cryptographic core (KEM decapsulation, AEAD open). + #[error("crypto error: {0}")] + Crypto(#[from] aura_crypto::CryptoError), + + /// An error from the PKI layer (certificate chain or name verification). + #[error("pki error: {0}")] + Pki(#[from] aura_pki::PkiError), + + /// The header carried an unknown message-type byte. + #[error("unknown message type byte: 0x{0:02x}")] + UnknownMsgType(u8), + + /// The header carried an unsupported protocol version. + #[error("unsupported protocol version: 0x{0:02x}")] + BadVersion(u8), + + /// A payload exceeded the u24 length field of the frame header. + #[error("frame payload too large: {0} bytes")] + FrameTooLarge(usize), + + /// A received frame had the wrong type for the current handshake/session step. + #[error("unexpected message type: expected {expected:?}, got {got:?}")] + UnexpectedMsg { + /// The message type the state machine required. + expected: MsgType, + /// The message type actually received. + got: MsgType, + }, + + /// A handshake message had an invalid length or internal structure. + #[error("malformed handshake message: {0}")] + MalformedHandshake(&'static str), + + /// An application [`Frame`] could not be decoded. + #[error("malformed frame: {0}")] + MalformedFrame(&'static str), + + /// A digital signature (over the handshake hash) failed to verify, or a key/cert could not be + /// parsed for signing or verification. + #[error("signature verification failed: {0}")] + Signature(&'static str), + + /// The peer's Finished HMAC did not match the locally computed value. + #[error("handshake Finished verification failed (HMAC mismatch)")] + FinishedMismatch, + + /// A Data record was rejected by the replay-protection window (duplicate or too old). + #[error("replay detected: record counter {0} is a duplicate or outside the window")] + Replay(u64), + + /// The peer sent a fatal Alert. + #[error("peer sent fatal alert (code {0})")] + Alert(u8), +} + +/// Client-side handshake configuration. +#[derive(Clone, Debug)] +pub struct ClientConfig { + /// PEM of the Aura CA certificate, used as the trust anchor for the server cert. + pub ca_cert_pem: String, + /// PEM of this client's leaf certificate (sent to the server during ClientAuth). + pub client_cert_pem: String, + /// PEM of this client's PKCS#8 private key (used to sign the handshake hash). + pub client_key_pem: String, + /// The DNS name the client expects in the server certificate's SAN. + pub server_name: String, +} + +/// Server-side handshake configuration. +#[derive(Clone, Debug)] +pub struct ServerConfig { + /// PEM of the Aura CA certificate, used as the trust anchor for client certs. + pub ca_cert_pem: String, + /// PEM of this server's leaf certificate (sent to the client during ServerAuth). + pub server_cert_pem: String, + /// PEM of this server's PKCS#8 private key (used to sign the handshake hash). + pub server_key_pem: String, +} diff --git a/crates/aura-proto/src/session.rs b/crates/aura-proto/src/session.rs new file mode 100644 index 0000000..5d44286 --- /dev/null +++ b/crates/aura-proto/src/session.rs @@ -0,0 +1,294 @@ +//! Post-handshake encrypted session: AEAD-protected [`Frame`] exchange with replay protection. +//! +//! A [`Session`] owns the transport reader + writer and the two directional [`AeadSession`]s +//! produced by the handshake. It exposes [`Session::send_frame`] / [`Session::recv_frame`], which +//! serialize a [`Frame`], AEAD-seal/open it, and ship it inside a [`MsgType::Data`] record framed +//! by the 5-byte protocol header. +//! +//! ## Record format and replay protection +//! +//! Each Data record's payload is: +//! +//! ```text +//! seq (u64, big-endian) || AEAD_seal(frame_bytes, aad = header || seq) +//! ``` +//! +//! The 8-byte `seq` is the AEAD nonce counter for that record. Because `aura_crypto::AeadSession` +//! advances its internal counter in lock-step on the sealing and opening sides, the transmitted +//! `seq` always equals the receiver's expected AEAD counter on the happy path. Carrying it +//! explicitly lets the receiver run a **sliding-window** replay check (window size +//! [`REPLAY_WINDOW`]) *before* touching the AEAD: a duplicate or too-old `seq` is rejected with +//! [`ProtoError::Replay`] without disturbing the AEAD's counter, so the session stays usable. +//! +//! The `seq` is also folded into the AEAD AAD (alongside the frame header), cryptographically +//! binding the record to its claimed position. + +use aura_crypto::AeadSession; + +use crate::frame::{self, Frame, MsgType, RawFrame, HEADER_LEN}; +use crate::ProtoError; + +/// Width of the sliding replay window, in records. +pub const REPLAY_WINDOW: u64 = 64; + +/// Width in bytes of the per-record sequence-number prefix. +const SEQ_LEN: usize = 8; + +/// Sliding-window replay detector over a monotonically increasing 64-bit counter. +/// +/// Tracks the highest accepted sequence number and a bitmap of the [`REPLAY_WINDOW`] positions +/// below it. A sequence number is accepted iff it is strictly newer than everything seen, or falls +/// within the window and has not been seen before. +#[derive(Debug)] +struct ReplayWindow { + /// Highest sequence number accepted so far. + highest: u64, + /// Bitmap of accepted positions in `(highest - REPLAY_WINDOW, highest]`; bit `i` set means + /// `highest - 1 - i` was accepted. + bitmap: u64, + /// Whether any record has been accepted yet (so seq 0 can itself be accepted exactly once). + seeded: bool, +} + +impl ReplayWindow { + /// Create a window primed so the first expected record is `start` (the AEAD counter at the end + /// of the handshake). Anything strictly below `start` is treated as already-consumed. + fn new(start: u64) -> Self { + // Treat (start - 1) as the highest already-accepted slot so the first Data record at `seq + // == start` is accepted as "newer". `seeded = true` keeps seq 0 handling uniform even when + // start == 0 (the handshake always leaves start >= 1, but be defensive). + Self { + highest: start.saturating_sub(1), + bitmap: 0, + seeded: start > 0, + } + } + + /// Check and record `seq`. Returns `Ok(())` if it is fresh; [`ProtoError::Replay`] otherwise. + fn check_and_set(&mut self, seq: u64) -> Result<(), ProtoError> { + if !self.seeded { + // First ever record (only reachable if started at 0): accept and seed. + self.seeded = true; + self.highest = seq; + self.bitmap = 0; + return Ok(()); + } + + if seq > self.highest { + // New high-water mark: shift the bitmap up by the gap, marking the old highest as seen. + let shift = seq - self.highest; + if shift >= 64 { + self.bitmap = 0; + } else { + self.bitmap = (self.bitmap << shift) | (1u64 << (shift - 1)); + } + self.highest = seq; + Ok(()) + } else { + // seq <= highest: must be inside the window and previously unseen. + let offset = self.highest - seq; // 0 == highest itself (already accepted) + if offset >= REPLAY_WINDOW { + return Err(ProtoError::Replay(seq)); + } + if offset == 0 { + // Exactly the current highest: already accepted. + return Err(ProtoError::Replay(seq)); + } + let mask = 1u64 << (offset - 1); + if self.bitmap & mask != 0 { + return Err(ProtoError::Replay(seq)); + } + self.bitmap |= mask; + Ok(()) + } + } +} + +/// An established, encrypted Aura session over a transport reader `R` and writer `W`. +/// +/// Created by [`crate::client_handshake`] / [`crate::server_handshake`]. Use +/// [`Session::send_frame`] and [`Session::recv_frame`] for application traffic. +pub struct Session { + reader: R, + writer: W, + /// AEAD this endpoint seals outgoing Data with. + send_aead: AeadSession, + /// AEAD this endpoint opens incoming Data with. + recv_aead: AeadSession, + /// Next sequence number to stamp on an outgoing Data record (mirrors `send_aead`'s counter). + send_seq: u64, + /// Replay window over incoming Data sequence numbers. + replay: ReplayWindow, + /// The verified identity (Common Name) of the peer, learned during the handshake. + peer_id: Option, +} + +impl Session +where + R: tokio::io::AsyncRead + Unpin, + W: tokio::io::AsyncWrite + Unpin, +{ + /// Assemble a session from the handshake outputs. + /// + /// `start_counter` is the AEAD nonce counter both directions have reached after the encrypted + /// handshake messages (so the first Data record stamps `seq == start_counter`). `peer_id` is + /// the verified peer Common Name (the server learns the client id; the client may store the + /// server name). + pub(crate) fn new( + reader: R, + writer: W, + send_aead: AeadSession, + recv_aead: AeadSession, + start_counter: u64, + peer_id: Option, + ) -> Self { + Self { + reader, + writer, + send_aead, + recv_aead, + send_seq: start_counter, + replay: ReplayWindow::new(start_counter), + peer_id, + } + } + + /// The verified identity (Common Name) of the peer, if one was captured during the handshake. + #[must_use] + pub fn peer_id(&self) -> Option<&str> { + self.peer_id.as_deref() + } + + /// Serialize, seal, and send a single application [`Frame`]. + /// + /// # Errors + /// Returns [`ProtoError::Io`] on a write failure or [`ProtoError::FrameTooLarge`] if the sealed + /// record exceeds the frame header's length field. + pub async fn send_frame(&mut self, frame: Frame) -> Result<(), ProtoError> { + let frame_bytes = frame.encode(); + let seq = self.send_seq; + + // Build the record payload: seq(8) || ciphertext. The frame header binds (type, length), + // and the seq binds the record position; both go into the AAD. + let sealed_len = SEQ_LEN + frame_bytes.len() + 16 /* Poly1305 tag */; + let header = frame::encode_header(MsgType::Data, sealed_len)?; + + let mut aad = [0u8; HEADER_LEN + SEQ_LEN]; + aad[..HEADER_LEN].copy_from_slice(&header); + aad[HEADER_LEN..].copy_from_slice(&seq.to_be_bytes()); + + let ciphertext = self.send_aead.seal(&frame_bytes, &aad); + + let mut payload = Vec::with_capacity(SEQ_LEN + ciphertext.len()); + payload.extend_from_slice(&seq.to_be_bytes()); + payload.extend_from_slice(&ciphertext); + debug_assert_eq!(payload.len(), sealed_len); + + // Write header + payload. (We already know the exact header for AAD; reuse it.) + use tokio::io::AsyncWriteExt; + self.writer.write_all(&header).await?; + self.writer.write_all(&payload).await?; + self.writer.flush().await?; + + self.send_seq = self.send_seq.wrapping_add(1); + Ok(()) + } + + /// Receive, open, and decode a single application [`Frame`]. + /// + /// # Errors + /// * [`ProtoError::Replay`] — the record's sequence number was a duplicate or too old. + /// * [`ProtoError::Crypto`] — AEAD authentication failed. + /// * [`ProtoError::UnexpectedMsg`] — a non-Data record arrived. + /// * [`ProtoError::Alert`] — the peer sent a fatal alert. + /// * [`ProtoError::Io`] / [`ProtoError::MalformedFrame`] on transport / decode failure. + pub async fn recv_frame(&mut self) -> Result { + let raw: RawFrame = frame::read_frame(&mut self.reader).await?; + match raw.msg_type { + MsgType::Data => {} + MsgType::Alert => { + let code = raw.payload.first().copied().unwrap_or(0); + return Err(ProtoError::Alert(code)); + } + other => { + return Err(ProtoError::UnexpectedMsg { + expected: MsgType::Data, + got: other, + }) + } + } + + if raw.payload.len() < SEQ_LEN { + return Err(ProtoError::MalformedFrame( + "Data record shorter than seq prefix", + )); + } + let mut seq_bytes = [0u8; SEQ_LEN]; + seq_bytes.copy_from_slice(&raw.payload[..SEQ_LEN]); + let seq = u64::from_be_bytes(seq_bytes); + let ciphertext = &raw.payload[SEQ_LEN..]; + + // Replay check FIRST — a duplicate/old record must not advance the AEAD counter. + self.replay.check_and_set(seq)?; + + let mut aad = [0u8; HEADER_LEN + SEQ_LEN]; + aad[..HEADER_LEN].copy_from_slice(&raw.header); + aad[HEADER_LEN..].copy_from_slice(&seq_bytes); + + let plaintext = self.recv_aead.open(ciphertext, &aad)?; + Frame::decode(&plaintext) + } + + /// Consume the session, returning its transport halves (for clean shutdown / reuse). + #[must_use] + pub fn into_inner(self) -> (R, W) { + (self.reader, self.writer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn replay_window_basic_monotonic() { + let mut w = ReplayWindow::new(2); + // First legitimate Data records. + assert!(w.check_and_set(2).is_ok()); + assert!(w.check_and_set(3).is_ok()); + assert!(w.check_and_set(4).is_ok()); + // Replays of each are rejected. + assert!(matches!(w.check_and_set(2), Err(ProtoError::Replay(2)))); + assert!(matches!(w.check_and_set(3), Err(ProtoError::Replay(3)))); + assert!(matches!(w.check_and_set(4), Err(ProtoError::Replay(4)))); + } + + #[test] + fn replay_window_out_of_order_within_window() { + let mut w = ReplayWindow::new(0); + assert!(w.check_and_set(0).is_ok()); + assert!(w.check_and_set(10).is_ok()); + // Older-but-unseen inside the window is accepted exactly once. + assert!(w.check_and_set(5).is_ok()); + assert!(matches!(w.check_and_set(5), Err(ProtoError::Replay(5)))); + // The new-high record (10) is a replay. + assert!(matches!(w.check_and_set(10), Err(ProtoError::Replay(10)))); + // Brand-new high still works. + assert!(w.check_and_set(11).is_ok()); + } + + #[test] + fn replay_window_rejects_too_old() { + let mut w = ReplayWindow::new(0); + assert!(w.check_and_set(0).is_ok()); + assert!(w.check_and_set(200).is_ok()); + // Far below the window edge => rejected as too old. + assert!(matches!(w.check_and_set(1), Err(ProtoError::Replay(1)))); + assert!(matches!( + w.check_and_set(200 - REPLAY_WINDOW), + Err(ProtoError::Replay(_)) + )); + // Just inside the window edge => still acceptable. + assert!(w.check_and_set(200 - REPLAY_WINDOW + 1).is_ok()); + } +} diff --git a/crates/aura-proto/tests/common/mod.rs b/crates/aura-proto/tests/common/mod.rs new file mode 100644 index 0000000..c6cf12f --- /dev/null +++ b/crates/aura-proto/tests/common/mod.rs @@ -0,0 +1,56 @@ +//! Shared test helpers: minting an Aura CA + leaf certs, and wiring an in-memory duplex transport. + +#![allow(dead_code)] // each integration test binary uses a different subset of these helpers + +use aura_pki::AuraCa; +use aura_proto::{ClientConfig, ServerConfig}; + +/// A minted PKI fixture: a CA, a server cert/key, and a client cert/key. +pub struct Pki { + pub ca_cert_pem: String, + pub server_cert_pem: String, + pub server_key_pem: String, + pub client_cert_pem: String, + pub client_key_pem: String, + pub server_name: String, + pub client_id: String, +} + +/// Mint a CA plus a server cert (for `server_name`) and a client cert (CN = `client_id`). +pub fn mint_pki(server_name: &str, client_id: &str) -> Pki { + let ca = AuraCa::generate("Aura Test Root CA").expect("generate CA"); + let server = ca + .issue_server_cert(server_name) + .expect("issue server cert"); + let client = ca.issue_client_cert(client_id).expect("issue client cert"); + Pki { + ca_cert_pem: ca.ca_cert_pem(), + server_cert_pem: server.cert_pem, + server_key_pem: server.key_pem, + client_cert_pem: client.cert_pem, + client_key_pem: client.key_pem, + server_name: server_name.to_string(), + client_id: client_id.to_string(), + } +} + +impl Pki { + /// Build a matching [`ClientConfig`] from this fixture. + pub fn client_config(&self) -> ClientConfig { + ClientConfig { + ca_cert_pem: self.ca_cert_pem.clone(), + client_cert_pem: self.client_cert_pem.clone(), + client_key_pem: self.client_key_pem.clone(), + server_name: self.server_name.clone(), + } + } + + /// Build a matching [`ServerConfig`] from this fixture. + pub fn server_config(&self) -> ServerConfig { + ServerConfig { + ca_cert_pem: self.ca_cert_pem.clone(), + server_cert_pem: self.server_cert_pem.clone(), + server_key_pem: self.server_key_pem.clone(), + } + } +} diff --git a/crates/aura-proto/tests/data_exchange.rs b/crates/aura-proto/tests/data_exchange.rs new file mode 100644 index 0000000..f7fd557 --- /dev/null +++ b/crates/aura-proto/tests/data_exchange.rs @@ -0,0 +1,136 @@ +//! `test_data_exchange_1000pkts` — after the handshake, exchange 1000 Data frames in each +//! direction and assert payload integrity and ordering. + +mod common; + +use aura_proto::{client_handshake, server_handshake, Frame}; +use bytes::Bytes; +use tokio::io::split; + +const N: u32 = 1000; + +/// Build the deterministic payload for frame `i` from `who`. +fn payload_for(who: &str, i: u32) -> Bytes { + Bytes::from(format!( + "{who}-packet-{i}-{}", + "x".repeat((i % 37) as usize) + )) +} + +#[tokio::test] +async fn test_data_exchange_1000pkts() { + let pki = common::mint_pki("vpn.aura.example", "client-alpha"); + let client_cfg = pki.client_config(); + let server_cfg = pki.server_config(); + + let (client_end, server_end) = tokio::io::duplex(64 * 1024); + let (c_read, c_write) = split(client_end); + let (s_read, s_write) = split(server_end); + + let client = tokio::spawn(async move { + let mut sess = client_handshake(c_read, c_write, &client_cfg) + .await + .expect("client handshake"); + // Interleave send + recv in lockstep to avoid filling the duplex buffer. + for i in 0..N { + sess.send_frame(Frame::Data { + stream_id: 1, + payload: payload_for("client", i), + }) + .await + .expect("client send"); + + match sess.recv_frame().await.expect("client recv") { + Frame::Data { stream_id, payload } => { + assert_eq!(stream_id, 2, "wrong stream id at i={i}"); + assert_eq!( + payload, + payload_for("server", i), + "payload mismatch at i={i}" + ); + } + other => panic!("client expected Data, got {other:?}"), + } + } + }); + + let server = tokio::spawn(async move { + let mut sess = server_handshake(s_read, s_write, &server_cfg) + .await + .expect("server handshake"); + for i in 0..N { + // Receive the client's i-th packet first, then reply, mirroring the client's lockstep. + match sess.recv_frame().await.expect("server recv") { + Frame::Data { stream_id, payload } => { + assert_eq!(stream_id, 1, "wrong stream id at i={i}"); + assert_eq!( + payload, + payload_for("client", i), + "payload mismatch at i={i}" + ); + } + other => panic!("server expected Data, got {other:?}"), + } + + sess.send_frame(Frame::Data { + stream_id: 2, + payload: payload_for("server", i), + }) + .await + .expect("server send"); + } + }); + + let (c, s) = tokio::join!(client, server); + c.expect("client task"); + s.expect("server task"); +} + +#[tokio::test] +async fn ping_pong_and_close_frames_roundtrip() { + let pki = common::mint_pki("vpn.aura.example", "c1"); + let client_cfg = pki.client_config(); + let server_cfg = pki.server_config(); + + let (client_end, server_end) = tokio::io::duplex(64 * 1024); + let (c_read, c_write) = split(client_end); + let (s_read, s_write) = split(server_end); + + let client = tokio::spawn(async move { + let mut sess = client_handshake(c_read, c_write, &client_cfg) + .await + .unwrap(); + sess.send_frame(Frame::Ping { seq: 7 }).await.unwrap(); + match sess.recv_frame().await.unwrap() { + Frame::Pong { seq } => assert_eq!(seq, 7), + other => panic!("expected Pong, got {other:?}"), + } + sess.send_frame(Frame::Close { + code: 0, + reason: "bye".into(), + }) + .await + .unwrap(); + }); + + let server = tokio::spawn(async move { + let mut sess = server_handshake(s_read, s_write, &server_cfg) + .await + .unwrap(); + match sess.recv_frame().await.unwrap() { + Frame::Ping { seq } => sess.send_frame(Frame::Pong { seq }).await.unwrap(), + other => panic!("expected Ping, got {other:?}"), + } + match sess.recv_frame().await.unwrap() { + Frame::Close { code, reason } => { + assert_eq!(code, 0); + assert_eq!(reason, "bye"); + } + other => panic!("expected Close, got {other:?}"), + } + }); + + let (c, s) = tokio::join!(client, server); + c.unwrap(); + s.unwrap(); +} diff --git a/crates/aura-proto/tests/handshake_loopback.rs b/crates/aura-proto/tests/handshake_loopback.rs new file mode 100644 index 0000000..e0ccfd5 --- /dev/null +++ b/crates/aura-proto/tests/handshake_loopback.rs @@ -0,0 +1,43 @@ +//! `test_full_handshake_loopback` — a full client+server handshake over an in-memory duplex. + +mod common; + +use aura_proto::{client_handshake, server_handshake}; +use tokio::io::split; + +#[tokio::test] +async fn test_full_handshake_loopback() { + let pki = common::mint_pki("vpn.aura.example", "client-alpha"); + let client_cfg = pki.client_config(); + let server_cfg = pki.server_config(); + + // Connected in-memory transport; split each end into independent read/write halves so the + // handshake can use separate reader + writer (matching quinn's split streams). + let (client_end, server_end) = tokio::io::duplex(64 * 1024); + let (c_read, c_write) = split(client_end); + let (s_read, s_write) = split(server_end); + + let client = tokio::spawn(async move { + client_handshake(c_read, c_write, &client_cfg) + .await + .map(|s| s.peer_id().map(str::to_string)) + }); + let server = tokio::spawn(async move { + server_handshake(s_read, s_write, &server_cfg) + .await + .map(|s| s.peer_id().map(str::to_string)) + }); + + let (client_res, server_res) = tokio::join!(client, server); + let client_peer = client_res + .expect("client task") + .expect("client handshake ok"); + let server_peer = server_res + .expect("server task") + .expect("server handshake ok"); + + // Server learned the client id from the verified client certificate. + assert_eq!(server_peer.as_deref(), Some("client-alpha")); + // Client recorded the server name it authenticated. + assert_eq!(client_peer.as_deref(), Some("vpn.aura.example")); +} diff --git a/crates/aura-proto/tests/pki_mutual_auth.rs b/crates/aura-proto/tests/pki_mutual_auth.rs new file mode 100644 index 0000000..d3d59da --- /dev/null +++ b/crates/aura-proto/tests/pki_mutual_auth.rs @@ -0,0 +1,86 @@ +//! `test_pki_mutual_auth` — the server must reject a client whose certificate was issued by a +//! different CA, and must reject a client that presents a valid certificate but a forged signature +//! (one made with a key that does not match the certificate). + +mod common; + +use aura_pki::AuraCa; +use aura_proto::{client_handshake, server_handshake, ClientConfig, ProtoError}; +use tokio::io::split; + +/// Run a handshake and return both sides' results. +async fn run( + client_cfg: ClientConfig, + server_cfg: aura_proto::ServerConfig, +) -> (Result<(), ProtoError>, Result, ProtoError>) { + let (client_end, server_end) = tokio::io::duplex(64 * 1024); + let (c_read, c_write) = split(client_end); + let (s_read, s_write) = split(server_end); + + let client = tokio::spawn(async move { + client_handshake(c_read, c_write, &client_cfg) + .await + .map(|_| ()) + }); + let server = tokio::spawn(async move { + server_handshake(s_read, s_write, &server_cfg) + .await + .map(|s| s.peer_id().map(str::to_string)) + }); + let (c, s) = tokio::join!(client, server); + (c.expect("client task"), s.expect("server task")) +} + +#[tokio::test] +async fn wrong_ca_client_cert_is_rejected() { + // The legitimate server-side PKI. + let pki = common::mint_pki("vpn.aura.example", "client-alpha"); + + // An attacker CA issues a client cert with a plausible CN, but it does NOT chain to the + // server's trusted CA. + let rogue_ca = AuraCa::generate("Rogue CA").expect("rogue CA"); + let rogue_client = rogue_ca + .issue_client_cert("client-alpha") + .expect("rogue client cert"); + + let client_cfg = ClientConfig { + ca_cert_pem: pki.ca_cert_pem.clone(), + client_cert_pem: rogue_client.cert_pem, + client_key_pem: rogue_client.key_pem, + server_name: pki.server_name.clone(), + }; + + let (_client_res, server_res) = run(client_cfg, pki.server_config()).await; + // The server must fail verifying the client chain against its trusted CA. + assert!( + matches!(server_res, Err(ProtoError::Pki(_))), + "expected a PKI verification failure, got {server_res:?}" + ); +} + +#[tokio::test] +async fn forged_client_signature_is_rejected() { + let pki = common::mint_pki("vpn.aura.example", "client-alpha"); + + // Mint an unrelated P-256 keypair (via a throwaway issued cert) to use as the WRONG signing + // key. We pair the legitimate client's certificate with this mismatched private key: the chain + // verifies fine, but the signature over the transcript is made with a key that does not match + // the certificate's public key, so signature verification must fail. + let throwaway_ca = AuraCa::generate("throwaway").expect("throwaway CA"); + let mismatched = throwaway_ca + .issue_client_cert("mismatched") + .expect("throwaway cert"); + + let client_cfg = ClientConfig { + ca_cert_pem: pki.ca_cert_pem.clone(), + client_cert_pem: pki.client_cert_pem.clone(), // valid cert (chains to trusted CA) + client_key_pem: mismatched.key_pem, // WRONG key -> forged signature + server_name: pki.server_name.clone(), + }; + + let (_client_res, server_res) = run(client_cfg, pki.server_config()).await; + assert!( + matches!(server_res, Err(ProtoError::Signature(_))), + "expected a signature verification failure, got {server_res:?}" + ); +} diff --git a/crates/aura-proto/tests/replay_protection.rs b/crates/aura-proto/tests/replay_protection.rs new file mode 100644 index 0000000..d4f2d38 --- /dev/null +++ b/crates/aura-proto/tests/replay_protection.rs @@ -0,0 +1,122 @@ +//! `test_replay_protection` — a Data record that was already delivered, replayed verbatim, must be +//! rejected by the receiver's sliding replay window. +//! +//! Topology: the client and server each talk to one end of their own duplex. A relay task in the +//! middle forwards bytes between the two. For the client->server direction the relay parses whole +//! frames (using the crate's public framing helpers) so that, once the client has sent its data +//! packet, the relay can forward that exact record to the server a SECOND time — a verbatim replay. + +mod common; + +use aura_proto::frame::{read_frame, write_frame, MsgType, RawFrame}; +use aura_proto::{client_handshake, server_handshake, Frame, ProtoError}; +use bytes::Bytes; +use tokio::io::{split, AsyncWriteExt}; +use tokio::sync::oneshot; + +#[tokio::test] +async fn test_replay_protection() { + let pki = common::mint_pki("vpn.aura.example", "client-alpha"); + let client_cfg = pki.client_config(); + let server_cfg = pki.server_config(); + + // Two duplexes with a relay in the middle. + let (client_io, relay_a) = tokio::io::duplex(64 * 1024); + let (relay_b, server_io) = tokio::io::duplex(64 * 1024); + + let (c_read, c_write) = split(client_io); + let (s_read, s_write) = split(server_io); + let (ra_read, ra_write) = split(relay_a); // faces the client + let (rb_read, rb_write) = split(relay_b); // faces the server + + // Signal so the relay forwards the replay only after the server has consumed the genuine copy. + let (genuine_done_tx, genuine_done_rx) = oneshot::channel::<()>(); + + // ---- Relay: client -> server, with a one-shot verbatim replay of the first Data record ---- + let relay_c2s = tokio::spawn(async move { + let mut ra_read = ra_read; + let mut rb_write = rb_write; + let mut genuine_done = Some(genuine_done_rx); + let mut replayed = false; + loop { + let frame: RawFrame = match read_frame(&mut ra_read).await { + Ok(f) => f, + Err(_) => break, // EOF when the client side closes + }; + // Forward the frame unchanged. + write_frame(&mut rb_write, frame.msg_type, &frame.payload) + .await + .expect("relay forward c->s"); + + // On the first Data record, wait until the server has accepted it, then replay it once. + if frame.msg_type == MsgType::Data && !replayed { + replayed = true; + if let Some(rx) = genuine_done.take() { + let _ = rx.await; // server signals after it accepted the genuine record + } + write_frame(&mut rb_write, frame.msg_type, &frame.payload) + .await + .expect("relay replay c->s"); + rb_write.flush().await.expect("flush replay"); + } + } + }); + + // ---- Relay: server -> client (straight byte copy) ---- + let relay_s2c = tokio::spawn(async move { + let mut rb_read = rb_read; + let mut ra_write = ra_write; + let _ = tokio::io::copy(&mut rb_read, &mut ra_write).await; + let _ = ra_write.shutdown().await; + }); + + // ---- Client: handshake, then send exactly one Data frame ---- + let client = tokio::spawn(async move { + let mut sess = client_handshake(c_read, c_write, &client_cfg) + .await + .expect("client handshake"); + sess.send_frame(Frame::Data { + stream_id: 9, + payload: Bytes::from_static(b"the one and only payload"), + }) + .await + .expect("client send"); + // Keep the session (and thus the transport) alive until the test signals completion. + sess + }); + + // ---- Server: handshake, accept the genuine record, then expect the replay to be rejected ---- + let server = tokio::spawn(async move { + let mut sess = server_handshake(s_read, s_write, &server_cfg) + .await + .expect("server handshake"); + + // 1) Genuine record is accepted. + let first = sess.recv_frame().await.expect("genuine recv"); + match first { + Frame::Data { stream_id, payload } => { + assert_eq!(stream_id, 9); + assert_eq!(&payload[..], b"the one and only payload"); + } + other => panic!("expected Data, got {other:?}"), + } + + // Tell the relay it may now inject the verbatim replay. + genuine_done_tx.send(()).expect("signal genuine done"); + + // 2) The replayed record must be rejected by the replay window. + let replay_result = sess.recv_frame().await; + assert!( + matches!(replay_result, Err(ProtoError::Replay(_))), + "expected ProtoError::Replay, got {replay_result:?}" + ); + sess + }); + + let (_client_sess, server_outcome) = tokio::join!(client, server); + server_outcome.expect("server task"); + drop(_client_sess); // closes the client side -> relays drain and exit + + let _ = relay_c2s.await; + let _ = relay_s2c.await; +}