From ff7c6425e649e594643fd1f47796972a498645c1 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sat, 2 May 2026 12:16:42 +0530 Subject: [PATCH] More work on dnd kitten --- docs/dnd-protocol.rst | 4 +- kittens/dnd/drag.go | 119 ++++++++++++++++++++++- kittens/dnd/main.go | 18 ++-- tools/utils/streaming_base64/api.go | 10 +- tools/utils/streaming_base64/api_test.go | 53 +++------- 5 files changed, 151 insertions(+), 53 deletions(-) diff --git a/docs/dnd-protocol.rst b/docs/dnd-protocol.rst index bca63c2da..5e6c942b0 100644 --- a/docs/dnd-protocol.rst +++ b/docs/dnd-protocol.rst @@ -386,9 +386,7 @@ These represent possibly chunked data for files, symlinks and directories, as denoted by the ``X`` key. As above, end of data for an individual entry is indicated by an escape code with ``m=0`` and no payload. ``idx`` is the one based index into the list of entries in the ``text/uri-list`` MIME type. -``file://`` URLs that point to symlinks must be resolved to files or -directories and sent. So actual symlinks will appear only when recursing -through directories as described below. Only regular files should be sent. +Only regular files, symlinks and directories should be sent. Terminals should write the transmitted data into a temporary directory and replace the entries in the ``text/uri-list`` data with the transmitted diff --git a/kittens/dnd/drag.go b/kittens/dnd/drag.go index 42fc124ca..ee1df9fea 100644 --- a/kittens/dnd/drag.go +++ b/kittens/dnd/drag.go @@ -1,16 +1,30 @@ package dnd import ( + "errors" "fmt" + "io" "maps" + "os" "slices" "strings" + "github.com/emmansun/base64" + "github.com/kovidgoyal/kitty/tools/tui/loop" "github.com/kovidgoyal/kitty/tools/utils" + "github.com/kovidgoyal/kitty/tools/utils/streaming_base64" ) var _ = fmt.Print +type data_request struct { + drag_source *drag_source + send_remote_data bool + index int + write_id loop.IdType + base64 streaming_base64.StreamingBase64Encoder +} + type drag_status struct { active bool terminal_accepted_drag bool @@ -18,6 +32,7 @@ type drag_status struct { accepted_mime int accepted_operation int dropped bool + data_requests []*data_request } func (dnd *dnd) on_potential_drag_start(cell_x, cell_y int) (err error) { @@ -64,10 +79,16 @@ func (dnd *dnd) on_drag_error(cmd DC) (err error) { } func (dnd *dnd) reset_drag() { + for _, dr := range dnd.drag_status.data_requests { + if dr.drag_source.file != nil { + dr.drag_source.file.Close() + dr.drag_source.file = nil + } + } dnd.drag_status = drag_status{} } -func (dnd *dnd) on_drag_event(x, y, operation int) (err error) { +func (dnd *dnd) on_drag_event(x, y, operation, Y int) (err error) { switch x { case 1: dnd.drag_status.accepted_mime = y @@ -77,6 +98,102 @@ func (dnd *dnd) on_drag_event(x, y, operation int) (err error) { dnd.drag_status.dropped = true case 4: dnd.reset_drag() + case 5: + if err = dnd.handle_data_request(y, Y == 1); err != nil { + return err + } } return dnd.render_screen() } + +func (dnd *dnd) finish_drag(errname string) { + if errname == "" { // cancel drag + dnd.lp.QueueDnDData(DC{Type: 'E', Y: -1}) + } else { + dnd.lp.QueueDnDData(DC{Type: 'E', Payload: []byte(errname)}) + } + dnd.reset_drag() +} + +func (dnd *dnd) handle_data_request(idx int, send_remote_data bool) (err error) { + if idx < 0 || idx >= len(dnd.drag_status.offered_mimes) { + dnd.finish_drag("EINVAL") + return fmt.Errorf("terminal asked for drag data from MIME list with out of bounds index: %d", idx) + } + mime := dnd.drag_status.offered_mimes[idx] + ds := dnd.drag_sources[mime] + send_remote_data = send_remote_data && mime == "text/uri-list" && len(ds.uri_list) > 0 + dr := &data_request{drag_source: ds, send_remote_data: send_remote_data, index: idx} + if ds.path == "" { + dnd.lp.QueueDnDData(DC{Type: 'e', Y: idx, Payload: utils.UnsafeStringToBytes(base64.RawStdEncoding.EncodeToString(ds.data))}) + if !dr.send_remote_data { + return + } + return dnd.start_remote_data_send(ds) + } else { + if ds.file != nil { + ds.file.Close() + } + if ds.file, err = os.Open(ds.path); err != nil { + dnd.finish_drag("EIO") + return err + } + } + dnd.drag_status.data_requests = append(dnd.drag_status.data_requests, dr) + return dnd.send_data_for_data_request(len(dnd.drag_status.data_requests) - 1) +} + +var read_buf [64 * 1024]byte +var encode_buf [128 * 1024]byte + +func (dnd *dnd) send_data_for_data_request(i int) (err error) { + dr := dnd.drag_status.data_requests[i] + n, err := dr.drag_source.file.Read(read_buf[:]) + if n > 0 { + for chunk := range dr.base64.Encode(read_buf[:n], encode_buf[:]) { + dr.write_id = dnd.lp.QueueDnDData(DC{Type: 'e', Y: dr.index, Payload: chunk}) + } + } + if err == nil { + return nil + } + if errors.Is(err, io.EOF) { + chunk := dr.base64.Finish() + if len(chunk) > 0 { + dr.write_id = dnd.lp.QueueDnDData(DC{Type: 'e', Y: dr.index, Payload: chunk}) + } + dr.write_id = dnd.lp.QueueDnDData(DC{Type: 'e', Y: dr.index}) // EOF + return dnd.on_data_request_finished(i) + } + dnd.finish_drag("EIO") + return err +} + +func (dnd *dnd) on_send_done(id loop.IdType) (err error) { + for i, dr := range dnd.drag_status.data_requests { + if dr.write_id == id { + return dnd.send_data_for_data_request(i) + } + } + return +} + +func (dnd *dnd) on_data_request_finished(i int) (err error) { + dr := dnd.drag_status.data_requests[i] + if dr.drag_source.file != nil { + dr.drag_source.file.Close() + dr.drag_source.file = nil + } + dnd.drag_status.data_requests = slices.Delete(dnd.drag_status.data_requests, i, i+1) + if dr.send_remote_data { + err = dnd.start_remote_data_send(dr.drag_source) + } else if len(dnd.drag_status.data_requests) > 0 { + err = dnd.send_data_for_data_request(0) + } + return +} + +func (dnd *dnd) start_remote_data_send(ds *drag_source) (err error) { + // TODO: Implement this + return +} diff --git a/kittens/dnd/main.go b/kittens/dnd/main.go index e151d87e8..b05d33c4a 100644 --- a/kittens/dnd/main.go +++ b/kittens/dnd/main.go @@ -76,7 +76,7 @@ func (d *dir_handle) unref() *dir_handle { type dnd struct { opts *Options drop_dests map[string]*drop_dest - drag_sources map[string]drag_source + drag_sources map[string]*drag_source allow_drops, allow_drags bool lp *loop.Loop @@ -253,11 +253,15 @@ func (dnd *dnd) run_loop() (err error) { case 'E': return dnd.on_drag_error(cmd) case 'e': - return dnd.on_drag_event(cmd.X, cmd.Y, cmd.Operation) + return dnd.on_drag_event(cmd.X, cmd.Y, cmd.Operation, cmd.Yp) } return nil } + dnd.lp.OnWriteComplete = func(msg_id loop.IdType, has_pending_writes bool) (err error) { + return dnd.on_send_done(msg_id) + } + dnd.lp.OnKeyEvent = func(e *loop.KeyEvent) (err error) { e.Handled = true if len(dnd.confirm_drop.overwrites) > 0 { @@ -311,20 +315,20 @@ func dnd_main(cmd *cli.Command, opts *Options, args []string) (rc int, err error drop_dests[mime] = &drop_dest{human_name: dest, path: path, mime_type: mime} } } - drag_sources := make(map[string]drag_source) + drag_sources := make(map[string]*drag_source) for _, spec := range opts.Drag { mime, src, found := strings.Cut(spec, ":") if !found { return 1, fmt.Errorf("invalid drag source %s, must be of the form mime-type:path", spec) } - s := drag_source{human_name: src, mime_type: mime} + s := &drag_source{human_name: src, mime_type: mime} if src == "-" || src == "/dev/stdin" { data, err := io.ReadAll(os.Stdin) if err != nil { return 1, err } if len(data) > 0 { - drag_sources["text/plain"] = drag_source{human_name: "STDIN", mime_type: "text/plain", data: data} + drag_sources["text/plain"] = &drag_source{human_name: "STDIN", mime_type: "text/plain", data: data} } } else { path, err := filepath.Abs(src) @@ -342,7 +346,7 @@ func dnd_main(cmd *cli.Command, opts *Options, args []string) (rc int, err error return 1, err } if len(data) > 0 { - drag_sources["text/plain"] = drag_source{human_name: "STDIN", mime_type: "text/plain", data: data} + drag_sources["text/plain"] = &drag_source{human_name: "STDIN", mime_type: "text/plain", data: data} } } var uri_list []uri_list_item @@ -372,7 +376,7 @@ func dnd_main(cmd *cli.Command, opts *Options, args []string) (rc int, err error uris[i] = u.uri } payload := strings.Join(uris, "\r\n") + "\r\n" - drag_sources["text/uri-list"] = drag_source{ + drag_sources["text/uri-list"] = &drag_source{ human_name: "Files", mime_type: "text/uri-list", uri_list: uri_list, data: utils.UnsafeStringToBytes(payload), } } diff --git a/tools/utils/streaming_base64/api.go b/tools/utils/streaming_base64/api.go index 6189dc00f..371c7345b 100644 --- a/tools/utils/streaming_base64/api.go +++ b/tools/utils/streaming_base64/api.go @@ -192,16 +192,16 @@ func (s *StreamingBase64Encoder) Encode(input []byte, output []byte) iter.Seq2[[ // Finish encoding the stream. Resets the encoder. Returned slice can be nil // if no leftover bytes are present. -func (s *StreamingBase64Encoder) Finish() ([]byte, error) { +func (s *StreamingBase64Encoder) Finish() []byte { defer func() { s.num_leftover = 0 s.total_read = 0 }() if s.num_leftover == 0 { - return nil, nil + return nil } encodedLen := base64.RawStdEncoding.EncodedLen(s.num_leftover) - output := make([]byte, encodedLen) - base64.RawStdEncoding.Encode(output, s.leftover[:s.num_leftover]) - return output, nil + output := [4]byte{} + base64.RawStdEncoding.Encode(output[:encodedLen], s.leftover[:s.num_leftover]) + return output[:encodedLen] } diff --git a/tools/utils/streaming_base64/api_test.go b/tools/utils/streaming_base64/api_test.go index cf8c68587..70c750c4d 100644 --- a/tools/utils/streaming_base64/api_test.go +++ b/tools/utils/streaming_base64/api_test.go @@ -403,10 +403,7 @@ func encodeRoundtrip(t *testing.T, plaintext []byte, chunkSize int) { if err != nil { t.Fatalf("chunkSize=%d: unexpected Encode error: %v", chunkSize, err) } - tail, err := e.Finish() - if err != nil { - t.Fatalf("chunkSize=%d: unexpected Finish error: %v", chunkSize, err) - } + tail := e.Finish() got = append(got, tail...) want := []byte(base64.RawStdEncoding.EncodeToString(plaintext)) @@ -445,9 +442,9 @@ func TestEncoderFinishLeftover(t *testing.T) { // num_leftover=0: empty encoder, Finish must return (nil, nil). t.Run("leftover=0", func(t *testing.T) { var e StreamingBase64Encoder - out, err := e.Finish() - if err != nil || out != nil { - t.Fatalf("expected (nil,nil), got (%v,%v)", out, err) + out := e.Finish() + if out != nil { + t.Fatalf("expected (nil,nil), got (%v)", out) } }) @@ -461,9 +458,9 @@ func TestEncoderFinishLeftover(t *testing.T) { t.Fatal(err) } } - out, err := e.Finish() - if err != nil || out != nil { - t.Fatalf("expected (nil,nil), got (%v,%v)", out, err) + out := e.Finish() + if out != nil { + t.Fatalf("expected (nil,nil), got (%v)", out) } }) @@ -477,10 +474,7 @@ func TestEncoderFinishLeftover(t *testing.T) { t.Fatalf("unexpected Encode error: %v", err) } } - tail, err := e.Finish() - if err != nil { - t.Fatalf("unexpected Finish error: %v", err) - } + tail := e.Finish() want := []byte(base64.RawStdEncoding.EncodeToString([]byte("d"))) if !bytes.Equal(tail, want) { t.Fatalf("want %q, got %q", want, tail) @@ -497,10 +491,7 @@ func TestEncoderFinishLeftover(t *testing.T) { t.Fatalf("unexpected Encode error: %v", err) } } - tail, err := e.Finish() - if err != nil { - t.Fatalf("unexpected Finish error: %v", err) - } + tail := e.Finish() want := []byte(base64.RawStdEncoding.EncodeToString([]byte("de"))) if !bytes.Equal(tail, want) { t.Fatalf("want %q, got %q", want, tail) @@ -536,9 +527,9 @@ func TestEncoderEmptyInput(t *testing.T) { t.Fatalf("unexpected error on empty input: %v", err) } } - out, err := e.Finish() - if err != nil || out != nil { - t.Fatalf("expected (nil,nil) for empty input, got (%v,%v)", out, err) + out := e.Finish() + if out != nil { + t.Fatalf("expected (nil,nil) for empty input, got (%v)", out) } } @@ -576,10 +567,7 @@ func TestEncoderNumLeftoverInEncode(t *testing.T) { t.Fatalf("firstCallLen=%d rest Encode error: %v", firstCallLen, err) } got = append(got, rest...) - tail, err := e.Finish() - if err != nil { - t.Fatalf("firstCallLen=%d Finish error: %v", firstCallLen, err) - } + tail := e.Finish() got = append(got, tail...) want := []byte(base64.RawStdEncoding.EncodeToString(plain)) @@ -603,10 +591,7 @@ func TestEncoderFinishResetsState(t *testing.T) { } got1 = append(got1, enc...) } - tail1, err := e.Finish() - if err != nil { - t.Fatal(err) - } + tail1 := e.Finish() got1 = append(got1, tail1...) want1 := []byte(base64.RawStdEncoding.EncodeToString(plain1)) if !bytes.Equal(got1, want1) { @@ -622,10 +607,7 @@ func TestEncoderFinishResetsState(t *testing.T) { } got2 = append(got2, enc...) } - tail2, err := e.Finish() - if err != nil { - t.Fatal(err) - } + tail2 := e.Finish() got2 = append(got2, tail2...) want2 := []byte(base64.RawStdEncoding.EncodeToString(plain2)) if !bytes.Equal(got2, want2) { @@ -684,10 +666,7 @@ func TestEncoderDecoderRoundtrip(t *testing.T) { if err != nil { t.Fatalf("Encode error: %v", err) } - tail, err := e.Finish() - if err != nil { - t.Fatalf("Encode Finish error: %v", err) - } + tail := e.Finish() encoded = append(encoded, tail...) // Decode using RawStdEncoding (no padding)