mirror of
https://github.com/kovidgoyal/kitty
synced 2026-06-06 01:05:48 +02:00
Merge branch 'copilot/add-streaming-base64-encoder' of https://github.com/kovidgoyal/kitty
This commit is contained in:
@@ -119,3 +119,89 @@ func (s *StreamingBase64Decoder) Finish() ([]byte, error) {
|
||||
n, err := base64.StdEncoding.Decode(output[:3], s.leftover[:4])
|
||||
return output[:n], wrap_error(err, s.total_read-int64(s.num_leftover))
|
||||
}
|
||||
|
||||
type StreamingBase64Encoder struct {
|
||||
leftover [3]byte
|
||||
num_leftover int
|
||||
total_read int64
|
||||
}
|
||||
|
||||
// The size of output buffer needed to encode the provided number of input bytes.
|
||||
func (s *StreamingBase64Encoder) NeededOutputLen(input_len int) int {
|
||||
return ((input_len + s.num_leftover) / 3) * 4
|
||||
}
|
||||
|
||||
// Encode provided input, iterating in chunks. Each chunk is a slice from the
|
||||
// provided output buffer, which must be at least s.NeededOutputLen() in size.
|
||||
// The only error returned is when the output slice is too small.
|
||||
func (s *StreamingBase64Encoder) Encode(input []byte, output []byte) iter.Seq2[[]byte, error] {
|
||||
maxPossibleOutput := s.NeededOutputLen(len(input))
|
||||
return func(yield func([]byte, error) bool) {
|
||||
if len(output) < maxPossibleOutput {
|
||||
yield(nil, fmt.Errorf("output slice too small: need at least %d, got %d", maxPossibleOutput, len(output)))
|
||||
return
|
||||
}
|
||||
currIn := input
|
||||
outOffset := 0
|
||||
|
||||
// 1. Handle leftover bytes from previous call
|
||||
if s.num_leftover > 0 {
|
||||
need := 3 - s.num_leftover
|
||||
if len(currIn) >= need {
|
||||
copy(s.leftover[s.num_leftover:], currIn[:need])
|
||||
|
||||
// Encode the bridge block (3 bytes -> 4 chars)
|
||||
base64.RawStdEncoding.Encode(output[outOffset:], s.leftover[:3])
|
||||
if !yield(output[outOffset:outOffset+4], nil) {
|
||||
return
|
||||
}
|
||||
outOffset += 4
|
||||
currIn = currIn[need:]
|
||||
s.total_read += int64(need)
|
||||
s.num_leftover = 0
|
||||
} else {
|
||||
// Still not enough to complete a group of 3
|
||||
copy(s.leftover[s.num_leftover:], currIn)
|
||||
s.num_leftover += len(currIn)
|
||||
s.total_read += int64(len(currIn))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Encode the bulk of the current chunk without copying
|
||||
processableLen := (len(currIn) / 3) * 3
|
||||
if processableLen > 0 {
|
||||
encodedLen := (processableLen / 3) * 4
|
||||
base64.RawStdEncoding.Encode(output[outOffset:], currIn[:processableLen])
|
||||
if !yield(output[outOffset:outOffset+encodedLen], nil) {
|
||||
return
|
||||
}
|
||||
outOffset += encodedLen
|
||||
currIn = currIn[processableLen:]
|
||||
s.total_read += int64(processableLen)
|
||||
}
|
||||
|
||||
// 3. Buffer remaining bytes (1-2) for the next Encode call
|
||||
if len(currIn) > 0 {
|
||||
copy(s.leftover[:], currIn)
|
||||
s.num_leftover = len(currIn)
|
||||
s.total_read += int64(len(currIn))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finish encoding the stream. Resets the encoder. Returned slice can be nil
|
||||
// if no leftover bytes are present.
|
||||
func (s *StreamingBase64Encoder) Finish() ([]byte, error) {
|
||||
defer func() {
|
||||
s.num_leftover = 0
|
||||
s.total_read = 0
|
||||
}()
|
||||
if s.num_leftover == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
encodedLen := base64.RawStdEncoding.EncodedLen(s.num_leftover)
|
||||
output := make([]byte, encodedLen)
|
||||
base64.RawStdEncoding.Encode(output, s.leftover[:s.num_leftover])
|
||||
return output, nil
|
||||
}
|
||||
|
||||
@@ -371,3 +371,340 @@ func TestLargeInput(t *testing.T) {
|
||||
roundtripNoPadding(t, plain, chunkSize)
|
||||
}
|
||||
}
|
||||
|
||||
// collectEncode feeds input to the encoder in chunks of chunkSize bytes and
|
||||
// accumulates all encoded output into a single slice.
|
||||
func collectEncode(t *testing.T, e *StreamingBase64Encoder, plain []byte, chunkSize int) ([]byte, error) {
|
||||
t.Helper()
|
||||
var result []byte
|
||||
for len(plain) > 0 {
|
||||
end := min(chunkSize, len(plain))
|
||||
chunk := plain[:end]
|
||||
plain = plain[end:]
|
||||
|
||||
outBuf := make([]byte, e.NeededOutputLen(len(chunk))+8)
|
||||
for encoded, err := range e.Encode(chunk, outBuf) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, encoded...)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// encodeRoundtrip encodes plaintext through the streaming encoder in chunkSize
|
||||
// pieces, appends the Finish output, and verifies the result matches
|
||||
// base64.RawStdEncoding.
|
||||
func encodeRoundtrip(t *testing.T, plaintext []byte, chunkSize int) {
|
||||
t.Helper()
|
||||
var e StreamingBase64Encoder
|
||||
got, err := collectEncode(t, &e, plaintext, chunkSize)
|
||||
if err != nil {
|
||||
t.Fatalf("chunkSize=%d: unexpected Encode error: %v", chunkSize, err)
|
||||
}
|
||||
tail, err := e.Finish()
|
||||
if err != nil {
|
||||
t.Fatalf("chunkSize=%d: unexpected Finish error: %v", chunkSize, err)
|
||||
}
|
||||
got = append(got, tail...)
|
||||
|
||||
want := []byte(base64.RawStdEncoding.EncodeToString(plaintext))
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("chunkSize=%d: encode mismatch:\n want %q\n got %q", chunkSize, want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderRoundtripAllChunkSizes exercises every chunk size from 1 to 7 with
|
||||
// plaintexts whose lengths cover all possible num_leftover values (0, 1, 2).
|
||||
func TestEncoderRoundtripAllChunkSizes(t *testing.T) {
|
||||
plaintexts := [][]byte{
|
||||
{}, // 0 bytes → 0 leftover
|
||||
[]byte("a"), // 1 byte → 1 leftover
|
||||
[]byte("ab"), // 2 bytes → 2 leftover
|
||||
[]byte("abc"), // 3 bytes → 0 leftover
|
||||
[]byte("abcd"), // 4 bytes → 1 leftover
|
||||
[]byte("abcde"), // 5 bytes → 2 leftover
|
||||
[]byte("abcdef"), // 6 bytes → 0 leftover
|
||||
[]byte("Hello, World!"), // 13 bytes
|
||||
[]byte("The quick brown fox jumps over the"), // 34 bytes
|
||||
bytes.Repeat([]byte{0x00, 0xff, 0x80}, 17), // binary data
|
||||
}
|
||||
for _, plain := range plaintexts {
|
||||
for chunkSize := 1; chunkSize <= 7; chunkSize++ {
|
||||
if len(plain) == 0 && chunkSize > 1 {
|
||||
continue
|
||||
}
|
||||
encodeRoundtrip(t, plain, chunkSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderFinishLeftover directly exercises all three branches of Finish.
|
||||
func TestEncoderFinishLeftover(t *testing.T) {
|
||||
// num_leftover=0: empty encoder, Finish must return (nil, nil).
|
||||
t.Run("leftover=0", func(t *testing.T) {
|
||||
var e StreamingBase64Encoder
|
||||
out, err := e.Finish()
|
||||
if err != nil || out != nil {
|
||||
t.Fatalf("expected (nil,nil), got (%v,%v)", out, err)
|
||||
}
|
||||
})
|
||||
|
||||
// num_leftover=0 after a complete stream: also (nil, nil).
|
||||
t.Run("leftover=0_after_stream", func(t *testing.T) {
|
||||
// "abc" is 3 bytes, encodes with no leftover
|
||||
var e StreamingBase64Encoder
|
||||
outBuf := make([]byte, 16)
|
||||
for _, err := range e.Encode([]byte("abc"), outBuf) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
out, err := e.Finish()
|
||||
if err != nil || out != nil {
|
||||
t.Fatalf("expected (nil,nil), got (%v,%v)", out, err)
|
||||
}
|
||||
})
|
||||
|
||||
// num_leftover=1 → Finish encodes 1 byte → 2 chars.
|
||||
t.Run("leftover=1", func(t *testing.T) {
|
||||
// Feed 4 bytes: 3 processed, 1 leftover
|
||||
var e StreamingBase64Encoder
|
||||
outBuf := make([]byte, 16)
|
||||
for _, err := range e.Encode([]byte("abcd"), outBuf) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Encode error: %v", err)
|
||||
}
|
||||
}
|
||||
tail, err := e.Finish()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Finish error: %v", err)
|
||||
}
|
||||
want := []byte(base64.RawStdEncoding.EncodeToString([]byte("d")))
|
||||
if !bytes.Equal(tail, want) {
|
||||
t.Fatalf("want %q, got %q", want, tail)
|
||||
}
|
||||
})
|
||||
|
||||
// num_leftover=2 → Finish encodes 2 bytes → 3 chars.
|
||||
t.Run("leftover=2", func(t *testing.T) {
|
||||
// Feed 5 bytes: 3 processed, 2 leftover
|
||||
var e StreamingBase64Encoder
|
||||
outBuf := make([]byte, 16)
|
||||
for _, err := range e.Encode([]byte("abcde"), outBuf) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Encode error: %v", err)
|
||||
}
|
||||
}
|
||||
tail, err := e.Finish()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Finish error: %v", err)
|
||||
}
|
||||
want := []byte(base64.RawStdEncoding.EncodeToString([]byte("de")))
|
||||
if !bytes.Equal(tail, want) {
|
||||
t.Fatalf("want %q, got %q", want, tail)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestEncoderOutputBufferTooSmall verifies that Encode returns an error (not a
|
||||
// panic) when the supplied output buffer is too small.
|
||||
func TestEncoderOutputBufferTooSmall(t *testing.T) {
|
||||
var e StreamingBase64Encoder
|
||||
tinyBuf := make([]byte, 0) // too small for any output
|
||||
input := []byte("abc") // 3 bytes → 4 chars needed
|
||||
var gotErr error
|
||||
for _, err := range e.Encode(input, tinyBuf) {
|
||||
if err != nil {
|
||||
gotErr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
if gotErr == nil {
|
||||
t.Fatal("expected an error for too-small output buffer, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderEmptyInput verifies that encoding an empty input followed by
|
||||
// Finish produces no output and no error.
|
||||
func TestEncoderEmptyInput(t *testing.T) {
|
||||
var e StreamingBase64Encoder
|
||||
outBuf := make([]byte, 16)
|
||||
for _, err := range e.Encode([]byte{}, outBuf) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on empty input: %v", err)
|
||||
}
|
||||
}
|
||||
out, err := e.Finish()
|
||||
if err != nil || out != nil {
|
||||
t.Fatalf("expected (nil,nil) for empty input, got (%v,%v)", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderLargeInput stress-tests the encoder with a large binary payload
|
||||
// across every chunk size from 1 to 13 to catch off-by-one errors.
|
||||
func TestEncoderLargeInput(t *testing.T) {
|
||||
plain := make([]byte, 1000)
|
||||
for i := range plain {
|
||||
plain[i] = byte(i * 7)
|
||||
}
|
||||
for chunkSize := 1; chunkSize <= 13; chunkSize++ {
|
||||
encodeRoundtrip(t, plain, chunkSize)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderNumLeftoverInEncode checks that the bridge-block path in Encode
|
||||
// (when num_leftover > 0 at the start of a new call) produces correct results.
|
||||
func TestEncoderNumLeftoverInEncode(t *testing.T) {
|
||||
plain := []byte("abcdefghijklmnopqrstuvwxyz") // 26 bytes
|
||||
|
||||
for firstCallLen := 1; firstCallLen <= 2; firstCallLen++ {
|
||||
var e StreamingBase64Encoder
|
||||
outBuf := make([]byte, 64)
|
||||
|
||||
var got []byte
|
||||
for enc, err := range e.Encode(plain[:firstCallLen], outBuf) {
|
||||
if err != nil {
|
||||
t.Fatalf("firstCallLen=%d first Encode error: %v", firstCallLen, err)
|
||||
}
|
||||
got = append(got, enc...)
|
||||
}
|
||||
|
||||
rest, err := collectEncode(t, &e, plain[firstCallLen:], 3)
|
||||
if err != nil {
|
||||
t.Fatalf("firstCallLen=%d rest Encode error: %v", firstCallLen, err)
|
||||
}
|
||||
got = append(got, rest...)
|
||||
tail, err := e.Finish()
|
||||
if err != nil {
|
||||
t.Fatalf("firstCallLen=%d Finish error: %v", firstCallLen, err)
|
||||
}
|
||||
got = append(got, tail...)
|
||||
|
||||
want := []byte(base64.RawStdEncoding.EncodeToString(plain))
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("firstCallLen=%d mismatch:\n want %q\n got %q", firstCallLen, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderFinishResetsState verifies that calling Finish resets the encoder
|
||||
// so it can be reused for a new stream.
|
||||
func TestEncoderFinishResetsState(t *testing.T) {
|
||||
var e StreamingBase64Encoder
|
||||
outBuf := make([]byte, 64)
|
||||
|
||||
plain1 := []byte("hello")
|
||||
var got1 []byte
|
||||
for enc, err := range e.Encode(plain1, outBuf) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got1 = append(got1, enc...)
|
||||
}
|
||||
tail1, err := e.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got1 = append(got1, tail1...)
|
||||
want1 := []byte(base64.RawStdEncoding.EncodeToString(plain1))
|
||||
if !bytes.Equal(got1, want1) {
|
||||
t.Fatalf("first stream: want %q, got %q", want1, got1)
|
||||
}
|
||||
|
||||
// Reuse encoder for a second stream
|
||||
plain2 := []byte("world!")
|
||||
var got2 []byte
|
||||
for enc, err := range e.Encode(plain2, outBuf) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got2 = append(got2, enc...)
|
||||
}
|
||||
tail2, err := e.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got2 = append(got2, tail2...)
|
||||
want2 := []byte(base64.RawStdEncoding.EncodeToString(plain2))
|
||||
if !bytes.Equal(got2, want2) {
|
||||
t.Fatalf("second stream: want %q, got %q", want2, got2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderNeededOutputLen verifies that NeededOutputLen returns the correct
|
||||
// minimum buffer size for various input lengths and leftover states.
|
||||
func TestEncoderNeededOutputLen(t *testing.T) {
|
||||
tests := []struct {
|
||||
leftover int
|
||||
inputLen int
|
||||
want int
|
||||
}{
|
||||
{0, 0, 0},
|
||||
{0, 1, 0}, // 1 byte → leftover, no output yet
|
||||
{0, 2, 0}, // 2 bytes → leftover, no output yet
|
||||
{0, 3, 4}, // 3 bytes → 4 chars
|
||||
{0, 4, 4}, // 4 bytes → 3+1, only 3 processed → 4 chars
|
||||
{0, 6, 8}, // 6 bytes → 8 chars
|
||||
{1, 2, 4}, // 1+2=3 bytes → 4 chars
|
||||
{1, 3, 4}, // 1+3=4, floor(4/3)*4 = 4
|
||||
{2, 1, 4}, // 2+1=3 bytes → 4 chars
|
||||
{2, 4, 8}, // 2+4=6, floor(6/3)*4 = 8
|
||||
}
|
||||
for _, tt := range tests {
|
||||
e := StreamingBase64Encoder{num_leftover: tt.leftover}
|
||||
got := e.NeededOutputLen(tt.inputLen)
|
||||
if got != tt.want {
|
||||
t.Errorf("leftover=%d inputLen=%d: NeededOutputLen=%d, want %d", tt.leftover, tt.inputLen, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncoderDecoderRoundtrip verifies that streaming-encode then
|
||||
// streaming-decode recovers the original data for a variety of inputs and
|
||||
// chunk sizes.
|
||||
func TestEncoderDecoderRoundtrip(t *testing.T) {
|
||||
plaintexts := [][]byte{
|
||||
{},
|
||||
[]byte("a"),
|
||||
[]byte("ab"),
|
||||
[]byte("abc"),
|
||||
[]byte("Hello, World!"),
|
||||
bytes.Repeat([]byte{0x00, 0xfe, 0x80}, 20),
|
||||
}
|
||||
for _, plain := range plaintexts {
|
||||
for chunkSize := 1; chunkSize <= 7; chunkSize++ {
|
||||
if len(plain) == 0 && chunkSize > 1 {
|
||||
continue
|
||||
}
|
||||
// Encode
|
||||
var e StreamingBase64Encoder
|
||||
encoded, err := collectEncode(t, &e, plain, chunkSize)
|
||||
if err != nil {
|
||||
t.Fatalf("Encode error: %v", err)
|
||||
}
|
||||
tail, err := e.Finish()
|
||||
if err != nil {
|
||||
t.Fatalf("Encode Finish error: %v", err)
|
||||
}
|
||||
encoded = append(encoded, tail...)
|
||||
|
||||
// Decode using RawStdEncoding (no padding)
|
||||
var d StreamingBase64Decoder
|
||||
decoded, err := collectDecode(t, &d, encoded, chunkSize)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode error: %v", err)
|
||||
}
|
||||
dtail, err := d.Finish()
|
||||
if err != nil {
|
||||
t.Fatalf("Decode Finish error: %v", err)
|
||||
}
|
||||
decoded = append(decoded, dtail...)
|
||||
|
||||
if !bytes.Equal(decoded, plain) {
|
||||
t.Fatalf("chunkSize=%d roundtrip mismatch:\n want %q\n got %q", chunkSize, plain, decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user