Add comprehensive tests for streaming_base64 decoder

Agent-Logs-Url: https://github.com/kovidgoyal/kitty/sessions/b5dbc339-6907-46fb-a7e7-9ee08b7ed20d

Co-authored-by: kovidgoyal <1308621+kovidgoyal@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-04-21 04:13:30 +00:00
committed by GitHub
parent 94651eaff2
commit 7b346d15c1

View File

@@ -0,0 +1,375 @@
// License: GPLv3 Copyright: 2024, Kovid Goyal, <kovid at kovidgoyal.net>
package streaming_base64
import (
"bytes"
"encoding/base64"
"fmt"
"testing"
)
var _ = fmt.Print
// collectDecode feeds input to the decoder in chunks of chunkSize bytes and
// accumulates all decoded output into a single slice. It uses a generously-
// sized output buffer so buffer-size issues never mask correctness bugs.
func collectDecode(t *testing.T, d *StreamingBase64Decoder, encoded []byte, chunkSize int) ([]byte, error) {
t.Helper()
var result []byte
for len(encoded) > 0 {
end := chunkSize
if end > len(encoded) {
end = len(encoded)
}
chunk := encoded[:end]
encoded = encoded[end:]
// Output buffer: worst-case 3 bytes per 4 input bytes, plus the
// 3 bytes for a buffered leftover block.
outBuf := make([]byte, (len(chunk)+4)*3)
for decoded, err := range d.Decode(chunk, outBuf) {
if err != nil {
return nil, err
}
result = append(result, decoded...)
}
}
return result, nil
}
// roundtrip encodes plaintext with standard base64, then decodes it through
// the streaming decoder in chunkSize pieces and verifies the result.
func roundtrip(t *testing.T, plaintext []byte, chunkSize int) {
t.Helper()
encoded := []byte(base64.StdEncoding.EncodeToString(plaintext))
var d StreamingBase64Decoder
got, err := collectDecode(t, &d, encoded, chunkSize)
if err != nil {
t.Fatalf("chunkSize=%d: unexpected decode error: %v", chunkSize, err)
}
tail, err := d.Finish()
if err != nil {
t.Fatalf("chunkSize=%d: unexpected Finish error: %v", chunkSize, err)
}
got = append(got, tail...)
if !bytes.Equal(got, plaintext) {
t.Fatalf("chunkSize=%d: roundtrip mismatch:\n want %q\n got %q", chunkSize, plaintext, got)
}
}
// roundtripNoPadding is like roundtrip but strips the trailing '=' padding
// characters from the encoded string. The Finish method must add them back.
func roundtripNoPadding(t *testing.T, plaintext []byte, chunkSize int) {
t.Helper()
encoded := []byte(base64.RawStdEncoding.EncodeToString(plaintext)) // no padding
var d StreamingBase64Decoder
got, err := collectDecode(t, &d, encoded, chunkSize)
if err != nil {
t.Fatalf("noPad chunkSize=%d: unexpected decode error: %v", chunkSize, err)
}
tail, err := d.Finish()
if err != nil {
t.Fatalf("noPad chunkSize=%d: unexpected Finish error: %v", chunkSize, err)
}
got = append(got, tail...)
if !bytes.Equal(got, plaintext) {
t.Fatalf("noPad chunkSize=%d: roundtrip mismatch:\n want %q\n got %q", chunkSize, plaintext, got)
}
}
// TestRoundtripAllChunkSizes exercises every chunk size from 1 to 7 with
// plaintexts whose lengths produce all possible num_leftover values (0-3)
// after encoding: 0 mod 3 → 0 leftover, 1 mod 3 → 2 leftover, 2 mod 3 → 1
// leftover after decoding (but 2 base64 chars left when unpadded).
func TestRoundtripAllChunkSizes(t *testing.T) {
plaintexts := [][]byte{
{}, // 0 bytes → 0 encoded → num_leftover=0
[]byte("a"), // 1 byte → 4 encoded → no leftover (padded)
[]byte("ab"), // 2 bytes → 4 encoded → no leftover (padded)
[]byte("abc"), // 3 bytes → 4 encoded → no leftover
[]byte("abcd"), // 4 bytes → 8 encoded → no leftover
[]byte("abcde"), // 5 bytes → 8 encoded → no leftover (padded)
[]byte("abcdef"), // 6 bytes → 8 encoded → no leftover (padded)
[]byte("Hello, World!"), // 13 bytes → 20 encoded
[]byte("The quick brown fox jumps over the"), // 34 bytes → 48 encoded
bytes.Repeat([]byte{0x00, 0xff, 0x80}, 17), // binary data
}
for _, plain := range plaintexts {
for chunkSize := 1; chunkSize <= 7; chunkSize++ {
if len(plain) == 0 && chunkSize > 1 {
continue // only test once for empty input
}
roundtrip(t, plain, chunkSize)
}
}
}
// TestRoundtripNoPaddingAllChunkSizes tests decoding without trailing '='
// padding bytes for all relevant chunk sizes.
func TestRoundtripNoPaddingAllChunkSizes(t *testing.T) {
plaintexts := [][]byte{
[]byte("a"), // 1 byte → "YQ" (2 base64 chars, no pad)
[]byte("ab"), // 2 bytes → "YWI" (3 base64 chars, no pad)
[]byte("abc"), // 3 bytes → "YWJj" (4 chars, no leftover)
[]byte("abcd"), // 4 bytes → "YWJjZA" (6 chars)
[]byte("Hello, World!"), // mixed
bytes.Repeat([]byte{0xde}, 10), // binary, 1 mod 3 remainder
bytes.Repeat([]byte{0xbe}, 11), // binary, 2 mod 3 remainder
bytes.Repeat([]byte{0xef}, 12), // binary, 0 mod 3 remainder
}
for _, plain := range plaintexts {
for chunkSize := 1; chunkSize <= 7; chunkSize++ {
roundtripNoPadding(t, plain, chunkSize)
}
}
}
// TestNumLeftoverInDecode checks that the bridge-block path in Decode (the
// path that fires when num_leftover > 0 at the start of a new call) produces
// correct results. We achieve specific num_leftover values by feeding exactly
// that many bytes in the first call.
func TestNumLeftoverInDecode(t *testing.T) {
// plaintext whose encoding is long enough to exercise all leftover values
plain := []byte("abcdefghijklmnopqrstuvwxyz") // 26 bytes → 36 encoded chars
for firstCallLen := 1; firstCallLen <= 3; firstCallLen++ {
encoded := []byte(base64.StdEncoding.EncodeToString(plain))
var d StreamingBase64Decoder
outBuf := make([]byte, 64)
// Feed exactly firstCallLen bytes → num_leftover = firstCallLen % 4
var got []byte
for dec, err := range d.Decode(encoded[:firstCallLen], outBuf) {
if err != nil {
t.Fatalf("firstCallLen=%d first Decode error: %v", firstCallLen, err)
}
got = append(got, dec...)
}
// Feed the rest
rest, err := collectDecode(t, &d, encoded[firstCallLen:], 4)
if err != nil {
t.Fatalf("firstCallLen=%d rest Decode error: %v", firstCallLen, err)
}
got = append(got, rest...)
tail, err := d.Finish()
if err != nil {
t.Fatalf("firstCallLen=%d Finish error: %v", firstCallLen, err)
}
got = append(got, tail...)
if !bytes.Equal(got, plain) {
t.Fatalf("firstCallLen=%d mismatch:\n want %q\n got %q", firstCallLen, plain, got)
}
}
}
// TestFinishNumLeftover directly exercises all four branches of Finish.
//
// - num_leftover=0 → (nil, nil)
// - num_leftover=1 → error pointing to correct byte offset
// - num_leftover=2 → decode 1 byte, pad "=="
// - num_leftover=3 → decode 2 bytes, pad "="
func TestFinishNumLeftover(t *testing.T) {
// num_leftover=0: empty decoder, Finish must return (nil, nil).
t.Run("leftover=0", func(t *testing.T) {
var d StreamingBase64Decoder
out, err := d.Finish()
if err != nil || out != nil {
t.Fatalf("expected (nil,nil), got (%v,%v)", out, err)
}
})
// num_leftover=0 after a complete stream: also (nil, nil).
t.Run("leftover=0_after_stream", func(t *testing.T) {
// "abc" encodes to "YWJj" — exactly 4 chars, no leftover
var d StreamingBase64Decoder
outBuf := make([]byte, 16)
for _, err := range d.Decode([]byte("YWJj"), outBuf) {
if err != nil {
t.Fatal(err)
}
}
out, err := d.Finish()
if err != nil || out != nil {
t.Fatalf("expected (nil,nil), got (%v,%v)", out, err)
}
})
// num_leftover=1 → CorruptInputError pointing to total_read - 1.
t.Run("leftover=1", func(t *testing.T) {
// Feed 5 base64 chars: 4 will be consumed, 1 leftover.
encoded := []byte(base64.StdEncoding.EncodeToString([]byte("abc"))) // "YWJj" (4)
encoded = append(encoded, 'Y') // + 1 → total 5
var d StreamingBase64Decoder
outBuf := make([]byte, 16)
for _, err := range d.Decode(encoded, outBuf) {
if err != nil {
t.Fatalf("unexpected Decode error: %v", err)
}
}
// num_leftover should now be 1; total_read = 5
_, err := d.Finish()
if err == nil {
t.Fatal("expected error for leftover=1, got nil")
}
ce, ok := err.(base64.CorruptInputError)
if !ok {
t.Fatalf("expected base64.CorruptInputError, got %T: %v", err, err)
}
// total_read=5, num_leftover=1 → offset should be 4
if int64(ce) != 4 {
t.Fatalf("wrong error offset: want 4, got %d", int64(ce))
}
})
// num_leftover=2 → Finish pads "==" and decodes 1 byte.
t.Run("leftover=2", func(t *testing.T) {
// "a" encodes to "YQ==" with padding, or "YQ" without.
// Feed "YQ" (2 chars) so num_leftover=2.
var d StreamingBase64Decoder
outBuf := make([]byte, 16)
for _, err := range d.Decode([]byte("YQ"), outBuf) {
if err != nil {
t.Fatalf("unexpected Decode error: %v", err)
}
}
out, err := d.Finish()
if err != nil {
t.Fatalf("unexpected Finish error: %v", err)
}
if !bytes.Equal(out, []byte("a")) {
t.Fatalf("want %q, got %q", "a", out)
}
})
// num_leftover=3 → Finish pads "=" and decodes 2 bytes.
t.Run("leftover=3", func(t *testing.T) {
// "ab" encodes to "YWI=" with padding, or "YWI" without.
// Feed "YWI" (3 chars) so num_leftover=3.
var d StreamingBase64Decoder
outBuf := make([]byte, 16)
for _, err := range d.Decode([]byte("YWI"), outBuf) {
if err != nil {
t.Fatalf("unexpected Decode error: %v", err)
}
}
out, err := d.Finish()
if err != nil {
t.Fatalf("unexpected Finish error: %v", err)
}
if !bytes.Equal(out, []byte("ab")) {
t.Fatalf("want %q, got %q", "ab", out)
}
})
}
// TestErrorOffsetInDecode checks that when Decode encounters an invalid byte
// the returned error is a base64.CorruptInputError with the correct absolute
// byte offset within the full stream.
func TestErrorOffsetInDecode(t *testing.T) {
tests := []struct {
name string
chunks []string // successive calls to Decode
wantOffset int64
}{
{
// Error in the very first block (no leftovers involved).
name: "error_in_first_block",
chunks: []string{"YWJj", "YW!j"},
wantOffset: 4 + 2, // 4 bytes of good data, then offset 2 within chunk
},
{
// Error in the bridge block built from leftover + new bytes.
// Feed 3 chars (leftover=3), then 1 invalid char.
// The bridge block is "YWJ!" (positions 03 in the stream).
// base64 reports the '!' as relative byte 3 within "YWJ!";
// wrap_error adds chunk_offset = total_read - num_leftover = 3 - 3 = 0,
// so the absolute error offset is 3.
name: "error_in_bridge_block",
chunks: []string{"YWJ", "!"},
wantOffset: 3,
},
{
// Error well into the stream after multiple complete blocks.
// "YWJj" × 3 = 12 valid chars, then invalid at position 0 of next chunk.
name: "error_after_multiple_blocks",
chunks: []string{"YWJjYWJjYWJj", "!WJj"},
wantOffset: 12,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var d StreamingBase64Decoder
var gotErr error
outBuf := make([]byte, 64)
outer:
for _, chunk := range tc.chunks {
for _, err := range d.Decode([]byte(chunk), outBuf) {
if err != nil {
gotErr = err
break outer
}
}
}
if gotErr == nil {
t.Fatal("expected a decode error, got nil")
}
ce, ok := gotErr.(base64.CorruptInputError)
if !ok {
t.Fatalf("expected base64.CorruptInputError, got %T: %v", gotErr, gotErr)
}
if int64(ce) != tc.wantOffset {
t.Fatalf("wrong error offset: want %d, got %d", tc.wantOffset, int64(ce))
}
})
}
}
// TestOutputBufferTooSmall verifies that Decode returns an error (not a panic)
// when the supplied output buffer is too small.
func TestOutputBufferTooSmall(t *testing.T) {
var d StreamingBase64Decoder
tinyBuf := make([]byte, 0) // definitely too small for any decoded output
encoded := []byte("YWJj") // decodes to 3 bytes
var gotErr error
for _, err := range d.Decode(encoded, tinyBuf) {
if err != nil {
gotErr = err
break
}
}
if gotErr == nil {
t.Fatal("expected an error for too-small output buffer, got nil")
}
}
// TestEmptyInput verifies that decoding an empty input followed by Finish
// produces no output and no error.
func TestEmptyInput(t *testing.T) {
var d StreamingBase64Decoder
outBuf := make([]byte, 16)
for _, err := range d.Decode([]byte{}, outBuf) {
if err != nil {
t.Fatalf("unexpected error on empty input: %v", err)
}
}
out, err := d.Finish()
if err != nil || out != nil {
t.Fatalf("expected (nil,nil) for empty input, got (%v,%v)", out, err)
}
}
// TestLargeInput stress-tests the decoder with a large binary payload across
// every chunk size from 1 to 13 to catch off-by-one errors in the leftover
// accumulation logic.
func TestLargeInput(t *testing.T) {
plain := make([]byte, 1000)
for i := range plain {
plain[i] = byte(i * 7)
}
for chunkSize := 1; chunkSize <= 13; chunkSize++ {
roundtrip(t, plain, chunkSize)
roundtripNoPadding(t, plain, chunkSize)
}
}