feat(proto): implement Wave 2 — hybrid PKI handshake + session

aura-proto: 5-byte wire header + Frame codec (§6.1/§6.3); transport-agnostic
handshake state machine (§6.2) over split tokio AsyncRead/AsyncWrite —
hybrid X25519+ML-KEM-768 KEM, SHA-256 transcript, mutual X.509 auth with
ECDSA-P256 transcript signatures (ring), constant-time HMAC Finished;
Session with sliding-window replay protection. 13 tests green, clippy clean.

Handshake message order pinned (resolves spec diagram ambiguity); reader/writer
taken by value since Session owns both halves.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xah30
2026-05-25 18:05:11 +03:00
parent b8ce58ddf0
commit bb835e4ca7
11 changed files with 1710 additions and 1 deletions
+371
View File
@@ -0,0 +1,371 @@
//! Wire format: the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3).
//!
//! Every Aura protocol message on the wire is a 5-byte header followed by a payload:
//!
//! ```text
//! byte 0 : msg_type (u8)
//! bytes 1..4 : length (u24, big-endian) = payload length in bytes
//! byte 4 : version = 0x01
//! bytes 5.. : payload (length bytes)
//! ```
//!
//! [`Frame`] is the post-handshake application payload. Each `Frame` is serialized with
//! [`Frame::encode`], AEAD-sealed, and shipped inside a [`MsgType::Data`] record (see
//! [`crate::session`]).
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::ProtoError;
/// Length in bytes of the protocol frame header.
pub const HEADER_LEN: usize = 5;
/// Protocol version carried in byte 4 of every header.
pub const PROTOCOL_VERSION: u8 = 0x01;
/// Largest payload expressible by the u24 length field.
pub const MAX_PAYLOAD_LEN: usize = 0x00FF_FFFF;
/// Message types carried in byte 0 of the header (§6.1).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum MsgType {
/// Handshake message 1 (C->S): hybrid public key + client nonce.
ClientHello = 0x01,
/// Handshake message 2 (S->C): hybrid ciphertext + server nonce.
ServerHello = 0x02,
/// Handshake message 4 (C->S, encrypted): client cert + signature.
ClientAuth = 0x03,
/// Handshake message 3 (S->C, encrypted): server cert + signature.
ServerAuth = 0x04,
/// Handshake Finished (encrypted): HMAC over the handshake hash.
Finished = 0x05,
/// Application data (encrypted): an AEAD-sealed [`Frame`].
Data = 0x06,
/// Fatal alert / error notification.
Alert = 0xFF,
}
impl MsgType {
/// Map the on-wire byte to a [`MsgType`].
///
/// # Errors
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized byte.
pub fn from_u8(b: u8) -> Result<Self, ProtoError> {
Ok(match b {
0x01 => Self::ClientHello,
0x02 => Self::ServerHello,
0x03 => Self::ClientAuth,
0x04 => Self::ServerAuth,
0x05 => Self::Finished,
0x06 => Self::Data,
0xFF => Self::Alert,
other => return Err(ProtoError::UnknownMsgType(other)),
})
}
}
/// Build a 5-byte header for `msg_type` carrying a payload of `payload_len` bytes.
///
/// # Errors
/// Returns [`ProtoError::FrameTooLarge`] if `payload_len` does not fit in the u24 length field.
pub fn encode_header(
msg_type: MsgType,
payload_len: usize,
) -> Result<[u8; HEADER_LEN], ProtoError> {
if payload_len > MAX_PAYLOAD_LEN {
return Err(ProtoError::FrameTooLarge(payload_len));
}
let len = payload_len as u32;
Ok([
msg_type as u8,
((len >> 16) & 0xFF) as u8,
((len >> 8) & 0xFF) as u8,
(len & 0xFF) as u8,
PROTOCOL_VERSION,
])
}
/// Parse a 5-byte header into `(msg_type, payload_len)`.
///
/// # Errors
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized type byte or
/// [`ProtoError::BadVersion`] if byte 4 is not [`PROTOCOL_VERSION`].
pub fn decode_header(header: &[u8; HEADER_LEN]) -> Result<(MsgType, usize), ProtoError> {
let msg_type = MsgType::from_u8(header[0])?;
let version = header[4];
if version != PROTOCOL_VERSION {
return Err(ProtoError::BadVersion(version));
}
let len = ((header[1] as usize) << 16) | ((header[2] as usize) << 8) | (header[3] as usize);
Ok((msg_type, len))
}
/// Write one full frame (`header || payload`) to `writer`.
///
/// # Errors
/// Returns [`ProtoError::FrameTooLarge`] if the payload is too long, or [`ProtoError::Io`] on a
/// write failure.
pub async fn write_frame<W>(
writer: &mut W,
msg_type: MsgType,
payload: &[u8],
) -> Result<(), ProtoError>
where
W: AsyncWrite + Unpin,
{
let header = encode_header(msg_type, payload.len())?;
writer.write_all(&header).await?;
writer.write_all(payload).await?;
writer.flush().await?;
Ok(())
}
/// A frame read off the wire: its type, the raw header bytes (useful as AEAD AAD and for the
/// handshake transcript hash), and the payload.
#[derive(Clone, Debug)]
pub struct RawFrame {
/// The decoded message type.
pub msg_type: MsgType,
/// The 5 header bytes exactly as transmitted.
pub header: [u8; HEADER_LEN],
/// The payload bytes.
pub payload: Vec<u8>,
}
impl RawFrame {
/// The full serialized frame (`header || payload`) exactly as it appeared on the wire.
///
/// Used to feed the handshake transcript hash, which must hash the bytes as transmitted.
#[must_use]
pub fn wire_bytes(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(HEADER_LEN + self.payload.len());
out.extend_from_slice(&self.header);
out.extend_from_slice(&self.payload);
out
}
}
/// Read one full frame (`header || payload`) from `reader`.
///
/// # Errors
/// Returns [`ProtoError::Io`] on a read failure (including a truncated frame / EOF), or a header
/// decode error from [`decode_header`].
pub async fn read_frame<R>(reader: &mut R) -> Result<RawFrame, ProtoError>
where
R: AsyncRead + Unpin,
{
let mut header = [0u8; HEADER_LEN];
reader.read_exact(&mut header).await?;
let (msg_type, len) = decode_header(&header)?;
let mut payload = vec![0u8; len];
reader.read_exact(&mut payload).await?;
Ok(RawFrame {
msg_type,
header,
payload,
})
}
/// Frame type tags used in the application [`Frame`] encoding (§6.3).
mod frame_tag {
pub const DATA: u8 = 0x01;
pub const PING: u8 = 0x02;
pub const PONG: u8 = 0x03;
pub const CLOSE: u8 = 0x04;
}
/// Application-level frames carried inside encrypted [`MsgType::Data`] records (§6.3).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Frame {
/// A stream data payload.
Data {
/// Logical stream identifier.
stream_id: u32,
/// Opaque application bytes.
payload: Bytes,
},
/// Liveness probe.
Ping {
/// Monotonic sequence number echoed back in the matching [`Frame::Pong`].
seq: u32,
},
/// Reply to a [`Frame::Ping`].
Pong {
/// Sequence number copied from the [`Frame::Ping`].
seq: u32,
},
/// Orderly shutdown of the logical connection.
Close {
/// Application-defined close code.
code: u8,
/// Human-readable reason (UTF-8).
reason: String,
},
}
impl Frame {
/// Serialize this frame to its compact byte encoding.
///
/// Layout (all multi-byte integers big-endian):
/// * `Data` : `0x01 || stream_id(u32) || payload`
/// * `Ping` : `0x02 || seq(u32)`
/// * `Pong` : `0x03 || seq(u32)`
/// * `Close` : `0x04 || code(u8) || reason_len(u32) || reason_utf8`
#[must_use]
pub fn encode(&self) -> Vec<u8> {
let mut out = Vec::new();
match self {
Frame::Data { stream_id, payload } => {
out.push(frame_tag::DATA);
out.extend_from_slice(&stream_id.to_be_bytes());
out.extend_from_slice(payload);
}
Frame::Ping { seq } => {
out.push(frame_tag::PING);
out.extend_from_slice(&seq.to_be_bytes());
}
Frame::Pong { seq } => {
out.push(frame_tag::PONG);
out.extend_from_slice(&seq.to_be_bytes());
}
Frame::Close { code, reason } => {
out.push(frame_tag::CLOSE);
out.push(*code);
let bytes = reason.as_bytes();
out.extend_from_slice(&(bytes.len() as u32).to_be_bytes());
out.extend_from_slice(bytes);
}
}
out
}
/// Parse a frame from its byte encoding (the inverse of [`Frame::encode`]).
///
/// # Errors
/// Returns [`ProtoError::MalformedFrame`] if the buffer is truncated, has an unknown tag, or
/// (for `Close`) does not contain valid UTF-8.
pub fn decode(buf: &[u8]) -> Result<Self, ProtoError> {
let (&tag, rest) = buf
.split_first()
.ok_or(ProtoError::MalformedFrame("empty frame"))?;
match tag {
frame_tag::DATA => {
let stream_id = read_u32(rest, "Data.stream_id")?;
let payload = Bytes::copy_from_slice(&rest[4..]);
Ok(Frame::Data { stream_id, payload })
}
frame_tag::PING => Ok(Frame::Ping {
seq: read_u32(rest, "Ping.seq")?,
}),
frame_tag::PONG => Ok(Frame::Pong {
seq: read_u32(rest, "Pong.seq")?,
}),
frame_tag::CLOSE => {
let code = *rest
.first()
.ok_or(ProtoError::MalformedFrame("Close: missing code"))?;
let reason_len = read_u32(&rest[1..], "Close.reason_len")? as usize;
let reason_bytes = rest
.get(5..5 + reason_len)
.ok_or(ProtoError::MalformedFrame("Close: truncated reason"))?;
let reason = String::from_utf8(reason_bytes.to_vec())
.map_err(|_| ProtoError::MalformedFrame("Close: reason not UTF-8"))?;
Ok(Frame::Close { code, reason })
}
_ => Err(ProtoError::MalformedFrame("unknown frame tag")),
}
}
}
/// Read a big-endian u32 from the start of `buf`, erroring if it is too short.
fn read_u32(buf: &[u8], what: &'static str) -> Result<u32, ProtoError> {
let bytes: [u8; 4] = buf
.get(..4)
.ok_or(ProtoError::MalformedFrame(what))?
.try_into()
.expect("slice of length 4 converts to [u8; 4]");
Ok(u32::from_be_bytes(bytes))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn header_roundtrip_all_types() {
for (ty, byte) in [
(MsgType::ClientHello, 0x01u8),
(MsgType::ServerHello, 0x02),
(MsgType::ClientAuth, 0x03),
(MsgType::ServerAuth, 0x04),
(MsgType::Finished, 0x05),
(MsgType::Data, 0x06),
(MsgType::Alert, 0xFF),
] {
let h = encode_header(ty, 0x0012_3456).unwrap();
assert_eq!(h[0], byte);
assert_eq!(h[4], PROTOCOL_VERSION);
let (got_ty, got_len) = decode_header(&h).unwrap();
assert_eq!(got_ty, ty);
assert_eq!(got_len, 0x0012_3456);
}
}
#[test]
fn header_rejects_oversize_and_bad_version() {
assert!(matches!(
encode_header(MsgType::Data, MAX_PAYLOAD_LEN + 1),
Err(ProtoError::FrameTooLarge(_))
));
let mut h = encode_header(MsgType::Data, 1).unwrap();
h[4] = 0x02;
assert!(matches!(
decode_header(&h),
Err(ProtoError::BadVersion(0x02))
));
h[0] = 0x77;
assert!(matches!(
decode_header(&h),
Err(ProtoError::UnknownMsgType(0x77))
));
}
#[test]
fn frame_roundtrip() {
let frames = vec![
Frame::Data {
stream_id: 0xDEAD_BEEF,
payload: Bytes::from_static(b"hello world"),
},
Frame::Data {
stream_id: 0,
payload: Bytes::new(),
},
Frame::Ping { seq: 42 },
Frame::Pong { seq: 0xFFFF_FFFF },
Frame::Close {
code: 7,
reason: "going away \u{1f44b}".to_string(),
},
Frame::Close {
code: 0,
reason: String::new(),
},
];
for f in frames {
let encoded = f.encode();
let decoded = Frame::decode(&encoded).unwrap();
assert_eq!(f, decoded);
}
}
#[test]
fn frame_decode_rejects_garbage() {
assert!(Frame::decode(&[]).is_err());
assert!(Frame::decode(&[0x99]).is_err());
assert!(Frame::decode(&[frame_tag::PING, 0x00]).is_err()); // truncated u32
assert!(Frame::decode(&[frame_tag::CLOSE]).is_err()); // missing code
}
}