feat(proto): implement Wave 2 — hybrid PKI handshake + session
aura-proto: 5-byte wire header + Frame codec (§6.1/§6.3); transport-agnostic handshake state machine (§6.2) over split tokio AsyncRead/AsyncWrite — hybrid X25519+ML-KEM-768 KEM, SHA-256 transcript, mutual X.509 auth with ECDSA-P256 transcript signatures (ring), constant-time HMAC Finished; Session with sliding-window replay protection. 13 tests green, clippy clean. Handshake message order pinned (resolves spec diagram ambiguity); reader/writer taken by value since Session owns both halves. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,371 @@
|
||||
//! Wire format: the 5-byte protocol header (§6.1) and the application [`Frame`] enum (§6.3).
|
||||
//!
|
||||
//! Every Aura protocol message on the wire is a 5-byte header followed by a payload:
|
||||
//!
|
||||
//! ```text
|
||||
//! byte 0 : msg_type (u8)
|
||||
//! bytes 1..4 : length (u24, big-endian) = payload length in bytes
|
||||
//! byte 4 : version = 0x01
|
||||
//! bytes 5.. : payload (length bytes)
|
||||
//! ```
|
||||
//!
|
||||
//! [`Frame`] is the post-handshake application payload. Each `Frame` is serialized with
|
||||
//! [`Frame::encode`], AEAD-sealed, and shipped inside a [`MsgType::Data`] record (see
|
||||
//! [`crate::session`]).
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::ProtoError;
|
||||
|
||||
/// Length in bytes of the protocol frame header.
|
||||
pub const HEADER_LEN: usize = 5;
|
||||
|
||||
/// Protocol version carried in byte 4 of every header.
|
||||
pub const PROTOCOL_VERSION: u8 = 0x01;
|
||||
|
||||
/// Largest payload expressible by the u24 length field.
|
||||
pub const MAX_PAYLOAD_LEN: usize = 0x00FF_FFFF;
|
||||
|
||||
/// Message types carried in byte 0 of the header (§6.1).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum MsgType {
|
||||
/// Handshake message 1 (C->S): hybrid public key + client nonce.
|
||||
ClientHello = 0x01,
|
||||
/// Handshake message 2 (S->C): hybrid ciphertext + server nonce.
|
||||
ServerHello = 0x02,
|
||||
/// Handshake message 4 (C->S, encrypted): client cert + signature.
|
||||
ClientAuth = 0x03,
|
||||
/// Handshake message 3 (S->C, encrypted): server cert + signature.
|
||||
ServerAuth = 0x04,
|
||||
/// Handshake Finished (encrypted): HMAC over the handshake hash.
|
||||
Finished = 0x05,
|
||||
/// Application data (encrypted): an AEAD-sealed [`Frame`].
|
||||
Data = 0x06,
|
||||
/// Fatal alert / error notification.
|
||||
Alert = 0xFF,
|
||||
}
|
||||
|
||||
impl MsgType {
|
||||
/// Map the on-wire byte to a [`MsgType`].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized byte.
|
||||
pub fn from_u8(b: u8) -> Result<Self, ProtoError> {
|
||||
Ok(match b {
|
||||
0x01 => Self::ClientHello,
|
||||
0x02 => Self::ServerHello,
|
||||
0x03 => Self::ClientAuth,
|
||||
0x04 => Self::ServerAuth,
|
||||
0x05 => Self::Finished,
|
||||
0x06 => Self::Data,
|
||||
0xFF => Self::Alert,
|
||||
other => return Err(ProtoError::UnknownMsgType(other)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a 5-byte header for `msg_type` carrying a payload of `payload_len` bytes.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::FrameTooLarge`] if `payload_len` does not fit in the u24 length field.
|
||||
pub fn encode_header(
|
||||
msg_type: MsgType,
|
||||
payload_len: usize,
|
||||
) -> Result<[u8; HEADER_LEN], ProtoError> {
|
||||
if payload_len > MAX_PAYLOAD_LEN {
|
||||
return Err(ProtoError::FrameTooLarge(payload_len));
|
||||
}
|
||||
let len = payload_len as u32;
|
||||
Ok([
|
||||
msg_type as u8,
|
||||
((len >> 16) & 0xFF) as u8,
|
||||
((len >> 8) & 0xFF) as u8,
|
||||
(len & 0xFF) as u8,
|
||||
PROTOCOL_VERSION,
|
||||
])
|
||||
}
|
||||
|
||||
/// Parse a 5-byte header into `(msg_type, payload_len)`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::UnknownMsgType`] for an unrecognized type byte or
|
||||
/// [`ProtoError::BadVersion`] if byte 4 is not [`PROTOCOL_VERSION`].
|
||||
pub fn decode_header(header: &[u8; HEADER_LEN]) -> Result<(MsgType, usize), ProtoError> {
|
||||
let msg_type = MsgType::from_u8(header[0])?;
|
||||
let version = header[4];
|
||||
if version != PROTOCOL_VERSION {
|
||||
return Err(ProtoError::BadVersion(version));
|
||||
}
|
||||
let len = ((header[1] as usize) << 16) | ((header[2] as usize) << 8) | (header[3] as usize);
|
||||
Ok((msg_type, len))
|
||||
}
|
||||
|
||||
/// Write one full frame (`header || payload`) to `writer`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::FrameTooLarge`] if the payload is too long, or [`ProtoError::Io`] on a
|
||||
/// write failure.
|
||||
pub async fn write_frame<W>(
|
||||
writer: &mut W,
|
||||
msg_type: MsgType,
|
||||
payload: &[u8],
|
||||
) -> Result<(), ProtoError>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let header = encode_header(msg_type, payload.len())?;
|
||||
writer.write_all(&header).await?;
|
||||
writer.write_all(payload).await?;
|
||||
writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A frame read off the wire: its type, the raw header bytes (useful as AEAD AAD and for the
|
||||
/// handshake transcript hash), and the payload.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RawFrame {
|
||||
/// The decoded message type.
|
||||
pub msg_type: MsgType,
|
||||
/// The 5 header bytes exactly as transmitted.
|
||||
pub header: [u8; HEADER_LEN],
|
||||
/// The payload bytes.
|
||||
pub payload: Vec<u8>,
|
||||
}
|
||||
|
||||
impl RawFrame {
|
||||
/// The full serialized frame (`header || payload`) exactly as it appeared on the wire.
|
||||
///
|
||||
/// Used to feed the handshake transcript hash, which must hash the bytes as transmitted.
|
||||
#[must_use]
|
||||
pub fn wire_bytes(&self) -> Vec<u8> {
|
||||
let mut out = Vec::with_capacity(HEADER_LEN + self.payload.len());
|
||||
out.extend_from_slice(&self.header);
|
||||
out.extend_from_slice(&self.payload);
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
/// Read one full frame (`header || payload`) from `reader`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::Io`] on a read failure (including a truncated frame / EOF), or a header
|
||||
/// decode error from [`decode_header`].
|
||||
pub async fn read_frame<R>(reader: &mut R) -> Result<RawFrame, ProtoError>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
let mut header = [0u8; HEADER_LEN];
|
||||
reader.read_exact(&mut header).await?;
|
||||
let (msg_type, len) = decode_header(&header)?;
|
||||
let mut payload = vec![0u8; len];
|
||||
reader.read_exact(&mut payload).await?;
|
||||
Ok(RawFrame {
|
||||
msg_type,
|
||||
header,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
/// Frame type tags used in the application [`Frame`] encoding (§6.3).
|
||||
mod frame_tag {
|
||||
pub const DATA: u8 = 0x01;
|
||||
pub const PING: u8 = 0x02;
|
||||
pub const PONG: u8 = 0x03;
|
||||
pub const CLOSE: u8 = 0x04;
|
||||
}
|
||||
|
||||
/// Application-level frames carried inside encrypted [`MsgType::Data`] records (§6.3).
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum Frame {
|
||||
/// A stream data payload.
|
||||
Data {
|
||||
/// Logical stream identifier.
|
||||
stream_id: u32,
|
||||
/// Opaque application bytes.
|
||||
payload: Bytes,
|
||||
},
|
||||
/// Liveness probe.
|
||||
Ping {
|
||||
/// Monotonic sequence number echoed back in the matching [`Frame::Pong`].
|
||||
seq: u32,
|
||||
},
|
||||
/// Reply to a [`Frame::Ping`].
|
||||
Pong {
|
||||
/// Sequence number copied from the [`Frame::Ping`].
|
||||
seq: u32,
|
||||
},
|
||||
/// Orderly shutdown of the logical connection.
|
||||
Close {
|
||||
/// Application-defined close code.
|
||||
code: u8,
|
||||
/// Human-readable reason (UTF-8).
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Frame {
|
||||
/// Serialize this frame to its compact byte encoding.
|
||||
///
|
||||
/// Layout (all multi-byte integers big-endian):
|
||||
/// * `Data` : `0x01 || stream_id(u32) || payload`
|
||||
/// * `Ping` : `0x02 || seq(u32)`
|
||||
/// * `Pong` : `0x03 || seq(u32)`
|
||||
/// * `Close` : `0x04 || code(u8) || reason_len(u32) || reason_utf8`
|
||||
#[must_use]
|
||||
pub fn encode(&self) -> Vec<u8> {
|
||||
let mut out = Vec::new();
|
||||
match self {
|
||||
Frame::Data { stream_id, payload } => {
|
||||
out.push(frame_tag::DATA);
|
||||
out.extend_from_slice(&stream_id.to_be_bytes());
|
||||
out.extend_from_slice(payload);
|
||||
}
|
||||
Frame::Ping { seq } => {
|
||||
out.push(frame_tag::PING);
|
||||
out.extend_from_slice(&seq.to_be_bytes());
|
||||
}
|
||||
Frame::Pong { seq } => {
|
||||
out.push(frame_tag::PONG);
|
||||
out.extend_from_slice(&seq.to_be_bytes());
|
||||
}
|
||||
Frame::Close { code, reason } => {
|
||||
out.push(frame_tag::CLOSE);
|
||||
out.push(*code);
|
||||
let bytes = reason.as_bytes();
|
||||
out.extend_from_slice(&(bytes.len() as u32).to_be_bytes());
|
||||
out.extend_from_slice(bytes);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Parse a frame from its byte encoding (the inverse of [`Frame::encode`]).
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns [`ProtoError::MalformedFrame`] if the buffer is truncated, has an unknown tag, or
|
||||
/// (for `Close`) does not contain valid UTF-8.
|
||||
pub fn decode(buf: &[u8]) -> Result<Self, ProtoError> {
|
||||
let (&tag, rest) = buf
|
||||
.split_first()
|
||||
.ok_or(ProtoError::MalformedFrame("empty frame"))?;
|
||||
match tag {
|
||||
frame_tag::DATA => {
|
||||
let stream_id = read_u32(rest, "Data.stream_id")?;
|
||||
let payload = Bytes::copy_from_slice(&rest[4..]);
|
||||
Ok(Frame::Data { stream_id, payload })
|
||||
}
|
||||
frame_tag::PING => Ok(Frame::Ping {
|
||||
seq: read_u32(rest, "Ping.seq")?,
|
||||
}),
|
||||
frame_tag::PONG => Ok(Frame::Pong {
|
||||
seq: read_u32(rest, "Pong.seq")?,
|
||||
}),
|
||||
frame_tag::CLOSE => {
|
||||
let code = *rest
|
||||
.first()
|
||||
.ok_or(ProtoError::MalformedFrame("Close: missing code"))?;
|
||||
let reason_len = read_u32(&rest[1..], "Close.reason_len")? as usize;
|
||||
let reason_bytes = rest
|
||||
.get(5..5 + reason_len)
|
||||
.ok_or(ProtoError::MalformedFrame("Close: truncated reason"))?;
|
||||
let reason = String::from_utf8(reason_bytes.to_vec())
|
||||
.map_err(|_| ProtoError::MalformedFrame("Close: reason not UTF-8"))?;
|
||||
Ok(Frame::Close { code, reason })
|
||||
}
|
||||
_ => Err(ProtoError::MalformedFrame("unknown frame tag")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a big-endian u32 from the start of `buf`, erroring if it is too short.
|
||||
fn read_u32(buf: &[u8], what: &'static str) -> Result<u32, ProtoError> {
|
||||
let bytes: [u8; 4] = buf
|
||||
.get(..4)
|
||||
.ok_or(ProtoError::MalformedFrame(what))?
|
||||
.try_into()
|
||||
.expect("slice of length 4 converts to [u8; 4]");
|
||||
Ok(u32::from_be_bytes(bytes))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn header_roundtrip_all_types() {
|
||||
for (ty, byte) in [
|
||||
(MsgType::ClientHello, 0x01u8),
|
||||
(MsgType::ServerHello, 0x02),
|
||||
(MsgType::ClientAuth, 0x03),
|
||||
(MsgType::ServerAuth, 0x04),
|
||||
(MsgType::Finished, 0x05),
|
||||
(MsgType::Data, 0x06),
|
||||
(MsgType::Alert, 0xFF),
|
||||
] {
|
||||
let h = encode_header(ty, 0x0012_3456).unwrap();
|
||||
assert_eq!(h[0], byte);
|
||||
assert_eq!(h[4], PROTOCOL_VERSION);
|
||||
let (got_ty, got_len) = decode_header(&h).unwrap();
|
||||
assert_eq!(got_ty, ty);
|
||||
assert_eq!(got_len, 0x0012_3456);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_rejects_oversize_and_bad_version() {
|
||||
assert!(matches!(
|
||||
encode_header(MsgType::Data, MAX_PAYLOAD_LEN + 1),
|
||||
Err(ProtoError::FrameTooLarge(_))
|
||||
));
|
||||
let mut h = encode_header(MsgType::Data, 1).unwrap();
|
||||
h[4] = 0x02;
|
||||
assert!(matches!(
|
||||
decode_header(&h),
|
||||
Err(ProtoError::BadVersion(0x02))
|
||||
));
|
||||
h[0] = 0x77;
|
||||
assert!(matches!(
|
||||
decode_header(&h),
|
||||
Err(ProtoError::UnknownMsgType(0x77))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_roundtrip() {
|
||||
let frames = vec![
|
||||
Frame::Data {
|
||||
stream_id: 0xDEAD_BEEF,
|
||||
payload: Bytes::from_static(b"hello world"),
|
||||
},
|
||||
Frame::Data {
|
||||
stream_id: 0,
|
||||
payload: Bytes::new(),
|
||||
},
|
||||
Frame::Ping { seq: 42 },
|
||||
Frame::Pong { seq: 0xFFFF_FFFF },
|
||||
Frame::Close {
|
||||
code: 7,
|
||||
reason: "going away \u{1f44b}".to_string(),
|
||||
},
|
||||
Frame::Close {
|
||||
code: 0,
|
||||
reason: String::new(),
|
||||
},
|
||||
];
|
||||
for f in frames {
|
||||
let encoded = f.encode();
|
||||
let decoded = Frame::decode(&encoded).unwrap();
|
||||
assert_eq!(f, decoded);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_decode_rejects_garbage() {
|
||||
assert!(Frame::decode(&[]).is_err());
|
||||
assert!(Frame::decode(&[0x99]).is_err());
|
||||
assert!(Frame::decode(&[frame_tag::PING, 0x00]).is_err()); // truncated u32
|
||||
assert!(Frame::decode(&[frame_tag::CLOSE]).is_err()); // missing code
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user