feat(proto): implement Wave 2 — hybrid PKI handshake + session
aura-proto: 5-byte wire header + Frame codec (§6.1/§6.3); transport-agnostic handshake state machine (§6.2) over split tokio AsyncRead/AsyncWrite — hybrid X25519+ML-KEM-768 KEM, SHA-256 transcript, mutual X.509 auth with ECDSA-P256 transcript signatures (ring), constant-time HMAC Finished; Session with sliding-window replay protection. 13 tests green, clippy clean. Handshake message order pinned (resolves spec diagram ambiguity); reader/writer taken by value since Session owns both halves. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,371 @@
|
||||
//! Wire format: the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3).
|
||||
//!
|
||||
//! Every Aura protocol message on the wire is a 5-byte header followed by a payload:
|
||||
//!
|
||||
//! ```text
|
||||
//! byte 0 : msg_type (u8)
|
||||
//! bytes 1..4 : length (u24, big-endian) = payload length in bytes
|
||||
//! byte 4 : version = 0x01
|
||||
//! bytes 5.. : payload (length bytes)
|
||||
//! ```
|
||||
//!
|
||||
//! [`Frame`] is the post-handshake application payload. Each `Frame` is serialized with
|
||||
//! [`Frame::encode`], AEAD-sealed, and shipped inside a [`MsgType::Data`] record (see
|
||||
//! [`crate::session`]).
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::ProtoError;
|
||||
|
||||
/// Length in bytes of the protocol frame header.
|
||||
pub const HEADER_LEN: usize = 5;
|
||||
|
||||
/// Protocol version carried in byte 4 of every header.
|
||||
pub const PROTOCOL_VERSION: u8 = 0x01;
|
||||
|
||||
/// Largest payload expressible by the u24 length field.
|
||||
pub const MAX_PAYLOAD_LEN: usize = 0x00FF_FFFF;
|
||||
|
||||
/// Message types carried in byte 0 of the header (§6.1).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum MsgType {
|
||||
/// Handshake message 1 (C->S): hybrid public key + client nonce.
|
||||
ClientHello = 0x01,
|
||||
/// Handshake message 2 (S->C): hybrid ciphertext + server nonce.
|
||||
ServerHello = 0x02,
|
||||
/// Handshake message 4 (C->S, encrypted): client cert + signature.
|
||||
ClientAuth = 0x03,
|
||||
/// Handshake message 3 (S->C, encrypted): server cert + signature.
|
||||
ServerAuth = 0x04,
|
||||
/// Handshake Finished (encrypted): HMAC over the handshake hash.
|
||||
Finished = 0x05,
|
||||
/// Application data (encrypted): an AEAD-sealed [`Frame`].
|
||||
Data = 0x06,
|
||||
/// Fatal alert / error notification.
|
||||
Alert = 0xFF,
|
||||
}
|
||||
|
||||
impl MsgType {
|
||||
/// Map the on-wire byte to a [`MsgType`].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized byte.
|
||||
pub fn from_u8(b: u8) -> Result<Self, ProtoError> {
|
||||
Ok(match b {
|
||||
0x01 => Self::ClientHello,
|
||||
0x02 => Self::ServerHello,
|
||||
0x03 => Self::ClientAuth,
|
||||
0x04 => Self::ServerAuth,
|
||||
0x05 => Self::Finished,
|
||||
0x06 => Self::Data,
|
||||
0xFF => Self::Alert,
|
||||
other => return Err(ProtoError::UnknownMsgType(other)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a 5-byte header for `msg_type` carrying a payload of `payload_len` bytes.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::FrameTooLarge`] if `payload_len` does not fit in the u24 length field.
|
||||
pub fn encode_header(
|
||||
msg_type: MsgType,
|
||||
payload_len: usize,
|
||||
) -> Result<[u8; HEADER_LEN], ProtoError> {
|
||||
if payload_len > MAX_PAYLOAD_LEN {
|
||||
return Err(ProtoError::FrameTooLarge(payload_len));
|
||||
}
|
||||
let len = payload_len as u32;
|
||||
Ok([
|
||||
msg_type as u8,
|
||||
((len >> 16) & 0xFF) as u8,
|
||||
((len >> 8) & 0xFF) as u8,
|
||||
(len & 0xFF) as u8,
|
||||
PROTOCOL_VERSION,
|
||||
])
|
||||
}
|
||||
|
||||
/// Parse a 5-byte header into `(msg_type, payload_len)`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized type byte or
|
||||
/// [`ProtoError::BadVersion`] if byte 4 is not [`PROTOCOL_VERSION`].
|
||||
pub fn decode_header(header: &[u8; HEADER_LEN]) -> Result<(MsgType, usize), ProtoError> {
|
||||
let msg_type = MsgType::from_u8(header[0])?;
|
||||
let version = header[4];
|
||||
if version != PROTOCOL_VERSION {
|
||||
return Err(ProtoError::BadVersion(version));
|
||||
}
|
||||
let len = ((header[1] as usize) << 16) | ((header[2] as usize) << 8) | (header[3] as usize);
|
||||
Ok((msg_type, len))
|
||||
}
|
||||
|
||||
/// Write one full frame (`header || payload`) to `writer`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::FrameTooLarge`] if the payload is too long, or [`ProtoError::Io`] on a
|
||||
/// write failure.
|
||||
pub async fn write_frame<W>(
|
||||
writer: &mut W,
|
||||
msg_type: MsgType,
|
||||
payload: &[u8],
|
||||
) -> Result<(), ProtoError>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let header = encode_header(msg_type, payload.len())?;
|
||||
writer.write_all(&header).await?;
|
||||
writer.write_all(payload).await?;
|
||||
writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A frame read off the wire: its type, the raw header bytes (useful as AEAD AAD and for the
|
||||
/// handshake transcript hash), and the payload.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RawFrame {
|
||||
/// The decoded message type.
|
||||
pub msg_type: MsgType,
|
||||
/// The 5 header bytes exactly as transmitted.
|
||||
pub header: [u8; HEADER_LEN],
|
||||
/// The payload bytes.
|
||||
pub payload: Vec<u8>,
|
||||
}
|
||||
|
||||
impl RawFrame {
|
||||
/// The full serialized frame (`header || payload`) exactly as it appeared on the wire.
|
||||
///
|
||||
/// Used to feed the handshake transcript hash, which must hash the bytes as transmitted.
|
||||
#[must_use]
|
||||
pub fn wire_bytes(&self) -> Vec<u8> {
|
||||
let mut out = Vec::with_capacity(HEADER_LEN + self.payload.len());
|
||||
out.extend_from_slice(&self.header);
|
||||
out.extend_from_slice(&self.payload);
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
/// Read one full frame (`header || payload`) from `reader`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::Io`] on a read failure (including a truncated frame / EOF), or a header
|
||||
/// decode error from [`decode_header`].
|
||||
pub async fn read_frame<R>(reader: &mut R) -> Result<RawFrame, ProtoError>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
let mut header = [0u8; HEADER_LEN];
|
||||
reader.read_exact(&mut header).await?;
|
||||
let (msg_type, len) = decode_header(&header)?;
|
||||
let mut payload = vec![0u8; len];
|
||||
reader.read_exact(&mut payload).await?;
|
||||
Ok(RawFrame {
|
||||
msg_type,
|
||||
header,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
/// Frame type tags used in the application [`Frame`] encoding (§6.3).
|
||||
mod frame_tag {
|
||||
pub const DATA: u8 = 0x01;
|
||||
pub const PING: u8 = 0x02;
|
||||
pub const PONG: u8 = 0x03;
|
||||
pub const CLOSE: u8 = 0x04;
|
||||
}
|
||||
|
||||
/// Application-level frames carried inside encrypted [`MsgType::Data`] records (§6.3).
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum Frame {
|
||||
/// A stream data payload.
|
||||
Data {
|
||||
/// Logical stream identifier.
|
||||
stream_id: u32,
|
||||
/// Opaque application bytes.
|
||||
payload: Bytes,
|
||||
},
|
||||
/// Liveness probe.
|
||||
Ping {
|
||||
/// Monotonic sequence number echoed back in the matching [`Frame::Pong`].
|
||||
seq: u32,
|
||||
},
|
||||
/// Reply to a [`Frame::Ping`].
|
||||
Pong {
|
||||
/// Sequence number copied from the [`Frame::Ping`].
|
||||
seq: u32,
|
||||
},
|
||||
/// Orderly shutdown of the logical connection.
|
||||
Close {
|
||||
/// Application-defined close code.
|
||||
code: u8,
|
||||
/// Human-readable reason (UTF-8).
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Frame {
|
||||
/// Serialize this frame to its compact byte encoding.
|
||||
///
|
||||
/// Layout (all multi-byte integers big-endian):
|
||||
/// * `Data` : `0x01 || stream_id(u32) || payload`
|
||||
/// * `Ping` : `0x02 || seq(u32)`
|
||||
/// * `Pong` : `0x03 || seq(u32)`
|
||||
/// * `Close` : `0x04 || code(u8) || reason_len(u32) || reason_utf8`
|
||||
#[must_use]
|
||||
pub fn encode(&self) -> Vec<u8> {
|
||||
let mut out = Vec::new();
|
||||
match self {
|
||||
Frame::Data { stream_id, payload } => {
|
||||
out.push(frame_tag::DATA);
|
||||
out.extend_from_slice(&stream_id.to_be_bytes());
|
||||
out.extend_from_slice(payload);
|
||||
}
|
||||
Frame::Ping { seq } => {
|
||||
out.push(frame_tag::PING);
|
||||
out.extend_from_slice(&seq.to_be_bytes());
|
||||
}
|
||||
Frame::Pong { seq } => {
|
||||
out.push(frame_tag::PONG);
|
||||
out.extend_from_slice(&seq.to_be_bytes());
|
||||
}
|
||||
Frame::Close { code, reason } => {
|
||||
out.push(frame_tag::CLOSE);
|
||||
out.push(*code);
|
||||
let bytes = reason.as_bytes();
|
||||
out.extend_from_slice(&(bytes.len() as u32).to_be_bytes());
|
||||
out.extend_from_slice(bytes);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Parse a frame from its byte encoding (the inverse of [`Frame::encode`]).
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::MalformedFrame`] if the buffer is truncated, has an unknown tag, or
|
||||
/// (for `Close`) does not contain valid UTF-8.
|
||||
pub fn decode(buf: &[u8]) -> Result<Self, ProtoError> {
|
||||
let (&tag, rest) = buf
|
||||
.split_first()
|
||||
.ok_or(ProtoError::MalformedFrame("empty frame"))?;
|
||||
match tag {
|
||||
frame_tag::DATA => {
|
||||
let stream_id = read_u32(rest, "Data.stream_id")?;
|
||||
let payload = Bytes::copy_from_slice(&rest[4..]);
|
||||
Ok(Frame::Data { stream_id, payload })
|
||||
}
|
||||
frame_tag::PING => Ok(Frame::Ping {
|
||||
seq: read_u32(rest, "Ping.seq")?,
|
||||
}),
|
||||
frame_tag::PONG => Ok(Frame::Pong {
|
||||
seq: read_u32(rest, "Pong.seq")?,
|
||||
}),
|
||||
frame_tag::CLOSE => {
|
||||
let code = *rest
|
||||
.first()
|
||||
.ok_or(ProtoError::MalformedFrame("Close: missing code"))?;
|
||||
let reason_len = read_u32(&rest[1..], "Close.reason_len")? as usize;
|
||||
let reason_bytes = rest
|
||||
.get(5..5 + reason_len)
|
||||
.ok_or(ProtoError::MalformedFrame("Close: truncated reason"))?;
|
||||
let reason = String::from_utf8(reason_bytes.to_vec())
|
||||
.map_err(|_| ProtoError::MalformedFrame("Close: reason not UTF-8"))?;
|
||||
Ok(Frame::Close { code, reason })
|
||||
}
|
||||
_ => Err(ProtoError::MalformedFrame("unknown frame tag")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a big-endian u32 from the start of `buf`, erroring if it is too short.
|
||||
fn read_u32(buf: &[u8], what: &'static str) -> Result<u32, ProtoError> {
|
||||
let bytes: [u8; 4] = buf
|
||||
.get(..4)
|
||||
.ok_or(ProtoError::MalformedFrame(what))?
|
||||
.try_into()
|
||||
.expect("slice of length 4 converts to [u8; 4]");
|
||||
Ok(u32::from_be_bytes(bytes))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn header_roundtrip_all_types() {
|
||||
for (ty, byte) in [
|
||||
(MsgType::ClientHello, 0x01u8),
|
||||
(MsgType::ServerHello, 0x02),
|
||||
(MsgType::ClientAuth, 0x03),
|
||||
(MsgType::ServerAuth, 0x04),
|
||||
(MsgType::Finished, 0x05),
|
||||
(MsgType::Data, 0x06),
|
||||
(MsgType::Alert, 0xFF),
|
||||
] {
|
||||
let h = encode_header(ty, 0x0012_3456).unwrap();
|
||||
assert_eq!(h[0], byte);
|
||||
assert_eq!(h[4], PROTOCOL_VERSION);
|
||||
let (got_ty, got_len) = decode_header(&h).unwrap();
|
||||
assert_eq!(got_ty, ty);
|
||||
assert_eq!(got_len, 0x0012_3456);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_rejects_oversize_and_bad_version() {
|
||||
assert!(matches!(
|
||||
encode_header(MsgType::Data, MAX_PAYLOAD_LEN + 1),
|
||||
Err(ProtoError::FrameTooLarge(_))
|
||||
));
|
||||
let mut h = encode_header(MsgType::Data, 1).unwrap();
|
||||
h[4] = 0x02;
|
||||
assert!(matches!(
|
||||
decode_header(&h),
|
||||
Err(ProtoError::BadVersion(0x02))
|
||||
));
|
||||
h[0] = 0x77;
|
||||
assert!(matches!(
|
||||
decode_header(&h),
|
||||
Err(ProtoError::UnknownMsgType(0x77))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_roundtrip() {
|
||||
let frames = vec![
|
||||
Frame::Data {
|
||||
stream_id: 0xDEAD_BEEF,
|
||||
payload: Bytes::from_static(b"hello world"),
|
||||
},
|
||||
Frame::Data {
|
||||
stream_id: 0,
|
||||
payload: Bytes::new(),
|
||||
},
|
||||
Frame::Ping { seq: 42 },
|
||||
Frame::Pong { seq: 0xFFFF_FFFF },
|
||||
Frame::Close {
|
||||
code: 7,
|
||||
reason: "going away \u{1f44b}".to_string(),
|
||||
},
|
||||
Frame::Close {
|
||||
code: 0,
|
||||
reason: String::new(),
|
||||
},
|
||||
];
|
||||
for f in frames {
|
||||
let encoded = f.encode();
|
||||
let decoded = Frame::decode(&encoded).unwrap();
|
||||
assert_eq!(f, decoded);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_rejects_garbage() {
|
||||
assert!(Frame::decode(&[]).is_err());
|
||||
assert!(Frame::decode(&[0x99]).is_err());
|
||||
assert!(Frame::decode(&[frame_tag::PING, 0x00]).is_err()); // truncated u32
|
||||
assert!(Frame::decode(&[frame_tag::CLOSE]).is_err()); // missing code
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
//! The Aura handshake state machine (§6.2): a hybrid X25519 + ML-KEM-768 key exchange with mutual
|
||||
//! X.509 authentication and a Finished-MAC transcript binding.
|
||||
//!
|
||||
//! Both entry points — [`client_handshake`] and [`server_handshake`] — are generic over a separate
|
||||
//! [`tokio::io::AsyncRead`] reader and [`tokio::io::AsyncWrite`] writer so they drive an in-memory
|
||||
//! duplex pipe (tests) or quinn's split `RecvStream` / `SendStream` (the QUIC transport)
|
||||
//! identically. On success each returns an established [`Session`].
|
||||
//!
|
||||
//! ## Message order (resolves the spec diagram ambiguity)
|
||||
//!
|
||||
//! ```text
|
||||
//! 1. C->S ClientHello (plaintext): x25519_pub[32] || kyber_ek[1184] || client_nonce[32]
|
||||
//! 2. S->C ServerHello (plaintext): x25519_ephemeral[32] || kyber_ct[1088] || server_nonce[32]
|
||||
//! -- both sides derive the hybrid shared secret and the directional SessionKeys --
|
||||
//! 3. S->C ServerAuth (encrypted): u16(cert_der_len) || server_leaf_cert_der || sig(transcript)
|
||||
//! 4. C->S ClientAuth (encrypted): u16(cert_der_len) || client_leaf_cert_der || sig(transcript)
|
||||
//! 5. C->S Finished (encrypted): HMAC-SHA256(key = key_c2s, transcript)
|
||||
//! 6. S->C Finished (encrypted): HMAC-SHA256(key = key_s2c, transcript)
|
||||
//! ```
|
||||
//!
|
||||
//! `transcript = SHA-256(ClientHello_frame_bytes || ServerHello_frame_bytes)` over the FULL
|
||||
//! serialized frames (header + payload) exactly as transmitted. The same two [`AeadSession`]s
|
||||
//! protect messages 3–6 and all subsequent application Data — their counters stay continuous.
|
||||
|
||||
use aura_crypto::{
|
||||
derive_session_keys, AeadSession, HybridCiphertext, HybridPrivateKey, HybridPublicKey,
|
||||
SessionKeys,
|
||||
};
|
||||
use aura_pki::AuraCertVerifier;
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::RngCore;
|
||||
use rustls_pki_types::CertificateDer;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::frame::{self, MsgType, RawFrame, HEADER_LEN};
|
||||
use crate::session::Session;
|
||||
use crate::{ClientConfig, ProtoError, ServerConfig};
|
||||
|
||||
/// X25519 public key / ephemeral / shared-secret length.
|
||||
const X25519_LEN: usize = 32;
|
||||
/// ML-KEM-768 encapsulation (public) key length.
|
||||
const KYBER_EK_LEN: usize = 1184;
|
||||
/// ML-KEM-768 ciphertext length.
|
||||
const KYBER_CT_LEN: usize = 1088;
|
||||
/// Handshake nonce length.
|
||||
const NONCE_LEN: usize = 32;
|
||||
|
||||
/// Exact ClientHello payload length: x25519_pub || kyber_ek || client_nonce.
|
||||
const CLIENT_HELLO_LEN: usize = X25519_LEN + KYBER_EK_LEN + NONCE_LEN;
|
||||
/// Exact ServerHello payload length: x25519_ephemeral || kyber_ct || server_nonce.
|
||||
const SERVER_HELLO_LEN: usize = X25519_LEN + KYBER_CT_LEN + NONCE_LEN;
|
||||
|
||||
/// HMAC-SHA256 alias for the Finished MAC.
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Each direction's AEAD seals exactly two encrypted handshake messages (an Auth and a Finished)
|
||||
/// before application Data begins, so both AEAD counters start Data at this value.
|
||||
const POST_HANDSHAKE_COUNTER: u64 = 2;
|
||||
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
// Client
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Drive the client side of the Aura handshake to completion.
|
||||
///
|
||||
/// Generates a hybrid keypair, exchanges hello messages, derives session keys, authenticates the
|
||||
/// server, proves possession of the client key, and exchanges Finished MACs. On success returns an
|
||||
/// established [`Session`] whose AEAD counters continue from the handshake.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns a [`ProtoError`] for any transport, cryptographic, certificate, signature, or
|
||||
/// transcript-MAC failure (see the variants of [`ProtoError`]).
|
||||
pub async fn client_handshake<R, W>(
|
||||
mut reader: R,
|
||||
mut writer: W,
|
||||
cfg: &ClientConfig,
|
||||
) -> Result<Session<R, W>, ProtoError>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
// (1) C->S ClientHello: generate our hybrid keypair; send public half + a fresh nonce.
|
||||
let (priv_key, pub_key): (HybridPrivateKey, HybridPublicKey) = HybridPrivateKey::generate();
|
||||
let client_nonce = random_nonce();
|
||||
|
||||
let mut ch_payload = Vec::with_capacity(CLIENT_HELLO_LEN);
|
||||
ch_payload.extend_from_slice(&pub_key.x25519);
|
||||
ch_payload.extend_from_slice(&pub_key.kyber);
|
||||
ch_payload.extend_from_slice(&client_nonce);
|
||||
debug_assert_eq!(ch_payload.len(), CLIENT_HELLO_LEN);
|
||||
|
||||
let ch_header = frame::encode_header(MsgType::ClientHello, ch_payload.len())?;
|
||||
frame::write_frame(&mut writer, MsgType::ClientHello, &ch_payload).await?;
|
||||
let client_hello_wire = concat_frame(&ch_header, &ch_payload);
|
||||
|
||||
// (2) S->C ServerHello: ciphertext + server nonce. Capture raw bytes for the transcript.
|
||||
let sh = read_expect(&mut reader, MsgType::ServerHello).await?;
|
||||
if sh.payload.len() != SERVER_HELLO_LEN {
|
||||
return Err(ProtoError::MalformedHandshake("ServerHello: wrong length"));
|
||||
}
|
||||
let (sh_x_eph, rest) = sh.payload.split_at(X25519_LEN);
|
||||
let (sh_kyber_ct, sh_nonce) = rest.split_at(KYBER_CT_LEN);
|
||||
let mut server_nonce = [0u8; NONCE_LEN];
|
||||
server_nonce.copy_from_slice(sh_nonce);
|
||||
|
||||
let ciphertext = HybridCiphertext {
|
||||
x25519_ephemeral: sh_x_eph.try_into().expect("32-byte x25519 ephemeral"),
|
||||
kyber_ciphertext: sh_kyber_ct.to_vec(),
|
||||
};
|
||||
let shared = priv_key.decapsulate(&ciphertext)?;
|
||||
let keys: SessionKeys = derive_session_keys(&shared, &client_nonce, &server_nonce);
|
||||
|
||||
// Transcript hash over the two hello frames exactly as transmitted.
|
||||
let transcript = transcript_hash(&client_hello_wire, &sh.wire_bytes());
|
||||
|
||||
// Two AEADs: client seals c2s, opens s2c.
|
||||
let mut aead_c2s = AeadSession::new(keys.client_to_server);
|
||||
let mut aead_s2c = AeadSession::new(keys.server_to_client);
|
||||
|
||||
// (3) S->C ServerAuth (encrypted under s2c): verify cert chain + signature over transcript.
|
||||
let server_auth = open_handshake_msg(&mut reader, MsgType::ServerAuth, &mut aead_s2c).await?;
|
||||
let (server_cert_der, server_sig) = split_cert_and_sig(&server_auth)?;
|
||||
let verifier = AuraCertVerifier::new(&cfg.ca_cert_pem)?;
|
||||
let chain = [CertificateDer::from(server_cert_der.to_vec())];
|
||||
verifier.verify_server_cert(&chain, &cfg.server_name)?;
|
||||
verify_signature(server_cert_der, &transcript, server_sig)?;
|
||||
|
||||
// (4) C->S ClientAuth (encrypted under c2s): our leaf + our signature over transcript.
|
||||
let client_cert_der = pem_cert_to_der(&cfg.client_cert_pem)?;
|
||||
let client_sig = sign_transcript(&cfg.client_key_pem, &transcript)?;
|
||||
let client_auth = build_cert_and_sig(&client_cert_der, &client_sig);
|
||||
seal_handshake_msg(
|
||||
&mut writer,
|
||||
MsgType::ClientAuth,
|
||||
&mut aead_c2s,
|
||||
&client_auth,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// (5) C->S Finished (encrypted under c2s): HMAC(key_c2s, transcript).
|
||||
let client_finished = finished_mac(&keys.client_to_server, &transcript);
|
||||
seal_handshake_msg(
|
||||
&mut writer,
|
||||
MsgType::Finished,
|
||||
&mut aead_c2s,
|
||||
&client_finished,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// (6) S->C Finished (encrypted under s2c): verify HMAC(key_s2c, transcript).
|
||||
let server_finished = open_handshake_msg(&mut reader, MsgType::Finished, &mut aead_s2c).await?;
|
||||
verify_finished(&keys.server_to_client, &transcript, &server_finished)?;
|
||||
|
||||
Ok(Session::new(
|
||||
reader,
|
||||
writer,
|
||||
aead_c2s,
|
||||
aead_s2c,
|
||||
POST_HANDSHAKE_COUNTER,
|
||||
Some(cfg.server_name.clone()),
|
||||
))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
// Server
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Drive the server side of the Aura handshake to completion.
|
||||
///
|
||||
/// Receives the client hello, encapsulates to the client's hybrid public key, derives session
|
||||
/// keys, authenticates itself, verifies the client certificate (capturing the client id) and its
|
||||
/// signature, and exchanges Finished MACs. On success returns an established [`Session`] whose
|
||||
/// [`Session::peer_id`] is the verified client id.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns a [`ProtoError`] for any transport, cryptographic, certificate, signature, or
|
||||
/// transcript-MAC failure.
|
||||
pub async fn server_handshake<R, W>(
|
||||
mut reader: R,
|
||||
mut writer: W,
|
||||
cfg: &ServerConfig,
|
||||
) -> Result<Session<R, W>, ProtoError>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
// (1) S receives ClientHello: reconstruct the client's hybrid public key + nonce.
|
||||
let ch = read_expect(&mut reader, MsgType::ClientHello).await?;
|
||||
if ch.payload.len() != CLIENT_HELLO_LEN {
|
||||
return Err(ProtoError::MalformedHandshake("ClientHello: wrong length"));
|
||||
}
|
||||
let (ch_x, rest) = ch.payload.split_at(X25519_LEN);
|
||||
let (ch_kyber_ek, ch_nonce) = rest.split_at(KYBER_EK_LEN);
|
||||
let mut client_nonce = [0u8; NONCE_LEN];
|
||||
client_nonce.copy_from_slice(ch_nonce);
|
||||
|
||||
let client_pub = HybridPublicKey {
|
||||
x25519: ch_x.try_into().expect("32-byte x25519 public key"),
|
||||
kyber: ch_kyber_ek.to_vec(),
|
||||
};
|
||||
|
||||
// (2) S->C ServerHello: encapsulate to the client's public key; send ciphertext + nonce.
|
||||
let (ciphertext, shared) = client_pub.encapsulate();
|
||||
let server_nonce = random_nonce();
|
||||
let mut sh_payload = Vec::with_capacity(SERVER_HELLO_LEN);
|
||||
sh_payload.extend_from_slice(&ciphertext.x25519_ephemeral);
|
||||
sh_payload.extend_from_slice(&ciphertext.kyber_ciphertext);
|
||||
sh_payload.extend_from_slice(&server_nonce);
|
||||
debug_assert_eq!(sh_payload.len(), SERVER_HELLO_LEN);
|
||||
|
||||
let sh_header = frame::encode_header(MsgType::ServerHello, sh_payload.len())?;
|
||||
frame::write_frame(&mut writer, MsgType::ServerHello, &sh_payload).await?;
|
||||
let server_hello_wire = concat_frame(&sh_header, &sh_payload);
|
||||
|
||||
let keys: SessionKeys = derive_session_keys(&shared, &client_nonce, &server_nonce);
|
||||
let transcript = transcript_hash(&ch.wire_bytes(), &server_hello_wire);
|
||||
|
||||
// Two AEADs: server seals s2c, opens c2s.
|
||||
let mut aead_c2s = AeadSession::new(keys.client_to_server);
|
||||
let mut aead_s2c = AeadSession::new(keys.server_to_client);
|
||||
|
||||
// (3) S->C ServerAuth (encrypted under s2c): our leaf + our signature over transcript.
|
||||
let server_cert_der = pem_cert_to_der(&cfg.server_cert_pem)?;
|
||||
let server_sig = sign_transcript(&cfg.server_key_pem, &transcript)?;
|
||||
let server_auth = build_cert_and_sig(&server_cert_der, &server_sig);
|
||||
seal_handshake_msg(
|
||||
&mut writer,
|
||||
MsgType::ServerAuth,
|
||||
&mut aead_s2c,
|
||||
&server_auth,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// (4) C->S ClientAuth (encrypted under c2s): verify cert chain (=> client_id) + signature.
|
||||
let client_auth = open_handshake_msg(&mut reader, MsgType::ClientAuth, &mut aead_c2s).await?;
|
||||
let (client_cert_der, client_sig) = split_cert_and_sig(&client_auth)?;
|
||||
let verifier = AuraCertVerifier::new(&cfg.ca_cert_pem)?;
|
||||
let chain = [CertificateDer::from(client_cert_der.to_vec())];
|
||||
let client_id = verifier.verify_client_cert(&chain)?;
|
||||
verify_signature(client_cert_der, &transcript, client_sig)?;
|
||||
|
||||
// (5) C->S Finished (encrypted under c2s): verify HMAC(key_c2s, transcript).
|
||||
let client_finished = open_handshake_msg(&mut reader, MsgType::Finished, &mut aead_c2s).await?;
|
||||
verify_finished(&keys.client_to_server, &transcript, &client_finished)?;
|
||||
|
||||
// (6) S->C Finished (encrypted under s2c): HMAC(key_s2c, transcript).
|
||||
let server_finished = finished_mac(&keys.server_to_client, &transcript);
|
||||
seal_handshake_msg(
|
||||
&mut writer,
|
||||
MsgType::Finished,
|
||||
&mut aead_s2c,
|
||||
&server_finished,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Session::new(
|
||||
reader,
|
||||
writer,
|
||||
aead_s2c,
|
||||
aead_c2s,
|
||||
POST_HANDSHAKE_COUNTER,
|
||||
Some(client_id),
|
||||
))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Generate a fresh 32-byte handshake nonce from the OS RNG.
|
||||
fn random_nonce() -> [u8; NONCE_LEN] {
|
||||
let mut n = [0u8; NONCE_LEN];
|
||||
rand::thread_rng().fill_bytes(&mut n);
|
||||
n
|
||||
}
|
||||
|
||||
/// Concatenate a header and payload into the full serialized frame bytes.
|
||||
fn concat_frame(header: &[u8; HEADER_LEN], payload: &[u8]) -> Vec<u8> {
|
||||
let mut out = Vec::with_capacity(HEADER_LEN + payload.len());
|
||||
out.extend_from_slice(header);
|
||||
out.extend_from_slice(payload);
|
||||
out
|
||||
}
|
||||
|
||||
/// `transcript = SHA-256(client_hello_frame || server_hello_frame)`.
|
||||
fn transcript_hash(client_hello_frame: &[u8], server_hello_frame: &[u8]) -> [u8; 32] {
|
||||
let mut h = Sha256::new();
|
||||
h.update(client_hello_frame);
|
||||
h.update(server_hello_frame);
|
||||
h.finalize().into()
|
||||
}
|
||||
|
||||
/// Read a frame and require it to be `expected`, mapping an Alert to [`ProtoError::Alert`].
|
||||
async fn read_expect<R>(reader: &mut R, expected: MsgType) -> Result<RawFrame, ProtoError>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
let raw = frame::read_frame(reader).await?;
|
||||
if raw.msg_type == MsgType::Alert {
|
||||
let code = raw.payload.first().copied().unwrap_or(0);
|
||||
return Err(ProtoError::Alert(code));
|
||||
}
|
||||
if raw.msg_type != expected {
|
||||
return Err(ProtoError::UnexpectedMsg {
|
||||
expected,
|
||||
got: raw.msg_type,
|
||||
});
|
||||
}
|
||||
Ok(raw)
|
||||
}
|
||||
|
||||
/// Seal `plaintext` and write it as an encrypted handshake frame of type `msg_type`.
|
||||
///
|
||||
/// The AAD is the 5-byte frame header (binding type + length), matching the Data-record convention.
|
||||
async fn seal_handshake_msg<W>(
|
||||
writer: &mut W,
|
||||
msg_type: MsgType,
|
||||
aead: &mut AeadSession,
|
||||
plaintext: &[u8],
|
||||
) -> Result<(), ProtoError>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let sealed_len = plaintext.len() + 16; // Poly1305 tag
|
||||
let header = frame::encode_header(msg_type, sealed_len)?;
|
||||
let ciphertext = aead.seal(plaintext, &header);
|
||||
debug_assert_eq!(ciphertext.len(), sealed_len);
|
||||
frame::write_frame(writer, msg_type, &ciphertext).await
|
||||
}
|
||||
|
||||
/// Read an encrypted handshake frame of type `msg_type` and AEAD-open it, returning the plaintext.
|
||||
async fn open_handshake_msg<R>(
|
||||
reader: &mut R,
|
||||
msg_type: MsgType,
|
||||
aead: &mut AeadSession,
|
||||
) -> Result<Vec<u8>, ProtoError>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
let raw = read_expect(reader, msg_type).await?;
|
||||
let plaintext = aead.open(&raw.payload, &raw.header)?;
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
/// Build an Auth payload: `u16_be(cert_der_len) || cert_der || signature`.
|
||||
fn build_cert_and_sig(cert_der: &[u8], sig: &[u8]) -> Vec<u8> {
|
||||
let mut out = Vec::with_capacity(2 + cert_der.len() + sig.len());
|
||||
out.extend_from_slice(&(cert_der.len() as u16).to_be_bytes());
|
||||
out.extend_from_slice(cert_der);
|
||||
out.extend_from_slice(sig);
|
||||
out
|
||||
}
|
||||
|
||||
/// Parse an Auth payload into `(cert_der, signature)`.
|
||||
fn split_cert_and_sig(buf: &[u8]) -> Result<(&[u8], &[u8]), ProtoError> {
|
||||
let len_bytes: [u8; 2] = buf
|
||||
.get(..2)
|
||||
.ok_or(ProtoError::MalformedHandshake("Auth: missing cert length"))?
|
||||
.try_into()
|
||||
.expect("2-byte length prefix");
|
||||
let cert_len = u16::from_be_bytes(len_bytes) as usize;
|
||||
let cert_der = buf
|
||||
.get(2..2 + cert_len)
|
||||
.ok_or(ProtoError::MalformedHandshake("Auth: truncated cert"))?;
|
||||
let sig = buf
|
||||
.get(2 + cert_len..)
|
||||
.ok_or(ProtoError::MalformedHandshake("Auth: missing signature"))?;
|
||||
if sig.is_empty() {
|
||||
return Err(ProtoError::MalformedHandshake("Auth: empty signature"));
|
||||
}
|
||||
Ok((cert_der, sig))
|
||||
}
|
||||
|
||||
/// Compute the Finished MAC: `HMAC-SHA256(key, transcript)`.
|
||||
fn finished_mac(key: &[u8; 32], transcript: &[u8; 32]) -> Vec<u8> {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts a 32-byte key");
|
||||
mac.update(transcript);
|
||||
mac.finalize().into_bytes().to_vec()
|
||||
}
|
||||
|
||||
/// Verify a received Finished MAC in constant time.
|
||||
fn verify_finished(
|
||||
key: &[u8; 32],
|
||||
transcript: &[u8; 32],
|
||||
received: &[u8],
|
||||
) -> Result<(), ProtoError> {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts a 32-byte key");
|
||||
mac.update(transcript);
|
||||
// `verify_slice` is constant-time and also checks the length.
|
||||
mac.verify_slice(received)
|
||||
.map_err(|_| ProtoError::FinishedMismatch)
|
||||
}
|
||||
|
||||
/// Sign the transcript with a PKCS#8 PEM key (ECDSA P-256 / SHA-256, ASN.1 DER signature).
|
||||
fn sign_transcript(key_pem: &str, transcript: &[u8; 32]) -> Result<Vec<u8>, ProtoError> {
|
||||
let pkcs8_der = pem_key_to_der(key_pem)?;
|
||||
let rng = ring::rand::SystemRandom::new();
|
||||
let key_pair = ring::signature::EcdsaKeyPair::from_pkcs8(
|
||||
&ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING,
|
||||
&pkcs8_der,
|
||||
&rng,
|
||||
)
|
||||
.map_err(|_| ProtoError::Signature("invalid PKCS#8 signing key"))?;
|
||||
let sig = key_pair
|
||||
.sign(&rng, transcript)
|
||||
.map_err(|_| ProtoError::Signature("ECDSA signing failed"))?;
|
||||
Ok(sig.as_ref().to_vec())
|
||||
}
|
||||
|
||||
/// Convenience wrapper used by the client/server symmetric calls.
|
||||
fn verify_signature(
|
||||
cert_der: &[u8],
|
||||
transcript: &[u8; 32],
|
||||
signature: &[u8],
|
||||
) -> Result<(), ProtoError> {
|
||||
let ec_point = ec_public_key_from_cert(cert_der)?;
|
||||
ring::signature::UnparsedPublicKey::new(&ring::signature::ECDSA_P256_SHA256_ASN1, ec_point)
|
||||
.verify(transcript, signature)
|
||||
.map_err(|_| ProtoError::Signature("handshake signature did not verify"))
|
||||
}
|
||||
|
||||
/// Extract the uncompressed EC public-key point (`04 || X || Y`) from a leaf certificate DER.
|
||||
fn ec_public_key_from_cert(cert_der: &[u8]) -> Result<Vec<u8>, ProtoError> {
|
||||
use x509_parser::prelude::FromDer;
|
||||
let (_, cert) = x509_parser::certificate::X509Certificate::from_der(cert_der)
|
||||
.map_err(|_| ProtoError::Signature("could not parse peer certificate DER"))?;
|
||||
Ok(cert.public_key().subject_public_key.data.to_vec())
|
||||
}
|
||||
|
||||
/// Decode the first `CERTIFICATE` PEM block to DER.
|
||||
fn pem_cert_to_der(pem: &str) -> Result<Vec<u8>, ProtoError> {
|
||||
pem_block_to_der(pem, &["CERTIFICATE"])
|
||||
.ok_or(ProtoError::Signature("no CERTIFICATE block in PEM"))
|
||||
}
|
||||
|
||||
/// Decode the first private-key PEM block (PKCS#8 `PRIVATE KEY`, or the EC/RSA variants) to DER.
|
||||
fn pem_key_to_der(pem: &str) -> Result<Vec<u8>, ProtoError> {
|
||||
pem_block_to_der(pem, &["PRIVATE KEY", "EC PRIVATE KEY", "RSA PRIVATE KEY"])
|
||||
.ok_or(ProtoError::Signature("no private-key block in PEM"))
|
||||
}
|
||||
|
||||
/// Decode the first PEM block whose label is one of `labels` into its DER contents.
|
||||
///
|
||||
/// Uses `x509-parser`'s generic PEM container reader, so we do not need `rustls-pemfile`.
|
||||
fn pem_block_to_der(pem: &str, labels: &[&str]) -> Option<Vec<u8>> {
|
||||
for item in x509_parser::pem::Pem::iter_from_buffer(pem.as_bytes()) {
|
||||
let item = item.ok()?;
|
||||
if labels.contains(&item.label.as_str()) {
|
||||
return Some(item.contents);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -1 +1,138 @@
|
||||
//! aura-proto — protocol wire format and handshake (skeleton; implemented in Wave 2).
|
||||
//! aura-proto — the Aura VPN protocol: wire format, hybrid-KEM + mutual-X.509 handshake, and the
|
||||
//! post-handshake encrypted session.
|
||||
//!
|
||||
//! This crate sits on top of [`aura_crypto`] (hybrid KEM, HKDF, AEAD) and [`aura_pki`] (mutual
|
||||
//! X.509 verification) and implements project §6:
|
||||
//!
|
||||
//! * [`frame`] — the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3).
|
||||
//! * [`handshake`] — [`client_handshake`] / [`server_handshake`], the hybrid-KEM + mutual-auth
|
||||
//! state machine (§6.2).
|
||||
//! * [`session`] — [`Session`], post-handshake encrypted [`Frame`] send/recv with replay
|
||||
//! protection.
|
||||
//!
|
||||
//! ## Transport-agnostic by design
|
||||
//!
|
||||
//! The handshake is generic over any [`tokio::io::AsyncRead`] reader and [`tokio::io::AsyncWrite`]
|
||||
//! writer, supplied as **separate** halves. That matches both an in-memory duplex pipe (tests) and
|
||||
//! quinn's split `RecvStream` / `SendStream` (the Wave 3 QUIC transport), so the same code drives
|
||||
//! both. See [`client_handshake`] / [`server_handshake`] and [`Session`].
|
||||
//!
|
||||
//! The handshake takes the reader and writer **by value** (`reader: R, writer: W`) rather than the
|
||||
//! `&mut R, &mut W` sketched in the brief: the returned [`Session`] *owns* both halves (so it can
|
||||
//! keep `send_frame` / `recv_frame` going), and an owning return cannot be built from borrows
|
||||
//! without a `Default` placeholder that generic `R`/`W` lack. Callers pass owned halves anyway —
|
||||
//! `tokio::io::split(...)` and quinn's `(SendStream, RecvStream)` both yield owned values — so this
|
||||
//! is ergonomically identical while being sound.
|
||||
//!
|
||||
//! ## Handshake message order (resolving the spec diagram)
|
||||
//!
|
||||
//! The spec diagram is ambiguous about the ordering of the encrypted auth/Finished messages. This
|
||||
//! implementation fixes the exact order below and both peers follow it lock-step:
|
||||
//!
|
||||
//! ```text
|
||||
//! 1. C->S ClientHello (plaintext)
|
||||
//! 2. S->C ServerHello (plaintext) -- both derive SessionKeys here --
|
||||
//! 3. S->C ServerAuth (encrypted)
|
||||
//! 4. C->S ClientAuth (encrypted)
|
||||
//! 5. C->S Finished (encrypted)
|
||||
//! 6. S->C Finished (encrypted) -- encrypted Data channel open both ways --
|
||||
//! ```
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
pub mod frame;
|
||||
pub mod handshake;
|
||||
pub mod session;
|
||||
|
||||
pub use frame::{Frame, MsgType};
|
||||
pub use handshake::{client_handshake, server_handshake};
|
||||
pub use session::Session;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors produced by the Aura protocol layer.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProtoError {
|
||||
/// An I/O error on the underlying transport (read/write/EOF).
|
||||
#[error("transport I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// An error from the cryptographic core (KEM decapsulation, AEAD open).
|
||||
#[error("crypto error: {0}")]
|
||||
Crypto(#[from] aura_crypto::CryptoError),
|
||||
|
||||
/// An error from the PKI layer (certificate chain or name verification).
|
||||
#[error("pki error: {0}")]
|
||||
Pki(#[from] aura_pki::PkiError),
|
||||
|
||||
/// The header carried an unknown message-type byte.
|
||||
#[error("unknown message type byte: 0x{0:02x}")]
|
||||
UnknownMsgType(u8),
|
||||
|
||||
/// The header carried an unsupported protocol version.
|
||||
#[error("unsupported protocol version: 0x{0:02x}")]
|
||||
BadVersion(u8),
|
||||
|
||||
/// A payload exceeded the u24 length field of the frame header.
|
||||
#[error("frame payload too large: {0} bytes")]
|
||||
FrameTooLarge(usize),
|
||||
|
||||
/// A received frame had the wrong type for the current handshake/session step.
|
||||
#[error("unexpected message type: expected {expected:?}, got {got:?}")]
|
||||
UnexpectedMsg {
|
||||
/// The message type the state machine required.
|
||||
expected: MsgType,
|
||||
/// The message type actually received.
|
||||
got: MsgType,
|
||||
},
|
||||
|
||||
/// A handshake message had an invalid length or internal structure.
|
||||
#[error("malformed handshake message: {0}")]
|
||||
MalformedHandshake(&'static str),
|
||||
|
||||
/// An application [`Frame`] could not be decoded.
|
||||
#[error("malformed frame: {0}")]
|
||||
MalformedFrame(&'static str),
|
||||
|
||||
/// A digital signature (over the handshake hash) failed to verify, or a key/cert could not be
|
||||
/// parsed for signing or verification.
|
||||
#[error("signature verification failed: {0}")]
|
||||
Signature(&'static str),
|
||||
|
||||
/// The peer's Finished HMAC did not match the locally computed value.
|
||||
#[error("handshake Finished verification failed (HMAC mismatch)")]
|
||||
FinishedMismatch,
|
||||
|
||||
/// A Data record was rejected by the replay-protection window (duplicate or too old).
|
||||
#[error("replay detected: record counter {0} is a duplicate or outside the window")]
|
||||
Replay(u64),
|
||||
|
||||
/// The peer sent a fatal Alert.
|
||||
#[error("peer sent fatal alert (code {0})")]
|
||||
Alert(u8),
|
||||
}
|
||||
|
||||
/// Client-side handshake configuration.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ClientConfig {
|
||||
/// PEM of the Aura CA certificate, used as the trust anchor for the server cert.
|
||||
pub ca_cert_pem: String,
|
||||
/// PEM of this client's leaf certificate (sent to the server during ClientAuth).
|
||||
pub client_cert_pem: String,
|
||||
/// PEM of this client's PKCS#8 private key (used to sign the handshake hash).
|
||||
pub client_key_pem: String,
|
||||
/// The DNS name the client expects in the server certificate's SAN.
|
||||
pub server_name: String,
|
||||
}
|
||||
|
||||
/// Server-side handshake configuration.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ServerConfig {
|
||||
/// PEM of the Aura CA certificate, used as the trust anchor for client certs.
|
||||
pub ca_cert_pem: String,
|
||||
/// PEM of this server's leaf certificate (sent to the client during ServerAuth).
|
||||
pub server_cert_pem: String,
|
||||
/// PEM of this server's PKCS#8 private key (used to sign the handshake hash).
|
||||
pub server_key_pem: String,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,294 @@
|
||||
//! Post-handshake encrypted session: AEAD-protected [`Frame`] exchange with replay protection.
|
||||
//!
|
||||
//! A [`Session`] owns the transport reader + writer and the two directional [`AeadSession`]s
|
||||
//! produced by the handshake. It exposes [`Session::send_frame`] / [`Session::recv_frame`], which
|
||||
//! serialize a [`Frame`], AEAD-seal/open it, and ship it inside a [`MsgType::Data`] record framed
|
||||
//! by the 5-byte protocol header.
|
||||
//!
|
||||
//! ## Record format and replay protection
|
||||
//!
|
||||
//! Each Data record's payload is:
|
||||
//!
|
||||
//! ```text
|
||||
//! seq (u64, big-endian) || AEAD_seal(frame_bytes, aad = header || seq)
|
||||
//! ```
|
||||
//!
|
||||
//! The 8-byte `seq` is the AEAD nonce counter for that record. Because `aura_crypto::AeadSession`
|
||||
//! advances its internal counter in lock-step on the sealing and opening sides, the transmitted
|
||||
//! `seq` always equals the receiver's expected AEAD counter on the happy path. Carrying it
|
||||
//! explicitly lets the receiver run a **sliding-window** replay check (window size
|
||||
//! [`REPLAY_WINDOW`]) *before* touching the AEAD: a duplicate or too-old `seq` is rejected with
|
||||
//! [`ProtoError::Replay`] without disturbing the AEAD's counter, so the session stays usable.
|
||||
//!
|
||||
//! The `seq` is also folded into the AEAD AAD (alongside the frame header), cryptographically
|
||||
//! binding the record to its claimed position.
|
||||
|
||||
use aura_crypto::AeadSession;
|
||||
|
||||
use crate::frame::{self, Frame, MsgType, RawFrame, HEADER_LEN};
|
||||
use crate::ProtoError;
|
||||
|
||||
/// Width of the sliding replay window, in records.
|
||||
pub const REPLAY_WINDOW: u64 = 64;
|
||||
|
||||
/// Width in bytes of the per-record sequence-number prefix.
|
||||
const SEQ_LEN: usize = 8;
|
||||
|
||||
/// Sliding-window replay detector over a monotonically increasing 64-bit counter.
|
||||
///
|
||||
/// Tracks the highest accepted sequence number and a bitmap of the [`REPLAY_WINDOW`] positions
|
||||
/// below it. A sequence number is accepted iff it is strictly newer than everything seen, or falls
|
||||
/// within the window and has not been seen before.
|
||||
#[derive(Debug)]
|
||||
struct ReplayWindow {
|
||||
/// Highest sequence number accepted so far.
|
||||
highest: u64,
|
||||
/// Bitmap of accepted positions in `(highest - REPLAY_WINDOW, highest]`; bit `i` set means
|
||||
/// `highest - 1 - i` was accepted.
|
||||
bitmap: u64,
|
||||
/// Whether any record has been accepted yet (so seq 0 can itself be accepted exactly once).
|
||||
seeded: bool,
|
||||
}
|
||||
|
||||
impl ReplayWindow {
|
||||
/// Create a window primed so the first expected record is `start` (the AEAD counter at the end
|
||||
/// of the handshake). Anything strictly below `start` is treated as already-consumed.
|
||||
fn new(start: u64) -> Self {
|
||||
// Treat (start - 1) as the highest already-accepted slot so the first Data record at `seq
|
||||
// == start` is accepted as "newer". `seeded = true` keeps seq 0 handling uniform even when
|
||||
// start == 0 (the handshake always leaves start >= 1, but be defensive).
|
||||
Self {
|
||||
highest: start.saturating_sub(1),
|
||||
bitmap: 0,
|
||||
seeded: start > 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check and record `seq`. Returns `Ok(())` if it is fresh; [`ProtoError::Replay`] otherwise.
|
||||
fn check_and_set(&mut self, seq: u64) -> Result<(), ProtoError> {
|
||||
if !self.seeded {
|
||||
// First ever record (only reachable if started at 0): accept and seed.
|
||||
self.seeded = true;
|
||||
self.highest = seq;
|
||||
self.bitmap = 0;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if seq > self.highest {
|
||||
// New high-water mark: shift the bitmap up by the gap, marking the old highest as seen.
|
||||
let shift = seq - self.highest;
|
||||
if shift >= 64 {
|
||||
self.bitmap = 0;
|
||||
} else {
|
||||
self.bitmap = (self.bitmap << shift) | (1u64 << (shift - 1));
|
||||
}
|
||||
self.highest = seq;
|
||||
Ok(())
|
||||
} else {
|
||||
// seq <= highest: must be inside the window and previously unseen.
|
||||
let offset = self.highest - seq; // 0 == highest itself (already accepted)
|
||||
if offset >= REPLAY_WINDOW {
|
||||
return Err(ProtoError::Replay(seq));
|
||||
}
|
||||
if offset == 0 {
|
||||
// Exactly the current highest: already accepted.
|
||||
return Err(ProtoError::Replay(seq));
|
||||
}
|
||||
let mask = 1u64 << (offset - 1);
|
||||
if self.bitmap & mask != 0 {
|
||||
return Err(ProtoError::Replay(seq));
|
||||
}
|
||||
self.bitmap |= mask;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An established, encrypted Aura session over a transport reader `R` and writer `W`.
|
||||
///
|
||||
/// Created by [`crate::client_handshake`] / [`crate::server_handshake`]. Use
|
||||
/// [`Session::send_frame`] and [`Session::recv_frame`] for application traffic.
|
||||
pub struct Session<R, W> {
|
||||
reader: R,
|
||||
writer: W,
|
||||
/// AEAD this endpoint seals outgoing Data with.
|
||||
send_aead: AeadSession,
|
||||
/// AEAD this endpoint opens incoming Data with.
|
||||
recv_aead: AeadSession,
|
||||
/// Next sequence number to stamp on an outgoing Data record (mirrors `send_aead`'s counter).
|
||||
send_seq: u64,
|
||||
/// Replay window over incoming Data sequence numbers.
|
||||
replay: ReplayWindow,
|
||||
/// The verified identity (Common Name) of the peer, learned during the handshake.
|
||||
peer_id: Option<String>,
|
||||
}
|
||||
|
||||
impl<R, W> Session<R, W>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
/// Assemble a session from the handshake outputs.
|
||||
///
|
||||
/// `start_counter` is the AEAD nonce counter both directions have reached after the encrypted
|
||||
/// handshake messages (so the first Data record stamps `seq == start_counter`). `peer_id` is
|
||||
/// the verified peer Common Name (the server learns the client id; the client may store the
|
||||
/// server name).
|
||||
pub(crate) fn new(
|
||||
reader: R,
|
||||
writer: W,
|
||||
send_aead: AeadSession,
|
||||
recv_aead: AeadSession,
|
||||
start_counter: u64,
|
||||
peer_id: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
reader,
|
||||
writer,
|
||||
send_aead,
|
||||
recv_aead,
|
||||
send_seq: start_counter,
|
||||
replay: ReplayWindow::new(start_counter),
|
||||
peer_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// The verified identity (Common Name) of the peer, if one was captured during the handshake.
|
||||
#[must_use]
|
||||
pub fn peer_id(&self) -> Option<&str> {
|
||||
self.peer_id.as_deref()
|
||||
}
|
||||
|
||||
/// Serialize, seal, and send a single application [`Frame`].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::Io`] on a write failure or [`ProtoError::FrameTooLarge`] if the sealed
|
||||
/// record exceeds the frame header's length field.
|
||||
pub async fn send_frame(&mut self, frame: Frame) -> Result<(), ProtoError> {
|
||||
let frame_bytes = frame.encode();
|
||||
let seq = self.send_seq;
|
||||
|
||||
// Build the record payload: seq(8) || ciphertext. The frame header binds (type, length),
|
||||
// and the seq binds the record position; both go into the AAD.
|
||||
let sealed_len = SEQ_LEN + frame_bytes.len() + 16 /* Poly1305 tag */;
|
||||
let header = frame::encode_header(MsgType::Data, sealed_len)?;
|
||||
|
||||
let mut aad = [0u8; HEADER_LEN + SEQ_LEN];
|
||||
aad[..HEADER_LEN].copy_from_slice(&header);
|
||||
aad[HEADER_LEN..].copy_from_slice(&seq.to_be_bytes());
|
||||
|
||||
let ciphertext = self.send_aead.seal(&frame_bytes, &aad);
|
||||
|
||||
let mut payload = Vec::with_capacity(SEQ_LEN + ciphertext.len());
|
||||
payload.extend_from_slice(&seq.to_be_bytes());
|
||||
payload.extend_from_slice(&ciphertext);
|
||||
debug_assert_eq!(payload.len(), sealed_len);
|
||||
|
||||
// Write header + payload. (We already know the exact header for AAD; reuse it.)
|
||||
use tokio::io::AsyncWriteExt;
|
||||
self.writer.write_all(&header).await?;
|
||||
self.writer.write_all(&payload).await?;
|
||||
self.writer.flush().await?;
|
||||
|
||||
self.send_seq = self.send_seq.wrapping_add(1);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive, open, and decode a single application [`Frame`].
|
||||
///
|
||||
/// # Errors
|
||||
/// * [`ProtoError::Replay`] — the record's sequence number was a duplicate or too old.
|
||||
/// * [`ProtoError::Crypto`] — AEAD authentication failed.
|
||||
/// * [`ProtoError::UnexpectedMsg`] — a non-Data record arrived.
|
||||
/// * [`ProtoError::Alert`] — the peer sent a fatal alert.
|
||||
/// * [`ProtoError::Io`] / [`ProtoError::MalformedFrame`] on transport / decode failure.
|
||||
pub async fn recv_frame(&mut self) -> Result<Frame, ProtoError> {
|
||||
let raw: RawFrame = frame::read_frame(&mut self.reader).await?;
|
||||
match raw.msg_type {
|
||||
MsgType::Data => {}
|
||||
MsgType::Alert => {
|
||||
let code = raw.payload.first().copied().unwrap_or(0);
|
||||
return Err(ProtoError::Alert(code));
|
||||
}
|
||||
other => {
|
||||
return Err(ProtoError::UnexpectedMsg {
|
||||
expected: MsgType::Data,
|
||||
got: other,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if raw.payload.len() < SEQ_LEN {
|
||||
return Err(ProtoError::MalformedFrame(
|
||||
"Data record shorter than seq prefix",
|
||||
));
|
||||
}
|
||||
let mut seq_bytes = [0u8; SEQ_LEN];
|
||||
seq_bytes.copy_from_slice(&raw.payload[..SEQ_LEN]);
|
||||
let seq = u64::from_be_bytes(seq_bytes);
|
||||
let ciphertext = &raw.payload[SEQ_LEN..];
|
||||
|
||||
// Replay check FIRST — a duplicate/old record must not advance the AEAD counter.
|
||||
self.replay.check_and_set(seq)?;
|
||||
|
||||
let mut aad = [0u8; HEADER_LEN + SEQ_LEN];
|
||||
aad[..HEADER_LEN].copy_from_slice(&raw.header);
|
||||
aad[HEADER_LEN..].copy_from_slice(&seq_bytes);
|
||||
|
||||
let plaintext = self.recv_aead.open(ciphertext, &aad)?;
|
||||
Frame::decode(&plaintext)
|
||||
}
|
||||
|
||||
/// Consume the session, returning its transport halves (for clean shutdown / reuse).
|
||||
#[must_use]
|
||||
pub fn into_inner(self) -> (R, W) {
|
||||
(self.reader, self.writer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn replay_window_basic_monotonic() {
|
||||
let mut w = ReplayWindow::new(2);
|
||||
// First legitimate Data records.
|
||||
assert!(w.check_and_set(2).is_ok());
|
||||
assert!(w.check_and_set(3).is_ok());
|
||||
assert!(w.check_and_set(4).is_ok());
|
||||
// Replays of each are rejected.
|
||||
assert!(matches!(w.check_and_set(2), Err(ProtoError::Replay(2))));
|
||||
assert!(matches!(w.check_and_set(3), Err(ProtoError::Replay(3))));
|
||||
assert!(matches!(w.check_and_set(4), Err(ProtoError::Replay(4))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replay_window_out_of_order_within_window() {
|
||||
let mut w = ReplayWindow::new(0);
|
||||
assert!(w.check_and_set(0).is_ok());
|
||||
assert!(w.check_and_set(10).is_ok());
|
||||
// Older-but-unseen inside the window is accepted exactly once.
|
||||
assert!(w.check_and_set(5).is_ok());
|
||||
assert!(matches!(w.check_and_set(5), Err(ProtoError::Replay(5))));
|
||||
// The new-high record (10) is a replay.
|
||||
assert!(matches!(w.check_and_set(10), Err(ProtoError::Replay(10))));
|
||||
// Brand-new high still works.
|
||||
assert!(w.check_and_set(11).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replay_window_rejects_too_old() {
|
||||
let mut w = ReplayWindow::new(0);
|
||||
assert!(w.check_and_set(0).is_ok());
|
||||
assert!(w.check_and_set(200).is_ok());
|
||||
// Far below the window edge => rejected as too old.
|
||||
assert!(matches!(w.check_and_set(1), Err(ProtoError::Replay(1))));
|
||||
assert!(matches!(
|
||||
w.check_and_set(200 - REPLAY_WINDOW),
|
||||
Err(ProtoError::Replay(_))
|
||||
));
|
||||
// Just inside the window edge => still acceptable.
|
||||
assert!(w.check_and_set(200 - REPLAY_WINDOW + 1).is_ok());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user