Merge branch 'copilot/add-streaming-base64-encoder' of https://github.com/kovidgoyal/kitty

This commit is contained in:
Kovid Goyal
2026-05-02 11:14:02 +05:30
2 changed files with 423 additions and 0 deletions

View File

@@ -119,3 +119,89 @@ func (s *StreamingBase64Decoder) Finish() ([]byte, error) {
n, err := base64.StdEncoding.Decode(output[:3], s.leftover[:4])
return output[:n], wrap_error(err, s.total_read-int64(s.num_leftover))
}
type StreamingBase64Encoder struct {
leftover [3]byte
num_leftover int
total_read int64
}
// The size of output buffer needed to encode the provided number of input bytes.
func (s *StreamingBase64Encoder) NeededOutputLen(input_len int) int {
return ((input_len + s.num_leftover) / 3) * 4
}
// Encode provided input, iterating in chunks. Each chunk is a slice from the
// provided output buffer, which must be at least s.NeededOutputLen() in size.
// The only error returned is when the output slice is too small.
func (s *StreamingBase64Encoder) Encode(input []byte, output []byte) iter.Seq2[[]byte, error] {
maxPossibleOutput := s.NeededOutputLen(len(input))
return func(yield func([]byte, error) bool) {
if len(output) < maxPossibleOutput {
yield(nil, fmt.Errorf("output slice too small: need at least %d, got %d", maxPossibleOutput, len(output)))
return
}
currIn := input
outOffset := 0
// 1. Handle leftover bytes from previous call
if s.num_leftover > 0 {
need := 3 - s.num_leftover
if len(currIn) >= need {
copy(s.leftover[s.num_leftover:], currIn[:need])
// Encode the bridge block (3 bytes -> 4 chars)
base64.RawStdEncoding.Encode(output[outOffset:], s.leftover[:3])
if !yield(output[outOffset:outOffset+4], nil) {
return
}
outOffset += 4
currIn = currIn[need:]
s.total_read += int64(need)
s.num_leftover = 0
} else {
// Still not enough to complete a group of 3
copy(s.leftover[s.num_leftover:], currIn)
s.num_leftover += len(currIn)
s.total_read += int64(len(currIn))
return
}
}
// 2. Encode the bulk of the current chunk without copying
processableLen := (len(currIn) / 3) * 3
if processableLen > 0 {
encodedLen := (processableLen / 3) * 4
base64.RawStdEncoding.Encode(output[outOffset:], currIn[:processableLen])
if !yield(output[outOffset:outOffset+encodedLen], nil) {
return
}
outOffset += encodedLen
currIn = currIn[processableLen:]
s.total_read += int64(processableLen)
}
// 3. Buffer remaining bytes (1-2) for the next Encode call
if len(currIn) > 0 {
copy(s.leftover[:], currIn)
s.num_leftover = len(currIn)
s.total_read += int64(len(currIn))
}
}
}
// Finish encoding the stream. Resets the encoder. Returned slice can be nil
// if no leftover bytes are present.
func (s *StreamingBase64Encoder) Finish() ([]byte, error) {
defer func() {
s.num_leftover = 0
s.total_read = 0
}()
if s.num_leftover == 0 {
return nil, nil
}
encodedLen := base64.RawStdEncoding.EncodedLen(s.num_leftover)
output := make([]byte, encodedLen)
base64.RawStdEncoding.Encode(output, s.leftover[:s.num_leftover])
return output, nil
}

View File

