// License: GPLv3 Copyright: 2024, Kovid Goyal, package streaming_base64 import ( "bytes" "fmt" "testing" "github.com/emmansun/base64" ) var _ = fmt.Print // collectDecode feeds input to the decoder in chunks of chunkSize bytes and // accumulates all decoded output into a single slice. It uses a generously- // sized output buffer so buffer-size issues never mask correctness bugs. func collectDecode(t *testing.T, d *StreamingBase64Decoder, encoded []byte, chunkSize int) ([]byte, error) { t.Helper() var result []byte for len(encoded) > 0 { end := min(chunkSize, len(encoded)) chunk := encoded[:end] encoded = encoded[end:] // Output buffer: worst-case 3 bytes per 4 input bytes, plus the // 3 bytes for a buffered leftover block. outBuf := make([]byte, (len(chunk)+4)*3) for decoded, err := range d.Decode(chunk, outBuf) { if err != nil { return nil, err } result = append(result, decoded...) } } return result, nil } // roundtrip encodes plaintext with standard base64, then decodes it through // the streaming decoder in chunkSize pieces and verifies the result. func roundtrip(t *testing.T, plaintext []byte, chunkSize int) { t.Helper() encoded := []byte(base64.StdEncoding.EncodeToString(plaintext)) var d StreamingBase64Decoder got, err := collectDecode(t, &d, encoded, chunkSize) if err != nil { t.Fatalf("chunkSize=%d: unexpected decode error: %v", chunkSize, err) } tail, err := d.Finish() if err != nil { t.Fatalf("chunkSize=%d: unexpected Finish error: %v", chunkSize, err) } got = append(got, tail...) if !bytes.Equal(got, plaintext) { t.Fatalf("chunkSize=%d: roundtrip mismatch:\n want %q\n got %q", chunkSize, plaintext, got) } } // roundtripNoPadding is like roundtrip but strips the trailing '=' padding // characters from the encoded string. The Finish method must add them back. func roundtripNoPadding(t *testing.T, plaintext []byte, chunkSize int) { t.Helper() encoded := []byte(base64.RawStdEncoding.EncodeToString(plaintext)) // no padding var d StreamingBase64Decoder got, err := collectDecode(t, &d, encoded, chunkSize) if err != nil { t.Fatalf("noPad chunkSize=%d: unexpected decode error: %v", chunkSize, err) } tail, err := d.Finish() if err != nil { t.Fatalf("noPad chunkSize=%d: unexpected Finish error: %v", chunkSize, err) } got = append(got, tail...) if !bytes.Equal(got, plaintext) { t.Fatalf("noPad chunkSize=%d: roundtrip mismatch:\n want %q\n got %q", chunkSize, plaintext, got) } } // TestRoundtripAllChunkSizes exercises every chunk size from 1 to 7 with // plaintexts whose lengths produce all possible num_leftover values (0-3) // after encoding: 0 mod 3 → 0 leftover, 1 mod 3 → 2 leftover, 2 mod 3 → 1 // leftover after decoding (but 2 base64 chars left when unpadded). func TestRoundtripAllChunkSizes(t *testing.T) { plaintexts := [][]byte{ {}, // 0 bytes → 0 encoded → num_leftover=0 []byte("a"), // 1 byte → 4 encoded → no leftover (padded) []byte("ab"), // 2 bytes → 4 encoded → no leftover (padded) []byte("abc"), // 3 bytes → 4 encoded → no leftover []byte("abcd"), // 4 bytes → 8 encoded → no leftover []byte("abcde"), // 5 bytes → 8 encoded → no leftover (padded) []byte("abcdef"), // 6 bytes → 8 encoded → no leftover (padded) []byte("Hello, World!"), // 13 bytes → 20 encoded []byte("The quick brown fox jumps over the"), // 34 bytes → 48 encoded bytes.Repeat([]byte{0x00, 0xff, 0x80}, 17), // binary data } for _, plain := range plaintexts { for chunkSize := 1; chunkSize <= 7; chunkSize++ { if len(plain) == 0 && chunkSize > 1 { continue // only test once for empty input } roundtrip(t, plain, chunkSize) } } } // TestRoundtripNoPaddingAllChunkSizes tests decoding without trailing '=' // padding bytes for all relevant chunk sizes. func TestRoundtripNoPaddingAllChunkSizes(t *testing.T) { plaintexts := [][]byte{ []byte("a"), // 1 byte → "YQ" (2 base64 chars, no pad) []byte("ab"), // 2 bytes → "YWI" (3 base64 chars, no pad) []byte("abc"), // 3 bytes → "YWJj" (4 chars, no leftover) []byte("abcd"), // 4 bytes → "YWJjZA" (6 chars) []byte("Hello, World!"), // mixed bytes.Repeat([]byte{0xde}, 10), // binary, 1 mod 3 remainder bytes.Repeat([]byte{0xbe}, 11), // binary, 2 mod 3 remainder bytes.Repeat([]byte{0xef}, 12), // binary, 0 mod 3 remainder } for _, plain := range plaintexts { for chunkSize := 1; chunkSize <= 7; chunkSize++ { roundtripNoPadding(t, plain, chunkSize) } } } // TestNumLeftoverInDecode checks that the bridge-block path in Decode (the // path that fires when num_leftover > 0 at the start of a new call) produces // correct results. We achieve specific num_leftover values by feeding exactly // that many bytes in the first call. func TestNumLeftoverInDecode(t *testing.T) { // plaintext whose encoding is long enough to exercise all leftover values plain := []byte("abcdefghijklmnopqrstuvwxyz") // 26 bytes → 36 encoded chars for firstCallLen := 1; firstCallLen <= 3; firstCallLen++ { encoded := []byte(base64.StdEncoding.EncodeToString(plain)) var d StreamingBase64Decoder outBuf := make([]byte, 64) // Feed exactly firstCallLen bytes → num_leftover = firstCallLen % 4 var got []byte for dec, err := range d.Decode(encoded[:firstCallLen], outBuf) { if err != nil { t.Fatalf("firstCallLen=%d first Decode error: %v", firstCallLen, err) } got = append(got, dec...) } // Feed the rest rest, err := collectDecode(t, &d, encoded[firstCallLen:], 4) if err != nil { t.Fatalf("firstCallLen=%d rest Decode error: %v", firstCallLen, err) } got = append(got, rest...) tail, err := d.Finish() if err != nil { t.Fatalf("firstCallLen=%d Finish error: %v", firstCallLen, err) } got = append(got, tail...) if !bytes.Equal(got, plain) { t.Fatalf("firstCallLen=%d mismatch:\n want %q\n got %q", firstCallLen, plain, got) } } } // TestFinishNumLeftover directly exercises all four branches of Finish. // // - num_leftover=0 → (nil, nil) // - num_leftover=1 → error pointing to correct byte offset // - num_leftover=2 → decode 1 byte, pad "==" // - num_leftover=3 → decode 2 bytes, pad "=" func TestFinishNumLeftover(t *testing.T) { // num_leftover=0: empty decoder, Finish must return (nil, nil). t.Run("leftover=0", func(t *testing.T) { var d StreamingBase64Decoder out, err := d.Finish() if err != nil || out != nil { t.Fatalf("expected (nil,nil), got (%v,%v)", out, err) } }) // num_leftover=0 after a complete stream: also (nil, nil). t.Run("leftover=0_after_stream", func(t *testing.T) { // "abc" encodes to "YWJj" — exactly 4 chars, no leftover var d StreamingBase64Decoder outBuf := make([]byte, 16) for _, err := range d.Decode([]byte("YWJj"), outBuf) { if err != nil { t.Fatal(err) } } out, err := d.Finish() if err != nil || out != nil { t.Fatalf("expected (nil,nil), got (%v,%v)", out, err) } }) // num_leftover=1 → CorruptInputError pointing to total_read - 1. t.Run("leftover=1", func(t *testing.T) { // Feed 5 base64 chars: 4 will be consumed, 1 leftover. encoded := []byte(base64.StdEncoding.EncodeToString([]byte("abc"))) // "YWJj" (4) encoded = append(encoded, 'Y') // + 1 → total 5 var d StreamingBase64Decoder outBuf := make([]byte, 16) for _, err := range d.Decode(encoded, outBuf) { if err != nil { t.Fatalf("unexpected Decode error: %v", err) } } // num_leftover should now be 1; total_read = 5 _, err := d.Finish() if err == nil { t.Fatal("expected error for leftover=1, got nil") } ce, ok := err.(base64.CorruptInputError) if !ok { t.Fatalf("expected base64.CorruptInputError, got %T: %v", err, err) } // total_read=5, num_leftover=1 → offset should be 4 if int64(ce) != 4 { t.Fatalf("wrong error offset: want 4, got %d", int64(ce)) } }) // num_leftover=2 → Finish pads "==" and decodes 1 byte. t.Run("leftover=2", func(t *testing.T) { // "a" encodes to "YQ==" with padding, or "YQ" without. // Feed "YQ" (2 chars) so num_leftover=2. var d StreamingBase64Decoder outBuf := make([]byte, 16) for _, err := range d.Decode([]byte("YQ"), outBuf) { if err != nil { t.Fatalf("unexpected Decode error: %v", err) } } out, err := d.Finish() if err != nil { t.Fatalf("unexpected Finish error: %v", err) } if !bytes.Equal(out, []byte("a")) { t.Fatalf("want %q, got %q", "a", out) } }) // num_leftover=3 → Finish pads "=" and decodes 2 bytes. t.Run("leftover=3", func(t *testing.T) { // "ab" encodes to "YWI=" with padding, or "YWI" without. // Feed "YWI" (3 chars) so num_leftover=3. var d StreamingBase64Decoder outBuf := make([]byte, 16) for _, err := range d.Decode([]byte("YWI"), outBuf) { if err != nil { t.Fatalf("unexpected Decode error: %v", err) } } out, err := d.Finish() if err != nil { t.Fatalf("unexpected Finish error: %v", err) } if !bytes.Equal(out, []byte("ab")) { t.Fatalf("want %q, got %q", "ab", out) } }) } // TestErrorOffsetInDecode checks that when Decode encounters an invalid byte // the returned error is a base64.CorruptInputError with the correct absolute // byte offset within the full stream. func TestErrorOffsetInDecode(t *testing.T) { tests := []struct { name string chunks []string // successive calls to Decode wantOffset int64 }{ { // Error in the very first block (no leftovers involved). name: "error_in_first_block", chunks: []string{"YWJj", "YW!j"}, wantOffset: 4 + 2, // 4 bytes of good data, then offset 2 within chunk }, { // Error in the bridge block built from leftover + new bytes. // Feed 3 chars (leftover=3), then 1 invalid char. // The bridge block is "YWJ!" (positions 0–3 in the stream). // base64 reports the '!' as relative byte 3 within "YWJ!"; // wrap_error adds chunk_offset = total_read - num_leftover = 3 - 3 = 0, // so the absolute error offset is 3. name: "error_in_bridge_block", chunks: []string{"YWJ", "!"}, wantOffset: 3, }, { // Error well into the stream after multiple complete blocks. // "YWJj" × 3 = 12 valid chars, then invalid at position 0 of next chunk. name: "error_after_multiple_blocks", chunks: []string{"YWJjYWJjYWJj", "!WJj"}, wantOffset: 12, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var d StreamingBase64Decoder var gotErr error outBuf := make([]byte, 64) outer: for _, chunk := range tc.chunks { for _, err := range d.Decode([]byte(chunk), outBuf) { if err != nil { gotErr = err break outer } } } if gotErr == nil { t.Fatal("expected a decode error, got nil") } ce, ok := gotErr.(base64.CorruptInputError) if !ok { t.Fatalf("expected base64.CorruptInputError, got %T: %v", gotErr, gotErr) } if int64(ce) != tc.wantOffset { t.Fatalf("wrong error offset: want %d, got %d", tc.wantOffset, int64(ce)) } }) } } // TestOutputBufferTooSmall verifies that Decode returns an error (not a panic) // when the supplied output buffer is too small. func TestOutputBufferTooSmall(t *testing.T) { var d StreamingBase64Decoder tinyBuf := make([]byte, 0) // definitely too small for any decoded output encoded := []byte("YWJj") // decodes to 3 bytes var gotErr error for _, err := range d.Decode(encoded, tinyBuf) { if err != nil { gotErr = err break } } if gotErr == nil { t.Fatal("expected an error for too-small output buffer, got nil") } } // TestEmptyInput verifies that decoding an empty input followed by Finish // produces no output and no error. func TestEmptyInput(t *testing.T) { var d StreamingBase64Decoder outBuf := make([]byte, 16) for _, err := range d.Decode([]byte{}, outBuf) { if err != nil { t.Fatalf("unexpected error on empty input: %v", err) } } out, err := d.Finish() if err != nil || out != nil { t.Fatalf("expected (nil,nil) for empty input, got (%v,%v)", out, err) } } // TestLargeInput stress-tests the decoder with a large binary payload across // every chunk size from 1 to 13 to catch off-by-one errors in the leftover // accumulation logic. func TestLargeInput(t *testing.T) { plain := make([]byte, 1000) for i := range plain { plain[i] = byte(i * 7) } for chunkSize := 1; chunkSize <= 13; chunkSize++ { roundtrip(t, plain, chunkSize) roundtripNoPadding(t, plain, chunkSize) } } // collectEncode feeds input to the encoder in chunks of chunkSize bytes and // accumulates all encoded output into a single slice. func collectEncode(t *testing.T, e *StreamingBase64Encoder, plain []byte, chunkSize int) ([]byte, error) { t.Helper() var result []byte for len(plain) > 0 { end := min(chunkSize, len(plain)) chunk := plain[:end] plain = plain[end:] outBuf := make([]byte, e.NeededOutputLen(len(chunk))+8) for encoded, err := range e.Encode(chunk, outBuf) { if err != nil { return nil, err } result = append(result, encoded...) } } return result, nil } // encodeRoundtrip encodes plaintext through the streaming encoder in chunkSize // pieces, appends the Finish output, and verifies the result matches // base64.RawStdEncoding. func encodeRoundtrip(t *testing.T, plaintext []byte, chunkSize int) { t.Helper() var e StreamingBase64Encoder got, err := collectEncode(t, &e, plaintext, chunkSize) if err != nil { t.Fatalf("chunkSize=%d: unexpected Encode error: %v", chunkSize, err) } tail := e.Finish() got = append(got, tail...) want := []byte(base64.RawStdEncoding.EncodeToString(plaintext)) if !bytes.Equal(got, want) { t.Fatalf("chunkSize=%d: encode mismatch:\n want %q\n got %q", chunkSize, want, got) } } // TestEncoderRoundtripAllChunkSizes exercises every chunk size from 1 to 7 with // plaintexts whose lengths cover all possible num_leftover values (0, 1, 2). func TestEncoderRoundtripAllChunkSizes(t *testing.T) { plaintexts := [][]byte{ {}, // 0 bytes → 0 leftover []byte("a"), // 1 byte → 1 leftover []byte("ab"), // 2 bytes → 2 leftover []byte("abc"), // 3 bytes → 0 leftover []byte("abcd"), // 4 bytes → 1 leftover []byte("abcde"), // 5 bytes → 2 leftover []byte("abcdef"), // 6 bytes → 0 leftover []byte("Hello, World!"), // 13 bytes []byte("The quick brown fox jumps over the"), // 34 bytes bytes.Repeat([]byte{0x00, 0xff, 0x80}, 17), // binary data } for _, plain := range plaintexts { for chunkSize := 1; chunkSize <= 7; chunkSize++ { if len(plain) == 0 && chunkSize > 1 { continue } encodeRoundtrip(t, plain, chunkSize) } } } // TestEncoderFinishLeftover directly exercises all three branches of Finish. func TestEncoderFinishLeftover(t *testing.T) { // num_leftover=0: empty encoder, Finish must return (nil, nil). t.Run("leftover=0", func(t *testing.T) { var e StreamingBase64Encoder out := e.Finish() if out != nil { t.Fatalf("expected (nil,nil), got (%v)", out) } }) // num_leftover=0 after a complete stream: also (nil, nil). t.Run("leftover=0_after_stream", func(t *testing.T) { // "abc" is 3 bytes, encodes with no leftover var e StreamingBase64Encoder outBuf := make([]byte, 16) for _, err := range e.Encode([]byte("abc"), outBuf) { if err != nil { t.Fatal(err) } } out := e.Finish() if out != nil { t.Fatalf("expected (nil,nil), got (%v)", out) } }) // num_leftover=1 → Finish encodes 1 byte → 2 chars. t.Run("leftover=1", func(t *testing.T) { // Feed 4 bytes: 3 processed, 1 leftover var e StreamingBase64Encoder outBuf := make([]byte, 16) for _, err := range e.Encode([]byte("abcd"), outBuf) { if err != nil { t.Fatalf("unexpected Encode error: %v", err) } } tail := e.Finish() want := []byte(base64.RawStdEncoding.EncodeToString([]byte("d"))) if !bytes.Equal(tail, want) { t.Fatalf("want %q, got %q", want, tail) } }) // num_leftover=2 → Finish encodes 2 bytes → 3 chars. t.Run("leftover=2", func(t *testing.T) { // Feed 5 bytes: 3 processed, 2 leftover var e StreamingBase64Encoder outBuf := make([]byte, 16) for _, err := range e.Encode([]byte("abcde"), outBuf) { if err != nil { t.Fatalf("unexpected Encode error: %v", err) } } tail := e.Finish() want := []byte(base64.RawStdEncoding.EncodeToString([]byte("de"))) if !bytes.Equal(tail, want) { t.Fatalf("want %q, got %q", want, tail) } }) } // TestEncoderOutputBufferTooSmall verifies that Encode returns an error (not a // panic) when the supplied output buffer is too small. func TestEncoderOutputBufferTooSmall(t *testing.T) { var e StreamingBase64Encoder tinyBuf := make([]byte, 0) // too small for any output input := []byte("abc") // 3 bytes → 4 chars needed var gotErr error for _, err := range e.Encode(input, tinyBuf) { if err != nil { gotErr = err break } } if gotErr == nil { t.Fatal("expected an error for too-small output buffer, got nil") } } // TestEncoderEmptyInput verifies that encoding an empty input followed by // Finish produces no output and no error. func TestEncoderEmptyInput(t *testing.T) { var e StreamingBase64Encoder outBuf := make([]byte, 16) for _, err := range e.Encode([]byte{}, outBuf) { if err != nil { t.Fatalf("unexpected error on empty input: %v", err) } } out := e.Finish() if out != nil { t.Fatalf("expected (nil,nil) for empty input, got (%v)", out) } } // TestEncoderLargeInput stress-tests the encoder with a large binary payload // across every chunk size from 1 to 13 to catch off-by-one errors. func TestEncoderLargeInput(t *testing.T) { plain := make([]byte, 1000) for i := range plain { plain[i] = byte(i * 7) } for chunkSize := 1; chunkSize <= 13; chunkSize++ { encodeRoundtrip(t, plain, chunkSize) } } // TestEncoderNumLeftoverInEncode checks that the bridge-block path in Encode // (when num_leftover > 0 at the start of a new call) produces correct results. func TestEncoderNumLeftoverInEncode(t *testing.T) { plain := []byte("abcdefghijklmnopqrstuvwxyz") // 26 bytes for firstCallLen := 1; firstCallLen <= 2; firstCallLen++ { var e StreamingBase64Encoder outBuf := make([]byte, 64) var got []byte for enc, err := range e.Encode(plain[:firstCallLen], outBuf) { if err != nil { t.Fatalf("firstCallLen=%d first Encode error: %v", firstCallLen, err) } got = append(got, enc...) } rest, err := collectEncode(t, &e, plain[firstCallLen:], 3) if err != nil { t.Fatalf("firstCallLen=%d rest Encode error: %v", firstCallLen, err) } got = append(got, rest...) tail := e.Finish() got = append(got, tail...) want := []byte(base64.RawStdEncoding.EncodeToString(plain)) if !bytes.Equal(got, want) { t.Fatalf("firstCallLen=%d mismatch:\n want %q\n got %q", firstCallLen, want, got) } } } // TestEncoderFinishResetsState verifies that calling Finish resets the encoder // so it can be reused for a new stream. func TestEncoderFinishResetsState(t *testing.T) { var e StreamingBase64Encoder outBuf := make([]byte, 64) plain1 := []byte("hello") var got1 []byte for enc, err := range e.Encode(plain1, outBuf) { if err != nil { t.Fatal(err) } got1 = append(got1, enc...) } tail1 := e.Finish() got1 = append(got1, tail1...) want1 := []byte(base64.RawStdEncoding.EncodeToString(plain1)) if !bytes.Equal(got1, want1) { t.Fatalf("first stream: want %q, got %q", want1, got1) } // Reuse encoder for a second stream plain2 := []byte("world!") var got2 []byte for enc, err := range e.Encode(plain2, outBuf) { if err != nil { t.Fatal(err) } got2 = append(got2, enc...) } tail2 := e.Finish() got2 = append(got2, tail2...) want2 := []byte(base64.RawStdEncoding.EncodeToString(plain2)) if !bytes.Equal(got2, want2) { t.Fatalf("second stream: want %q, got %q", want2, got2) } } // TestEncoderNeededOutputLen verifies that NeededOutputLen returns the correct // minimum buffer size for various input lengths and leftover states. func TestEncoderNeededOutputLen(t *testing.T) { tests := []struct { leftover int inputLen int want int }{ {0, 0, 0}, {0, 1, 0}, // 1 byte → leftover, no output yet {0, 2, 0}, // 2 bytes → leftover, no output yet {0, 3, 4}, // 3 bytes → 4 chars {0, 4, 4}, // 4 bytes → 3+1, only 3 processed → 4 chars {0, 6, 8}, // 6 bytes → 8 chars {1, 2, 4}, // 1+2=3 bytes → 4 chars {1, 3, 4}, // 1+3=4, floor(4/3)*4 = 4 {2, 1, 4}, // 2+1=3 bytes → 4 chars {2, 4, 8}, // 2+4=6, floor(6/3)*4 = 8 } for _, tt := range tests { e := StreamingBase64Encoder{num_leftover: tt.leftover} got := e.NeededOutputLen(tt.inputLen) if got != tt.want { t.Errorf("leftover=%d inputLen=%d: NeededOutputLen=%d, want %d", tt.leftover, tt.inputLen, got, tt.want) } } } // TestEncoderDecoderRoundtrip verifies that streaming-encode then // streaming-decode recovers the original data for a variety of inputs and // chunk sizes. func TestEncoderDecoderRoundtrip(t *testing.T) { plaintexts := [][]byte{ {}, []byte("a"), []byte("ab"), []byte("abc"), []byte("Hello, World!"), bytes.Repeat([]byte{0x00, 0xfe, 0x80}, 20), } for _, plain := range plaintexts { for chunkSize := 1; chunkSize <= 7; chunkSize++ { if len(plain) == 0 && chunkSize > 1 { continue } // Encode var e StreamingBase64Encoder encoded, err := collectEncode(t, &e, plain, chunkSize) if err != nil { t.Fatalf("Encode error: %v", err) } tail := e.Finish() encoded = append(encoded, tail...) // Decode using RawStdEncoding (no padding) var d StreamingBase64Decoder decoded, err := collectDecode(t, &d, encoded, chunkSize) if err != nil { t.Fatalf("Decode error: %v", err) } dtail, err := d.Finish() if err != nil { t.Fatalf("Decode Finish error: %v", err) } decoded = append(decoded, dtail...) if !bytes.Equal(decoded, plain) { t.Fatalf("chunkSize=%d roundtrip mismatch:\n want %q\n got %q", chunkSize, plain, decoded) } } } }