From 7b346d15c13adfa5820bdbf46481f1a811747464 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 21 Apr 2026 04:13:30 +0000 Subject: [PATCH] Add comprehensive tests for streaming_base64 decoder Agent-Logs-Url: https://github.com/kovidgoyal/kitty/sessions/b5dbc339-6907-46fb-a7e7-9ee08b7ed20d Co-authored-by: kovidgoyal <1308621+kovidgoyal@users.noreply.github.com> --- tools/utils/streaming_base64/api_test.go | 375 +++++++++++++++++++++++ 1 file changed, 375 insertions(+) create mode 100644 tools/utils/streaming_base64/api_test.go diff --git a/tools/utils/streaming_base64/api_test.go b/tools/utils/streaming_base64/api_test.go new file mode 100644 index 000000000..fdde19b27 --- /dev/null +++ b/tools/utils/streaming_base64/api_test.go @@ -0,0 +1,375 @@ +// License: GPLv3 Copyright: 2024, Kovid Goyal, + +package streaming_base64 + +import ( + "bytes" + "encoding/base64" + "fmt" + "testing" +) + +var _ = fmt.Print + +// collectDecode feeds input to the decoder in chunks of chunkSize bytes and +// accumulates all decoded output into a single slice. It uses a generously- +// sized output buffer so buffer-size issues never mask correctness bugs. +func collectDecode(t *testing.T, d *StreamingBase64Decoder, encoded []byte, chunkSize int) ([]byte, error) { + t.Helper() + var result []byte + for len(encoded) > 0 { + end := chunkSize + if end > len(encoded) { + end = len(encoded) + } + chunk := encoded[:end] + encoded = encoded[end:] + + // Output buffer: worst-case 3 bytes per 4 input bytes, plus the + // 3 bytes for a buffered leftover block. + outBuf := make([]byte, (len(chunk)+4)*3) + for decoded, err := range d.Decode(chunk, outBuf) { + if err != nil { + return nil, err + } + result = append(result, decoded...) + } + } + return result, nil +} + +// roundtrip encodes plaintext with standard base64, then decodes it through +// the streaming decoder in chunkSize pieces and verifies the result. +func roundtrip(t *testing.T, plaintext []byte, chunkSize int) { + t.Helper() + encoded := []byte(base64.StdEncoding.EncodeToString(plaintext)) + var d StreamingBase64Decoder + got, err := collectDecode(t, &d, encoded, chunkSize) + if err != nil { + t.Fatalf("chunkSize=%d: unexpected decode error: %v", chunkSize, err) + } + tail, err := d.Finish() + if err != nil { + t.Fatalf("chunkSize=%d: unexpected Finish error: %v", chunkSize, err) + } + got = append(got, tail...) + if !bytes.Equal(got, plaintext) { + t.Fatalf("chunkSize=%d: roundtrip mismatch:\n want %q\n got %q", chunkSize, plaintext, got) + } +} + +// roundtripNoPadding is like roundtrip but strips the trailing '=' padding +// characters from the encoded string. The Finish method must add them back. +func roundtripNoPadding(t *testing.T, plaintext []byte, chunkSize int) { + t.Helper() + encoded := []byte(base64.RawStdEncoding.EncodeToString(plaintext)) // no padding + var d StreamingBase64Decoder + got, err := collectDecode(t, &d, encoded, chunkSize) + if err != nil { + t.Fatalf("noPad chunkSize=%d: unexpected decode error: %v", chunkSize, err) + } + tail, err := d.Finish() + if err != nil { + t.Fatalf("noPad chunkSize=%d: unexpected Finish error: %v", chunkSize, err) + } + got = append(got, tail...) + if !bytes.Equal(got, plaintext) { + t.Fatalf("noPad chunkSize=%d: roundtrip mismatch:\n want %q\n got %q", chunkSize, plaintext, got) + } +} + +// TestRoundtripAllChunkSizes exercises every chunk size from 1 to 7 with +// plaintexts whose lengths produce all possible num_leftover values (0-3) +// after encoding: 0 mod 3 → 0 leftover, 1 mod 3 → 2 leftover, 2 mod 3 → 1 +// leftover after decoding (but 2 base64 chars left when unpadded). +func TestRoundtripAllChunkSizes(t *testing.T) { + plaintexts := [][]byte{ + {}, // 0 bytes → 0 encoded → num_leftover=0 + []byte("a"), // 1 byte → 4 encoded → no leftover (padded) + []byte("ab"), // 2 bytes → 4 encoded → no leftover (padded) + []byte("abc"), // 3 bytes → 4 encoded → no leftover + []byte("abcd"), // 4 bytes → 8 encoded → no leftover + []byte("abcde"), // 5 bytes → 8 encoded → no leftover (padded) + []byte("abcdef"), // 6 bytes → 8 encoded → no leftover (padded) + []byte("Hello, World!"), // 13 bytes → 20 encoded + []byte("The quick brown fox jumps over the"), // 34 bytes → 48 encoded + bytes.Repeat([]byte{0x00, 0xff, 0x80}, 17), // binary data + } + for _, plain := range plaintexts { + for chunkSize := 1; chunkSize <= 7; chunkSize++ { + if len(plain) == 0 && chunkSize > 1 { + continue // only test once for empty input + } + roundtrip(t, plain, chunkSize) + } + } +} + +// TestRoundtripNoPaddingAllChunkSizes tests decoding without trailing '=' +// padding bytes for all relevant chunk sizes. +func TestRoundtripNoPaddingAllChunkSizes(t *testing.T) { + plaintexts := [][]byte{ + []byte("a"), // 1 byte → "YQ" (2 base64 chars, no pad) + []byte("ab"), // 2 bytes → "YWI" (3 base64 chars, no pad) + []byte("abc"), // 3 bytes → "YWJj" (4 chars, no leftover) + []byte("abcd"), // 4 bytes → "YWJjZA" (6 chars) + []byte("Hello, World!"), // mixed + bytes.Repeat([]byte{0xde}, 10), // binary, 1 mod 3 remainder + bytes.Repeat([]byte{0xbe}, 11), // binary, 2 mod 3 remainder + bytes.Repeat([]byte{0xef}, 12), // binary, 0 mod 3 remainder + } + for _, plain := range plaintexts { + for chunkSize := 1; chunkSize <= 7; chunkSize++ { + roundtripNoPadding(t, plain, chunkSize) + } + } +} + +// TestNumLeftoverInDecode checks that the bridge-block path in Decode (the +// path that fires when num_leftover > 0 at the start of a new call) produces +// correct results. We achieve specific num_leftover values by feeding exactly +// that many bytes in the first call. +func TestNumLeftoverInDecode(t *testing.T) { + // plaintext whose encoding is long enough to exercise all leftover values + plain := []byte("abcdefghijklmnopqrstuvwxyz") // 26 bytes → 36 encoded chars + + for firstCallLen := 1; firstCallLen <= 3; firstCallLen++ { + encoded := []byte(base64.StdEncoding.EncodeToString(plain)) + var d StreamingBase64Decoder + outBuf := make([]byte, 64) + + // Feed exactly firstCallLen bytes → num_leftover = firstCallLen % 4 + var got []byte + for dec, err := range d.Decode(encoded[:firstCallLen], outBuf) { + if err != nil { + t.Fatalf("firstCallLen=%d first Decode error: %v", firstCallLen, err) + } + got = append(got, dec...) + } + // Feed the rest + rest, err := collectDecode(t, &d, encoded[firstCallLen:], 4) + if err != nil { + t.Fatalf("firstCallLen=%d rest Decode error: %v", firstCallLen, err) + } + got = append(got, rest...) + tail, err := d.Finish() + if err != nil { + t.Fatalf("firstCallLen=%d Finish error: %v", firstCallLen, err) + } + got = append(got, tail...) + if !bytes.Equal(got, plain) { + t.Fatalf("firstCallLen=%d mismatch:\n want %q\n got %q", firstCallLen, plain, got) + } + } +} + +// TestFinishNumLeftover directly exercises all four branches of Finish. +// +// - num_leftover=0 → (nil, nil) +// - num_leftover=1 → error pointing to correct byte offset +// - num_leftover=2 → decode 1 byte, pad "==" +// - num_leftover=3 → decode 2 bytes, pad "=" +func TestFinishNumLeftover(t *testing.T) { + // num_leftover=0: empty decoder, Finish must return (nil, nil). + t.Run("leftover=0", func(t *testing.T) { + var d StreamingBase64Decoder + out, err := d.Finish() + if err != nil || out != nil { + t.Fatalf("expected (nil,nil), got (%v,%v)", out, err) + } + }) + + // num_leftover=0 after a complete stream: also (nil, nil). + t.Run("leftover=0_after_stream", func(t *testing.T) { + // "abc" encodes to "YWJj" — exactly 4 chars, no leftover + var d StreamingBase64Decoder + outBuf := make([]byte, 16) + for _, err := range d.Decode([]byte("YWJj"), outBuf) { + if err != nil { + t.Fatal(err) + } + } + out, err := d.Finish() + if err != nil || out != nil { + t.Fatalf("expected (nil,nil), got (%v,%v)", out, err) + } + }) + + // num_leftover=1 → CorruptInputError pointing to total_read - 1. + t.Run("leftover=1", func(t *testing.T) { + // Feed 5 base64 chars: 4 will be consumed, 1 leftover. + encoded := []byte(base64.StdEncoding.EncodeToString([]byte("abc"))) // "YWJj" (4) + encoded = append(encoded, 'Y') // + 1 → total 5 + var d StreamingBase64Decoder + outBuf := make([]byte, 16) + for _, err := range d.Decode(encoded, outBuf) { + if err != nil { + t.Fatalf("unexpected Decode error: %v", err) + } + } + // num_leftover should now be 1; total_read = 5 + _, err := d.Finish() + if err == nil { + t.Fatal("expected error for leftover=1, got nil") + } + ce, ok := err.(base64.CorruptInputError) + if !ok { + t.Fatalf("expected base64.CorruptInputError, got %T: %v", err, err) + } + // total_read=5, num_leftover=1 → offset should be 4 + if int64(ce) != 4 { + t.Fatalf("wrong error offset: want 4, got %d", int64(ce)) + } + }) + + // num_leftover=2 → Finish pads "==" and decodes 1 byte. + t.Run("leftover=2", func(t *testing.T) { + // "a" encodes to "YQ==" with padding, or "YQ" without. + // Feed "YQ" (2 chars) so num_leftover=2. + var d StreamingBase64Decoder + outBuf := make([]byte, 16) + for _, err := range d.Decode([]byte("YQ"), outBuf) { + if err != nil { + t.Fatalf("unexpected Decode error: %v", err) + } + } + out, err := d.Finish() + if err != nil { + t.Fatalf("unexpected Finish error: %v", err) + } + if !bytes.Equal(out, []byte("a")) { + t.Fatalf("want %q, got %q", "a", out) + } + }) + + // num_leftover=3 → Finish pads "=" and decodes 2 bytes. + t.Run("leftover=3", func(t *testing.T) { + // "ab" encodes to "YWI=" with padding, or "YWI" without. + // Feed "YWI" (3 chars) so num_leftover=3. + var d StreamingBase64Decoder + outBuf := make([]byte, 16) + for _, err := range d.Decode([]byte("YWI"), outBuf) { + if err != nil { + t.Fatalf("unexpected Decode error: %v", err) + } + } + out, err := d.Finish() + if err != nil { + t.Fatalf("unexpected Finish error: %v", err) + } + if !bytes.Equal(out, []byte("ab")) { + t.Fatalf("want %q, got %q", "ab", out) + } + }) +} + +// TestErrorOffsetInDecode checks that when Decode encounters an invalid byte +// the returned error is a base64.CorruptInputError with the correct absolute +// byte offset within the full stream. +func TestErrorOffsetInDecode(t *testing.T) { + tests := []struct { + name string + chunks []string // successive calls to Decode + wantOffset int64 + }{ + { + // Error in the very first block (no leftovers involved). + name: "error_in_first_block", + chunks: []string{"YWJj", "YW!j"}, + wantOffset: 4 + 2, // 4 bytes of good data, then offset 2 within chunk + }, + { + // Error in the bridge block built from leftover + new bytes. + // Feed 3 chars (leftover=3), then 1 invalid char. + // The bridge block is "YWJ!" (positions 0–3 in the stream). + // base64 reports the '!' as relative byte 3 within "YWJ!"; + // wrap_error adds chunk_offset = total_read - num_leftover = 3 - 3 = 0, + // so the absolute error offset is 3. + name: "error_in_bridge_block", + chunks: []string{"YWJ", "!"}, + wantOffset: 3, + }, + { + // Error well into the stream after multiple complete blocks. + // "YWJj" × 3 = 12 valid chars, then invalid at position 0 of next chunk. + name: "error_after_multiple_blocks", + chunks: []string{"YWJjYWJjYWJj", "!WJj"}, + wantOffset: 12, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var d StreamingBase64Decoder + var gotErr error + outBuf := make([]byte, 64) + outer: + for _, chunk := range tc.chunks { + for _, err := range d.Decode([]byte(chunk), outBuf) { + if err != nil { + gotErr = err + break outer + } + } + } + if gotErr == nil { + t.Fatal("expected a decode error, got nil") + } + ce, ok := gotErr.(base64.CorruptInputError) + if !ok { + t.Fatalf("expected base64.CorruptInputError, got %T: %v", gotErr, gotErr) + } + if int64(ce) != tc.wantOffset { + t.Fatalf("wrong error offset: want %d, got %d", tc.wantOffset, int64(ce)) + } + }) + } +} + +// TestOutputBufferTooSmall verifies that Decode returns an error (not a panic) +// when the supplied output buffer is too small. +func TestOutputBufferTooSmall(t *testing.T) { + var d StreamingBase64Decoder + tinyBuf := make([]byte, 0) // definitely too small for any decoded output + encoded := []byte("YWJj") // decodes to 3 bytes + var gotErr error + for _, err := range d.Decode(encoded, tinyBuf) { + if err != nil { + gotErr = err + break + } + } + if gotErr == nil { + t.Fatal("expected an error for too-small output buffer, got nil") + } +} + +// TestEmptyInput verifies that decoding an empty input followed by Finish +// produces no output and no error. +func TestEmptyInput(t *testing.T) { + var d StreamingBase64Decoder + outBuf := make([]byte, 16) + for _, err := range d.Decode([]byte{}, outBuf) { + if err != nil { + t.Fatalf("unexpected error on empty input: %v", err) + } + } + out, err := d.Finish() + if err != nil || out != nil { + t.Fatalf("expected (nil,nil) for empty input, got (%v,%v)", out, err) + } +} + +// TestLargeInput stress-tests the decoder with a large binary payload across +// every chunk size from 1 to 13 to catch off-by-one errors in the leftover +// accumulation logic. +func TestLargeInput(t *testing.T) { + plain := make([]byte, 1000) + for i := range plain { + plain[i] = byte(i * 7) + } + for chunkSize := 1; chunkSize <= 13; chunkSize++ { + roundtrip(t, plain, chunkSize) + roundtripNoPadding(t, plain, chunkSize) + } +}