@@ -371,3 +371,340 @@ func TestLargeInput(t *testing.T) {
roundtripNoPadding(t, plain, chunkSize)
}
}
// collectEncode feeds input to the encoder in chunks of chunkSize bytes and
// accumulates all encoded output into a single slice.
func collectEncode(t *testing.T, e *StreamingBase64Encoder, plain []byte, chunkSize int) ([]byte, error) {
t.Helper()
var result []byte
for len(plain) > 0 {
end := min(chunkSize, len(plain))
chunk := plain[:end]
plain = plain[end:]
outBuf := make([]byte, e.NeededOutputLen(len(chunk))+8)
for encoded, err := range e.Encode(chunk, outBuf) {
if err != nil {
return nil, err
}
result = append(result, encoded...)
}
}
return result, nil
}
// encodeRoundtrip encodes plaintext through the streaming encoder in chunkSize
// pieces, appends the Finish output, and verifies the result matches
// base64.RawStdEncoding.
func encodeRoundtrip(t *testing.T, plaintext []byte, chunkSize int) {
t.Helper()
var e StreamingBase64Encoder
got, err := collectEncode(t, &e, plaintext, chunkSize)
if err != nil {
t.Fatalf("chunkSize=%d: unexpected Encode error: %v", chunkSize, err)
}
tail, err := e.Finish()
if err != nil {
t.Fatalf("chunkSize=%d: unexpected Finish error: %v", chunkSize, err)
}
got = append(got, tail...)
want := []byte(base64.RawStdEncoding.EncodeToString(plaintext))
if !bytes.Equal(got, want) {
t.Fatalf("chunkSize=%d: encode mismatch:\n want %q\n got %q", chunkSize, want, got)
}
}
// TestEncoderRoundtripAllChunkSizes exercises every chunk size from 1 to 7 with
// plaintexts whose lengths cover all possible num_leftover values (0, 1, 2).
func TestEncoderRoundtripAllChunkSizes(t *testing.T) {
plaintexts := [][]byte{
{}, // 0 bytes → 0 leftover
[]byte("a"), // 1 byte → 1 leftover
[]byte("ab"), // 2 bytes → 2 leftover
[]byte("abc"), // 3 bytes → 0 leftover
[]byte("abcd"), // 4 bytes → 1 leftover
[]byte("abcde"), // 5 bytes → 2 leftover
[]byte("abcdef"), // 6 bytes → 0 leftover
[]byte("Hello, World!"), // 13 bytes
[]byte("The quick brown fox jumps over the"), // 34 bytes
bytes.Repeat([]byte{0x00, 0xff, 0x80}, 17), // binary data
}
for _, plain := range plaintexts {
for chunkSize := 1; chunkSize <= 7; chunkSize++ {
if len(plain) == 0 && chunkSize > 1 {
continue
}
encodeRoundtrip(t, plain, chunkSize)
}
}
}
// TestEncoderFinishLeftover directly exercises all three branches of Finish.
func TestEncoderFinishLeftover(t *testing.T) {
// num_leftover=0: empty encoder, Finish must return (nil, nil).
t.Run("leftover=0", func(t *testing.T) {
var e StreamingBase64Encoder
out, err := e.Finish()
if err != nil || out != nil {
t.Fatalf("expected (nil,nil), got (%v,%v)", out, err)
}
})
// num_leftover=0 after a complete stream: also (nil, nil).
t.Run("leftover=0_after_stream", func(t *testing.T) {
// "abc" is 3 bytes, encodes with no leftover
var e StreamingBase64Encoder
outBuf := make([]byte, 16)
for _, err := range e.Encode([]byte("abc"), outBuf) {
if err != nil {
t.Fatal(err)
}
}
out, err := e.Finish()
if err != nil || out != nil {
t.Fatalf("expected (nil,nil), got (%v,%v)", out, err)
}
})
// num_leftover=1 → Finish encodes 1 byte → 2 chars.
t.Run("leftover=1", func(t *testing.T) {
// Feed 4 bytes: 3 processed, 1 leftover
var e StreamingBase64Encoder
outBuf := make([]byte, 16)
for _, err := range e.Encode([]byte("abcd"), outBuf) {
if err != nil {
t.Fatalf("unexpected Encode error: %v", err)
}
}
tail, err := e.Finish()
if err != nil {
t.Fatalf("unexpected Finish error: %v", err)
}
want := []byte(base64.RawStdEncoding.EncodeToString([]byte("d")))
if !bytes.Equal(tail, want) {
t.Fatalf("want %q, got %q", want, tail)
}
})
// num_leftover=2 → Finish encodes 2 bytes → 3 chars.
t.Run("leftover=2", func(t *testing.T) {
// Feed 5 bytes: 3 processed, 2 leftover
var e StreamingBase64Encoder
outBuf := make([]byte, 16)
for _, err := range e.Encode([]byte("abcde"), outBuf) {
if err != nil {
t.Fatalf("unexpected Encode error: %v", err)
}
}
tail, err := e.Finish()
if err != nil {
t.Fatalf("unexpected Finish error: %v", err)
}
want := []byte(base64.RawStdEncoding.EncodeToString([]byte("de")))
if !bytes.Equal(tail, want) {
t.Fatalf("want %q, got %q", want, tail)
}
})
}
// TestEncoderOutputBufferTooSmall verifies that Encode returns an error (not a
// panic) when the supplied output buffer is too small.
func TestEncoderOutputBufferTooSmall(t *testing.T) {
var e StreamingBase64Encoder
tinyBuf := make([]byte, 0) // too small for any output
input := []byte("abc") // 3 bytes → 4 chars needed
var gotErr error
for _, err := range e.Encode(input, tinyBuf) {
if err != nil {
gotErr = err
break
}
}
if gotErr == nil {
t.Fatal("expected an error for too-small output buffer, got nil")
}
}
// TestEncoderEmptyInput verifies that encoding an empty input followed by
// Finish produces no output and no error.
func TestEncoderEmptyInput(t *testing.T) {
var e StreamingBase64Encoder
outBuf := make([]byte, 16)
for _, err := range e.Encode([]byte{}, outBuf) {
if err != nil {
t.Fatalf("unexpected error on empty input: %v", err)
}
}
out, err := e.Finish()
if err != nil || out != nil {
t.Fatalf("expected (nil,nil) for empty input, got (%v,%v)", out, err)
}
}
// TestEncoderLargeInput stress-tests the encoder with a large binary payload
// across every chunk size from 1 to 13 to catch off-by-one errors.
func TestEncoderLargeInput(t *testing.T) {
plain := make([]byte, 1000)
for i := range plain {
plain[i] = byte(i * 7)
}
for chunkSize := 1; chunkSize <= 13; chunkSize++ {
encodeRoundtrip(t, plain, chunkSize)
}
}
// TestEncoderNumLeftoverInEncode checks that the bridge-block path in Encode
// (when num_leftover > 0 at the start of a new call) produces correct results.
func TestEncoderNumLeftoverInEncode(t *testing.T) {
plain := []byte("abcdefghijklmnopqrstuvwxyz") // 26 bytes
for firstCallLen := 1; firstCallLen <= 2; firstCallLen++ {
var e StreamingBase64Encoder
outBuf := make([]byte, 64)
var got []byte
for enc, err := range e.Encode(plain[:firstCallLen], outBuf) {
if err != nil {
t.Fatalf("firstCallLen=%d first Encode error: %v", firstCallLen, err)
}
got = append(got, enc...)
}
rest, err := collectEncode(t, &e, plain[firstCallLen:], 3)
if err != nil {
t.Fatalf("firstCallLen=%d rest Encode error: %v", firstCallLen, err)
}
got = append(got, rest...)
tail, err := e.Finish()
if err != nil {
t.Fatalf("firstCallLen=%d Finish error: %v", firstCallLen, err)
}
got = append(got, tail...)
want := []byte(base64.RawStdEncoding.EncodeToString(plain))
if !bytes.Equal(got, want) {
t.Fatalf("firstCallLen=%d mismatch:\n want %q\n got %q", firstCallLen, want, got)
}
}
}
// TestEncoderFinishResetsState verifies that calling Finish resets the encoder
// so it can be reused for a new stream.
func TestEncoderFinishResetsState(t *testing.T) {
var e StreamingBase64Encoder
outBuf := make([]byte, 64)
plain1 := []byte("hello")
var got1 []byte
for enc, err := range e.Encode(plain1, outBuf) {
if err != nil {
t.Fatal(err)
}
got1 = append(got1, enc...)
}
tail1, err := e.Finish()
if err != nil {
t.Fatal(err)
}
got1 = append(got1, tail1...)
want1 := []byte(base64.RawStdEncoding.EncodeToString(plain1))
if !bytes.Equal(got1, want1) {
t.Fatalf("first stream: want %q, got %q", want1, got1)
}
// Reuse encoder for a second stream
plain2 := []byte("world!")
var got2 []byte
for enc, err := range e.Encode(plain2, outBuf) {
if err != nil {
t.Fatal(err)
}
got2 = append(got2, enc...)
}
tail2, err := e.Finish()
if err != nil {
t.Fatal(err)
}
got2 = append(got2, tail2...)
want2 := []byte(base64.RawStdEncoding.EncodeToString(plain2))
if !bytes.Equal(got2, want2) {
t.Fatalf("second stream: want %q, got %q", want2, got2)
}
}
// TestEncoderNeededOutputLen verifies that NeededOutputLen returns the correct
// minimum buffer size for various input lengths and leftover states.
func TestEncoderNeededOutputLen(t *testing.T) {
tests := []struct {
leftover int
inputLen int
want int
}{
{0, 0, 0},
{0, 1, 0}, // 1 byte → leftover, no output yet
{0, 2, 0}, // 2 bytes → leftover, no output yet
{0, 3, 4}, // 3 bytes → 4 chars
{0, 4, 4}, // 4 bytes → 3+1, only 3 processed → 4 chars
{0, 6, 8}, // 6 bytes → 8 chars
{1, 2, 4}, // 1+2=3 bytes → 4 chars
{1, 3, 4}, // 1+3=4, floor(4/3)*4 = 4
{2, 1, 4}, // 2+1=3 bytes → 4 chars
{2, 4, 8}, // 2+4=6, floor(6/3)*4 = 8
}
for _, tt := range tests {
e := StreamingBase64Encoder{num_leftover: tt.leftover}
got := e.NeededOutputLen(tt.inputLen)
if got != tt.want {
t.Errorf("leftover=%d inputLen=%d: NeededOutputLen=%d, want %d", tt.leftover, tt.inputLen, got, tt.want)
}
}
}
// TestEncoderDecoderRoundtrip verifies that streaming-encode then
// streaming-decode recovers the original data for a variety of inputs and
// chunk sizes.
func TestEncoderDecoderRoundtrip(t *testing.T) {
plaintexts := [][]byte{
{},
[]byte("a"),
[]byte("ab"),
[]byte("abc"),
[]byte("Hello, World!"),
bytes.Repeat([]byte{0x00, 0xfe, 0x80}, 20),
}
for _, plain := range plaintexts {
for chunkSize := 1; chunkSize <= 7; chunkSize++ {
if len(plain) == 0 && chunkSize > 1 {
continue
}
// Encode
var e StreamingBase64Encoder
encoded, err := collectEncode(t, &e, plain, chunkSize)
if err != nil {
t.Fatalf("Encode error: %v", err)
}
tail, err := e.Finish()
if err != nil {
t.Fatalf("Encode Finish error: %v", err)
}
encoded = append(encoded, tail...)
// Decode using RawStdEncoding (no padding)
var d StreamingBase64Decoder
decoded, err := collectDecode(t, &d, encoded, chunkSize)
if err != nil {
t.Fatalf("Decode error: %v", err)
}
dtail, err := d.Finish()
if err != nil {
t.Fatalf("Decode Finish error: %v", err)
}
decoded = append(decoded, dtail...)
if !bytes.Equal(decoded, plain) {
t.Fatalf("chunkSize=%d roundtrip mismatch:\n want %q\n got %q", chunkSize, plain, decoded)
}
}
}
}