package frame

import (
	"bytes"
	"errors"
	"fmt"
	"testing"
)

// TestHeaderRoundtripAllTypes encodes a header for every message type, checks
// the type byte (h[0]) and version byte (h[4]) on the wire, and verifies that
// DecodeHeader returns the original type and payload length.
func TestHeaderRoundtripAllTypes(t *testing.T) {
	// Arbitrary payload length with several distinct non-zero bytes.
	const payloadLen = 0x00123456

	cases := []struct {
		ty MsgType
		b  byte // expected wire value of the type byte
	}{
		{MsgClientHello, 0x01},
		{MsgServerHello, 0x02},
		{MsgClientAuth, 0x03},
		{MsgServerAuth, 0x04},
		{MsgFinished, 0x05},
		{MsgData, 0x06},
		{MsgAlert, 0xFF},
	}
	for _, c := range cases {
		t.Run(fmt.Sprint(c.ty), func(t *testing.T) {
			h, err := EncodeHeader(c.ty, payloadLen)
			if err != nil {
				t.Fatalf("encode %s: %v", c.ty, err)
			}
			if h[0] != c.b {
				t.Fatalf("type byte for %s: got 0x%02X want 0x%02X", c.ty, h[0], c.b)
			}
			// Report the constant itself rather than a hard-coded value so the
			// message stays correct if ProtocolVersion is bumped.
			if h[4] != ProtocolVersion {
				t.Fatalf("version byte: got 0x%02X want 0x%02X", h[4], ProtocolVersion)
			}
			mt, plen, err := DecodeHeader(h)
			if err != nil {
				t.Fatalf("decode %s: %v", c.ty, err)
			}
			if mt != c.ty || plen != payloadLen {
				t.Fatalf("roundtrip mismatch: got (%s, %d) want (%s, %d)", mt, plen, c.ty, payloadLen)
			}
		})
	}
}

// TestHeaderRejectsOversizeAndBadVersion checks the sentinel errors returned
// for a payload length above MaxPayloadLen, a non-current version byte, and an
// unrecognized type byte.
func TestHeaderRejectsOversizeAndBadVersion(t *testing.T) {
	if _, err := EncodeHeader(MsgData, MaxPayloadLen+1); !errors.Is(err, ErrFrameTooLarge) {
		t.Fatalf("oversize: want ErrFrameTooLarge, got %v", err)
	}

	h, err := EncodeHeader(MsgData, 1)
	if err != nil {
		t.Fatal(err)
	}
	h[4] = 0x02
	if _, _, err := DecodeHeader(h); !errors.Is(err, ErrBadVersion) {
		t.Fatalf("bad version: want ErrBadVersion, got %v", err)
	}

	// Reset the version so the unknown-type check actually exercises the type branch.
	h[4] = ProtocolVersion
	h[0] = 0x77
	if _, _, err := DecodeHeader(h); !errors.Is(err, ErrUnknownMsgType) {
		t.Fatalf("unknown type: want ErrUnknownMsgType, got %v", err)
	}
}

// TestFrameRoundtrip encodes and decodes one frame of each kind, including
// zero-valued fields, and compares every field that frame kinds populate.
func TestFrameRoundtrip(t *testing.T) {
	frames := []*Frame{
		{Kind: FrameData, StreamID: 0xDEADBEEF, Payload: []byte("hello world")},
		{Kind: FrameData, StreamID: 0, Payload: nil},
		{Kind: FramePing, Seq: 42},
		{Kind: FramePong, Seq: 0xFFFFFFFF},
		{Kind: FrameClose, Code: 7, Reason: "going away \U0001F44B"},
		{Kind: FrameClose, Code: 0, Reason: ""},
	}
	for i, f := range frames {
		t.Run(fmt.Sprintf("%d_%v", i, f.Kind), func(t *testing.T) {
			got, err := DecodeFrame(EncodeFrame(f))
			if err != nil {
				t.Fatalf("decode %v: %v", f.Kind, err)
			}
			// bytes.Equal treats nil and empty slices as equal, so a nil
			// payload may legitimately decode as an empty one.
			if got.Kind != f.Kind || got.StreamID != f.StreamID || got.Seq != f.Seq ||
				got.Code != f.Code || got.Reason != f.Reason || !bytes.Equal(got.Payload, f.Payload) {
				t.Fatalf("roundtrip mismatch: %+v vs %+v", f, got)
			}
		})
	}
}

