feat(proto): implement Wave 2 — hybrid PKI handshake + session

aura-proto: 5-byte wire header + Frame codec (§6.1/§6.3); transport-agnostic
handshake state machine (§6.2) over split tokio AsyncRead/AsyncWrite —
hybrid X25519+ML-KEM-768 KEM, SHA-256 transcript, mutual X.509 auth with
ECDSA-P256 transcript signatures (ring), constant-time HMAC Finished;
Session with sliding-window replay protection. 13 tests green, clippy clean.

Handshake message order pinned (resolves spec diagram ambiguity); reader/writer
taken by value since Session owns both halves.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xah30
2026-05-25 18:05:11 +03:00
parent b8ce58ddf0
commit bb835e4ca7
11 changed files with 1710 additions and 1 deletions
Generated
+2
View File
@@ -248,11 +248,13 @@ dependencies = [
"bytes", "bytes",
"hmac", "hmac",
"rand 0.8.6", "rand 0.8.6",
"ring",
"rustls-pki-types", "rustls-pki-types",
"serde", "serde",
"sha2", "sha2",
"thiserror 1.0.69", "thiserror 1.0.69",
"tokio", "tokio",
"x509-parser 0.16.0",
"zeroize", "zeroize",
] ]
+9
View File
@@ -17,6 +17,15 @@ sha2.workspace = true
rand.workspace = true rand.workspace = true
rustls-pki-types.workspace = true rustls-pki-types.workspace = true
thiserror.workspace = true thiserror.workspace = true
# Handshake signatures (ECDSA P-256 / SHA-256, ASN.1 DER). Already in the workspace lockfile.
ring = "0.17"
# Parse leaf cert DER (extract the EC SubjectPublicKeyInfo point) and decode PEM blocks
# (certificates + PKCS#8 keys) to DER. Already a workspace dependency and used by aura-pki, so
# this adds no new resolution and lets us avoid pulling in rustls-pemfile.
x509-parser.workspace = true
# The handshake and session are async over tokio::io::{AsyncRead, AsyncWrite}, so tokio must be a
# normal dependency (available via the workspace `full` feature set), not only a dev-dependency.
tokio.workspace = true
[dev-dependencies] [dev-dependencies]
tokio.workspace = true tokio.workspace = true
+371
View File
@@ -0,0 +1,371 @@
//! Wire format: the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3).
//!
//! Every Aura protocol message on the wire is a 5-byte header followed by a payload:
//!
//! ```text
//! byte 0 : msg_type (u8)
//! bytes 1..4 : length (u24, big-endian) = payload length in bytes
//! byte 4 : version = 0x01
//! bytes 5.. : payload (length bytes)
//! ```
//!
//! [`Frame`] is the post-handshake application payload. Each `Frame` is serialized with
//! [`Frame::encode`], AEAD-sealed, and shipped inside a [`MsgType::Data`] record (see
//! [`crate::session`]).
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::ProtoError;
/// Length in bytes of the protocol frame header.
pub const HEADER_LEN: usize = 5;
/// Protocol version carried in byte 4 of every header.
pub const PROTOCOL_VERSION: u8 = 0x01;
/// Largest payload expressible by the u24 length field.
pub const MAX_PAYLOAD_LEN: usize = 0x00FF_FFFF;
/// Message types carried in byte 0 of the header (§6.1).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum MsgType {
/// Handshake message 1 (C->S): hybrid public key + client nonce.
ClientHello = 0x01,
/// Handshake message 2 (S->C): hybrid ciphertext + server nonce.
ServerHello = 0x02,
/// Handshake message 4 (C->S, encrypted): client cert + signature.
ClientAuth = 0x03,
/// Handshake message 3 (S->C, encrypted): server cert + signature.
ServerAuth = 0x04,
/// Handshake Finished (encrypted): HMAC over the handshake hash.
Finished = 0x05,
/// Application data (encrypted): an AEAD-sealed [`Frame`].
Data = 0x06,
/// Fatal alert / error notification.
Alert = 0xFF,
}
impl MsgType {
/// Map the on-wire byte to a [`MsgType`].
///
/// # Errors
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized byte.
pub fn from_u8(b: u8) -> Result<Self, ProtoError> {
Ok(match b {
0x01 => Self::ClientHello,
0x02 => Self::ServerHello,
0x03 => Self::ClientAuth,
0x04 => Self::ServerAuth,
0x05 => Self::Finished,
0x06 => Self::Data,
0xFF => Self::Alert,
other => return Err(ProtoError::UnknownMsgType(other)),
})
}
}
/// Build a 5-byte header for `msg_type` carrying a payload of `payload_len` bytes.
///
/// # Errors
/// Returns [`ProtoError::FrameTooLarge`] if `payload_len` does not fit in the u24 length field.
pub fn encode_header(
msg_type: MsgType,
payload_len: usize,
) -> Result<[u8; HEADER_LEN], ProtoError> {
if payload_len > MAX_PAYLOAD_LEN {
return Err(ProtoError::FrameTooLarge(payload_len));
}
let len = payload_len as u32;
Ok([
msg_type as u8,
((len >> 16) & 0xFF) as u8,
((len >> 8) & 0xFF) as u8,
(len & 0xFF) as u8,
PROTOCOL_VERSION,
])
}
/// Parse a 5-byte header into `(msg_type, payload_len)`.
///
/// # Errors
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized type byte or
/// [`ProtoError::BadVersion`] if byte 4 is not [`PROTOCOL_VERSION`].
pub fn decode_header(header: &[u8; HEADER_LEN]) -> Result<(MsgType, usize), ProtoError> {
let msg_type = MsgType::from_u8(header[0])?;
let version = header[4];
if version != PROTOCOL_VERSION {
return Err(ProtoError::BadVersion(version));
}
let len = ((header[1] as usize) << 16) | ((header[2] as usize) << 8) | (header[3] as usize);
Ok((msg_type, len))
}
/// Write one full frame (`header || payload`) to `writer`.
///
/// # Errors
/// Returns [`ProtoError::FrameTooLarge`] if the payload is too long, or [`ProtoError::Io`] on a
/// write failure.
pub async fn write_frame<W>(
writer: &mut W,
msg_type: MsgType,
payload: &[u8],
) -> Result<(), ProtoError>
where
W: AsyncWrite + Unpin,
{
let header = encode_header(msg_type, payload.len())?;
writer.write_all(&header).await?;
writer.write_all(payload).await?;
writer.flush().await?;
Ok(())
}
/// A frame read off the wire: its type, the raw header bytes (useful as AEAD AAD and for the
/// handshake transcript hash), and the payload.
#[derive(Clone, Debug)]
pub struct RawFrame {
/// The decoded message type.
pub msg_type: MsgType,
/// The 5 header bytes exactly as transmitted.
pub header: [u8; HEADER_LEN],
/// The payload bytes.
pub payload: Vec<u8>,
}
impl RawFrame {
/// The full serialized frame (`header || payload`) exactly as it appeared on the wire.
///
/// Used to feed the handshake transcript hash, which must hash the bytes as transmitted.
#[must_use]
pub fn wire_bytes(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(HEADER_LEN + self.payload.len());
out.extend_from_slice(&self.header);
out.extend_from_slice(&self.payload);
out
}
}
/// Read one full frame (`header || payload`) from `reader`.
///
/// # Errors
/// Returns [`ProtoError::Io`] on a read failure (including a truncated frame / EOF), or a header
/// decode error from [`decode_header`].
pub async fn read_frame<R>(reader: &mut R) -> Result<RawFrame, ProtoError>
where
R: AsyncRead + Unpin,
{
let mut header = [0u8; HEADER_LEN];
reader.read_exact(&mut header).await?;
let (msg_type, len) = decode_header(&header)?;
let mut payload = vec![0u8; len];
reader.read_exact(&mut payload).await?;
Ok(RawFrame {
msg_type,
header,
payload,
})
}
/// Frame type tags used in the application [`Frame`] encoding (§6.3).
mod frame_tag {
pub const DATA: u8 = 0x01;
pub const PING: u8 = 0x02;
pub const PONG: u8 = 0x03;
pub const CLOSE: u8 = 0x04;
}
/// Application-level frames carried inside encrypted [`MsgType::Data`] records (§6.3).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Frame {
/// A stream data payload.
Data {
/// Logical stream identifier.
stream_id: u32,
/// Opaque application bytes.
payload: Bytes,
},
/// Liveness probe.
Ping {
/// Monotonic sequence number echoed back in the matching [`Frame::Pong`].
seq: u32,
},
/// Reply to a [`Frame::Ping`].
Pong {
/// Sequence number copied from the [`Frame::Ping`].
seq: u32,
},
/// Orderly shutdown of the logical connection.
Close {
/// Application-defined close code.
code: u8,
/// Human-readable reason (UTF-8).
reason: String,
},
}
impl Frame {
/// Serialize this frame to its compact byte encoding.
///
/// Layout (all multi-byte integers big-endian):
/// * `Data` : `0x01 || stream_id(u32) || payload`
/// * `Ping` : `0x02 || seq(u32)`
/// * `Pong` : `0x03 || seq(u32)`
/// * `Close` : `0x04 || code(u8) || reason_len(u32) || reason_utf8`
#[must_use]
pub fn encode(&self) -> Vec<u8> {
let mut out = Vec::new();
match self {
Frame::Data { stream_id, payload } => {
out.push(frame_tag::DATA);
out.extend_from_slice(&stream_id.to_be_bytes());
out.extend_from_slice(payload);
}
Frame::Ping { seq } => {
out.push(frame_tag::PING);
out.extend_from_slice(&seq.to_be_bytes());
}
Frame::Pong { seq } => {
out.push(frame_tag::PONG);
out.extend_from_slice(&seq.to_be_bytes());
}
Frame::Close { code, reason } => {
out.push(frame_tag::CLOSE);
out.push(*code);
let bytes = reason.as_bytes();
out.extend_from_slice(&(bytes.len() as u32).to_be_bytes());
out.extend_from_slice(bytes);
}
}
out
}
/// Parse a frame from its byte encoding (the inverse of [`Frame::encode`]).
///
/// # Errors
/// Returns [`ProtoError::MalformedFrame`] if the buffer is truncated, has an unknown tag, or
/// (for `Close`) does not contain valid UTF-8.
pub fn decode(buf: &[u8]) -> Result<Self, ProtoError> {
let (&tag, rest) = buf
.split_first()
.ok_or(ProtoError::MalformedFrame("empty frame"))?;
match tag {
frame_tag::DATA => {
let stream_id = read_u32(rest, "Data.stream_id")?;
let payload = Bytes::copy_from_slice(&rest[4..]);
Ok(Frame::Data { stream_id, payload })
}
frame_tag::PING => Ok(Frame::Ping {
seq: read_u32(rest, "Ping.seq")?,
}),
frame_tag::PONG => Ok(Frame::Pong {
seq: read_u32(rest, "Pong.seq")?,
}),
frame_tag::CLOSE => {
let code = *rest
.first()
.ok_or(ProtoError::MalformedFrame("Close: missing code"))?;
let reason_len = read_u32(&rest[1..], "Close.reason_len")? as usize;
let reason_bytes = rest
.get(5..5 + reason_len)
.ok_or(ProtoError::MalformedFrame("Close: truncated reason"))?;
let reason = String::from_utf8(reason_bytes.to_vec())
.map_err(|_| ProtoError::MalformedFrame("Close: reason not UTF-8"))?;
Ok(Frame::Close { code, reason })
}
_ => Err(ProtoError::MalformedFrame("unknown frame tag")),
}
}
}
/// Read a big-endian u32 from the start of `buf`, erroring if it is too short.
fn read_u32(buf: &[u8], what: &'static str) -> Result<u32, ProtoError> {
let bytes: [u8; 4] = buf
.get(..4)
.ok_or(ProtoError::MalformedFrame(what))?
.try_into()
.expect("slice of length 4 converts to [u8; 4]");
Ok(u32::from_be_bytes(bytes))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn header_roundtrip_all_types() {
for (ty, byte) in [
(MsgType::ClientHello, 0x01u8),
(MsgType::ServerHello, 0x02),
(MsgType::ClientAuth, 0x03),
(MsgType::ServerAuth, 0x04),
(MsgType::Finished, 0x05),
(MsgType::Data, 0x06),
(MsgType::Alert, 0xFF),
] {
let h = encode_header(ty, 0x0012_3456).unwrap();
assert_eq!(h[0], byte);
assert_eq!(h[4], PROTOCOL_VERSION);
let (got_ty, got_len) = decode_header(&h).unwrap();
assert_eq!(got_ty, ty);
assert_eq!(got_len, 0x0012_3456);
}
}
#[test]
fn header_rejects_oversize_and_bad_version() {
assert!(matches!(
encode_header(MsgType::Data, MAX_PAYLOAD_LEN + 1),
Err(ProtoError::FrameTooLarge(_))
));
let mut h = encode_header(MsgType::Data, 1).unwrap();
h[4] = 0x02;
assert!(matches!(
decode_header(&h),
Err(ProtoError::BadVersion(0x02))
));
h[0] = 0x77;
assert!(matches!(
decode_header(&h),
Err(ProtoError::UnknownMsgType(0x77))
));
}
#[test]
fn frame_roundtrip() {
let frames = vec![
Frame::Data {
stream_id: 0xDEAD_BEEF,
payload: Bytes::from_static(b"hello world"),
},
Frame::Data {
stream_id: 0,
payload: Bytes::new(),
},
Frame::Ping { seq: 42 },
Frame::Pong { seq: 0xFFFF_FFFF },
Frame::Close {
code: 7,
reason: "going away \u{1f44b}".to_string(),
},
Frame::Close {
code: 0,
reason: String::new(),
},
];
for f in frames {
let encoded = f.encode();
let decoded = Frame::decode(&encoded).unwrap();
assert_eq!(f, decoded);
}
}
#[test]
fn frame_decode_rejects_garbage() {
assert!(Frame::decode(&[]).is_err());
assert!(Frame::decode(&[0x99]).is_err());
assert!(Frame::decode(&[frame_tag::PING, 0x00]).is_err()); // truncated u32
assert!(Frame::decode(&[frame_tag::CLOSE]).is_err()); // missing code
}
}
+453
View File
@@ -0,0 +1,453 @@
//! The Aura handshake state machine (§6.2): a hybrid X25519 + ML-KEM-768 key exchange with mutual
//! X.509 authentication and a Finished-MAC transcript binding.
//!
//! Both entry points — [`client_handshake`] and [`server_handshake`] — are generic over a separate
//! [`tokio::io::AsyncRead`] reader and [`tokio::io::AsyncWrite`] writer so they drive an in-memory
//! duplex pipe (tests) or quinn's split `RecvStream` / `SendStream` (the QUIC transport)
//! identically. On success each returns an established [`Session`].
//!
//! ## Message order (resolves the spec diagram ambiguity)
//!
//! ```text
//! 1. C->S ClientHello (plaintext): x25519_pub[32] || kyber_ek[1184] || client_nonce[32]
//! 2. S->C ServerHello (plaintext): x25519_ephemeral[32] || kyber_ct[1088] || server_nonce[32]
//! -- both sides derive the hybrid shared secret and the directional SessionKeys --
//! 3. S->C ServerAuth (encrypted): u16(cert_der_len) || server_leaf_cert_der || sig(transcript)
//! 4. C->S ClientAuth (encrypted): u16(cert_der_len) || client_leaf_cert_der || sig(transcript)
//! 5. C->S Finished (encrypted): HMAC-SHA256(key = key_c2s, transcript)
//! 6. S->C Finished (encrypted): HMAC-SHA256(key = key_s2c, transcript)
//! ```
//!
//! `transcript = SHA-256(ClientHello_frame_bytes || ServerHello_frame_bytes)` over the FULL
//! serialized frames (header + payload) exactly as transmitted. The same two [`AeadSession`]s
//! protect messages 36 and all subsequent application Data — their counters stay continuous.
use aura_crypto::{
derive_session_keys, AeadSession, HybridCiphertext, HybridPrivateKey, HybridPublicKey,
SessionKeys,
};
use aura_pki::AuraCertVerifier;
use hmac::{Hmac, Mac};
use rand::RngCore;
use rustls_pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use crate::frame::{self, MsgType, RawFrame, HEADER_LEN};
use crate::session::Session;
use crate::{ClientConfig, ProtoError, ServerConfig};
/// X25519 public key / ephemeral / shared-secret length.
const X25519_LEN: usize = 32;
/// ML-KEM-768 encapsulation (public) key length.
const KYBER_EK_LEN: usize = 1184;
/// ML-KEM-768 ciphertext length.
const KYBER_CT_LEN: usize = 1088;
/// Handshake nonce length.
const NONCE_LEN: usize = 32;
/// Exact ClientHello payload length: x25519_pub || kyber_ek || client_nonce.
const CLIENT_HELLO_LEN: usize = X25519_LEN + KYBER_EK_LEN + NONCE_LEN;
/// Exact ServerHello payload length: x25519_ephemeral || kyber_ct || server_nonce.
const SERVER_HELLO_LEN: usize = X25519_LEN + KYBER_CT_LEN + NONCE_LEN;
/// HMAC-SHA256 alias for the Finished MAC.
type HmacSha256 = Hmac<Sha256>;
/// Each direction's AEAD seals exactly two encrypted handshake messages (an Auth and a Finished)
/// before application Data begins, so both AEAD counters start Data at this value.
const POST_HANDSHAKE_COUNTER: u64 = 2;
// ---------------------------------------------------------------------------------------------
// Client
// ---------------------------------------------------------------------------------------------
/// Drive the client side of the Aura handshake to completion.
///
/// Generates a hybrid keypair, exchanges hello messages, derives session keys, authenticates the
/// server, proves possession of the client key, and exchanges Finished MACs. On success returns an
/// established [`Session`] whose AEAD counters continue from the handshake.
///
/// # Errors
/// Returns a [`ProtoError`] for any transport, cryptographic, certificate, signature, or
/// transcript-MAC failure (see the variants of [`ProtoError`]).
pub async fn client_handshake<R, W>(
mut reader: R,
mut writer: W,
cfg: &ClientConfig,
) -> Result<Session<R, W>, ProtoError>
where
R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
// (1) C->S ClientHello: generate our hybrid keypair; send public half + a fresh nonce.
let (priv_key, pub_key): (HybridPrivateKey, HybridPublicKey) = HybridPrivateKey::generate();
let client_nonce = random_nonce();
let mut ch_payload = Vec::with_capacity(CLIENT_HELLO_LEN);
ch_payload.extend_from_slice(&pub_key.x25519);
ch_payload.extend_from_slice(&pub_key.kyber);
ch_payload.extend_from_slice(&client_nonce);
debug_assert_eq!(ch_payload.len(), CLIENT_HELLO_LEN);
let ch_header = frame::encode_header(MsgType::ClientHello, ch_payload.len())?;
frame::write_frame(&mut writer, MsgType::ClientHello, &ch_payload).await?;
let client_hello_wire = concat_frame(&ch_header, &ch_payload);
// (2) S->C ServerHello: ciphertext + server nonce. Capture raw bytes for the transcript.
let sh = read_expect(&mut reader, MsgType::ServerHello).await?;
if sh.payload.len() != SERVER_HELLO_LEN {
return Err(ProtoError::MalformedHandshake("ServerHello: wrong length"));
}
let (sh_x_eph, rest) = sh.payload.split_at(X25519_LEN);
let (sh_kyber_ct, sh_nonce) = rest.split_at(KYBER_CT_LEN);
let mut server_nonce = [0u8; NONCE_LEN];
server_nonce.copy_from_slice(sh_nonce);
let ciphertext = HybridCiphertext {
x25519_ephemeral: sh_x_eph.try_into().expect("32-byte x25519 ephemeral"),
kyber_ciphertext: sh_kyber_ct.to_vec(),
};
let shared = priv_key.decapsulate(&ciphertext)?;
let keys: SessionKeys = derive_session_keys(&shared, &client_nonce, &server_nonce);
// Transcript hash over the two hello frames exactly as transmitted.
let transcript = transcript_hash(&client_hello_wire, &sh.wire_bytes());
// Two AEADs: client seals c2s, opens s2c.
let mut aead_c2s = AeadSession::new(keys.client_to_server);
let mut aead_s2c = AeadSession::new(keys.server_to_client);
// (3) S->C ServerAuth (encrypted under s2c): verify cert chain + signature over transcript.
let server_auth = open_handshake_msg(&mut reader, MsgType::ServerAuth, &mut aead_s2c).await?;
let (server_cert_der, server_sig) = split_cert_and_sig(&server_auth)?;
let verifier = AuraCertVerifier::new(&cfg.ca_cert_pem)?;
let chain = [CertificateDer::from(server_cert_der.to_vec())];
verifier.verify_server_cert(&chain, &cfg.server_name)?;
verify_signature(server_cert_der, &transcript, server_sig)?;
// (4) C->S ClientAuth (encrypted under c2s): our leaf + our signature over transcript.
let client_cert_der = pem_cert_to_der(&cfg.client_cert_pem)?;
let client_sig = sign_transcript(&cfg.client_key_pem, &transcript)?;
let client_auth = build_cert_and_sig(&client_cert_der, &client_sig);
seal_handshake_msg(
&mut writer,
MsgType::ClientAuth,
&mut aead_c2s,
&client_auth,
)
.await?;
// (5) C->S Finished (encrypted under c2s): HMAC(key_c2s, transcript).
let client_finished = finished_mac(&keys.client_to_server, &transcript);
seal_handshake_msg(
&mut writer,
MsgType::Finished,
&mut aead_c2s,
&client_finished,
)
.await?;
// (6) S->C Finished (encrypted under s2c): verify HMAC(key_s2c, transcript).
let server_finished = open_handshake_msg(&mut reader, MsgType::Finished, &mut aead_s2c).await?;
verify_finished(&keys.server_to_client, &transcript, &server_finished)?;
Ok(Session::new(
reader,
writer,
aead_c2s,
aead_s2c,
POST_HANDSHAKE_COUNTER,
Some(cfg.server_name.clone()),
))
}
// ---------------------------------------------------------------------------------------------
// Server
// ---------------------------------------------------------------------------------------------
/// Drive the server side of the Aura handshake to completion.
///
/// Receives the client hello, encapsulates to the client's hybrid public key, derives session
/// keys, authenticates itself, verifies the client certificate (capturing the client id) and its
/// signature, and exchanges Finished MACs. On success returns an established [`Session`] whose
/// [`Session::peer_id`] is the verified client id.
///
/// # Errors
/// Returns a [`ProtoError`] for any transport, cryptographic, certificate, signature, or
/// transcript-MAC failure.
pub async fn server_handshake<R, W>(
mut reader: R,
mut writer: W,
cfg: &ServerConfig,
) -> Result<Session<R, W>, ProtoError>
where
R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
// (1) S receives ClientHello: reconstruct the client's hybrid public key + nonce.
let ch = read_expect(&mut reader, MsgType::ClientHello).await?;
if ch.payload.len() != CLIENT_HELLO_LEN {
return Err(ProtoError::MalformedHandshake("ClientHello: wrong length"));
}
let (ch_x, rest) = ch.payload.split_at(X25519_LEN);
let (ch_kyber_ek, ch_nonce) = rest.split_at(KYBER_EK_LEN);
let mut client_nonce = [0u8; NONCE_LEN];
client_nonce.copy_from_slice(ch_nonce);
let client_pub = HybridPublicKey {
x25519: ch_x.try_into().expect("32-byte x25519 public key"),
kyber: ch_kyber_ek.to_vec(),
};
// (2) S->C ServerHello: encapsulate to the client's public key; send ciphertext + nonce.
let (ciphertext, shared) = client_pub.encapsulate();
let server_nonce = random_nonce();
let mut sh_payload = Vec::with_capacity(SERVER_HELLO_LEN);
sh_payload.extend_from_slice(&ciphertext.x25519_ephemeral);
sh_payload.extend_from_slice(&ciphertext.kyber_ciphertext);
sh_payload.extend_from_slice(&server_nonce);
debug_assert_eq!(sh_payload.len(), SERVER_HELLO_LEN);
let sh_header = frame::encode_header(MsgType::ServerHello, sh_payload.len())?;
frame::write_frame(&mut writer, MsgType::ServerHello, &sh_payload).await?;
let server_hello_wire = concat_frame(&sh_header, &sh_payload);
let keys: SessionKeys = derive_session_keys(&shared, &client_nonce, &server_nonce);
let transcript = transcript_hash(&ch.wire_bytes(), &server_hello_wire);
// Two AEADs: server seals s2c, opens c2s.
let mut aead_c2s = AeadSession::new(keys.client_to_server);
let mut aead_s2c = AeadSession::new(keys.server_to_client);
// (3) S->C ServerAuth (encrypted under s2c): our leaf + our signature over transcript.
let server_cert_der = pem_cert_to_der(&cfg.server_cert_pem)?;
let server_sig = sign_transcript(&cfg.server_key_pem, &transcript)?;
let server_auth = build_cert_and_sig(&server_cert_der, &server_sig);
seal_handshake_msg(
&mut writer,
MsgType::ServerAuth,
&mut aead_s2c,
&server_auth,
)
.await?;
// (4) C->S ClientAuth (encrypted under c2s): verify cert chain (=> client_id) + signature.
let client_auth = open_handshake_msg(&mut reader, MsgType::ClientAuth, &mut aead_c2s).await?;
let (client_cert_der, client_sig) = split_cert_and_sig(&client_auth)?;
let verifier = AuraCertVerifier::new(&cfg.ca_cert_pem)?;
let chain = [CertificateDer::from(client_cert_der.to_vec())];
let client_id = verifier.verify_client_cert(&chain)?;
verify_signature(client_cert_der, &transcript, client_sig)?;
// (5) C->S Finished (encrypted under c2s): verify HMAC(key_c2s, transcript).
let client_finished = open_handshake_msg(&mut reader, MsgType::Finished, &mut aead_c2s).await?;
verify_finished(&keys.client_to_server, &transcript, &client_finished)?;
// (6) S->C Finished (encrypted under s2c): HMAC(key_s2c, transcript).
let server_finished = finished_mac(&keys.server_to_client, &transcript);
seal_handshake_msg(
&mut writer,
MsgType::Finished,
&mut aead_s2c,
&server_finished,
)
.await?;
Ok(Session::new(
reader,
writer,
aead_s2c,
aead_c2s,
POST_HANDSHAKE_COUNTER,
Some(client_id),
))
}
// ---------------------------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------------------------
/// Generate a fresh 32-byte handshake nonce from the OS RNG.
fn random_nonce() -> [u8; NONCE_LEN] {
let mut n = [0u8; NONCE_LEN];
rand::thread_rng().fill_bytes(&mut n);
n
}
/// Concatenate a header and payload into the full serialized frame bytes.
fn concat_frame(header: &[u8; HEADER_LEN], payload: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(HEADER_LEN + payload.len());
out.extend_from_slice(header);
out.extend_from_slice(payload);
out
}
/// `transcript = SHA-256(client_hello_frame || server_hello_frame)`.
fn transcript_hash(client_hello_frame: &[u8], server_hello_frame: &[u8]) -> [u8; 32] {
let mut h = Sha256::new();
h.update(client_hello_frame);
h.update(server_hello_frame);
h.finalize().into()
}
/// Read a frame and require it to be `expected`, mapping an Alert to [`ProtoError::Alert`].
async fn read_expect<R>(reader: &mut R, expected: MsgType) -> Result<RawFrame, ProtoError>
where
R: tokio::io::AsyncRead + Unpin,
{
let raw = frame::read_frame(reader).await?;
if raw.msg_type == MsgType::Alert {
let code = raw.payload.first().copied().unwrap_or(0);
return Err(ProtoError::Alert(code));
}
if raw.msg_type != expected {
return Err(ProtoError::UnexpectedMsg {
expected,
got: raw.msg_type,
});
}
Ok(raw)
}
/// Seal `plaintext` and write it as an encrypted handshake frame of type `msg_type`.
///
/// The AAD is the 5-byte frame header (binding type + length), matching the Data-record convention.
async fn seal_handshake_msg<W>(
writer: &mut W,
msg_type: MsgType,
aead: &mut AeadSession,
plaintext: &[u8],
) -> Result<(), ProtoError>
where
W: tokio::io::AsyncWrite + Unpin,
{
let sealed_len = plaintext.len() + 16; // Poly1305 tag
let header = frame::encode_header(msg_type, sealed_len)?;
let ciphertext = aead.seal(plaintext, &header);
debug_assert_eq!(ciphertext.len(), sealed_len);
frame::write_frame(writer, msg_type, &ciphertext).await
}
/// Read an encrypted handshake frame of type `msg_type` and AEAD-open it, returning the plaintext.
async fn open_handshake_msg<R>(
reader: &mut R,
msg_type: MsgType,
aead: &mut AeadSession,
) -> Result<Vec<u8>, ProtoError>
where
R: tokio::io::AsyncRead + Unpin,
{
let raw = read_expect(reader, msg_type).await?;
let plaintext = aead.open(&raw.payload, &raw.header)?;
Ok(plaintext)
}
/// Build an Auth payload: `u16_be(cert_der_len) || cert_der || signature`.
fn build_cert_and_sig(cert_der: &[u8], sig: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(2 + cert_der.len() + sig.len());
out.extend_from_slice(&(cert_der.len() as u16).to_be_bytes());
out.extend_from_slice(cert_der);
out.extend_from_slice(sig);
out
}
/// Parse an Auth payload into `(cert_der, signature)`.
fn split_cert_and_sig(buf: &[u8]) -> Result<(&[u8], &[u8]), ProtoError> {
let len_bytes: [u8; 2] = buf
.get(..2)
.ok_or(ProtoError::MalformedHandshake("Auth: missing cert length"))?
.try_into()
.expect("2-byte length prefix");
let cert_len = u16::from_be_bytes(len_bytes) as usize;
let cert_der = buf
.get(2..2 + cert_len)
.ok_or(ProtoError::MalformedHandshake("Auth: truncated cert"))?;
let sig = buf
.get(2 + cert_len..)
.ok_or(ProtoError::MalformedHandshake("Auth: missing signature"))?;
if sig.is_empty() {
return Err(ProtoError::MalformedHandshake("Auth: empty signature"));
}
Ok((cert_der, sig))
}
/// Compute the Finished MAC: `HMAC-SHA256(key, transcript)`.
fn finished_mac(key: &[u8; 32], transcript: &[u8; 32]) -> Vec<u8> {
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts a 32-byte key");
mac.update(transcript);
mac.finalize().into_bytes().to_vec()
}
/// Verify a received Finished MAC in constant time.
fn verify_finished(
key: &[u8; 32],
transcript: &[u8; 32],
received: &[u8],
) -> Result<(), ProtoError> {
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts a 32-byte key");
mac.update(transcript);
// `verify_slice` is constant-time and also checks the length.
mac.verify_slice(received)
.map_err(|_| ProtoError::FinishedMismatch)
}
/// Sign the transcript with a PKCS#8 PEM key (ECDSA P-256 / SHA-256, ASN.1 DER signature).
fn sign_transcript(key_pem: &str, transcript: &[u8; 32]) -> Result<Vec<u8>, ProtoError> {
let pkcs8_der = pem_key_to_der(key_pem)?;
let rng = ring::rand::SystemRandom::new();
let key_pair = ring::signature::EcdsaKeyPair::from_pkcs8(
&ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING,
&pkcs8_der,
&rng,
)
.map_err(|_| ProtoError::Signature("invalid PKCS#8 signing key"))?;
let sig = key_pair
.sign(&rng, transcript)
.map_err(|_| ProtoError::Signature("ECDSA signing failed"))?;
Ok(sig.as_ref().to_vec())
}
/// Convenience wrapper used by the client/server symmetric calls.
fn verify_signature(
cert_der: &[u8],
transcript: &[u8; 32],
signature: &[u8],
) -> Result<(), ProtoError> {
let ec_point = ec_public_key_from_cert(cert_der)?;
ring::signature::UnparsedPublicKey::new(&ring::signature::ECDSA_P256_SHA256_ASN1, ec_point)
.verify(transcript, signature)
.map_err(|_| ProtoError::Signature("handshake signature did not verify"))
}
/// Extract the uncompressed EC public-key point (`04 || X || Y`) from a leaf certificate DER.
fn ec_public_key_from_cert(cert_der: &[u8]) -> Result<Vec<u8>, ProtoError> {
use x509_parser::prelude::FromDer;
let (_, cert) = x509_parser::certificate::X509Certificate::from_der(cert_der)
.map_err(|_| ProtoError::Signature("could not parse peer certificate DER"))?;
Ok(cert.public_key().subject_public_key.data.to_vec())
}
/// Decode the first `CERTIFICATE` PEM block to DER.
fn pem_cert_to_der(pem: &str) -> Result<Vec<u8>, ProtoError> {
pem_block_to_der(pem, &["CERTIFICATE"])
.ok_or(ProtoError::Signature("no CERTIFICATE block in PEM"))
}
/// Decode the first private-key PEM block (PKCS#8 `PRIVATE KEY`, or the EC/RSA variants) to DER.
fn pem_key_to_der(pem: &str) -> Result<Vec<u8>, ProtoError> {
pem_block_to_der(pem, &["PRIVATE KEY", "EC PRIVATE KEY", "RSA PRIVATE KEY"])
.ok_or(ProtoError::Signature("no private-key block in PEM"))
}
/// Decode the first PEM block whose label is one of `labels` into its DER contents.
///
/// Uses `x509-parser`'s generic PEM container reader, so we do not need `rustls-pemfile`.
fn pem_block_to_der(pem: &str, labels: &[&str]) -> Option<Vec<u8>> {
for item in x509_parser::pem::Pem::iter_from_buffer(pem.as_bytes()) {
let item = item.ok()?;
if labels.contains(&item.label.as_str()) {
return Some(item.contents);
}
}
None
}
+138 -1
View File
@@ -1 +1,138 @@
//! aura-proto — protocol wire format and handshake (skeleton; implemented in Wave 2). //! aura-proto — the Aura VPN protocol: wire format, hybrid-KEM + mutual-X.509 handshake, and the
//! post-handshake encrypted session.
//!
//! This crate sits on top of [`aura_crypto`] (hybrid KEM, HKDF, AEAD) and [`aura_pki`] (mutual
//! X.509 verification) and implements project §6:
//!
//! * [`frame`] — the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3).
//! * [`handshake`] — [`client_handshake`] / [`server_handshake`], the hybrid-KEM + mutual-auth
//! state machine (§6.2).
//! * [`session`] — [`Session`], post-handshake encrypted [`Frame`] send/recv with replay
//! protection.
//!
//! ## Transport-agnostic by design
//!
//! The handshake is generic over any [`tokio::io::AsyncRead`] reader and [`tokio::io::AsyncWrite`]
//! writer, supplied as **separate** halves. That matches both an in-memory duplex pipe (tests) and
//! quinn's split `RecvStream` / `SendStream` (the Wave 3 QUIC transport), so the same code drives
//! both. See [`client_handshake`] / [`server_handshake`] and [`Session`].
//!
//! The handshake takes the reader and writer **by value** (`reader: R, writer: W`) rather than the
//! `&mut R, &mut W` sketched in the brief: the returned [`Session`] *owns* both halves (so it can
//! keep `send_frame` / `recv_frame` going), and an owning return cannot be built from borrows
//! without a `Default` placeholder that generic `R`/`W` lack. Callers pass owned halves anyway —
//! `tokio::io::split(...)` and quinn's `(SendStream, RecvStream)` both yield owned values — so this
//! is ergonomically identical while being sound.
//!
//! ## Handshake message order (resolving the spec diagram)
//!
//! The spec diagram is ambiguous about the ordering of the encrypted auth/Finished messages. This
//! implementation fixes the exact order below and both peers follow it lock-step:
//!
//! ```text
//! 1. C->S ClientHello (plaintext)
//! 2. S->C ServerHello (plaintext) -- both derive SessionKeys here --
//! 3. S->C ServerAuth (encrypted)
//! 4. C->S ClientAuth (encrypted)
//! 5. C->S Finished (encrypted)
//! 6. S->C Finished (encrypted) -- encrypted Data channel open both ways --
//! ```
#![forbid(unsafe_code)]
#![warn(missing_docs)]
pub mod frame;
pub mod handshake;
pub mod session;
pub use frame::{Frame, MsgType};
pub use handshake::{client_handshake, server_handshake};
pub use session::Session;
use thiserror::Error;
/// Errors produced by the Aura protocol layer.
#[derive(Debug, Error)]
pub enum ProtoError {
/// An I/O error on the underlying transport (read/write/EOF).
#[error("transport I/O error: {0}")]
Io(#[from] std::io::Error),
/// An error from the cryptographic core (KEM decapsulation, AEAD open).
#[error("crypto error: {0}")]
Crypto(#[from] aura_crypto::CryptoError),
/// An error from the PKI layer (certificate chain or name verification).
#[error("pki error: {0}")]
Pki(#[from] aura_pki::PkiError),
/// The header carried an unknown message-type byte.
#[error("unknown message type byte: 0x{0:02x}")]
UnknownMsgType(u8),
/// The header carried an unsupported protocol version.
#[error("unsupported protocol version: 0x{0:02x}")]
BadVersion(u8),
/// A payload exceeded the u24 length field of the frame header.
#[error("frame payload too large: {0} bytes")]
FrameTooLarge(usize),
/// A received frame had the wrong type for the current handshake/session step.
#[error("unexpected message type: expected {expected:?}, got {got:?}")]
UnexpectedMsg {
/// The message type the state machine required.
expected: MsgType,
/// The message type actually received.
got: MsgType,
},
/// A handshake message had an invalid length or internal structure.
#[error("malformed handshake message: {0}")]
MalformedHandshake(&'static str),
/// An application [`Frame`] could not be decoded.
#[error("malformed frame: {0}")]
MalformedFrame(&'static str),
/// A digital signature (over the handshake hash) failed to verify, or a key/cert could not be
/// parsed for signing or verification.
#[error("signature verification failed: {0}")]
Signature(&'static str),
/// The peer's Finished HMAC did not match the locally computed value.
#[error("handshake Finished verification failed (HMAC mismatch)")]
FinishedMismatch,
/// A Data record was rejected by the replay-protection window (duplicate or too old).
#[error("replay detected: record counter {0} is a duplicate or outside the window")]
Replay(u64),
/// The peer sent a fatal Alert.
#[error("peer sent fatal alert (code {0})")]
Alert(u8),
}
/// Client-side handshake configuration.
#[derive(Clone, Debug)]
pub struct ClientConfig {
/// PEM of the Aura CA certificate, used as the trust anchor for the server cert.
pub ca_cert_pem: String,
/// PEM of this client's leaf certificate (sent to the server during ClientAuth).
pub client_cert_pem: String,
/// PEM of this client's PKCS#8 private key (used to sign the handshake hash).
pub client_key_pem: String,
/// The DNS name the client expects in the server certificate's SAN.
pub server_name: String,
}
/// Server-side handshake configuration.
#[derive(Clone, Debug)]
pub struct ServerConfig {
/// PEM of the Aura CA certificate, used as the trust anchor for client certs.
pub ca_cert_pem: String,
/// PEM of this server's leaf certificate (sent to the client during ServerAuth).
pub server_cert_pem: String,
/// PEM of this server's PKCS#8 private key (used to sign the handshake hash).
pub server_key_pem: String,
}
+294
View File
@@ -0,0 +1,294 @@
//! Post-handshake encrypted session: AEAD-protected [`Frame`] exchange with replay protection.
//!
//! A [`Session`] owns the transport reader + writer and the two directional [`AeadSession`]s
//! produced by the handshake. It exposes [`Session::send_frame`] / [`Session::recv_frame`], which
//! serialize a [`Frame`], AEAD-seal/open it, and ship it inside a [`MsgType::Data`] record framed
//! by the 5-byte protocol header.
//!
//! ## Record format and replay protection
//!
//! Each Data record's payload is:
//!
//! ```text
//! seq (u64, big-endian) || AEAD_seal(frame_bytes, aad = header || seq)
//! ```
//!
//! The 8-byte `seq` is the AEAD nonce counter for that record. Because `aura_crypto::AeadSession`
//! advances its internal counter in lock-step on the sealing and opening sides, the transmitted
//! `seq` always equals the receiver's expected AEAD counter on the happy path. Carrying it
//! explicitly lets the receiver run a **sliding-window** replay check (window size
//! [`REPLAY_WINDOW`]) *before* touching the AEAD: a duplicate or too-old `seq` is rejected with
//! [`ProtoError::Replay`] without disturbing the AEAD's counter, so the session stays usable.
//!
//! The `seq` is also folded into the AEAD AAD (alongside the frame header), cryptographically
//! binding the record to its claimed position.
use aura_crypto::AeadSession;
use crate::frame::{self, Frame, MsgType, RawFrame, HEADER_LEN};
use crate::ProtoError;
/// Width of the sliding replay window, in records.
pub const REPLAY_WINDOW: u64 = 64;
/// Width in bytes of the per-record sequence-number prefix.
const SEQ_LEN: usize = 8;
/// Sliding-window replay detector over a monotonically increasing 64-bit counter.
///
/// Tracks the highest accepted sequence number and a bitmap of the [`REPLAY_WINDOW`] positions
/// below it. A sequence number is accepted iff it is strictly newer than everything seen, or falls
/// within the window and has not been seen before.
#[derive(Debug)]
struct ReplayWindow {
/// Highest sequence number accepted so far.
highest: u64,
/// Bitmap of accepted positions in `(highest - REPLAY_WINDOW, highest]`; bit `i` set means
/// `highest - 1 - i` was accepted.
bitmap: u64,
/// Whether any record has been accepted yet (so seq 0 can itself be accepted exactly once).
seeded: bool,
}
impl ReplayWindow {
/// Create a window primed so the first expected record is `start` (the AEAD counter at the end
/// of the handshake). Anything strictly below `start` is treated as already-consumed.
fn new(start: u64) -> Self {
// Treat (start - 1) as the highest already-accepted slot so the first Data record at `seq
// == start` is accepted as "newer". `seeded = true` keeps seq 0 handling uniform even when
// start == 0 (the handshake always leaves start >= 1, but be defensive).
Self {
highest: start.saturating_sub(1),
bitmap: 0,
seeded: start > 0,
}
}
/// Check and record `seq`. Returns `Ok(())` if it is fresh; [`ProtoError::Replay`] otherwise.
fn check_and_set(&mut self, seq: u64) -> Result<(), ProtoError> {
if !self.seeded {
// First ever record (only reachable if started at 0): accept and seed.
self.seeded = true;
self.highest = seq;
self.bitmap = 0;
return Ok(());
}
if seq > self.highest {
// New high-water mark: shift the bitmap up by the gap, marking the old highest as seen.
let shift = seq - self.highest;
if shift >= 64 {
self.bitmap = 0;
} else {
self.bitmap = (self.bitmap << shift) | (1u64 << (shift - 1));
}
self.highest = seq;
Ok(())
} else {
// seq <= highest: must be inside the window and previously unseen.
let offset = self.highest - seq; // 0 == highest itself (already accepted)
if offset >= REPLAY_WINDOW {
return Err(ProtoError::Replay(seq));
}
if offset == 0 {
// Exactly the current highest: already accepted.
return Err(ProtoError::Replay(seq));
}
let mask = 1u64 << (offset - 1);
if self.bitmap & mask != 0 {
return Err(ProtoError::Replay(seq));
}
self.bitmap |= mask;
Ok(())
}
}
}
/// An established, encrypted Aura session over a transport reader `R` and writer `W`.
///
/// Created by [`crate::client_handshake`] / [`crate::server_handshake`]. Use
/// [`Session::send_frame`] and [`Session::recv_frame`] for application traffic.
pub struct Session<R, W> {
reader: R,
writer: W,
/// AEAD this endpoint seals outgoing Data with.
send_aead: AeadSession,
/// AEAD this endpoint opens incoming Data with.
recv_aead: AeadSession,
/// Next sequence number to stamp on an outgoing Data record (mirrors `send_aead`'s counter).
send_seq: u64,
/// Replay window over incoming Data sequence numbers.
replay: ReplayWindow,
/// The verified identity (Common Name) of the peer, learned during the handshake.
peer_id: Option<String>,
}
impl<R, W> Session<R, W>
where
R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
/// Assemble a session from the handshake outputs.
///
/// `start_counter` is the AEAD nonce counter both directions have reached after the encrypted
/// handshake messages (so the first Data record stamps `seq == start_counter`). `peer_id` is
/// the verified peer Common Name (the server learns the client id; the client may store the
/// server name).
pub(crate) fn new(
reader: R,
writer: W,
send_aead: AeadSession,
recv_aead: AeadSession,
start_counter: u64,
peer_id: Option<String>,
) -> Self {
Self {
reader,
writer,
send_aead,
recv_aead,
send_seq: start_counter,
replay: ReplayWindow::new(start_counter),
peer_id,
}
}
/// The verified identity (Common Name) of the peer, if one was captured during the handshake.
#[must_use]
pub fn peer_id(&self) -> Option<&str> {
self.peer_id.as_deref()
}
/// Serialize, seal, and send a single application [`Frame`].
///
/// # Errors
/// Returns [`ProtoError::Io`] on a write failure or [`ProtoError::FrameTooLarge`] if the sealed
/// record exceeds the frame header's length field.
pub async fn send_frame(&mut self, frame: Frame) -> Result<(), ProtoError> {
let frame_bytes = frame.encode();
let seq = self.send_seq;
// Build the record payload: seq(8) || ciphertext. The frame header binds (type, length),
// and the seq binds the record position; both go into the AAD.
let sealed_len = SEQ_LEN + frame_bytes.len() + 16 /* Poly1305 tag */;
let header = frame::encode_header(MsgType::Data, sealed_len)?;
let mut aad = [0u8; HEADER_LEN + SEQ_LEN];
aad[..HEADER_LEN].copy_from_slice(&header);
aad[HEADER_LEN..].copy_from_slice(&seq.to_be_bytes());
let ciphertext = self.send_aead.seal(&frame_bytes, &aad);
let mut payload = Vec::with_capacity(SEQ_LEN + ciphertext.len());
payload.extend_from_slice(&seq.to_be_bytes());
payload.extend_from_slice(&ciphertext);
debug_assert_eq!(payload.len(), sealed_len);
// Write header + payload. (We already know the exact header for AAD; reuse it.)
use tokio::io::AsyncWriteExt;
self.writer.write_all(&header).await?;
self.writer.write_all(&payload).await?;
self.writer.flush().await?;
self.send_seq = self.send_seq.wrapping_add(1);
Ok(())
}
/// Receive, open, and decode a single application [`Frame`].
///
/// # Errors
/// * [`ProtoError::Replay`] — the record's sequence number was a duplicate or too old.
/// * [`ProtoError::Crypto`] — AEAD authentication failed.
/// * [`ProtoError::UnexpectedMsg`] — a non-Data record arrived.
/// * [`ProtoError::Alert`] — the peer sent a fatal alert.
/// * [`ProtoError::Io`] / [`ProtoError::MalformedFrame`] on transport / decode failure.
pub async fn recv_frame(&mut self) -> Result<Frame, ProtoError> {
let raw: RawFrame = frame::read_frame(&mut self.reader).await?;
match raw.msg_type {
MsgType::Data => {}
MsgType::Alert => {
let code = raw.payload.first().copied().unwrap_or(0);
return Err(ProtoError::Alert(code));
}
other => {
return Err(ProtoError::UnexpectedMsg {
expected: MsgType::Data,
got: other,
})
}
}
if raw.payload.len() < SEQ_LEN {
return Err(ProtoError::MalformedFrame(
"Data record shorter than seq prefix",
));
}
let mut seq_bytes = [0u8; SEQ_LEN];
seq_bytes.copy_from_slice(&raw.payload[..SEQ_LEN]);
let seq = u64::from_be_bytes(seq_bytes);
let ciphertext = &raw.payload[SEQ_LEN..];
// Replay check FIRST — a duplicate/old record must not advance the AEAD counter.
self.replay.check_and_set(seq)?;
let mut aad = [0u8; HEADER_LEN + SEQ_LEN];
aad[..HEADER_LEN].copy_from_slice(&raw.header);
aad[HEADER_LEN..].copy_from_slice(&seq_bytes);
let plaintext = self.recv_aead.open(ciphertext, &aad)?;
Frame::decode(&plaintext)
}
/// Consume the session, returning its transport halves (for clean shutdown / reuse).
#[must_use]
pub fn into_inner(self) -> (R, W) {
(self.reader, self.writer)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn replay_window_basic_monotonic() {
let mut w = ReplayWindow::new(2);
// First legitimate Data records.
assert!(w.check_and_set(2).is_ok());
assert!(w.check_and_set(3).is_ok());
assert!(w.check_and_set(4).is_ok());
// Replays of each are rejected.
assert!(matches!(w.check_and_set(2), Err(ProtoError::Replay(2))));
assert!(matches!(w.check_and_set(3), Err(ProtoError::Replay(3))));
assert!(matches!(w.check_and_set(4), Err(ProtoError::Replay(4))));
}
#[test]
fn replay_window_out_of_order_within_window() {
let mut w = ReplayWindow::new(0);
assert!(w.check_and_set(0).is_ok());
assert!(w.check_and_set(10).is_ok());
// Older-but-unseen inside the window is accepted exactly once.
assert!(w.check_and_set(5).is_ok());
assert!(matches!(w.check_and_set(5), Err(ProtoError::Replay(5))));
// The new-high record (10) is a replay.
assert!(matches!(w.check_and_set(10), Err(ProtoError::Replay(10))));
// Brand-new high still works.
assert!(w.check_and_set(11).is_ok());
}
#[test]
fn replay_window_rejects_too_old() {
let mut w = ReplayWindow::new(0);
assert!(w.check_and_set(0).is_ok());
assert!(w.check_and_set(200).is_ok());
// Far below the window edge => rejected as too old.
assert!(matches!(w.check_and_set(1), Err(ProtoError::Replay(1))));
assert!(matches!(
w.check_and_set(200 - REPLAY_WINDOW),
Err(ProtoError::Replay(_))
));
// Just inside the window edge => still acceptable.
assert!(w.check_and_set(200 - REPLAY_WINDOW + 1).is_ok());
}
}
+56
View File
@@ -0,0 +1,56 @@
//! Shared test helpers: minting an Aura CA + leaf certs, and wiring an in-memory duplex transport.
#![allow(dead_code)] // each integration test binary uses a different subset of these helpers
use aura_pki::AuraCa;
use aura_proto::{ClientConfig, ServerConfig};
/// A minted PKI fixture: a CA, a server cert/key, and a client cert/key.
pub struct Pki {
pub ca_cert_pem: String,
pub server_cert_pem: String,
pub server_key_pem: String,
pub client_cert_pem: String,
pub client_key_pem: String,
pub server_name: String,
pub client_id: String,
}
/// Mint a CA plus a server cert (for `server_name`) and a client cert (CN = `client_id`).
pub fn mint_pki(server_name: &str, client_id: &str) -> Pki {
let ca = AuraCa::generate("Aura Test Root CA").expect("generate CA");
let server = ca
.issue_server_cert(server_name)
.expect("issue server cert");
let client = ca.issue_client_cert(client_id).expect("issue client cert");
Pki {
ca_cert_pem: ca.ca_cert_pem(),
server_cert_pem: server.cert_pem,
server_key_pem: server.key_pem,
client_cert_pem: client.cert_pem,
client_key_pem: client.key_pem,
server_name: server_name.to_string(),
client_id: client_id.to_string(),
}
}
impl Pki {
/// Build a matching [`ClientConfig`] from this fixture.
pub fn client_config(&self) -> ClientConfig {
ClientConfig {
ca_cert_pem: self.ca_cert_pem.clone(),
client_cert_pem: self.client_cert_pem.clone(),
client_key_pem: self.client_key_pem.clone(),
server_name: self.server_name.clone(),
}
}
/// Build a matching [`ServerConfig`] from this fixture.
pub fn server_config(&self) -> ServerConfig {
ServerConfig {
ca_cert_pem: self.ca_cert_pem.clone(),
server_cert_pem: self.server_cert_pem.clone(),
server_key_pem: self.server_key_pem.clone(),
}
}
}
+136
View File
@@ -0,0 +1,136 @@
//! `test_data_exchange_1000pkts` — after the handshake, exchange 1000 Data frames in each
//! direction and assert payload integrity and ordering.
mod common;
use aura_proto::{client_handshake, server_handshake, Frame};
use bytes::Bytes;
use tokio::io::split;
const N: u32 = 1000;
/// Build the deterministic payload for frame `i` from `who`.
fn payload_for(who: &str, i: u32) -> Bytes {
Bytes::from(format!(
"{who}-packet-{i}-{}",
"x".repeat((i % 37) as usize)
))
}
#[tokio::test]
async fn test_data_exchange_1000pkts() {
let pki = common::mint_pki("vpn.aura.example", "client-alpha");
let client_cfg = pki.client_config();
let server_cfg = pki.server_config();
let (client_end, server_end) = tokio::io::duplex(64 * 1024);
let (c_read, c_write) = split(client_end);
let (s_read, s_write) = split(server_end);
let client = tokio::spawn(async move {
let mut sess = client_handshake(c_read, c_write, &client_cfg)
.await
.expect("client handshake");
// Interleave send + recv in lockstep to avoid filling the duplex buffer.
for i in 0..N {
sess.send_frame(Frame::Data {
stream_id: 1,
payload: payload_for("client", i),
})
.await
.expect("client send");
match sess.recv_frame().await.expect("client recv") {
Frame::Data { stream_id, payload } => {
assert_eq!(stream_id, 2, "wrong stream id at i={i}");
assert_eq!(
payload,
payload_for("server", i),
"payload mismatch at i={i}"
);
}
other => panic!("client expected Data, got {other:?}"),
}
}
});
let server = tokio::spawn(async move {
let mut sess = server_handshake(s_read, s_write, &server_cfg)
.await
.expect("server handshake");
for i in 0..N {
// Receive the client's i-th packet first, then reply, mirroring the client's lockstep.
match sess.recv_frame().await.expect("server recv") {
Frame::Data { stream_id, payload } => {
assert_eq!(stream_id, 1, "wrong stream id at i={i}");
assert_eq!(
payload,
payload_for("client", i),
"payload mismatch at i={i}"
);
}
other => panic!("server expected Data, got {other:?}"),
}
sess.send_frame(Frame::Data {
stream_id: 2,
payload: payload_for("server", i),
})
.await
.expect("server send");
}
});
let (c, s) = tokio::join!(client, server);
c.expect("client task");
s.expect("server task");
}
#[tokio::test]
async fn ping_pong_and_close_frames_roundtrip() {
let pki = common::mint_pki("vpn.aura.example", "c1");
let client_cfg = pki.client_config();
let server_cfg = pki.server_config();
let (client_end, server_end) = tokio::io::duplex(64 * 1024);
let (c_read, c_write) = split(client_end);
let (s_read, s_write) = split(server_end);
let client = tokio::spawn(async move {
let mut sess = client_handshake(c_read, c_write, &client_cfg)
.await
.unwrap();
sess.send_frame(Frame::Ping { seq: 7 }).await.unwrap();
match sess.recv_frame().await.unwrap() {
Frame::Pong { seq } => assert_eq!(seq, 7),
other => panic!("expected Pong, got {other:?}"),
}
sess.send_frame(Frame::Close {
code: 0,
reason: "bye".into(),
})
.await
.unwrap();
});
let server = tokio::spawn(async move {
let mut sess = server_handshake(s_read, s_write, &server_cfg)
.await
.unwrap();
match sess.recv_frame().await.unwrap() {
Frame::Ping { seq } => sess.send_frame(Frame::Pong { seq }).await.unwrap(),
other => panic!("expected Ping, got {other:?}"),
}
match sess.recv_frame().await.unwrap() {
Frame::Close { code, reason } => {
assert_eq!(code, 0);
assert_eq!(reason, "bye");
}
other => panic!("expected Close, got {other:?}"),
}
});
let (c, s) = tokio::join!(client, server);
c.unwrap();
s.unwrap();
}
@@ -0,0 +1,43 @@
//! `test_full_handshake_loopback` — a full client+server handshake over an in-memory duplex.
mod common;
use aura_proto::{client_handshake, server_handshake};
use tokio::io::split;
#[tokio::test]
async fn test_full_handshake_loopback() {
let pki = common::mint_pki("vpn.aura.example", "client-alpha");
let client_cfg = pki.client_config();
let server_cfg = pki.server_config();
// Connected in-memory transport; split each end into independent read/write halves so the
// handshake can use separate reader + writer (matching quinn's split streams).
let (client_end, server_end) = tokio::io::duplex(64 * 1024);
let (c_read, c_write) = split(client_end);
let (s_read, s_write) = split(server_end);
let client = tokio::spawn(async move {
client_handshake(c_read, c_write, &client_cfg)
.await
.map(|s| s.peer_id().map(str::to_string))
});
let server = tokio::spawn(async move {
server_handshake(s_read, s_write, &server_cfg)
.await
.map(|s| s.peer_id().map(str::to_string))
});
let (client_res, server_res) = tokio::join!(client, server);
let client_peer = client_res
.expect("client task")
.expect("client handshake ok");
let server_peer = server_res
.expect("server task")
.expect("server handshake ok");
// Server learned the client id from the verified client certificate.
assert_eq!(server_peer.as_deref(), Some("client-alpha"));
// Client recorded the server name it authenticated.
assert_eq!(client_peer.as_deref(), Some("vpn.aura.example"));
}
@@ -0,0 +1,86 @@
//! `test_pki_mutual_auth` — the server must reject a client whose certificate was issued by a
//! different CA, and must reject a client that presents a valid certificate but a forged signature
//! (one made with a key that does not match the certificate).
mod common;
use aura_pki::AuraCa;
use aura_proto::{client_handshake, server_handshake, ClientConfig, ProtoError};
use tokio::io::split;
/// Run a handshake and return both sides' results.
async fn run(
client_cfg: ClientConfig,
server_cfg: aura_proto::ServerConfig,
) -> (Result<(), ProtoError>, Result<Option<String>, ProtoError>) {
let (client_end, server_end) = tokio::io::duplex(64 * 1024);
let (c_read, c_write) = split(client_end);
let (s_read, s_write) = split(server_end);
let client = tokio::spawn(async move {
client_handshake(c_read, c_write, &client_cfg)
.await
.map(|_| ())
});
let server = tokio::spawn(async move {
server_handshake(s_read, s_write, &server_cfg)
.await
.map(|s| s.peer_id().map(str::to_string))
});
let (c, s) = tokio::join!(client, server);
(c.expect("client task"), s.expect("server task"))
}
#[tokio::test]
async fn wrong_ca_client_cert_is_rejected() {
// The legitimate server-side PKI.
let pki = common::mint_pki("vpn.aura.example", "client-alpha");
// An attacker CA issues a client cert with a plausible CN, but it does NOT chain to the
// server's trusted CA.
let rogue_ca = AuraCa::generate("Rogue CA").expect("rogue CA");
let rogue_client = rogue_ca
.issue_client_cert("client-alpha")
.expect("rogue client cert");
let client_cfg = ClientConfig {
ca_cert_pem: pki.ca_cert_pem.clone(),
client_cert_pem: rogue_client.cert_pem,
client_key_pem: rogue_client.key_pem,
server_name: pki.server_name.clone(),
};
let (_client_res, server_res) = run(client_cfg, pki.server_config()).await;
// The server must fail verifying the client chain against its trusted CA.
assert!(
matches!(server_res, Err(ProtoError::Pki(_))),
"expected a PKI verification failure, got {server_res:?}"
);
}
#[tokio::test]
async fn forged_client_signature_is_rejected() {
let pki = common::mint_pki("vpn.aura.example", "client-alpha");
// Mint an unrelated P-256 keypair (via a throwaway issued cert) to use as the WRONG signing
// key. We pair the legitimate client's certificate with this mismatched private key: the chain
// verifies fine, but the signature over the transcript is made with a key that does not match
// the certificate's public key, so signature verification must fail.
let throwaway_ca = AuraCa::generate("throwaway").expect("throwaway CA");
let mismatched = throwaway_ca
.issue_client_cert("mismatched")
.expect("throwaway cert");
let client_cfg = ClientConfig {
ca_cert_pem: pki.ca_cert_pem.clone(),
client_cert_pem: pki.client_cert_pem.clone(), // valid cert (chains to trusted CA)
client_key_pem: mismatched.key_pem, // WRONG key -> forged signature
server_name: pki.server_name.clone(),
};
let (_client_res, server_res) = run(client_cfg, pki.server_config()).await;
assert!(
matches!(server_res, Err(ProtoError::Signature(_))),
"expected a signature verification failure, got {server_res:?}"
);
}
@@ -0,0 +1,122 @@
//! `test_replay_protection` — a Data record that was already delivered, replayed verbatim, must be
//! rejected by the receiver's sliding replay window.
//!
//! Topology: the client and server each talk to one end of their own duplex. A relay task in the
//! middle forwards bytes between the two. For the client->server direction the relay parses whole
//! frames (using the crate's public framing helpers) so that, once the client has sent its data
//! packet, the relay can forward that exact record to the server a SECOND time — a verbatim replay.
mod common;
use aura_proto::frame::{read_frame, write_frame, MsgType, RawFrame};
use aura_proto::{client_handshake, server_handshake, Frame, ProtoError};
use bytes::Bytes;
use tokio::io::{split, AsyncWriteExt};
use tokio::sync::oneshot;
#[tokio::test]
async fn test_replay_protection() {
let pki = common::mint_pki("vpn.aura.example", "client-alpha");
let client_cfg = pki.client_config();
let server_cfg = pki.server_config();
// Two duplexes with a relay in the middle.
let (client_io, relay_a) = tokio::io::duplex(64 * 1024);
let (relay_b, server_io) = tokio::io::duplex(64 * 1024);
let (c_read, c_write) = split(client_io);
let (s_read, s_write) = split(server_io);
let (ra_read, ra_write) = split(relay_a); // faces the client
let (rb_read, rb_write) = split(relay_b); // faces the server
// Signal so the relay forwards the replay only after the server has consumed the genuine copy.
let (genuine_done_tx, genuine_done_rx) = oneshot::channel::<()>();
// ---- Relay: client -> server, with a one-shot verbatim replay of the first Data record ----
let relay_c2s = tokio::spawn(async move {
let mut ra_read = ra_read;
let mut rb_write = rb_write;
let mut genuine_done = Some(genuine_done_rx);
let mut replayed = false;
loop {
let frame: RawFrame = match read_frame(&mut ra_read).await {
Ok(f) => f,
Err(_) => break, // EOF when the client side closes
};
// Forward the frame unchanged.
write_frame(&mut rb_write, frame.msg_type, &frame.payload)
.await
.expect("relay forward c->s");
// On the first Data record, wait until the server has accepted it, then replay it once.
if frame.msg_type == MsgType::Data && !replayed {
replayed = true;
if let Some(rx) = genuine_done.take() {
let _ = rx.await; // server signals after it accepted the genuine record
}
write_frame(&mut rb_write, frame.msg_type, &frame.payload)
.await
.expect("relay replay c->s");
rb_write.flush().await.expect("flush replay");
}
}
});
// ---- Relay: server -> client (straight byte copy) ----
let relay_s2c = tokio::spawn(async move {
let mut rb_read = rb_read;
let mut ra_write = ra_write;
let _ = tokio::io::copy(&mut rb_read, &mut ra_write).await;
let _ = ra_write.shutdown().await;
});
// ---- Client: handshake, then send exactly one Data frame ----
let client = tokio::spawn(async move {
let mut sess = client_handshake(c_read, c_write, &client_cfg)
.await
.expect("client handshake");
sess.send_frame(Frame::Data {
stream_id: 9,
payload: Bytes::from_static(b"the one and only payload"),
})
.await
.expect("client send");
// Keep the session (and thus the transport) alive until the test signals completion.
sess
});
// ---- Server: handshake, accept the genuine record, then expect the replay to be rejected ----
let server = tokio::spawn(async move {
let mut sess = server_handshake(s_read, s_write, &server_cfg)
.await
.expect("server handshake");
// 1) Genuine record is accepted.
let first = sess.recv_frame().await.expect("genuine recv");
match first {
Frame::Data { stream_id, payload } => {
assert_eq!(stream_id, 9);
assert_eq!(&payload[..], b"the one and only payload");
}
other => panic!("expected Data, got {other:?}"),
}
// Tell the relay it may now inject the verbatim replay.
genuine_done_tx.send(()).expect("signal genuine done");
// 2) The replayed record must be rejected by the replay window.
let replay_result = sess.recv_frame().await;
assert!(
matches!(replay_result, Err(ProtoError::Replay(_))),
"expected ProtoError::Replay, got {replay_result:?}"
);
sess
});
let (_client_sess, server_outcome) = tokio::join!(client, server);
server_outcome.expect("server task");
drop(_client_sess); // closes the client side -> relays drain and exit
let _ = relay_c2s.await;
let _ = relay_s2c.await;
}