feat(singbox-aura,tools): Go port of Aura UDP client + KAT bridge to Rust
Lays the foundation for sing-box mobile clients (Option B from
docs/sing-box.md): an independent Go module that speaks the AuraVPN wire
protocol byte-for-byte. Proof of equivalence is in KAT tests cross-loaded
from a Rust-side deterministic vector exporter.
- tools/export-kat (new Rust bin in workspace): captures a handshake +
derived keys + a sealed datagram record + a knock token using seeded
RNGs (rand::rngs::StdRng + ml-kem's *_deterministic public API), emits
JSON. Reproducible byte-for-byte.
- singbox-aura/ (new Go module, ~3000 LOC, 22 files):
- aura/frame: 5-byte protocol header + Frame{Data,Ping,Pong,Close,
Control} + magic envelope (0xAA,0xAA,0xC0,0x01) — encode/decode
matching aura-proto::frame.
- aura/crypto: hybrid X25519 + ML-KEM-768 (stdlib crypto/ecdh +
crypto/mlkem on Go 1.24+; falls back to circl on older Go via a
documented swap), HKDF-SHA256 derive_session_keys, ChaCha20-Poly1305
with the **LE(u64 counter) || [0;4]** nonce scheme that matches
aura-crypto::AeadKey/AeadSession.
- aura/handshake: client_handshake state machine reproducing protocol.md
§6.2 exactly (CH→SH→ServerAuth→ClientAuth→Finished×2; transcript hash;
ECDSA-P256 transcript signature; HMAC-SHA256 Finished).
- aura/session: DatagramSender/Receiver + 64-wide sliding replay window.
- aura/transport: reliable HS-adapter (DTLS-flight retransmit) + UDP
datagram data path + 16-byte HMAC port-knock with ±1-minute window.
- aura/outbound: sing-box-shaped shim (interface signatures only — sing-
box upstream registration is one more step, documented in README).
- cmd/aura-client: standalone Go binary; reads client.toml via
pelletier/go-toml/v2 and connects to a real aura server. Validates
end-to-end interop with the Rust side.
- KAT: 6 comparisons against Rust vectors — session_keys (HKDF), hybrid
KEM ek/encaps roundtrip, c2s + s2c Finished HMAC, sealed datagram
record at seq=2 (incl. 16-byte Poly1305 tag), knock token. All byte-
for-byte.
Go: 29 tests across 5 packages, all green. Only deps: golang.org/x/crypto
and pelletier/go-toml/v2. Rust: 293 tests still green; tools/export-kat
added to workspace members.
v1 limits documented in singbox-aura/README.md: UDP-only (no TCP/QUIC
fallback yet), no cell padding / cover traffic, no relay/exit role, no
multi-hop, sing-box upstream-registration sketch (vendor sagernet/sing-box +
init() RegisterOutbound) for follow-up.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// NonceLen is the AEAD nonce length (96 bits for ChaCha20-Poly1305).
|
||||
const NonceLen = 12
|
||||
|
||||
// NonceFor reproduces the AeadSession::nonce_for layout exactly:
|
||||
//
|
||||
// nonce[0..8] = LE(u64) counter
|
||||
// nonce[8..12] = 0
|
||||
//
|
||||
// Both stream- and datagram-mode AEADs share this nonce derivation; the only difference is
|
||||
// whether the counter is advanced lock-step (stream) or carried on the wire (datagram).
|
||||
func NonceFor(counter uint64) [NonceLen]byte {
|
||||
var n [NonceLen]byte
|
||||
binary.LittleEndian.PutUint64(n[0:8], counter)
|
||||
return n
|
||||
}
|
||||
|
||||
// AeadKey wraps a 32-byte ChaCha20-Poly1305 key for explicit-nonce datagram use. The caller owns
|
||||
// nonce uniqueness — Aura's datagram codec carries the counter on the wire as `seq`.
|
||||
type AeadKey struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
// NewAeadKey builds an AeadKey from a 32-byte key. Returns an error if the key is the wrong
|
||||
// size; ChaCha20-Poly1305 always wants 32.
|
||||
func NewAeadKey(key []byte) (*AeadKey, error) {
|
||||
if len(key) != SessionKeyLen {
|
||||
return nil, fmt.Errorf("aead key must be %d bytes, got %d", SessionKeyLen, len(key))
|
||||
}
|
||||
a, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chacha20poly1305.New: %w", err)
|
||||
}
|
||||
return &AeadKey{aead: a}, nil
|
||||
}
|
||||
|
||||
// Seal encrypts plaintext under the nonce derived from counter, returning ciphertext||tag.
|
||||
func (k *AeadKey) Seal(counter uint64, plaintext, aad []byte) []byte {
|
||||
nonce := NonceFor(counter)
|
||||
return k.aead.Seal(nil, nonce[:], plaintext, aad)
|
||||
}
|
||||
|
||||
// Open authenticates and decrypts ciphertext (which must include the 16-byte Poly1305 tag).
|
||||
// Returns the plaintext, or an error on authentication failure.
|
||||
func (k *AeadKey) Open(counter uint64, ciphertext, aad []byte) ([]byte, error) {
|
||||
nonce := NonceFor(counter)
|
||||
out, err := k.aead.Open(nil, nonce[:], ciphertext, aad)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aead open: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// AeadSession is the stream-mode counterpart: it holds the key plus a monotonically increasing
|
||||
// 64-bit counter that advances on every Seal and Open. Used by the handshake's encrypted
|
||||
// messages (ServerAuth, ClientAuth, Finished) so the two sides stay in lockstep without putting
|
||||
// the counter on the wire.
|
||||
type AeadSession struct {
|
||||
key *AeadKey
|
||||
counter uint64
|
||||
}
|
||||
|
||||
// NewAeadSession starts a session at counter 0.
|
||||
func NewAeadSession(rawKey []byte) (*AeadSession, error) {
|
||||
k, err := NewAeadKey(rawKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AeadSession{key: k, counter: 0}, nil
|
||||
}
|
||||
|
||||
// Counter is the current counter (the nonce that the next Seal/Open will use). Test-only and
|
||||
// used by Session.IntoDatagramParts to hand off the explicit-nonce key.
|
||||
func (s *AeadSession) Counter() uint64 { return s.counter }
|
||||
|
||||
// Seal seals plaintext at the current counter then advances it.
|
||||
func (s *AeadSession) Seal(plaintext, aad []byte) []byte {
|
||||
ct := s.key.Seal(s.counter, plaintext, aad)
|
||||
s.counter++
|
||||
return ct
|
||||
}
|
||||
|
||||
// Open verifies+decrypts ciphertext at the current counter then advances it (symmetric to Seal
|
||||
// so a failed decrypt keeps the two ends aligned).
|
||||
func (s *AeadSession) Open(ciphertext, aad []byte) ([]byte, error) {
|
||||
pt, err := s.key.Open(s.counter, ciphertext, aad)
|
||||
s.counter++
|
||||
return pt, err
|
||||
}
|
||||
|
||||
// IntoKey returns the underlying AeadKey so datagram-mode codecs can continue at the same
|
||||
// counter without re-deriving anything (matches Rust's into_parts).
|
||||
func (s *AeadSession) IntoKey() *AeadKey { return s.key }
|
||||
@@ -0,0 +1,279 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// vectorsJSON mirrors the JSON written by tools/export-kat (in Rust). Every field is hex.
|
||||
type vectorsJSON struct {
|
||||
CAFingerprint string `json:"ca_fingerprint"`
|
||||
ClientX25519Priv string `json:"client_x25519_priv"`
|
||||
ClientX25519Pub string `json:"client_x25519_pub"`
|
||||
ClientKyberPriv string `json:"client_kyber_priv"`
|
||||
ClientKyberPub string `json:"client_kyber_pub"`
|
||||
ServerX25519EphPriv string `json:"server_x25519_eph_priv"`
|
||||
ServerX25519EphPub string `json:"server_x25519_eph_pub"`
|
||||
ServerKyberCt string `json:"server_kyber_ct"`
|
||||
ClientNonce string `json:"client_nonce"`
|
||||
ServerNonce string `json:"server_nonce"`
|
||||
X25519SS string `json:"x25519_ss"`
|
||||
KyberSS string `json:"kyber_ss"`
|
||||
SessionKeys struct {
|
||||
C2S string `json:"c2s"`
|
||||
S2C string `json:"s2c"`
|
||||
} `json:"session_keys"`
|
||||
TranscriptHash string `json:"transcript_hash"`
|
||||
ClientFinishedHmac string `json:"client_finished_hmac"`
|
||||
ServerFinishedHmac string `json:"server_finished_hmac"`
|
||||
DatagramTest struct {
|
||||
Seq uint64 `json:"seq"`
|
||||
Frame string `json:"frame"`
|
||||
Key string `json:"key"`
|
||||
SealedRecord string `json:"sealed_record"`
|
||||
} `json:"datagram_test"`
|
||||
KnockTest struct {
|
||||
CAFingerprint string `json:"ca_fingerprint"`
|
||||
UnixMinute uint64 `json:"unix_minute"`
|
||||
Knock string `json:"knock"`
|
||||
} `json:"knock_test"`
|
||||
}
|
||||
|
||||
// loadVectors finds the vectors file at <module>/kat/vectors.json. The file is created by
|
||||
//
|
||||
// cargo run -p export-kat
|
||||
//
|
||||
// from the workspace root.
|
||||
func loadVectors(t *testing.T) *vectorsJSON {
|
||||
t.Helper()
|
||||
// crypto_test.go is at singbox-aura/aura/crypto/. The KAT lives at singbox-aura/kat/.
|
||||
_, thisFile, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("runtime.Caller failed")
|
||||
}
|
||||
path := filepath.Join(filepath.Dir(thisFile), "..", "..", "kat", "vectors.json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Skipf("KAT vectors.json not present at %s — run `cargo run -p export-kat` first: %v", path, err)
|
||||
return nil
|
||||
}
|
||||
var v vectorsJSON
|
||||
if err := json.Unmarshal(data, &v); err != nil {
|
||||
t.Fatalf("parse vectors.json: %v", err)
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
func mustHex(t *testing.T, s string) []byte {
|
||||
t.Helper()
|
||||
b, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
t.Fatalf("hex decode %q: %v", s, err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func mustHex32(t *testing.T, s string) [32]byte {
|
||||
b := mustHex(t, s)
|
||||
if len(b) != 32 {
|
||||
t.Fatalf("want 32 bytes, got %d", len(b))
|
||||
}
|
||||
var out [32]byte
|
||||
copy(out[:], b)
|
||||
return out
|
||||
}
|
||||
|
||||
// TestKAT_SessionKeys: HKDF-derive from the shared secrets in the vector reproduces the
|
||||
// session_keys.{c2s,s2c} byte-for-byte.
|
||||
func TestKAT_SessionKeys(t *testing.T) {
|
||||
v := loadVectors(t)
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
xss := mustHex32(t, v.X25519SS)
|
||||
kss := mustHex32(t, v.KyberSS)
|
||||
cn := mustHex32(t, v.ClientNonce)
|
||||
sn := mustHex32(t, v.ServerNonce)
|
||||
wantC2S := mustHex(t, v.SessionKeys.C2S)
|
||||
wantS2C := mustHex(t, v.SessionKeys.S2C)
|
||||
|
||||
shared := &HybridSharedSecret{X25519SS: xss, MLKEMSS: kss}
|
||||
keys := DeriveSessionKeys(shared, cn, sn)
|
||||
if !bytes.Equal(keys.ClientToServer[:], wantC2S) {
|
||||
t.Fatalf("c2s mismatch:\n got %x\nwant %x", keys.ClientToServer, wantC2S)
|
||||
}
|
||||
if !bytes.Equal(keys.ServerToClient[:], wantS2C) {
|
||||
t.Fatalf("s2c mismatch:\n got %x\nwant %x", keys.ServerToClient, wantS2C)
|
||||
}
|
||||
}
|
||||
|
||||
// TestKAT_HybridDecapsulateRoundtrip: load the client's deterministic hybrid key from the
|
||||
// vector, then run Decapsulate against the server's ciphertext. The derived shared secrets must
|
||||
// match x25519_ss / kyber_ss in the vector.
|
||||
func TestKAT_HybridDecapsulateRoundtrip(t *testing.T) {
|
||||
v := loadVectors(t)
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
xPriv := mustHex32(t, v.ClientX25519Priv)
|
||||
// We don't ship the ml-kem seed in the JSON directly (the export tool uses a fixed seed and
|
||||
// stores only the expanded private key for diagnostics). Instead, reconstruct from the seed
|
||||
// the export tool documents — match the literal bytes in tools/export-kat/src/main.rs.
|
||||
var seed [64]byte
|
||||
copy(seed[:32], []byte("AURA-MLKEM-DSEED-CLIENT--FIXED32"))
|
||||
copy(seed[32:], []byte("AURA-MLKEM-ZSEED-CLIENT--FIXED32"))
|
||||
priv, pub, err := NewHybridPrivateFromBytes(xPriv, seed)
|
||||
if err != nil {
|
||||
t.Fatalf("rebuild hybrid: %v", err)
|
||||
}
|
||||
// Sanity: the recomputed encapsulation key must match what the Rust side emitted.
|
||||
if !bytes.Equal(pub.MLKEM, mustHex(t, v.ClientKyberPub)) {
|
||||
t.Fatalf("ml-kem ek mismatch: Go and Rust derive different bytes from the same seed")
|
||||
}
|
||||
if !bytes.Equal(pub.X25519[:], mustHex(t, v.ClientX25519Pub)) {
|
||||
t.Fatalf("x25519 pub mismatch")
|
||||
}
|
||||
// Decapsulate.
|
||||
ct := &HybridCiphertext{MLKEMCT: mustHex(t, v.ServerKyberCt)}
|
||||
copy(ct.X25519Eph[:], mustHex(t, v.ServerX25519EphPub))
|
||||
ss, err := priv.Decapsulate(ct)
|
||||
if err != nil {
|
||||
t.Fatalf("decapsulate: %v", err)
|
||||
}
|
||||
if !bytes.Equal(ss.X25519SS[:], mustHex(t, v.X25519SS)) {
|
||||
t.Fatalf("x25519_ss mismatch:\n got %x\nwant %s", ss.X25519SS, v.X25519SS)
|
||||
}
|
||||
if !bytes.Equal(ss.MLKEMSS[:], mustHex(t, v.KyberSS)) {
|
||||
t.Fatalf("kyber_ss mismatch:\n got %x\nwant %s", ss.MLKEMSS, v.KyberSS)
|
||||
}
|
||||
}
|
||||
|
||||
// TestKAT_ClientFinishedHMAC: HMAC-SHA256(c2s, transcript_hash) reproduces the Rust value.
|
||||
func TestKAT_ClientFinishedHMAC(t *testing.T) {
|
||||
v := loadVectors(t)
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
key := mustHex(t, v.SessionKeys.C2S)
|
||||
transcript := mustHex(t, v.TranscriptHash)
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(transcript)
|
||||
got := mac.Sum(nil)
|
||||
want := mustHex(t, v.ClientFinishedHmac)
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("client finished mismatch:\n got %x\nwant %x", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestKAT_ServerFinishedHMAC: HMAC-SHA256(s2c, transcript_hash) reproduces the Rust value.
|
||||
func TestKAT_ServerFinishedHMAC(t *testing.T) {
|
||||
v := loadVectors(t)
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
key := mustHex(t, v.SessionKeys.S2C)
|
||||
transcript := mustHex(t, v.TranscriptHash)
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(transcript)
|
||||
got := mac.Sum(nil)
|
||||
want := mustHex(t, v.ServerFinishedHmac)
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("server finished mismatch:\n got %x\nwant %x", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestKAT_SealedDatagramRecord: ChaCha20-Poly1305.Seal under the c2s key at seq 2 with
|
||||
// aad=seq_be reproduces the exact sealed_record bytes (seq_be || ciphertext).
|
||||
func TestKAT_SealedDatagramRecord(t *testing.T) {
|
||||
v := loadVectors(t)
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
key, err := NewAeadKey(mustHex(t, v.DatagramTest.Key))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
frameBytes := mustHex(t, v.DatagramTest.Frame)
|
||||
seq := v.DatagramTest.Seq
|
||||
var seqBE [8]byte
|
||||
binary.BigEndian.PutUint64(seqBE[:], seq)
|
||||
ct := key.Seal(seq, frameBytes, seqBE[:])
|
||||
got := append(append([]byte{}, seqBE[:]...), ct...)
|
||||
want := mustHex(t, v.DatagramTest.SealedRecord)
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("sealed datagram mismatch:\n got %x\nwant %x", got, want)
|
||||
}
|
||||
// Round-trip: opening at the same seq must return the original frame bytes.
|
||||
pt, err := key.Open(seq, ct, seqBE[:])
|
||||
if err != nil {
|
||||
t.Fatalf("open: %v", err)
|
||||
}
|
||||
if !bytes.Equal(pt, frameBytes) {
|
||||
t.Fatal("open returned different plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
// TestKAT_KnockToken: HMAC-SHA256(ca_fp, u64_be(minute))[:16] matches the Rust knock value.
|
||||
func TestKAT_KnockToken(t *testing.T) {
|
||||
v := loadVectors(t)
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
key := mustHex(t, v.KnockTest.CAFingerprint)
|
||||
var mb [8]byte
|
||||
binary.BigEndian.PutUint64(mb[:], v.KnockTest.UnixMinute)
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(mb[:])
|
||||
tag := mac.Sum(nil)
|
||||
if len(tag) < 16 {
|
||||
t.Fatalf("hmac too short: %d", len(tag))
|
||||
}
|
||||
got := tag[:16]
|
||||
want := mustHex(t, v.KnockTest.Knock)
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("knock mismatch:\n got %x\nwant %x", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonceLayout: explicit sanity that NonceFor matches the documented LE(u64) || 0x00000000.
|
||||
func TestNonceLayout(t *testing.T) {
|
||||
if got := NonceFor(0); got != ([NonceLen]byte{}) {
|
||||
t.Fatalf("counter 0: want zero, got %x", got)
|
||||
}
|
||||
n := NonceFor(0x0807060504030201)
|
||||
if !bytes.Equal(n[:8], []byte{1, 2, 3, 4, 5, 6, 7, 8}) {
|
||||
t.Fatalf("LE layout wrong: %x", n[:8])
|
||||
}
|
||||
if !bytes.Equal(n[8:], []byte{0, 0, 0, 0}) {
|
||||
t.Fatalf("upper 4 bytes not zero: %x", n[8:])
|
||||
}
|
||||
}
|
||||
|
||||
// TestAeadSessionCounterMonotonic: Seal/Open lock-step advances the counter by exactly 1.
|
||||
func TestAeadSessionCounterMonotonic(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
s, err := NewAeadSession(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s.Counter() != 0 {
|
||||
t.Fatalf("initial counter %d", s.Counter())
|
||||
}
|
||||
for want := uint64(1); want <= 5; want++ {
|
||||
_ = s.Seal([]byte("x"), nil)
|
||||
if s.Counter() != want {
|
||||
t.Fatalf("after %d seals: counter %d", want, s.Counter())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// HKDFInfo is the domain-separation string bound into the HKDF expansion.
|
||||
// MUST match HKDF_INFO in crates/aura-crypto/src/kdf.rs.
|
||||
var HKDFInfo = []byte("aura-v1-session")
|
||||
|
||||
// SessionKeyLen is the size of one directional AEAD key.
|
||||
const SessionKeyLen = 32
|
||||
|
||||
// SessionKeys is the pair of directional 256-bit keys produced by the HKDF expansion.
|
||||
type SessionKeys struct {
|
||||
ClientToServer [SessionKeyLen]byte
|
||||
ServerToClient [SessionKeyLen]byte
|
||||
}
|
||||
|
||||
// DeriveSessionKeys runs HKDF-SHA256 with
|
||||
//
|
||||
// salt = client_nonce || server_nonce (64 bytes)
|
||||
// IKM = x25519_ss || mlkem_ss (64 bytes)
|
||||
// info = "aura-v1-session", OKM 64 bytes -> (c2s, s2c)
|
||||
//
|
||||
// matching the production helper in crates/aura-crypto/src/kdf.rs byte-for-byte.
|
||||
func DeriveSessionKeys(shared *HybridSharedSecret, clientNonce, serverNonce [32]byte) *SessionKeys {
|
||||
salt := make([]byte, 64)
|
||||
copy(salt[:32], clientNonce[:])
|
||||
copy(salt[32:], serverNonce[:])
|
||||
|
||||
ikm := shared.Concat()
|
||||
hk := hkdf.New(func() hash.Hash { return sha256.New() }, ikm, salt, HKDFInfo)
|
||||
okm := make([]byte, 64)
|
||||
if _, err := hk.Read(okm); err != nil {
|
||||
// HKDF-Read for 64 bytes from SHA-256 is infallible; treat any error as a bug.
|
||||
panic(err)
|
||||
}
|
||||
var keys SessionKeys
|
||||
copy(keys.ClientToServer[:], okm[:32])
|
||||
copy(keys.ServerToClient[:], okm[32:])
|
||||
return &keys
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
// Package crypto implements the Aura primitives the Go client side needs: hybrid X25519 +
|
||||
// ML-KEM-768 KEM, HKDF-SHA256 session-key derivation, ChaCha20-Poly1305 AEAD using the same
|
||||
// LE(u64)||[0;4] nonce scheme the Rust side uses, and the HMAC-SHA256 port-knock token.
|
||||
//
|
||||
// All exported sizes match the on-wire constants in crates/aura-crypto and aura-proto:
|
||||
//
|
||||
// X25519 public / shared secret 32 bytes
|
||||
// ML-KEM-768 encapsulation key 1184 bytes
|
||||
// ML-KEM-768 ciphertext 1088 bytes
|
||||
// ML-KEM-768 shared secret 32 bytes
|
||||
//
|
||||
// We use crypto/mlkem (Go 1.24+ stdlib) for the post-quantum half. The Rust side uses the
|
||||
// `ml_kem` 0.3 crate; both are FIPS 203 ML-KEM-768. The shared secrets agree byte-for-byte —
|
||||
// asserted in crypto_test.go against the KAT vector emitted by `tools/export-kat`.
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/mlkem"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Sizes of the hybrid KEM building blocks, all in bytes.
|
||||
const (
|
||||
X25519Len = 32
|
||||
MLKEMEKLen = 1184
|
||||
MLKEMCTLen = 1088
|
||||
MLKEMSSLen = 32
|
||||
HybridSSLen = X25519Len + MLKEMSSLen
|
||||
)
|
||||
|
||||
// HybridPublicKey is the client's public half: a 32-byte X25519 public key plus a 1184-byte
|
||||
// ML-KEM-768 encapsulation key.
|
||||
type HybridPublicKey struct {
|
||||
X25519 [X25519Len]byte
|
||||
MLKEM []byte // 1184 bytes
|
||||
}
|
||||
|
||||
// HybridPrivateKey is the client's secret half. We hold the high-level keys so encapsulate /
|
||||
// decapsulate are simple method calls.
|
||||
type HybridPrivateKey struct {
|
||||
x25519Priv *ecdh.PrivateKey
|
||||
mlkemDk *mlkem.DecapsulationKey768
|
||||
}
|
||||
|
||||
// HybridCiphertext is the server's response: its ephemeral X25519 public key plus the ML-KEM
|
||||
// ciphertext.
|
||||
type HybridCiphertext struct {
|
||||
X25519Eph [X25519Len]byte
|
||||
MLKEMCT []byte // 1088 bytes
|
||||
}
|
||||
|
||||
// HybridSharedSecret is the 64-byte concatenation x25519_ss || kyber_ss.
|
||||
type HybridSharedSecret struct {
|
||||
X25519SS [X25519Len]byte
|
||||
MLKEMSS [MLKEMSSLen]byte
|
||||
}
|
||||
|
||||
// Concat returns x25519_ss || mlkem_ss in one slice (the IKM HKDF consumes).
|
||||
func (h *HybridSharedSecret) Concat() []byte {
|
||||
out := make([]byte, HybridSSLen)
|
||||
copy(out[:X25519Len], h.X25519SS[:])
|
||||
copy(out[X25519Len:], h.MLKEMSS[:])
|
||||
return out
|
||||
}
|
||||
|
||||
// GenerateHybridKeypair produces a fresh client hybrid keypair using the OS RNG. Used by the
|
||||
// standalone CLI; tests that need determinism instead call NewHybridPrivateFromSeeds or
|
||||
// reconstruct from explicit bytes.
|
||||
func GenerateHybridKeypair() (*HybridPrivateKey, *HybridPublicKey, error) {
|
||||
x, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("x25519 keygen: %w", err)
|
||||
}
|
||||
dk, err := mlkem.GenerateKey768()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("ml-kem keygen: %w", err)
|
||||
}
|
||||
return buildHybrid(x, dk)
|
||||
}
|
||||
|
||||
// NewHybridPrivateFromBytes reconstructs a hybrid private key from raw 32-byte X25519 seed and
|
||||
// the 64-byte ML-KEM seed (d || z). Mirrors the deterministic constructor the export-kat tool
|
||||
// uses so the Go side can drive a handshake against the same KAT vector.
|
||||
func NewHybridPrivateFromBytes(x25519Priv [X25519Len]byte, mlkemSeed [64]byte) (*HybridPrivateKey, *HybridPublicKey, error) {
|
||||
// x25519: NewPrivateKey requires a 32-byte scalar. Go enforces clamping inside the curve.
|
||||
x, err := ecdh.X25519().NewPrivateKey(x25519Priv[:])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("x25519 from bytes: %w", err)
|
||||
}
|
||||
dk, err := mlkem.NewDecapsulationKey768(mlkemSeed[:])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("ml-kem from seed: %w", err)
|
||||
}
|
||||
return buildHybrid(x, dk)
|
||||
}
|
||||
|
||||
func buildHybrid(x *ecdh.PrivateKey, dk *mlkem.DecapsulationKey768) (*HybridPrivateKey, *HybridPublicKey, error) {
|
||||
priv := &HybridPrivateKey{x25519Priv: x, mlkemDk: dk}
|
||||
pub := &HybridPublicKey{MLKEM: dk.EncapsulationKey().Bytes()}
|
||||
if len(pub.MLKEM) != MLKEMEKLen {
|
||||
return nil, nil, fmt.Errorf("ml-kem ek wrong length: %d", len(pub.MLKEM))
|
||||
}
|
||||
xPub := x.PublicKey().Bytes()
|
||||
if len(xPub) != X25519Len {
|
||||
return nil, nil, fmt.Errorf("x25519 pub wrong length: %d", len(xPub))
|
||||
}
|
||||
copy(pub.X25519[:], xPub)
|
||||
return priv, pub, nil
|
||||
}
|
||||
|
||||
// Decapsulate runs the client-side decapsulation: ECDH against the server's ephemeral X25519
|
||||
// plus ML-KEM-768 decapsulation under the stored secret key.
|
||||
func (h *HybridPrivateKey) Decapsulate(ct *HybridCiphertext) (*HybridSharedSecret, error) {
|
||||
if len(ct.MLKEMCT) != MLKEMCTLen {
|
||||
return nil, fmt.Errorf("ml-kem ct wrong length: %d", len(ct.MLKEMCT))
|
||||
}
|
||||
peerPub, err := ecdh.X25519().NewPublicKey(ct.X25519Eph[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("x25519 peer pub: %w", err)
|
||||
}
|
||||
xss, err := h.x25519Priv.ECDH(peerPub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("x25519 ecdh: %w", err)
|
||||
}
|
||||
if len(xss) != X25519Len {
|
||||
return nil, fmt.Errorf("x25519 ss wrong length: %d", len(xss))
|
||||
}
|
||||
kss, err := h.mlkemDk.Decapsulate(ct.MLKEMCT)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ml-kem decapsulate: %w", err)
|
||||
}
|
||||
if len(kss) != MLKEMSSLen {
|
||||
return nil, fmt.Errorf("ml-kem ss wrong length: %d", len(kss))
|
||||
}
|
||||
out := &HybridSharedSecret{}
|
||||
copy(out.X25519SS[:], xss)
|
||||
copy(out.MLKEMSS[:], kss)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Encapsulate is the server side of the handshake. Provided here purely so a Go-side end-to-end
|
||||
// test can drive both halves in-process. The standalone client never calls this.
|
||||
func (p *HybridPublicKey) Encapsulate() (*HybridCiphertext, *HybridSharedSecret, error) {
|
||||
if len(p.MLKEM) != MLKEMEKLen {
|
||||
return nil, nil, errors.New("hybrid pub: invalid ml-kem ek length")
|
||||
}
|
||||
eph, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("x25519 eph keygen: %w", err)
|
||||
}
|
||||
peer, err := ecdh.X25519().NewPublicKey(p.X25519[:])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("x25519 peer: %w", err)
|
||||
}
|
||||
xss, err := eph.ECDH(peer)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("x25519 ecdh: %w", err)
|
||||
}
|
||||
ek, err := mlkem.NewEncapsulationKey768(p.MLKEM)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("ml-kem ek parse: %w", err)
|
||||
}
|
||||
kss, kct := ek.Encapsulate()
|
||||
|
||||
ct := &HybridCiphertext{MLKEMCT: kct}
|
||||
copy(ct.X25519Eph[:], eph.PublicKey().Bytes())
|
||||
ss := &HybridSharedSecret{}
|
||||
copy(ss.X25519SS[:], xss)
|
||||
copy(ss.MLKEMSS[:], kss)
|
||||
return ct, ss, nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package frame
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ControlEnvelopeMagic is the 4-byte prefix marking a v2 control message multiplexed through the
|
||||
// PacketConnection's send_packet path. An IPv4 packet's first byte is 0x4X and an IPv6 packet's
|
||||
// first byte is 0x6X, so this magic (starting with 0xAA) never collides with a real IP packet.
|
||||
var ControlEnvelopeMagic = [4]byte{0xAA, 0xAA, 0xC0, 0x01}
|
||||
|
||||
// ControlKind is the on-wire byte selector inside a control envelope.
|
||||
type ControlKind byte
|
||||
|
||||
// Known control kinds (must match crates/aura-proto/src/frame.rs ControlKind).
|
||||
const (
|
||||
ControlCrlPush ControlKind = 0x01
|
||||
ControlCrlAck ControlKind = 0x02
|
||||
ControlExtendBridge ControlKind = 0x03
|
||||
ControlCircuitReady ControlKind = 0x04
|
||||
ControlCircuitFailed ControlKind = 0x05
|
||||
)
|
||||
|
||||
// EncodeControlEnvelope wraps (kind, payload) as
|
||||
//
|
||||
// MAGIC(4) || kind(u8) || u32_be(payload_len) || payload
|
||||
//
|
||||
// suitable for shipping through PacketConnection.SendPacket.
|
||||
func EncodeControlEnvelope(kind ControlKind, payload []byte) []byte {
|
||||
out := make([]byte, 0, len(ControlEnvelopeMagic)+1+4+len(payload))
|
||||
out = append(out, ControlEnvelopeMagic[:]...)
|
||||
out = append(out, byte(kind))
|
||||
var lb [4]byte
|
||||
binary.BigEndian.PutUint32(lb[:], uint32(len(payload)))
|
||||
out = append(out, lb[:]...)
|
||||
out = append(out, payload...)
|
||||
return out
|
||||
}
|
||||
|
||||
// DecodeControlEnvelope returns (kind, payload, true, nil) if buf starts with the magic and
|
||||
// parses cleanly. If buf does NOT start with the magic (i.e. it is a normal IP packet) the third
|
||||
// return is false and the error is nil. A malformed envelope (truncated) returns an error.
|
||||
func DecodeControlEnvelope(buf []byte) (ControlKind, []byte, bool, error) {
|
||||
if len(buf) < len(ControlEnvelopeMagic) {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
for i, b := range ControlEnvelopeMagic {
|
||||
if buf[i] != b {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
}
|
||||
rest := buf[len(ControlEnvelopeMagic):]
|
||||
if len(rest) < 1 {
|
||||
return 0, nil, true, fmt.Errorf("%w: control envelope: missing kind", ErrMalformedFrame)
|
||||
}
|
||||
kind := ControlKind(rest[0])
|
||||
if len(rest) < 5 {
|
||||
return 0, nil, true, fmt.Errorf("%w: control envelope: missing payload length", ErrMalformedFrame)
|
||||
}
|
||||
plen := int(binary.BigEndian.Uint32(rest[1:5]))
|
||||
if len(rest) < 5+plen {
|
||||
return 0, nil, true, fmt.Errorf("%w: control envelope: truncated payload", ErrMalformedFrame)
|
||||
}
|
||||
payload := make([]byte, plen)
|
||||
copy(payload, rest[5:5+plen])
|
||||
return kind, payload, true, nil
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
// Package frame implements Aura's wire framing: the 5-byte protocol header and the
|
||||
// application-level Frame{Data,Ping,Pong,Close}.
|
||||
//
|
||||
// This is a byte-for-byte port of crates/aura-proto/src/frame.rs. The Rust unit tests in that
|
||||
// file are the wire spec; matching them here keeps the Go port interoperable with the Rust
|
||||
// server.
|
||||
//
|
||||
// Wire layout (from docs/protocol.md §6.1):
|
||||
//
|
||||
// byte 0 : msg_type (u8)
|
||||
// bytes 1..4 : length (u24, big-endian) — payload length in bytes
|
||||
// byte 4 : version = 0x01
|
||||
// bytes 5.. : payload (length bytes)
|
||||
package frame
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// HeaderLen is the size of the protocol header in bytes.
|
||||
const HeaderLen = 5
|
||||
|
||||
// ProtocolVersion is the constant carried in byte 4 of every header.
|
||||
const ProtocolVersion byte = 0x01
|
||||
|
||||
// MaxPayloadLen is the largest payload expressible by the u24 length field.
|
||||
const MaxPayloadLen = 0x00FF_FFFF
|
||||
|
||||
// MsgType is the on-wire message-type discriminant carried in byte 0 of the header.
|
||||
type MsgType byte
|
||||
|
||||
// Message-type bytes (must match the Rust MsgType repr in aura-proto/frame.rs).
|
||||
const (
|
||||
MsgClientHello MsgType = 0x01
|
||||
MsgServerHello MsgType = 0x02
|
||||
MsgClientAuth MsgType = 0x03
|
||||
MsgServerAuth MsgType = 0x04
|
||||
MsgFinished MsgType = 0x05
|
||||
MsgData MsgType = 0x06
|
||||
MsgAlert MsgType = 0xFF
|
||||
)
|
||||
|
||||
// String returns the short name of the message type, for logs.
|
||||
func (m MsgType) String() string {
|
||||
switch m {
|
||||
case MsgClientHello:
|
||||
return "ClientHello"
|
||||
case MsgServerHello:
|
||||
return "ServerHello"
|
||||
case MsgClientAuth:
|
||||
return "ClientAuth"
|
||||
case MsgServerAuth:
|
||||
return "ServerAuth"
|
||||
case MsgFinished:
|
||||
return "Finished"
|
||||
case MsgData:
|
||||
return "Data"
|
||||
case MsgAlert:
|
||||
return "Alert"
|
||||
default:
|
||||
return fmt.Sprintf("MsgType(0x%02X)", byte(m))
|
||||
}
|
||||
}
|
||||
|
||||
// Errors returned by the codec. They mirror the ProtoError variants the Rust side returns so
|
||||
// callers can map them onto identical wire alerts.
|
||||
var (
|
||||
ErrFrameTooLarge = errors.New("aura/frame: payload exceeds 16 MiB u24 length field")
|
||||
ErrBadVersion = errors.New("aura/frame: header byte 4 is not protocol version 0x01")
|
||||
ErrUnknownMsgType = errors.New("aura/frame: unknown message-type byte")
|
||||
ErrMalformedFrame = errors.New("aura/frame: malformed application frame")
|
||||
)
|
||||
|
||||
// EncodeHeader builds a 5-byte header for msgType carrying a payload of payloadLen bytes.
|
||||
func EncodeHeader(msgType MsgType, payloadLen int) ([HeaderLen]byte, error) {
|
||||
var h [HeaderLen]byte
|
||||
if payloadLen < 0 || payloadLen > MaxPayloadLen {
|
||||
return h, fmt.Errorf("%w: len=%d", ErrFrameTooLarge, payloadLen)
|
||||
}
|
||||
h[0] = byte(msgType)
|
||||
// u24 big-endian.
|
||||
h[1] = byte((payloadLen >> 16) & 0xFF)
|
||||
h[2] = byte((payloadLen >> 8) & 0xFF)
|
||||
h[3] = byte(payloadLen & 0xFF)
|
||||
h[4] = ProtocolVersion
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// DecodeHeader parses a 5-byte header into (msgType, payloadLen).
|
||||
func DecodeHeader(h [HeaderLen]byte) (MsgType, int, error) {
|
||||
if h[4] != ProtocolVersion {
|
||||
return 0, 0, fmt.Errorf("%w: got 0x%02X", ErrBadVersion, h[4])
|
||||
}
|
||||
mt := MsgType(h[0])
|
||||
switch mt {
|
||||
case MsgClientHello, MsgServerHello, MsgClientAuth, MsgServerAuth, MsgFinished, MsgData, MsgAlert:
|
||||
// recognized
|
||||
default:
|
||||
return 0, 0, fmt.Errorf("%w: got 0x%02X", ErrUnknownMsgType, h[0])
|
||||
}
|
||||
plen := int(h[1])<<16 | int(h[2])<<8 | int(h[3])
|
||||
return mt, plen, nil
|
||||
}
|
||||
|
||||
// RawFrame is a frame as it was on the wire: type, header bytes (useful for AEAD AAD and the
|
||||
// handshake transcript hash), and payload bytes.
|
||||
type RawFrame struct {
|
||||
MsgType MsgType
|
||||
Header [HeaderLen]byte
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
// WireBytes returns header || payload in a fresh slice — used to feed the transcript hash, which
|
||||
// hashes the bytes exactly as transmitted.
|
||||
func (rf *RawFrame) WireBytes() []byte {
|
||||
out := make([]byte, 0, HeaderLen+len(rf.Payload))
|
||||
out = append(out, rf.Header[:]...)
|
||||
out = append(out, rf.Payload...)
|
||||
return out
|
||||
}
|
||||
|
||||
// WriteFrame serializes header || payload and writes it to w. Single Write, so on a streaming
|
||||
// transport a single TCP segment is preferred.
|
||||
func WriteFrame(w io.Writer, msgType MsgType, payload []byte) error {
|
||||
h, err := EncodeHeader(msgType, len(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := make([]byte, 0, HeaderLen+len(payload))
|
||||
buf = append(buf, h[:]...)
|
||||
buf = append(buf, payload...)
|
||||
_, err = w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// ReadFrame reads one full frame (header || payload) from r.
|
||||
func ReadFrame(r io.Reader) (*RawFrame, error) {
|
||||
var h [HeaderLen]byte
|
||||
if _, err := io.ReadFull(r, h[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mt, plen, err := DecodeHeader(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := make([]byte, plen)
|
||||
if plen > 0 {
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &RawFrame{MsgType: mt, Header: h, Payload: payload}, nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------------
|
||||
// Application frames (§6.3) — Data, Ping, Pong, Close, Control.
|
||||
// ----------------------------------------------------------------------------------------------
|
||||
|
||||
// FrameKind identifies the Application-frame variant.
|
||||
type FrameKind byte
|
||||
|
||||
// On-wire frame tags (must match crates/aura-proto/src/frame.rs frame_tag::*).
|
||||
const (
|
||||
FrameData FrameKind = 0x01
|
||||
FramePing FrameKind = 0x02
|
||||
FramePong FrameKind = 0x03
|
||||
FrameClose FrameKind = 0x04
|
||||
)
|
||||
|
||||
// Frame is the post-handshake application payload carried inside an AEAD-sealed MsgData record.
|
||||
// One Frame is mapped to one of the four variants by Kind.
|
||||
type Frame struct {
|
||||
Kind FrameKind
|
||||
StreamID uint32 // Data only
|
||||
Payload []byte // Data only
|
||||
Seq uint32 // Ping / Pong only
|
||||
Code byte // Close only
|
||||
Reason string // Close only
|
||||
}
|
||||
|
||||
// EncodeFrame serializes f into its compact byte encoding (all multi-byte ints big-endian):
|
||||
//
|
||||
// Data : 0x01 || stream_id(u32) || payload
|
||||
// Ping : 0x02 || seq(u32)
|
||||
// Pong : 0x03 || seq(u32)
|
||||
// Close : 0x04 || code(u8) || reason_len(u32) || reason_utf8
|
||||
func EncodeFrame(f *Frame) []byte {
|
||||
switch f.Kind {
|
||||
case FrameData:
|
||||
out := make([]byte, 1+4+len(f.Payload))
|
||||
out[0] = byte(FrameData)
|
||||
binary.BigEndian.PutUint32(out[1:5], f.StreamID)
|
||||
copy(out[5:], f.Payload)
|
||||
return out
|
||||
case FramePing:
|
||||
out := make([]byte, 1+4)
|
||||
out[0] = byte(FramePing)
|
||||
binary.BigEndian.PutUint32(out[1:5], f.Seq)
|
||||
return out
|
||||
case FramePong:
|
||||
out := make([]byte, 1+4)
|
||||
out[0] = byte(FramePong)
|
||||
binary.BigEndian.PutUint32(out[1:5], f.Seq)
|
||||
return out
|
||||
case FrameClose:
|
||||
reason := []byte(f.Reason)
|
||||
out := make([]byte, 1+1+4+len(reason))
|
||||
out[0] = byte(FrameClose)
|
||||
out[1] = f.Code
|
||||
binary.BigEndian.PutUint32(out[2:6], uint32(len(reason)))
|
||||
copy(out[6:], reason)
|
||||
return out
|
||||
default:
|
||||
// Programmer error — encode nothing rather than panic so call sites can defensively
|
||||
// inspect the returned length.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeFrame parses one byte-encoded Frame (the inverse of EncodeFrame).
|
||||
func DecodeFrame(b []byte) (*Frame, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, fmt.Errorf("%w: empty frame", ErrMalformedFrame)
|
||||
}
|
||||
tag := FrameKind(b[0])
|
||||
rest := b[1:]
|
||||
switch tag {
|
||||
case FrameData:
|
||||
if len(rest) < 4 {
|
||||
return nil, fmt.Errorf("%w: Data: missing stream_id", ErrMalformedFrame)
|
||||
}
|
||||
sid := binary.BigEndian.Uint32(rest[:4])
|
||||
// Payload is everything after the 4-byte stream_id.
|
||||
payload := make([]byte, len(rest)-4)
|
||||
copy(payload, rest[4:])
|
||||
return &Frame{Kind: FrameData, StreamID: sid, Payload: payload}, nil
|
||||
case FramePing:
|
||||
if len(rest) < 4 {
|
||||
return nil, fmt.Errorf("%w: Ping: truncated seq", ErrMalformedFrame)
|
||||
}
|
||||
return &Frame{Kind: FramePing, Seq: binary.BigEndian.Uint32(rest[:4])}, nil
|
||||
case FramePong:
|
||||
if len(rest) < 4 {
|
||||
return nil, fmt.Errorf("%w: Pong: truncated seq", ErrMalformedFrame)
|
||||
}
|
||||
return &Frame{Kind: FramePong, Seq: binary.BigEndian.Uint32(rest[:4])}, nil
|
||||
case FrameClose:
|
||||
if len(rest) < 1 {
|
||||
return nil, fmt.Errorf("%w: Close: missing code", ErrMalformedFrame)
|
||||
}
|
||||
code := rest[0]
|
||||
if len(rest) < 5 {
|
||||
return nil, fmt.Errorf("%w: Close: missing reason_len", ErrMalformedFrame)
|
||||
}
|
||||
rlen := int(binary.BigEndian.Uint32(rest[1:5]))
|
||||
if len(rest) < 5+rlen {
|
||||
return nil, fmt.Errorf("%w: Close: truncated reason", ErrMalformedFrame)
|
||||
}
|
||||
// We do not enforce strict UTF-8 here (Go strings can hold any bytes); the Rust side
|
||||
// rejects non-UTF-8 in this slot, so peers that follow the spec only ever send valid
|
||||
// strings.
|
||||
return &Frame{Kind: FrameClose, Code: code, Reason: string(rest[5 : 5+rlen])}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unknown frame tag 0x%02X", ErrMalformedFrame, byte(tag))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package frame
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHeaderRoundtripAllTypes(t *testing.T) {
|
||||
cases := []struct {
|
||||
ty MsgType
|
||||
b byte
|
||||
}{
|
||||
{MsgClientHello, 0x01},
|
||||
{MsgServerHello, 0x02},
|
||||
{MsgClientAuth, 0x03},
|
||||
{MsgServerAuth, 0x04},
|
||||
{MsgFinished, 0x05},
|
||||
{MsgData, 0x06},
|
||||
{MsgAlert, 0xFF},
|
||||
}
|
||||
for _, c := range cases {
|
||||
h, err := EncodeHeader(c.ty, 0x00123456)
|
||||
if err != nil {
|
||||
t.Fatalf("encode %s: %v", c.ty, err)
|
||||
}
|
||||
if h[0] != c.b {
|
||||
t.Fatalf("type byte for %s: got 0x%02X want 0x%02X", c.ty, h[0], c.b)
|
||||
}
|
||||
if h[4] != ProtocolVersion {
|
||||
t.Fatalf("version byte: got 0x%02X want 0x01", h[4])
|
||||
}
|
||||
mt, plen, err := DecodeHeader(h)
|
||||
if err != nil {
|
||||
t.Fatalf("decode %s: %v", c.ty, err)
|
||||
}
|
||||
if mt != c.ty || plen != 0x00123456 {
|
||||
t.Fatalf("roundtrip mismatch: got (%s, %d)", mt, plen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderRejectsOversizeAndBadVersion(t *testing.T) {
|
||||
if _, err := EncodeHeader(MsgData, MaxPayloadLen+1); !errors.Is(err, ErrFrameTooLarge) {
|
||||
t.Fatalf("oversize: want ErrFrameTooLarge, got %v", err)
|
||||
}
|
||||
h, err := EncodeHeader(MsgData, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h[4] = 0x02
|
||||
if _, _, err := DecodeHeader(h); !errors.Is(err, ErrBadVersion) {
|
||||
t.Fatalf("bad version: want ErrBadVersion, got %v", err)
|
||||
}
|
||||
// Reset the version so the unknown-type check actually exercises the type branch.
|
||||
h[4] = ProtocolVersion
|
||||
h[0] = 0x77
|
||||
if _, _, err := DecodeHeader(h); !errors.Is(err, ErrUnknownMsgType) {
|
||||
t.Fatalf("unknown type: want ErrUnknownMsgType, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameRoundtrip(t *testing.T) {
|
||||
frames := []*Frame{
|
||||
{Kind: FrameData, StreamID: 0xDEADBEEF, Payload: []byte("hello world")},
|
||||
{Kind: FrameData, StreamID: 0, Payload: nil},
|
||||
{Kind: FramePing, Seq: 42},
|
||||
{Kind: FramePong, Seq: 0xFFFFFFFF},
|
||||
{Kind: FrameClose, Code: 7, Reason: "going away \U0001F44B"},
|
||||
{Kind: FrameClose, Code: 0, Reason: ""},
|
||||
}
|
||||
for _, f := range frames {
|
||||
enc := EncodeFrame(f)
|
||||
got, err := DecodeFrame(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("decode %v: %v", f.Kind, err)
|
||||
}
|
||||
if got.Kind != f.Kind || got.StreamID != f.StreamID || got.Seq != f.Seq ||
|
||||
got.Code != f.Code || got.Reason != f.Reason || !bytes.Equal(got.Payload, f.Payload) {
|
||||
t.Fatalf("roundtrip mismatch: %+v vs %+v", f, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameDecodeRejectsGarbage(t *testing.T) {
|
||||
if _, err := DecodeFrame(nil); err == nil {
|
||||
t.Fatal("nil: want error")
|
||||
}
|
||||
if _, err := DecodeFrame([]byte{0x99}); err == nil {
|
||||
t.Fatal("unknown tag: want error")
|
||||
}
|
||||
if _, err := DecodeFrame([]byte{byte(FramePing), 0x00}); err == nil {
|
||||
t.Fatal("truncated ping: want error")
|
||||
}
|
||||
if _, err := DecodeFrame([]byte{byte(FrameClose)}); err == nil {
|
||||
t.Fatal("missing close code: want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlEnvelopeRoundtrip(t *testing.T) {
|
||||
env := EncodeControlEnvelope(ControlCrlPush, []byte("hello"))
|
||||
if !bytes.Equal(env[:4], ControlEnvelopeMagic[:]) {
|
||||
t.Fatalf("magic mismatch: %x", env[:4])
|
||||
}
|
||||
kind, payload, ok, err := DecodeControlEnvelope(env)
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("decode: ok=%v err=%v", ok, err)
|
||||
}
|
||||
if kind != ControlCrlPush || string(payload) != "hello" {
|
||||
t.Fatalf("decode mismatch: kind=%v payload=%q", kind, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlEnvelopeSkipsNormalIPPackets(t *testing.T) {
|
||||
cases := [][]byte{
|
||||
{0x45, 0x00, 0x00, 0x14}, // IPv4 packet
|
||||
{0x60, 0x00, 0x00, 0x00}, // IPv6 packet
|
||||
{0xAA, 0xAA, 0xC0, 0x02}, // wrong magic last byte
|
||||
{0xAA, 0xAA}, // shorter than magic
|
||||
}
|
||||
for _, c := range cases {
|
||||
_, _, ok, err := DecodeControlEnvelope(c)
|
||||
if ok || err != nil {
|
||||
t.Fatalf("expected pass-through (ok=false, err=nil): got ok=%v err=%v on %x", ok, err, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlEnvelopeRejectsTruncatedPayload(t *testing.T) {
|
||||
env := EncodeControlEnvelope(ControlCrlPush, []byte("payload-bytes"))
|
||||
env = env[:len(env)-3]
|
||||
if _, _, _, err := DecodeControlEnvelope(env); err == nil {
|
||||
t.Fatal("want truncated payload error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAndReadFrameRoundtrip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
payload := []byte{1, 2, 3, 4, 5}
|
||||
if err := WriteFrame(&buf, MsgData, payload); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
raw, err := ReadFrame(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if raw.MsgType != MsgData || !bytes.Equal(raw.Payload, payload) {
|
||||
t.Fatalf("roundtrip mismatch: %+v", raw)
|
||||
}
|
||||
if got := raw.WireBytes(); len(got) != HeaderLen+len(payload) {
|
||||
t.Fatalf("wire bytes wrong length: %d", len(got))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,363 @@
|
||||
// Package handshake implements the client side of the Aura handshake state machine — a direct
|
||||
// port of crates/aura-proto/src/handshake.rs::client_handshake.
|
||||
//
|
||||
// Order of messages (fixed by the Rust impl; see protocol.md §6.2):
|
||||
//
|
||||
// 1. C->S ClientHello (plaintext): x25519_pub(32) || mlkem_ek(1184) || client_nonce(32)
|
||||
// 2. S->C ServerHello (plaintext): x25519_ephemeral(32) || mlkem_ct(1088) || server_nonce(32)
|
||||
// -- both sides derive the hybrid shared secret + directional SessionKeys --
|
||||
// 3. S->C ServerAuth (encrypted under s2c): u16(cert_der_len) || server_leaf_cert_der || sig
|
||||
// 4. C->S ClientAuth (encrypted under c2s): u16(cert_der_len) || client_leaf_cert_der || sig
|
||||
// 5. C->S Finished (encrypted under c2s): HMAC-SHA256(key_c2s, transcript)
|
||||
// 6. S->C Finished (encrypted under s2c): HMAC-SHA256(key_s2c, transcript)
|
||||
//
|
||||
// transcript = SHA-256(ClientHello_frame || ServerHello_frame), over the full serialized frames
|
||||
// (header + payload) exactly as transmitted.
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/aura/singbox-aura/aura/crypto"
|
||||
"github.com/aura/singbox-aura/aura/frame"
|
||||
)
|
||||
|
||||
// ClientConfig is what the standalone CLI / sing-box outbound passes into Client.
|
||||
//
|
||||
// CAPEM, CertPEM, KeyPEM are PEM-encoded blobs (newlines, BEGIN/END lines and all). ServerName
|
||||
// is the DNS name we expect to find in the server cert's SAN — must match the cert the server
|
||||
// presents.
|
||||
type ClientConfig struct {
|
||||
CAPEM []byte
|
||||
CertPEM []byte
|
||||
KeyPEM []byte // PKCS#8 PEM, ECDSA P-256
|
||||
ServerName string
|
||||
}
|
||||
|
||||
// Client runs the client side of the handshake to completion.
|
||||
//
|
||||
// On success it returns:
|
||||
// - DerivedKeys: the (c2s, s2c) session keys to seed the datagram codecs.
|
||||
// - PeerID: the verified server name (the same string we passed in, on success).
|
||||
//
|
||||
// The caller wraps `r` / `w` over whatever transport is in use (the UDP reliability adapter
|
||||
// for plain UDP; a TCP stream for the TCP fallback; a paired pipe in tests).
|
||||
type Result struct {
|
||||
C2S [32]byte
|
||||
S2C [32]byte
|
||||
Transcript [32]byte
|
||||
PeerID string
|
||||
}
|
||||
|
||||
// Client drives the handshake state machine end-to-end.
|
||||
func Client(r io.Reader, w io.Writer, cfg *ClientConfig) (*Result, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("aura/handshake: nil config")
|
||||
}
|
||||
|
||||
// (1) Generate our hybrid keypair + nonce, send ClientHello.
|
||||
priv, pub, err := crypto.GenerateHybridKeypair()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hybrid keygen: %w", err)
|
||||
}
|
||||
var clientNonce [32]byte
|
||||
if _, err := rand.Read(clientNonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("client nonce: %w", err)
|
||||
}
|
||||
|
||||
chPayload := make([]byte, 0, crypto.X25519Len+crypto.MLKEMEKLen+32)
|
||||
chPayload = append(chPayload, pub.X25519[:]...)
|
||||
chPayload = append(chPayload, pub.MLKEM...)
|
||||
chPayload = append(chPayload, clientNonce[:]...)
|
||||
if len(chPayload) != crypto.X25519Len+crypto.MLKEMEKLen+32 {
|
||||
return nil, fmt.Errorf("client hello wrong size: %d", len(chPayload))
|
||||
}
|
||||
chHeader, err := frame.EncodeHeader(frame.MsgClientHello, len(chPayload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := frame.WriteFrame(w, frame.MsgClientHello, chPayload); err != nil {
|
||||
return nil, fmt.Errorf("write ClientHello: %w", err)
|
||||
}
|
||||
chWire := append(append([]byte{}, chHeader[:]...), chPayload...)
|
||||
|
||||
// (2) Read ServerHello.
|
||||
sh, err := readExpect(r, frame.MsgServerHello)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
const expectSHLen = crypto.X25519Len + crypto.MLKEMCTLen + 32
|
||||
if len(sh.Payload) != expectSHLen {
|
||||
return nil, fmt.Errorf("ServerHello: wrong length %d (want %d)", len(sh.Payload), expectSHLen)
|
||||
}
|
||||
ct := &crypto.HybridCiphertext{MLKEMCT: append([]byte{}, sh.Payload[crypto.X25519Len:crypto.X25519Len+crypto.MLKEMCTLen]...)}
|
||||
copy(ct.X25519Eph[:], sh.Payload[:crypto.X25519Len])
|
||||
var serverNonce [32]byte
|
||||
copy(serverNonce[:], sh.Payload[crypto.X25519Len+crypto.MLKEMCTLen:])
|
||||
|
||||
shared, err := priv.Decapsulate(ct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decapsulate: %w", err)
|
||||
}
|
||||
keys := crypto.DeriveSessionKeys(shared, clientNonce, serverNonce)
|
||||
|
||||
// transcript = SHA-256(client_hello_wire || server_hello_wire) over the bytes as transmitted.
|
||||
hash := sha256.New()
|
||||
hash.Write(chWire)
|
||||
hash.Write(sh.WireBytes())
|
||||
var transcript [32]byte
|
||||
copy(transcript[:], hash.Sum(nil))
|
||||
|
||||
// Two AEAD sessions: client seals under c2s, opens under s2c. The counters continue across
|
||||
// the handshake/data boundary, so we must keep using the same instances.
|
||||
aeadC2S, err := crypto.NewAeadSession(keys.ClientToServer[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aeadS2C, err := crypto.NewAeadSession(keys.ServerToClient[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// (3) Server -> client ServerAuth (encrypted under s2c).
|
||||
serverAuth, err := openHandshakeMsg(r, frame.MsgServerAuth, aeadS2C)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ServerAuth: %w", err)
|
||||
}
|
||||
serverCertDER, serverSig, err := splitCertAndSig(serverAuth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := verifyServerCert(serverCertDER, cfg.CAPEM, cfg.ServerName); err != nil {
|
||||
return nil, fmt.Errorf("verify server cert: %w", err)
|
||||
}
|
||||
if err := verifySignature(serverCertDER, transcript[:], serverSig); err != nil {
|
||||
return nil, fmt.Errorf("verify server signature: %w", err)
|
||||
}
|
||||
|
||||
// (4) Client -> server ClientAuth (encrypted under c2s).
|
||||
clientCertDER, err := pemCertToDER(cfg.CertPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("client cert: %w", err)
|
||||
}
|
||||
clientSig, err := signTranscript(cfg.KeyPEM, transcript[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign transcript: %w", err)
|
||||
}
|
||||
clientAuth := buildCertAndSig(clientCertDER, clientSig)
|
||||
if err := sealHandshakeMsg(w, frame.MsgClientAuth, aeadC2S, clientAuth); err != nil {
|
||||
return nil, fmt.Errorf("write ClientAuth: %w", err)
|
||||
}
|
||||
|
||||
// (5) Client -> server Finished (encrypted under c2s).
|
||||
clientFinished := hmacSHA256(keys.ClientToServer[:], transcript[:])
|
||||
if err := sealHandshakeMsg(w, frame.MsgFinished, aeadC2S, clientFinished); err != nil {
|
||||
return nil, fmt.Errorf("write client Finished: %w", err)
|
||||
}
|
||||
|
||||
// (6) Server -> client Finished: verify against expected.
|
||||
serverFinished, err := openHandshakeMsg(r, frame.MsgFinished, aeadS2C)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server Finished: %w", err)
|
||||
}
|
||||
expectedServerFinished := hmacSHA256(keys.ServerToClient[:], transcript[:])
|
||||
if !hmac.Equal(serverFinished, expectedServerFinished) {
|
||||
return nil, errors.New("aura/handshake: server Finished MAC mismatch")
|
||||
}
|
||||
|
||||
return &Result{
|
||||
C2S: keys.ClientToServer,
|
||||
S2C: keys.ServerToClient,
|
||||
Transcript: transcript,
|
||||
PeerID: cfg.ServerName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// readExpect reads one frame from r and demands it be of type want. An Alert is converted into
|
||||
// a typed error.
|
||||
func readExpect(r io.Reader, want frame.MsgType) (*frame.RawFrame, error) {
|
||||
rf, err := frame.ReadFrame(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rf.MsgType == frame.MsgAlert {
|
||||
code := byte(0)
|
||||
if len(rf.Payload) > 0 {
|
||||
code = rf.Payload[0]
|
||||
}
|
||||
return nil, fmt.Errorf("aura/handshake: peer alert code %d", code)
|
||||
}
|
||||
if rf.MsgType != want {
|
||||
return nil, fmt.Errorf("aura/handshake: expected %s, got %s", want, rf.MsgType)
|
||||
}
|
||||
return rf, nil
|
||||
}
|
||||
|
||||
// sealHandshakeMsg seals plaintext under aead (advancing its counter) and writes one frame.
|
||||
// AAD is the 5-byte header — same convention as Data records.
|
||||
func sealHandshakeMsg(w io.Writer, msgType frame.MsgType, aead *crypto.AeadSession, plaintext []byte) error {
|
||||
sealedLen := len(plaintext) + 16 // Poly1305 tag
|
||||
hdr, err := frame.EncodeHeader(msgType, sealedLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ct := aead.Seal(plaintext, hdr[:])
|
||||
if len(ct) != sealedLen {
|
||||
return fmt.Errorf("aura/handshake: sealed wrong size %d (want %d)", len(ct), sealedLen)
|
||||
}
|
||||
return frame.WriteFrame(w, msgType, ct)
|
||||
}
|
||||
|
||||
// openHandshakeMsg reads one frame of type msgType and AEAD-opens it.
|
||||
func openHandshakeMsg(r io.Reader, msgType frame.MsgType, aead *crypto.AeadSession) ([]byte, error) {
|
||||
rf, err := readExpect(r, msgType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead.Open(rf.Payload, rf.Header[:])
|
||||
}
|
||||
|
||||
// buildCertAndSig: u16_be(cert_der_len) || cert_der || signature.
|
||||
func buildCertAndSig(certDER, sig []byte) []byte {
|
||||
out := make([]byte, 0, 2+len(certDER)+len(sig))
|
||||
var lb [2]byte
|
||||
binary.BigEndian.PutUint16(lb[:], uint16(len(certDER)))
|
||||
out = append(out, lb[:]...)
|
||||
out = append(out, certDER...)
|
||||
out = append(out, sig...)
|
||||
return out
|
||||
}
|
||||
|
||||
// splitCertAndSig is the inverse.
|
||||
func splitCertAndSig(buf []byte) (certDER, sig []byte, err error) {
|
||||
if len(buf) < 2 {
|
||||
return nil, nil, errors.New("aura/handshake: Auth: missing cert length")
|
||||
}
|
||||
certLen := int(binary.BigEndian.Uint16(buf[:2]))
|
||||
if len(buf) < 2+certLen {
|
||||
return nil, nil, errors.New("aura/handshake: Auth: truncated cert")
|
||||
}
|
||||
certDER = buf[2 : 2+certLen]
|
||||
sig = buf[2+certLen:]
|
||||
if len(sig) == 0 {
|
||||
return nil, nil, errors.New("aura/handshake: Auth: empty signature")
|
||||
}
|
||||
return certDER, sig, nil
|
||||
}
|
||||
|
||||
// hmacSHA256 returns HMAC-SHA256(key, msg).
|
||||
func hmacSHA256(key, msg []byte) []byte {
|
||||
m := hmac.New(sha256.New, key)
|
||||
m.Write(msg)
|
||||
return m.Sum(nil)
|
||||
}
|
||||
|
||||
// pemCertToDER decodes the first CERTIFICATE PEM block.
|
||||
func pemCertToDER(pemBytes []byte) ([]byte, error) {
|
||||
rest := pemBytes
|
||||
for {
|
||||
block, r := pem.Decode(rest)
|
||||
if block == nil {
|
||||
return nil, errors.New("aura/handshake: no CERTIFICATE block in PEM")
|
||||
}
|
||||
if block.Type == "CERTIFICATE" {
|
||||
return block.Bytes, nil
|
||||
}
|
||||
rest = r
|
||||
}
|
||||
}
|
||||
|
||||
// pemKeyToDER decodes the first PRIVATE KEY-style PEM block. ECDSA leaves typically use PKCS#8
|
||||
// ("PRIVATE KEY"); we also accept the old "EC PRIVATE KEY" form for compatibility.
|
||||
func pemKeyToDER(pemBytes []byte) ([]byte, error) {
|
||||
rest := pemBytes
|
||||
for {
|
||||
block, r := pem.Decode(rest)
|
||||
if block == nil {
|
||||
return nil, errors.New("aura/handshake: no private-key block in PEM")
|
||||
}
|
||||
switch block.Type {
|
||||
case "PRIVATE KEY", "EC PRIVATE KEY", "RSA PRIVATE KEY":
|
||||
return block.Bytes, nil
|
||||
}
|
||||
rest = r
|
||||
}
|
||||
}
|
||||
|
||||
// signTranscript signs a 32-byte transcript with the ECDSA P-256 PKCS#8 key in PEM form. The
|
||||
// signature is the ASN.1 DER encoding ring uses on the Rust side (ECDSA_P256_SHA256_ASN1).
|
||||
func signTranscript(keyPEM, transcript []byte) ([]byte, error) {
|
||||
if len(transcript) != 32 {
|
||||
return nil, fmt.Errorf("transcript must be 32 bytes, got %d", len(transcript))
|
||||
}
|
||||
der, err := pemKeyToDER(keyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := x509.ParsePKCS8PrivateKey(der)
|
||||
if err != nil {
|
||||
// Fall back to the old EC-specific encoding (rfc 5915).
|
||||
ec, err2 := x509.ParseECPrivateKey(der)
|
||||
if err2 != nil {
|
||||
return nil, fmt.Errorf("parse client key: pkcs8=%v ec=%v", err, err2)
|
||||
}
|
||||
parsed = ec
|
||||
}
|
||||
key, ok := parsed.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("aura/handshake: client key is %T, want *ecdsa.PrivateKey", parsed)
|
||||
}
|
||||
// ecdsa.SignASN1 returns the same ASN.1 DER (r,s) encoding ring produces.
|
||||
sig, err := ecdsa.SignASN1(rand.Reader, key, transcript)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ecdsa sign: %w", err)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// verifySignature checks an ECDSA P-256/SHA-256 signature (ASN.1 DER) over the 32-byte transcript
|
||||
// against the leaf cert's public key.
|
||||
func verifySignature(certDER, transcript, sig []byte) error {
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse peer cert: %w", err)
|
||||
}
|
||||
pub, ok := cert.PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("peer key is %T, want *ecdsa.PublicKey", cert.PublicKey)
|
||||
}
|
||||
if !ecdsa.VerifyASN1(pub, transcript, sig) {
|
||||
return errors.New("aura/handshake: signature did not verify")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyServerCert validates the server leaf against the CA PEM and the expected DNS name.
|
||||
func verifyServerCert(certDER, caPEM []byte, serverName string) error {
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caPEM) {
|
||||
return errors.New("aura/handshake: CA PEM contains no certs")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse server cert: %w", err)
|
||||
}
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: pool,
|
||||
DNSName: serverName,
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
if _, err := cert.Verify(opts); err != nil {
|
||||
return fmt.Errorf("verify chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aura/singbox-aura/aura/frame"
|
||||
)
|
||||
|
||||
// TestSplitAndBuildCertAndSigRoundtrip: tiny but load-bearing — Auth payload layout must match
|
||||
// the Rust wire format byte-for-byte.
|
||||
func TestSplitAndBuildCertAndSigRoundtrip(t *testing.T) {
|
||||
cert := bytes.Repeat([]byte{0xAB}, 250)
|
||||
sig := []byte{0xCD, 0xEF, 0x01, 0x02}
|
||||
enc := buildCertAndSig(cert, sig)
|
||||
gotCert, gotSig, err := splitCertAndSig(enc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(gotCert, cert) || !bytes.Equal(gotSig, sig) {
|
||||
t.Fatalf("roundtrip mismatch")
|
||||
}
|
||||
// Empty signature must be rejected.
|
||||
if _, _, err := splitCertAndSig(enc[:2+len(cert)]); err == nil {
|
||||
t.Fatal("empty sig must error")
|
||||
}
|
||||
// Truncated cert must be rejected.
|
||||
if _, _, err := splitCertAndSig(enc[:3]); err == nil {
|
||||
t.Fatal("truncated cert must error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSignVerifyTranscriptRoundtrip: generate an ECDSA P-256 key + self-signed cert, sign a
|
||||
// 32-byte transcript with our helper, verify with our helper, asserting we match the Rust side
|
||||
// (ECDSA P-256 / SHA-256 / ASN.1 DER).
|
||||
func TestSignVerifyTranscriptRoundtrip(t *testing.T) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Self-signed cert wrapping this key.
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "test-leaf"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
SignatureAlgorithm: x509.ECDSAWithSHA256,
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Wrap our key in PKCS#8 PEM, as the production cert issuance does.
|
||||
keyDER, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
var transcript [32]byte
|
||||
for i := range transcript {
|
||||
transcript[i] = byte(i ^ 0x55)
|
||||
}
|
||||
sig, err := signTranscript(keyPEM, transcript[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := verifySignature(certDER, transcript[:], sig); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Tampered transcript: verification must fail.
|
||||
bad := transcript
|
||||
bad[0] ^= 1
|
||||
if err := verifySignature(certDER, bad[:], sig); err == nil {
|
||||
t.Fatal("tampered transcript must fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientHelloLayoutSize: sanity that we compute the expected hello payload size.
|
||||
func TestClientHelloLayoutSize(t *testing.T) {
|
||||
const expected = 32 + 1184 + 32 // X25519 + ML-KEM ek + nonce
|
||||
if expected != 1248 {
|
||||
t.Fatalf("ClientHello expected size 1248, got %d", expected)
|
||||
}
|
||||
// And the on-wire frame adds the 5-byte header.
|
||||
hdr, err := frame.EncodeHeader(frame.MsgClientHello, expected)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hdr[0] != 0x01 || hdr[4] != 0x01 {
|
||||
t.Fatalf("header byte 0/4 mismatch: %x", hdr)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
# Aura outbound for sing-box
|
||||
|
||||
`outbound.Outbound` exposes a sing-box-shaped surface (`Network() / DialContext / ListenPacket`)
|
||||
without importing `github.com/sagernet/sing-box`. This keeps the build self-contained for v1;
|
||||
the next step is to vendor the sing-box module, register Aura via `init()` and add the JSON
|
||||
options struct.
|
||||
|
||||
## Integration sketch (Option B from `docs/sing-box.md`)
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/sagernet/sing-box"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
auraout "github.com/aura/singbox-aura/aura/outbound"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sing-box.RegisterOutbound(auraout.Tag, func(ctx context.Context, router adapter.Router, logger logger.ContextLogger, tag string, options option.Outbound) (adapter.Outbound, error) {
|
||||
// Translate option fields to handshake.ClientConfig + transport.Options.
|
||||
// Construct &auraout.Outbound{...} and adapt to adapter.Outbound (DialContext signature).
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
The exact `option.Outbound` schema is up to you — at minimum it needs:
|
||||
|
||||
* `server` (host:port)
|
||||
* `tls.ca_cert_path` (PEM)
|
||||
* `tls.cert_path`, `tls.key_path` (PEM, ECDSA P-256)
|
||||
* `tls.server_name` (DNS SAN to verify in the server leaf)
|
||||
* optional `knock_enabled`, `knock_secret_source = "ca_fingerprint"`
|
||||
|
||||
The packet path is **opaque IP** — Aura tunnels inner IP packets exactly as the existing Rust
|
||||
client does. The router writes IPv4/IPv6 packets to the returned `net.PacketConn`; the same
|
||||
conn yields incoming packets on `ReadFrom`. Multi-flow demultiplexing is the router's job, not
|
||||
ours.
|
||||
@@ -0,0 +1,99 @@
|
||||
// Package outbound is a thin sing-box-shaped wrapper around the Aura UDP client. It does NOT
|
||||
// import github.com/sagernet/sing-box — keeping the dependency footprint small and the build
|
||||
// self-contained for v1. The interface is shaped after sing-box's outbound (Network,
|
||||
// DialContext, ListenPacket) so a follow-up patch can register this as a real outbound by
|
||||
// vendoring the sing-box module + filling in the missing glue.
|
||||
//
|
||||
// See README.md for the concrete integration steps.
|
||||
package outbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/aura/singbox-aura/aura/handshake"
|
||||
"github.com/aura/singbox-aura/aura/transport"
|
||||
)
|
||||
|
||||
// Tag is the identifier this outbound advertises to a sing-box router. Real registration would
|
||||
// set it on the outbound options struct.
|
||||
const Tag = "aura"
|
||||
|
||||
// Network returns the sing-box network type. Aura is connection-oriented over UDP underneath
|
||||
// but the application-layer abstraction is reliable+ordered for streams (TCP-like) and
|
||||
// best-effort for datagrams (UDP-like), so we expose UDP here — matches how the QUIC outbound
|
||||
// is registered.
|
||||
func Network() []string { return []string{"udp"} }
|
||||
|
||||
// Outbound is the per-server configuration that a sing-box-style host instantiates once per
|
||||
// upstream. One Outbound can dial many short-lived connections.
|
||||
type Outbound struct {
|
||||
ServerAddr string // e.g. "203.0.113.10:443"
|
||||
HSConfig *handshake.ClientConfig // CA + leaf cert + leaf key + expected server SNI
|
||||
Opts *transport.Options // optional knock + handshake timers
|
||||
}
|
||||
|
||||
// DialContext opens an Aura UDP connection to the upstream and wraps it as a net.PacketConn
|
||||
// for the sing-box stack to write IP packets to. `network` must be "udp"/"udp4"/"udp6";
|
||||
// `destination` is the application target the sing-box router computed (unused by v1 — Aura
|
||||
// carries opaque IP packets, not per-flow destinations).
|
||||
func (o *Outbound) DialContext(ctx context.Context, network, destination string) (net.PacketConn, error) {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
default:
|
||||
return nil, fmt.Errorf("aura/outbound: unsupported network %q", network)
|
||||
}
|
||||
if o.ServerAddr == "" {
|
||||
return nil, errors.New("aura/outbound: ServerAddr is empty")
|
||||
}
|
||||
if o.HSConfig == nil {
|
||||
return nil, errors.New("aura/outbound: HSConfig is nil")
|
||||
}
|
||||
conn, err := transport.Dial(ctx, o.ServerAddr, o.HSConfig, o.Opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &packetConnAdapter{conn: conn, dest: destination}, nil
|
||||
}
|
||||
|
||||
// ListenPacket is the same call shape sing-box uses for inbound-style transports; for an
|
||||
// outbound this is a convenience that delegates to DialContext.
|
||||
func (o *Outbound) ListenPacket(ctx context.Context, destination string) (net.PacketConn, error) {
|
||||
return o.DialContext(ctx, "udp", destination)
|
||||
}
|
||||
|
||||
// packetConnAdapter exposes a transport.Connection as net.PacketConn. ReadFrom returns the
|
||||
// next inner IP payload and a placeholder *net.UDPAddr (Aura tunnels opaque IP packets — the
|
||||
// concrete destination addr is decoded by the upper layer). WriteTo simply ships the payload.
|
||||
type packetConnAdapter struct {
|
||||
conn *transport.Connection
|
||||
dest string
|
||||
}
|
||||
|
||||
func (p *packetConnAdapter) ReadFrom(buf []byte) (int, net.Addr, error) {
|
||||
pkt, err := p.conn.Recv()
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
n := copy(buf, pkt)
|
||||
// We do not have a real source addr at this layer; report the peer's identity as a fake
|
||||
// UDP address so any sing-box code that logs addr.String() gets something sensible.
|
||||
addr, _ := net.ResolveUDPAddr("udp", p.dest)
|
||||
return n, addr, nil
|
||||
}
|
||||
|
||||
func (p *packetConnAdapter) WriteTo(buf []byte, _ net.Addr) (int, error) {
|
||||
if err := p.conn.Send(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
func (p *packetConnAdapter) Close() error { return p.conn.Close() }
|
||||
func (p *packetConnAdapter) LocalAddr() net.Addr { return &net.UDPAddr{IP: net.IPv4zero} }
|
||||
func (p *packetConnAdapter) SetDeadline(_ time.Time) error { return nil }
|
||||
func (p *packetConnAdapter) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (p *packetConnAdapter) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
@@ -0,0 +1,70 @@
|
||||
// Package session provides the post-handshake AEAD-protected Frame exchange and the sliding
|
||||
// replay window — direct port of crates/aura-proto/src/session.rs.
|
||||
package session
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ReplayWindow is the width (in records) of the anti-replay sliding window.
|
||||
const ReplayWindow uint64 = 64
|
||||
|
||||
// ErrReplay is returned when a record's sequence number is a duplicate or too old.
|
||||
type ErrReplay struct{ Seq uint64 }
|
||||
|
||||
func (e *ErrReplay) Error() string { return fmt.Sprintf("aura/session: replay seq=%d", e.Seq) }
|
||||
|
||||
// Replay tracks the highest accepted sequence number and a 64-bit bitmap of the positions
|
||||
// below it that have already been accepted. A datagram is accepted iff its seq is strictly
|
||||
// newer than everything seen, or falls inside the window and was not previously seen.
|
||||
type Replay struct {
|
||||
highest uint64
|
||||
bitmap uint64
|
||||
seeded bool
|
||||
}
|
||||
|
||||
// NewReplay primes a window so the first expected record is `start` (the AEAD counter at the
|
||||
// end of the handshake). Anything strictly below `start` is treated as already-consumed.
|
||||
//
|
||||
// This mirrors ReplayWindow::new in the Rust impl: highest = start - 1 (saturating),
|
||||
// seeded = start > 0.
|
||||
func NewReplay(start uint64) *Replay {
|
||||
r := &Replay{seeded: start > 0}
|
||||
if start > 0 {
|
||||
r.highest = start - 1
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// CheckAndSet records a seen seq. Returns nil if it is fresh; *ErrReplay otherwise.
|
||||
func (r *Replay) CheckAndSet(seq uint64) error {
|
||||
if !r.seeded {
|
||||
// First-ever record (only reachable when started at 0): accept and seed.
|
||||
r.seeded = true
|
||||
r.highest = seq
|
||||
r.bitmap = 0
|
||||
return nil
|
||||
}
|
||||
if seq > r.highest {
|
||||
shift := seq - r.highest
|
||||
if shift >= 64 {
|
||||
r.bitmap = 0
|
||||
} else {
|
||||
r.bitmap = (r.bitmap << shift) | (1 << (shift - 1))
|
||||
}
|
||||
r.highest = seq
|
||||
return nil
|
||||
}
|
||||
// seq <= highest: must be inside the window and previously unseen.
|
||||
offset := r.highest - seq
|
||||
if offset >= ReplayWindow {
|
||||
return &ErrReplay{Seq: seq}
|
||||
}
|
||||
if offset == 0 {
|
||||
return &ErrReplay{Seq: seq}
|
||||
}
|
||||
mask := uint64(1) << (offset - 1)
|
||||
if r.bitmap&mask != 0 {
|
||||
return &ErrReplay{Seq: seq}
|
||||
}
|
||||
r.bitmap |= mask
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/aura/singbox-aura/aura/crypto"
|
||||
"github.com/aura/singbox-aura/aura/frame"
|
||||
)
|
||||
|
||||
// SeqLen is the size of the per-record sequence-number prefix.
|
||||
const SeqLen = 8
|
||||
|
||||
// PostHandshakeCounter is the AEAD counter at which the first application Data record starts,
|
||||
// because each direction sealed exactly two encrypted handshake messages before it.
|
||||
const PostHandshakeCounter uint64 = 2
|
||||
|
||||
// DatagramSender holds the outbound explicit-nonce AEAD plus the next sequence number to
|
||||
// stamp. Produced by Session.IntoDatagramParts() after the handshake completes.
|
||||
type DatagramSender struct {
|
||||
key *crypto.AeadKey
|
||||
seq uint64
|
||||
}
|
||||
|
||||
// NewDatagramSender wraps a 32-byte key starting at the given counter.
|
||||
func NewDatagramSender(rawKey []byte, startCounter uint64) (*DatagramSender, error) {
|
||||
k, err := crypto.NewAeadKey(rawKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DatagramSender{key: k, seq: startCounter}, nil
|
||||
}
|
||||
|
||||
// Seal encodes f, seals it under the next sequence number, and returns the on-wire datagram
|
||||
// payload: seq(8 BE) || ciphertext.
|
||||
func (s *DatagramSender) Seal(f *frame.Frame) []byte {
|
||||
seq := s.seq
|
||||
enc := frame.EncodeFrame(f)
|
||||
var seqBE [SeqLen]byte
|
||||
binary.BigEndian.PutUint64(seqBE[:], seq)
|
||||
ct := s.key.Seal(seq, enc, seqBE[:])
|
||||
out := make([]byte, 0, SeqLen+len(ct))
|
||||
out = append(out, seqBE[:]...)
|
||||
out = append(out, ct...)
|
||||
s.seq++
|
||||
return out
|
||||
}
|
||||
|
||||
// NextSeq is the sequence number the next Seal will use (test/diagnostic helper).
|
||||
func (s *DatagramSender) NextSeq() uint64 { return s.seq }
|
||||
|
||||
// DatagramReceiver authenticates, replay-checks, and decodes incoming datagram payloads.
|
||||
type DatagramReceiver struct {
|
||||
key *crypto.AeadKey
|
||||
replay *Replay
|
||||
}
|
||||
|
||||
// NewDatagramReceiver wraps a 32-byte key plus a replay window primed at startCounter.
|
||||
func NewDatagramReceiver(rawKey []byte, startCounter uint64) (*DatagramReceiver, error) {
|
||||
k, err := crypto.NewAeadKey(rawKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DatagramReceiver{key: k, replay: NewReplay(startCounter)}, nil
|
||||
}
|
||||
|
||||
// Open parses one datagram payload, runs the replay check first (so a duplicate cannot advance
|
||||
// the AEAD state), then verifies and decodes the inner Frame.
|
||||
func (r *DatagramReceiver) Open(datagram []byte) (*frame.Frame, error) {
|
||||
if len(datagram) < SeqLen {
|
||||
return nil, fmt.Errorf("aura/session: datagram shorter than seq prefix")
|
||||
}
|
||||
seqBE := datagram[:SeqLen]
|
||||
seq := binary.BigEndian.Uint64(seqBE)
|
||||
ct := datagram[SeqLen:]
|
||||
if err := r.replay.CheckAndSet(seq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pt, err := r.key.Open(seq, ct, seqBE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return frame.DecodeFrame(pt)
|
||||
}
|
||||
|
||||
// ErrUnexpectedMsg is returned by the stream half when the wire carries a non-Data record.
|
||||
var ErrUnexpectedMsg = errors.New("aura/session: unexpected message type")
|
||||
@@ -0,0 +1,123 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/aura/singbox-aura/aura/frame"
|
||||
)
|
||||
|
||||
func TestReplayWindowBasicMonotonic(t *testing.T) {
|
||||
w := NewReplay(2)
|
||||
for _, s := range []uint64{2, 3, 4} {
|
||||
if err := w.CheckAndSet(s); err != nil {
|
||||
t.Fatalf("seq %d: unexpected %v", s, err)
|
||||
}
|
||||
}
|
||||
for _, s := range []uint64{2, 3, 4} {
|
||||
var e *ErrReplay
|
||||
if err := w.CheckAndSet(s); !errors.As(err, &e) {
|
||||
t.Fatalf("seq %d: want replay, got %v", s, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplayWindowOutOfOrderWithinWindow(t *testing.T) {
|
||||
w := NewReplay(0)
|
||||
if err := w.CheckAndSet(0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.CheckAndSet(10); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.CheckAndSet(5); err != nil {
|
||||
t.Fatalf("5 inside window: %v", err)
|
||||
}
|
||||
if err := w.CheckAndSet(5); err == nil {
|
||||
t.Fatal("replay of 5 must be rejected")
|
||||
}
|
||||
if err := w.CheckAndSet(10); err == nil {
|
||||
t.Fatal("replay of 10 must be rejected")
|
||||
}
|
||||
if err := w.CheckAndSet(11); err != nil {
|
||||
t.Fatalf("new high 11: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplayWindowRejectsTooOld(t *testing.T) {
|
||||
w := NewReplay(0)
|
||||
if err := w.CheckAndSet(0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.CheckAndSet(200); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.CheckAndSet(1); err == nil {
|
||||
t.Fatal("far below window must be rejected")
|
||||
}
|
||||
if err := w.CheckAndSet(200 - ReplayWindow); err == nil {
|
||||
t.Fatal("at the floor of the window must be rejected")
|
||||
}
|
||||
if err := w.CheckAndSet(200 - ReplayWindow + 1); err != nil {
|
||||
t.Fatalf("just inside window: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatagramRoundtripReorderAndReplay(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = 11
|
||||
}
|
||||
tx, err := NewDatagramSender(key, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rx, err := NewDatagramReceiver(key, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
d0 := tx.Seal(&frame.Frame{Kind: frame.FrameData, StreamID: 0, Payload: []byte("pkt-a")})
|
||||
d1 := tx.Seal(&frame.Frame{Kind: frame.FrameData, StreamID: 0, Payload: []byte("pkt-b")})
|
||||
|
||||
// Out-of-order delivery within the window.
|
||||
gotB, err := rx.Open(d1)
|
||||
if err != nil {
|
||||
t.Fatalf("open d1: %v", err)
|
||||
}
|
||||
if gotB.Kind != frame.FrameData || string(gotB.Payload) != "pkt-b" {
|
||||
t.Fatalf("d1: %+v", gotB)
|
||||
}
|
||||
gotA, err := rx.Open(d0)
|
||||
if err != nil {
|
||||
t.Fatalf("open d0: %v", err)
|
||||
}
|
||||
if string(gotA.Payload) != "pkt-a" {
|
||||
t.Fatalf("d0: %+v", gotA)
|
||||
}
|
||||
|
||||
if _, err := rx.Open(d1); err == nil {
|
||||
t.Fatal("replay of d1 must be rejected")
|
||||
}
|
||||
|
||||
bad := tx.Seal(&frame.Frame{Kind: frame.FramePing, Seq: 7})
|
||||
bad[len(bad)-1] ^= 1
|
||||
if _, err := rx.Open(bad); err == nil {
|
||||
t.Fatal("tampered ciphertext must fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSenderNextSeqAdvances(t *testing.T) {
|
||||
tx, err := NewDatagramSender(bytes.Repeat([]byte{1}, 32), 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tx.NextSeq() != 2 {
|
||||
t.Fatalf("initial next seq %d", tx.NextSeq())
|
||||
}
|
||||
_ = tx.Seal(&frame.Frame{Kind: frame.FramePing, Seq: 1})
|
||||
if tx.NextSeq() != 3 {
|
||||
t.Fatalf("after seal: %d", tx.NextSeq())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KnockLen is the length in bytes of the truncated HMAC-SHA256 port-knock token.
|
||||
const KnockLen = 16
|
||||
|
||||
// KnockForMinute derives the 16-byte port-knock token for a given Unix minute under the shared
|
||||
// 32-byte key.
|
||||
//
|
||||
// Wire formula (mirrors aura-transport/src/udp.rs):
|
||||
//
|
||||
// HMAC-SHA256(key, u64_be(minute))[..16]
|
||||
//
|
||||
// The server validates against floor(now/60) and ±1 minute (~3-minute acceptance window).
|
||||
func KnockForMinute(key [32]byte, minute uint64) [KnockLen]byte {
|
||||
var mb [8]byte
|
||||
binary.BigEndian.PutUint64(mb[:], minute)
|
||||
m := hmac.New(sha256.New, key[:])
|
||||
m.Write(mb[:])
|
||||
tag := m.Sum(nil)
|
||||
var out [KnockLen]byte
|
||||
copy(out[:], tag[:KnockLen])
|
||||
return out
|
||||
}
|
||||
|
||||
// CurrentUnixMinute returns floor(now/60). Used by the client to compute the knock for "now".
|
||||
func CurrentUnixMinute() uint64 {
|
||||
return uint64(time.Now().Unix() / 60)
|
||||
}
|
||||
@@ -0,0 +1,548 @@
|
||||
// Package transport implements the Aura UDP client: a reliable HS-adapter wrapping the
|
||||
// handshake so it can run over lossy UDP, plus the post-handshake datagram data path.
|
||||
//
|
||||
// Wire layout (mirrors aura-transport/src/udp.rs):
|
||||
//
|
||||
// 0x01 HS : [optional 16-byte knock prefix] || 0x01 || hs_seq(u16 BE) || ack_upto(u16 BE) || msg_bytes
|
||||
// 0x02 DATA : 0x02 || rec_len(u16 BE) || sealed_record [|| random_padding]
|
||||
//
|
||||
// The HS phase is a DTLS-flight style reliability layer: every sent datagram is retransmitted
|
||||
// every `hs_rto` until either acked or the overall `hs_timeout` elapses; cumulative acks prune
|
||||
// the retransmit queue.
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aura/singbox-aura/aura/frame"
|
||||
"github.com/aura/singbox-aura/aura/handshake"
|
||||
"github.com/aura/singbox-aura/aura/session"
|
||||
)
|
||||
|
||||
// Wire-layer type bytes.
|
||||
const (
|
||||
typeHS byte = 0x01
|
||||
typeDATA byte = 0x02
|
||||
)
|
||||
|
||||
// HS header layout: type(1) || hs_seq(2 BE) || ack_upto(2 BE) || msg_bytes.
|
||||
const hsPrefixLen = 1 + 2 + 2
|
||||
|
||||
// DATA header layout: type(1) || rec_len(2 BE) || sealed_record.
|
||||
const dataPrefixLen = 1 + 2
|
||||
|
||||
// AckNone is the on-wire sentinel for "I have received nothing yet".
|
||||
const ackNone uint16 = 0xFFFF
|
||||
|
||||
// Default UDP read buffer — large enough for ClientHello (1253 bytes + headers) with slack.
|
||||
const recvBuf = 2048
|
||||
|
||||
// Options exposes the same knobs as Rust's UdpOpts. Defaults intentionally match.
|
||||
type Options struct {
|
||||
// Probe resistance (optional). When KnockEnabled is true, KnockKey must be 32 bytes.
|
||||
KnockEnabled bool
|
||||
KnockKey [32]byte
|
||||
|
||||
// Handshake retransmit timeout: every HsRTO, all unacked HS datagrams are resent.
|
||||
HsRTO time.Duration
|
||||
// Overall handshake deadline.
|
||||
HsTimeout time.Duration
|
||||
// Linger duration: after the handshake completes, the client briefly resends the final
|
||||
// flight to recover from a lost last message.
|
||||
HsLinger time.Duration
|
||||
}
|
||||
|
||||
// DefaultOptions matches Rust's UdpOpts::default sans obfuscation / cover-traffic (a TODO for v1
|
||||
// of the Go port).
|
||||
func DefaultOptions() *Options {
|
||||
return &Options{
|
||||
HsRTO: 250 * time.Millisecond,
|
||||
HsTimeout: 10 * time.Second,
|
||||
HsLinger: 2 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Connection is an established Aura UDP connection. After Dial succeeds, the caller uses Send
|
||||
// / Recv to ship application packets.
|
||||
type Connection struct {
|
||||
conn *net.UDPConn
|
||||
sender *session.DatagramSender
|
||||
recvr *session.DatagramReceiver
|
||||
peer string
|
||||
mu sync.Mutex // serializes sender access (Pong replies + user sends)
|
||||
}
|
||||
|
||||
// PeerID returns the verified peer identity (the server name).
|
||||
func (c *Connection) PeerID() string { return c.peer }
|
||||
|
||||
// Send seals one application packet as a Frame::Data on stream 0 and ships it in one DATA
|
||||
// datagram.
|
||||
func (c *Connection) Send(payload []byte) error {
|
||||
c.mu.Lock()
|
||||
rec := c.sender.Seal(&frame.Frame{Kind: frame.FrameData, StreamID: 0, Payload: payload})
|
||||
c.mu.Unlock()
|
||||
return c.writeDataDgram(rec)
|
||||
}
|
||||
|
||||
// Recv blocks until the next application packet arrives. Ping is answered with Pong
|
||||
// transparently; Pong is ignored; Close surfaces as an error (terminating the connection).
|
||||
func (c *Connection) Recv() ([]byte, error) {
|
||||
buf := make([]byte, recvBuf)
|
||||
for {
|
||||
n, err := c.conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dg := buf[:n]
|
||||
if len(dg) == 0 {
|
||||
continue
|
||||
}
|
||||
switch dg[0] {
|
||||
case typeDATA:
|
||||
if len(dg) < dataPrefixLen {
|
||||
continue
|
||||
}
|
||||
recLen := int(binary.BigEndian.Uint16(dg[1:3]))
|
||||
end := dataPrefixLen + recLen
|
||||
if len(dg) < end {
|
||||
continue
|
||||
}
|
||||
f, err := c.recvr.Open(dg[dataPrefixLen:end])
|
||||
if err != nil {
|
||||
continue // replay / tampered / out-of-window: defensive drop
|
||||
}
|
||||
switch f.Kind {
|
||||
case frame.FrameData:
|
||||
return f.Payload, nil
|
||||
case frame.FramePing:
|
||||
// Answer with Pong on the same datagram path.
|
||||
c.mu.Lock()
|
||||
rec := c.sender.Seal(&frame.Frame{Kind: frame.FramePong, Seq: f.Seq})
|
||||
c.mu.Unlock()
|
||||
if err := c.writeDataDgram(rec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case frame.FramePong:
|
||||
continue
|
||||
case frame.FrameClose:
|
||||
return nil, fmt.Errorf("aura/transport: peer closed (code=%d): %s", f.Code, f.Reason)
|
||||
}
|
||||
case typeHS:
|
||||
// Late HS retransmit on the data path: ignore.
|
||||
continue
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases the underlying socket.
|
||||
func (c *Connection) Close() error { return c.conn.Close() }
|
||||
|
||||
func (c *Connection) writeDataDgram(rec []byte) error {
|
||||
if len(rec) > 0xFFFF {
|
||||
return fmt.Errorf("aura/transport: sealed record too large: %d", len(rec))
|
||||
}
|
||||
dg := make([]byte, 0, dataPrefixLen+len(rec))
|
||||
dg = append(dg, typeDATA)
|
||||
var lb [2]byte
|
||||
binary.BigEndian.PutUint16(lb[:], uint16(len(rec)))
|
||||
dg = append(dg, lb[:]...)
|
||||
dg = append(dg, rec...)
|
||||
_, err := c.conn.Write(dg)
|
||||
return err
|
||||
}
|
||||
|
||||
// Dial connects to an Aura UDP server, performs the mutual-auth handshake over the reliable
|
||||
// adapter, and returns an established Connection.
|
||||
func Dial(ctx context.Context, addr string, hsCfg *handshake.ClientConfig, opts *Options) (*Connection, error) {
|
||||
if opts == nil {
|
||||
opts = DefaultOptions()
|
||||
}
|
||||
rAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve %s: %w", addr, err)
|
||||
}
|
||||
conn, err := net.DialUDP("udp", nil, rAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial udp: %w", err)
|
||||
}
|
||||
// The reliable adapter manages send/recv during the handshake; once we have a Connection
|
||||
// the user owns it.
|
||||
adapter := newHSAdapter(conn, opts)
|
||||
done := make(chan struct{})
|
||||
adapter.start(done)
|
||||
defer close(done) // stop the driver once Dial returns
|
||||
|
||||
res, err := handshake.Client(adapter, adapter, hsCfg)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("aura handshake: %w", err)
|
||||
}
|
||||
|
||||
// After the handshake, build datagram codecs starting at PostHandshakeCounter (2): both
|
||||
// directions sealed exactly two encrypted handshake messages (Auth + Finished), so the AEAD
|
||||
// counters resume from there.
|
||||
sender, err := session.NewDatagramSender(res.C2S[:], session.PostHandshakeCounter)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
recvr, err := session.NewDatagramReceiver(res.S2C[:], session.PostHandshakeCounter)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Linger: briefly resend the last unacked flight so a lost final message is recovered.
|
||||
adapter.linger(opts.HsRTO, opts.HsLinger)
|
||||
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
sender: sender,
|
||||
recvr: recvr,
|
||||
peer: res.PeerID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================================================================
|
||||
// Reliable HS adapter
|
||||
// ============================================================================================
|
||||
|
||||
// hsAdapter wraps a *net.UDPConn with a small DTLS-flight reliability layer. It implements
|
||||
// io.Reader and io.Writer so handshake.Client can drive it like a stream — the adapter parses
|
||||
// the 5-byte Aura frame header in its outbound buffer to know each message's total length, so
|
||||
// each whole frame becomes exactly one HS datagram.
|
||||
type hsAdapter struct {
|
||||
conn *net.UDPConn
|
||||
opts *Options
|
||||
|
||||
mu sync.Mutex
|
||||
// Outbound: bytes the handshake wrote but not yet framed into an HS datagram.
|
||||
outPartial []byte
|
||||
// Outbound: hs_seq -> msg_bytes for retransmit.
|
||||
unacked map[uint16][]byte
|
||||
// Outbound: next hs_seq to stamp.
|
||||
nextSendSeq uint16
|
||||
|
||||
// Inbound: hs_seq -> received msg_bytes (reorder buffer).
|
||||
inBuf map[uint16][]byte
|
||||
// Inbound: next hs_seq we expect to deliver.
|
||||
nextDeliverSeq uint16
|
||||
// Inbound: bytes delivered in order but not yet read by the caller.
|
||||
ready []byte
|
||||
readyPos int
|
||||
|
||||
// Signals from the network goroutine to a parked reader.
|
||||
readCond *sync.Cond
|
||||
closed bool
|
||||
|
||||
// Network goroutine errors (only the first sticks).
|
||||
netErr error
|
||||
}
|
||||
|
||||
func newHSAdapter(conn *net.UDPConn, opts *Options) *hsAdapter {
|
||||
a := &hsAdapter{
|
||||
conn: conn,
|
||||
opts: opts,
|
||||
unacked: make(map[uint16][]byte),
|
||||
inBuf: make(map[uint16][]byte),
|
||||
}
|
||||
a.readCond = sync.NewCond(&a.mu)
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *hsAdapter) ackUpto() uint16 {
|
||||
if a.nextDeliverSeq == 0 {
|
||||
return ackNone
|
||||
}
|
||||
return a.nextDeliverSeq - 1
|
||||
}
|
||||
|
||||
// pruneAcked drops every entry with hs_seq <= ack_upto (cumulative ack).
|
||||
func (a *hsAdapter) pruneAcked(ackUpto uint16) {
|
||||
if ackUpto == ackNone {
|
||||
return
|
||||
}
|
||||
for k := range a.unacked {
|
||||
if k <= ackUpto {
|
||||
delete(a.unacked, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// acceptIncoming integrates a received HS payload at seq, advancing contiguous delivery.
|
||||
func (a *hsAdapter) acceptIncoming(seq uint16, msg []byte) {
|
||||
if seq < a.nextDeliverSeq {
|
||||
return // already delivered (a retransmit): drop
|
||||
}
|
||||
if _, ok := a.inBuf[seq]; !ok {
|
||||
a.inBuf[seq] = msg
|
||||
}
|
||||
before := len(a.ready)
|
||||
for {
|
||||
m, ok := a.inBuf[a.nextDeliverSeq]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
delete(a.inBuf, a.nextDeliverSeq)
|
||||
a.ready = append(a.ready, m...)
|
||||
a.nextDeliverSeq++ // wraps mod 2^16
|
||||
}
|
||||
if len(a.ready) > before {
|
||||
a.readCond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// sendHS builds and sends one HS datagram carrying msg at seq+ack.
|
||||
// Called under the lock or with the values already snapshotted.
|
||||
func (a *hsAdapter) sendHS(seq, ack uint16, msg []byte) error {
|
||||
prefix := 0
|
||||
if a.opts.KnockEnabled {
|
||||
prefix = KnockLen
|
||||
}
|
||||
dg := make([]byte, 0, prefix+hsPrefixLen+len(msg))
|
||||
if a.opts.KnockEnabled {
|
||||
tok := KnockForMinute(a.opts.KnockKey, CurrentUnixMinute())
|
||||
dg = append(dg, tok[:]...)
|
||||
}
|
||||
dg = append(dg, typeHS)
|
||||
var sb [2]byte
|
||||
binary.BigEndian.PutUint16(sb[:], seq)
|
||||
dg = append(dg, sb[:]...)
|
||||
binary.BigEndian.PutUint16(sb[:], ack)
|
||||
dg = append(dg, sb[:]...)
|
||||
dg = append(dg, msg...)
|
||||
_, err := a.conn.Write(dg)
|
||||
return err
|
||||
}
|
||||
|
||||
// flushOutgoing parses message boundaries out of outPartial and emits one HS datagram per
|
||||
// whole frame. Holds the lock internally; safe to call from any goroutine.
|
||||
func (a *hsAdapter) flushOutgoing() error {
|
||||
for {
|
||||
a.mu.Lock()
|
||||
if len(a.outPartial) < frame.HeaderLen {
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
var hdr [frame.HeaderLen]byte
|
||||
copy(hdr[:], a.outPartial[:frame.HeaderLen])
|
||||
_, plen, err := frame.DecodeHeader(hdr)
|
||||
if err != nil {
|
||||
a.mu.Unlock()
|
||||
return nil // wait for more bytes
|
||||
}
|
||||
total := frame.HeaderLen + plen
|
||||
if len(a.outPartial) < total {
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
msg := make([]byte, total)
|
||||
copy(msg, a.outPartial[:total])
|
||||
a.outPartial = a.outPartial[total:]
|
||||
seq := a.nextSendSeq
|
||||
a.nextSendSeq++
|
||||
ack := a.ackUpto()
|
||||
a.unacked[seq] = msg
|
||||
a.mu.Unlock()
|
||||
|
||||
if err := a.sendHS(seq, ack, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// maybeBareAck emits a zero-length HS datagram so the peer can prune its retransmit queue. The
|
||||
// bare ack does not consume a sequence number.
|
||||
func (a *hsAdapter) maybeBareAck() error {
|
||||
a.mu.Lock()
|
||||
should := a.nextDeliverSeq > 0 && len(a.outPartial) == 0
|
||||
seq := a.nextSendSeq
|
||||
ack := a.ackUpto()
|
||||
a.mu.Unlock()
|
||||
if !should {
|
||||
return nil
|
||||
}
|
||||
return a.sendHS(seq, ack, nil)
|
||||
}
|
||||
|
||||
// retransmitUnacked re-sends every still-unacked HS datagram. Called on the RTO timer.
|
||||
func (a *hsAdapter) retransmitUnacked() error {
|
||||
a.mu.Lock()
|
||||
ack := a.ackUpto()
|
||||
// Iterate in seq order for deterministic wire behaviour.
|
||||
seqs := make([]uint16, 0, len(a.unacked))
|
||||
for k := range a.unacked {
|
||||
seqs = append(seqs, k)
|
||||
}
|
||||
sort.Slice(seqs, func(i, j int) bool { return seqs[i] < seqs[j] })
|
||||
batch := make([][2]any, 0, len(seqs))
|
||||
for _, s := range seqs {
|
||||
batch = append(batch, [2]any{s, a.unacked[s]})
|
||||
}
|
||||
a.mu.Unlock()
|
||||
for _, e := range batch {
|
||||
seq := e[0].(uint16)
|
||||
msg := e[1].([]byte)
|
||||
if err := a.sendHS(seq, ack, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pumpOneIncoming reads and integrates exactly one HS datagram.
|
||||
func (a *hsAdapter) pumpOneIncoming() error {
|
||||
buf := make([]byte, recvBuf)
|
||||
n, err := a.conn.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dg := buf[:n]
|
||||
if len(dg) == 0 || dg[0] != typeHS || len(dg) < hsPrefixLen {
|
||||
return nil
|
||||
}
|
||||
seq := binary.BigEndian.Uint16(dg[1:3])
|
||||
ack := binary.BigEndian.Uint16(dg[3:5])
|
||||
msg := append([]byte{}, dg[hsPrefixLen:]...)
|
||||
|
||||
a.mu.Lock()
|
||||
a.pruneAcked(ack)
|
||||
if len(msg) > 0 {
|
||||
a.acceptIncoming(seq, msg)
|
||||
}
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// start launches the driver goroutine that interleaves I/O while the handshake future runs.
|
||||
// The driver stops when `done` is closed.
|
||||
func (a *hsAdapter) start(done chan struct{}) {
|
||||
// Reader goroutine.
|
||||
go func() {
|
||||
for {
|
||||
// Use a short read timeout so the goroutine can notice `done` promptly.
|
||||
_ = a.conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
|
||||
if err := a.pumpOneIncoming(); err != nil {
|
||||
if isTimeout(err) {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
a.setErr(err)
|
||||
return
|
||||
}
|
||||
// After pumping an incoming datagram, flush any replies + maybe a bare ack.
|
||||
_ = a.flushOutgoing()
|
||||
_ = a.maybeBareAck()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
// RTO + timeout driver.
|
||||
go func() {
|
||||
rto := time.NewTicker(a.opts.HsRTO)
|
||||
defer rto.Stop()
|
||||
dead := time.NewTimer(a.opts.HsTimeout)
|
||||
defer dead.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-rto.C:
|
||||
_ = a.flushOutgoing()
|
||||
_ = a.retransmitUnacked()
|
||||
case <-dead.C:
|
||||
a.setErr(fmt.Errorf("aura/transport: UDP handshake timed out after %s", a.opts.HsTimeout))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// linger briefly resends the last unacked flight after the handshake returns. Stops early if
|
||||
// nothing is unacked.
|
||||
func (a *hsAdapter) linger(rto, total time.Duration) {
|
||||
rounds := 3
|
||||
per := rto
|
||||
if total/time.Duration(rounds) < per {
|
||||
per = total / time.Duration(rounds)
|
||||
}
|
||||
for i := 0; i < rounds; i++ {
|
||||
a.mu.Lock()
|
||||
empty := len(a.unacked) == 0
|
||||
a.mu.Unlock()
|
||||
if empty {
|
||||
return
|
||||
}
|
||||
_ = a.retransmitUnacked()
|
||||
time.Sleep(per)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *hsAdapter) setErr(err error) {
|
||||
a.mu.Lock()
|
||||
if a.netErr == nil {
|
||||
a.netErr = err
|
||||
}
|
||||
a.closed = true
|
||||
a.readCond.Broadcast()
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// Read implements io.Reader for the handshake driver: hand out already-delivered contiguous
|
||||
// bytes. Blocks (via Cond) until some bytes are ready or the adapter is closed/errored.
|
||||
func (a *hsAdapter) Read(p []byte) (int, error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
for {
|
||||
if a.readyPos < len(a.ready) {
|
||||
n := copy(p, a.ready[a.readyPos:])
|
||||
a.readyPos += n
|
||||
if a.readyPos == len(a.ready) {
|
||||
a.ready = a.ready[:0]
|
||||
a.readyPos = 0
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
if a.netErr != nil {
|
||||
return 0, a.netErr
|
||||
}
|
||||
if a.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
a.readCond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer: append to outPartial and flush any newly-complete messages.
|
||||
func (a *hsAdapter) Write(p []byte) (int, error) {
|
||||
a.mu.Lock()
|
||||
a.outPartial = append(a.outPartial, p...)
|
||||
a.mu.Unlock()
|
||||
if err := a.flushOutgoing(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// isTimeout returns true if err is a net.Error with Timeout()==true.
|
||||
func isTimeout(err error) bool {
|
||||
var ne net.Error
|
||||
return err != nil && errors.As(err, &ne) && ne.Timeout()
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestKnockForMinuteDeterministicAndMinuteSensitive(t *testing.T) {
|
||||
var k [32]byte
|
||||
for i := range k {
|
||||
k[i] = byte(i)
|
||||
}
|
||||
a := KnockForMinute(k, 1_000_000)
|
||||
b := KnockForMinute(k, 1_000_000)
|
||||
if a != b {
|
||||
t.Fatalf("same inputs gave different output: %x vs %x", a, b)
|
||||
}
|
||||
c := KnockForMinute(k, 1_000_001)
|
||||
if a == c {
|
||||
t.Fatalf("different minute gave same output: %x", c)
|
||||
}
|
||||
var k2 [32]byte
|
||||
copy(k2[:], k[:])
|
||||
k2[0] ^= 1
|
||||
d := KnockForMinute(k2, 1_000_000)
|
||||
if a == d {
|
||||
t.Fatalf("different key gave same output: %x", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReorderBufferDeliversInSequenceOrder(t *testing.T) {
|
||||
a := newHSAdapter(nil, DefaultOptions())
|
||||
// Direct manipulation of the adapter's reorder buffer mimicking the Rust unit test
|
||||
// `reorder_buffer_delivers_in_sequence_order`.
|
||||
a.acceptIncoming(2, []byte("ccc"))
|
||||
a.acceptIncoming(1, []byte("bbb"))
|
||||
if len(a.ready) != 0 {
|
||||
t.Fatalf("contiguous run unexpectedly emitted: %x", a.ready)
|
||||
}
|
||||
a.acceptIncoming(0, []byte("aaa"))
|
||||
if string(a.ready) != "aaabbbccc" {
|
||||
t.Fatalf("delivery order wrong: %s", a.ready)
|
||||
}
|
||||
if a.nextDeliverSeq != 3 {
|
||||
t.Fatalf("contig counter wrong: %d", a.nextDeliverSeq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateDatagramsAreDropped(t *testing.T) {
|
||||
a := newHSAdapter(nil, DefaultOptions())
|
||||
a.acceptIncoming(0, []byte("x"))
|
||||
a.acceptIncoming(0, []byte("x"))
|
||||
if string(a.ready) != "x" {
|
||||
t.Fatalf("duplicate retransmit double-counted: %s", a.ready)
|
||||
}
|
||||
a.acceptIncoming(2, []byte("z"))
|
||||
a.acceptIncoming(2, []byte("z"))
|
||||
a.acceptIncoming(1, []byte("y"))
|
||||
if string(a.ready) != "xyz" {
|
||||
t.Fatalf("delivery wrong with duplicates: %s", a.ready)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckUptoReportsHighestContiguousOrSentinel(t *testing.T) {
|
||||
a := newHSAdapter(nil, DefaultOptions())
|
||||
if a.ackUpto() != ackNone {
|
||||
t.Fatalf("initial ack not sentinel: 0x%04X", a.ackUpto())
|
||||
}
|
||||
a.acceptIncoming(0, []byte("a"))
|
||||
if a.ackUpto() != 0 {
|
||||
t.Fatalf("after seq 0: %d", a.ackUpto())
|
||||
}
|
||||
a.acceptIncoming(2, []byte("c"))
|
||||
if a.ackUpto() != 0 {
|
||||
t.Fatalf("gap should not advance ack: %d", a.ackUpto())
|
||||
}
|
||||
a.acceptIncoming(1, []byte("b"))
|
||||
if a.ackUpto() != 2 {
|
||||
t.Fatalf("filling gap should advance: %d", a.ackUpto())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneAckedIsCumulativeAndRespectsSentinel(t *testing.T) {
|
||||
a := newHSAdapter(nil, DefaultOptions())
|
||||
a.unacked[0] = []byte{0}
|
||||
a.unacked[1] = []byte{1}
|
||||
a.unacked[2] = []byte{2}
|
||||
|
||||
a.pruneAcked(ackNone)
|
||||
if len(a.unacked) != 3 {
|
||||
t.Fatalf("sentinel should prune nothing, got %d", len(a.unacked))
|
||||
}
|
||||
a.pruneAcked(1)
|
||||
if _, ok := a.unacked[0]; ok {
|
||||
t.Fatal("seq 0 should be pruned")
|
||||
}
|
||||
if _, ok := a.unacked[1]; ok {
|
||||
t.Fatal("seq 1 should be pruned")
|
||||
}
|
||||
if _, ok := a.unacked[2]; !ok {
|
||||
t.Fatal("seq 2 should remain")
|
||||
}
|
||||
a.pruneAcked(2)
|
||||
if len(a.unacked) != 0 {
|
||||
t.Fatalf("should be empty: %d", len(a.unacked))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user