// TestFrameDecodeRejectsGarbage checks that DecodeFrame returns an error for
// empty input, an unknown tag, and truncated ping/close frames.
func TestFrameDecodeRejectsGarbage(t *testing.T) {
	cases := []struct {
		name string
		in   []byte
	}{
		{"nil", nil},
		{"unknown tag", []byte{0x99}},
		{"truncated ping", []byte{byte(FramePing), 0x00}},
		{"missing close code", []byte{byte(FrameClose)}},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if _, err := DecodeFrame(c.in); err == nil {
				t.Fatalf("%s: want error", c.name)
			}
		})
	}
}

// TestControlEnvelopeRoundtrip checks that an encoded control envelope starts
// with ControlEnvelopeMagic and decodes back to the same kind and payload.
func TestControlEnvelopeRoundtrip(t *testing.T) {
	env := EncodeControlEnvelope(ControlCrlPush, []byte("hello"))
	if !bytes.Equal(env[:4], ControlEnvelopeMagic[:]) {
		t.Fatalf("magic mismatch: %x", env[:4])
	}
	kind, payload, ok, err := DecodeControlEnvelope(env)
	if err != nil || !ok {
		t.Fatalf("decode: ok=%v err=%v", ok, err)
	}
	if kind != ControlCrlPush || string(payload) != "hello" {
		t.Fatalf("decode mismatch: kind=%v payload=%q", kind, payload)
	}
}

// TestControlEnvelopeSkipsNormalIPPackets checks that inputs without the full
// magic prefix are reported as "not an envelope" (ok=false) with no error.
func TestControlEnvelopeSkipsNormalIPPackets(t *testing.T) {
	cases := []struct {
		name string
		in   []byte
	}{
		{"ipv4 packet", []byte{0x45, 0x00, 0x00, 0x14}},
		{"ipv6 packet", []byte{0x60, 0x00, 0x00, 0x00}},
		{"wrong magic last byte", []byte{0xAA, 0xAA, 0xC0, 0x02}},
		{"shorter than magic", []byte{0xAA, 0xAA}},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			_, _, ok, err := DecodeControlEnvelope(c.in)
			if ok || err != nil {
				t.Fatalf("expected pass-through (ok=false, err=nil): got ok=%v err=%v on %x", ok, err, c.in)
			}
		})
	}
}

// TestControlEnvelopeRejectsTruncatedPayload checks that an envelope with a
// valid prefix but a payload cut short by three bytes yields an error.
func TestControlEnvelopeRejectsTruncatedPayload(t *testing.T) {
	env := EncodeControlEnvelope(ControlCrlPush, []byte("payload-bytes"))
	env = env[:len(env)-3]
	if _, _, _, err := DecodeControlEnvelope(env); err == nil {
		t.Fatal("want truncated payload error")
	}
}

// TestWriteAndReadFrameRoundtrip writes one frame to a buffer, reads it back,
// and checks the type, payload, and that WireBytes reproduces exactly what
// WriteFrame emitted.
func TestWriteAndReadFrameRoundtrip(t *testing.T) {
	var buf bytes.Buffer
	payload := []byte{1, 2, 3, 4, 5}
	if err := WriteFrame(&buf, MsgData, payload); err != nil {
		t.Fatal(err)
	}
	// Snapshot the written bytes before ReadFrame drains the buffer.
	written := append([]byte(nil), buf.Bytes()...)

	raw, err := ReadFrame(&buf)
	if err != nil {
		t.Fatal(err)
	}
	if raw.MsgType != MsgData || !bytes.Equal(raw.Payload, payload) {
		t.Fatalf("roundtrip mismatch: %+v", raw)
	}
	got := raw.WireBytes()
	if len(got) != HeaderLen+len(payload) {
		t.Fatalf("wire bytes wrong length: %d", len(got))
	}
	// A length check alone would miss reordered or corrupted bytes.
	if !bytes.Equal(got, written) {
		t.Fatalf("wire bytes differ from written frame: got %x want %x", got, written)
	}
}