From 1b90c033045f0066ce903215b1e3990db1e9ef54 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Fri, 26 Aug 2022 12:47:58 +0530 Subject: [PATCH] Refactor loop code into its own package --- tools/cmd/at/tty_io.go | 37 +- tools/tui/loop.go | 789 ------------------------- tools/tui/loop/api.go | 184 ++++++ tools/tui/{ => loop}/key-encoding.go | 2 +- tools/tui/loop/read.go | 99 ++++ tools/tui/loop/run.go | 298 ++++++++++ tools/tui/{ => loop}/terminal-state.go | 2 +- tools/tui/loop/timers.go | 44 ++ tools/tui/loop/write.go | 211 +++++++ tools/tui/password.go | 33 +- tools/{tui => utils}/select.go | 40 +- 11 files changed, 893 insertions(+), 846 deletions(-) delete mode 100644 tools/tui/loop.go create mode 100644 tools/tui/loop/api.go rename tools/tui/{ => loop}/key-encoding.go (99%) create mode 100644 tools/tui/loop/read.go create mode 100644 tools/tui/loop/run.go rename tools/tui/{ => loop}/terminal-state.go (99%) create mode 100644 tools/tui/loop/timers.go create mode 100644 tools/tui/loop/write.go rename tools/{tui => utils}/select.go (60%) diff --git a/tools/cmd/at/tty_io.go b/tools/cmd/at/tty_io.go index f313f1a09..c8ac37390 100644 --- a/tools/cmd/at/tty_io.go +++ b/tools/cmd/at/tty_io.go @@ -3,58 +3,59 @@ package at import ( - "kitty/tools/tui" "os" "time" + + "kitty/tools/tui/loop" ) func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error) { serialized_response = make([]byte, 0) - loop, err := tui.CreateLoop() - loop.NoAlternateScreen() + lp, err := loop.New() + lp.NoAlternateScreen() if err != nil { return } var last_received_data_at time.Time - var final_write_id tui.IdType - var check_for_timeout func(loop *tui.Loop, timer_id tui.IdType) error + var final_write_id loop.IdType + var check_for_timeout func(timer_id loop.IdType) error - check_for_timeout = func(loop *tui.Loop, timer_id tui.IdType) error { + check_for_timeout = func(timer_id loop.IdType) error { time_since_last_received_data := time.Now().Sub(last_received_data_at) if time_since_last_received_data >= io_data.timeout { return os.ErrDeadlineExceeded } - loop.AddTimer(io_data.timeout-time_since_last_received_data, false, check_for_timeout) + lp.AddTimer(io_data.timeout-time_since_last_received_data, false, check_for_timeout) return nil } transition_to_read := func() { if io_data.rc.NoResponse { - loop.Quit(0) + lp.Quit(0) } last_received_data_at = time.Now() - loop.AddTimer(io_data.timeout, false, check_for_timeout) + lp.AddTimer(io_data.timeout, false, check_for_timeout) } - loop.OnReceivedData = func(loop *tui.Loop, data []byte) error { + lp.OnReceivedData = func(data []byte) error { last_received_data_at = time.Now() return nil } - loop.OnInitialize = func(loop *tui.Loop) (string, error) { + lp.OnInitialize = func() (string, error) { chunk, err := io_data.next_chunk(true) if err != nil { return "", err } - write_id := loop.QueueWriteBytesDangerous(chunk) + write_id := lp.QueueWriteBytesDangerous(chunk) if len(chunk) == 0 { final_write_id = write_id } return "", nil } - loop.OnWriteComplete = func(loop *tui.Loop, completed_write_id tui.IdType) error { + lp.OnWriteComplete = func(completed_write_id loop.IdType) error { if completed_write_id == final_write_id { transition_to_read() return nil @@ -63,22 +64,22 @@ func do_chunked_io(io_data *rc_io_data) (serialized_response []byte, err error) if err != nil { return err } - write_id := loop.QueueWriteBytesDangerous(chunk) + write_id := lp.QueueWriteBytesDangerous(chunk) if len(chunk) == 0 { final_write_id = write_id } return nil } - loop.OnRCResponse = func(loop *tui.Loop, raw []byte) error { + lp.OnRCResponse = func(raw []byte) error { serialized_response = raw - loop.Quit(0) + lp.Quit(0) return nil } - err = loop.Run() + err = lp.Run() if err == nil { - loop.KillIfSignalled() + lp.KillIfSignalled() } return diff --git a/tools/tui/loop.go b/tools/tui/loop.go deleted file mode 100644 index 77f5657ac..000000000 --- a/tools/tui/loop.go +++ /dev/null @@ -1,789 +0,0 @@ -// License: GPLv3 Copyright: 2022, Kovid Goyal, - -package tui - -import ( - "bytes" - "errors" - "fmt" - "io" - "kitty/tools/tty" - "os" - "os/signal" - "runtime/debug" - "sort" - "time" - - "golang.org/x/sys/unix" - - "kitty/tools/utils" - "kitty/tools/wcswidth" -) - -func read_ignoring_temporary_errors(f *tty.Term, buf []byte) (int, error) { - n, err := f.Read(buf) - if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { - return 0, nil - } - if n == 0 { - return 0, io.EOF - } - return n, err -} - -func is_temporary_error(err error) bool { - return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite) -} - -func write_ignoring_temporary_errors(f *tty.Term, buf []byte) (int, error) { - n, err := f.Write(buf) - if err != nil { - if is_temporary_error(err) { - err = nil - } - return n, err - } - if n == 0 { - return 0, io.EOF - } - return n, err -} - -func writestring_ignoring_temporary_errors(f *tty.Term, buf string) (int, error) { - n, err := f.WriteString(buf) - if err != nil { - if is_temporary_error(err) { - err = nil - } - return n, err - } - if n == 0 { - return 0, io.EOF - } - return n, err -} - -type ScreenSize struct { - WidthCells, HeightCells, WidthPx, HeightPx, CellWidth, CellHeight uint - updated bool -} - -type IdType uint64 -type TimerCallback func(loop *Loop, timer_id IdType) error - -type timer struct { - interval time.Duration - deadline time.Time - repeats bool - id IdType - callback TimerCallback -} - -func (self *timer) update_deadline(now time.Time) { - self.deadline = now.Add(self.interval) -} - -var SIGNULL unix.Signal - -type write_msg struct { - id IdType - bytes []byte - str string -} - -func (self *write_msg) String() string { - return fmt.Sprintf("write_msg{%v %#v %#v}", self.id, string(self.bytes), self.str) -} - -type Loop struct { - controlling_term *tty.Term - terminal_options TerminalStateOptions - screen_size ScreenSize - escape_code_parser wcswidth.EscapeCodeParser - keep_going bool - death_signal unix.Signal - exit_code int - timers []*timer - timer_id_counter, write_msg_id_counter IdType - tty_read_channel chan []byte - tty_write_channel chan *write_msg - write_done_channel chan IdType - err_channel chan error - tty_writing_done_channel, tty_reading_done_channel, wakeup_channel chan byte - pending_writes []*write_msg - - // Callbacks - - // Called when the terminal has been fully setup. Any string returned is sent to - // the terminal on shutdown - OnInitialize func(loop *Loop) (string, error) - - // Called when a key event happens - OnKeyEvent func(loop *Loop, event *KeyEvent) error - - // Called when text is received either from a key event or directly from the terminal - OnText func(loop *Loop, text string, from_key_event bool, in_bracketed_paste bool) error - - // Called when the terminal is resize - OnResize func(loop *Loop, old_size ScreenSize, new_size ScreenSize) error - - // Called when writing is done - OnWriteComplete func(loop *Loop, msg_id IdType) error - - // Called when a response to an rc command is received - OnRCResponse func(loop *Loop, data []byte) error - - // Called when any input form tty is received - OnReceivedData func(loop *Loop, data []byte) error -} - -func (self *Loop) update_screen_size() error { - if self.controlling_term != nil { - return fmt.Errorf("No controlling terminal cannot update screen size") - } - ws, err := self.controlling_term.GetSize() - if err != nil { - return err - } - s := &self.screen_size - s.updated = true - s.HeightCells, s.WidthCells = uint(ws.Row), uint(ws.Col) - s.HeightPx, s.WidthPx = uint(ws.Ypixel), uint(ws.Xpixel) - s.CellWidth = s.WidthPx / s.WidthCells - s.CellHeight = s.HeightPx / s.HeightCells - return nil -} - -func (self *Loop) handle_csi(raw []byte) error { - csi := string(raw) - ke := KeyEventFromCSI(csi) - if ke != nil { - return self.handle_key_event(ke) - } - return nil -} - -func (self *Loop) handle_key_event(ev *KeyEvent) error { - // self.DebugPrintln(ev) - if self.OnKeyEvent != nil { - err := self.OnKeyEvent(self, ev) - if err != nil { - return err - } - if ev.Handled { - return nil - } - } - if ev.MatchesPressOrRepeat("ctrl+c") { - ev.Handled = true - return self.on_SIGINT() - } - if ev.MatchesPressOrRepeat("ctrl+z") { - ev.Handled = true - return self.on_SIGTSTP() - } - if ev.Text != "" && self.OnText != nil { - return self.OnText(self, ev.Text, true, false) - } - return nil -} - -func (self *Loop) handle_osc(raw []byte) error { - return nil -} - -func (self *Loop) handle_dcs(raw []byte) error { - if self.OnRCResponse != nil && bytes.HasPrefix(raw, []byte("@kitty-cmd")) { - return self.OnRCResponse(self, raw[len("@kitty-cmd"):]) - } - return nil -} - -func (self *Loop) handle_apc(raw []byte) error { - return nil -} - -func (self *Loop) handle_sos(raw []byte) error { - return nil -} - -func (self *Loop) handle_pm(raw []byte) error { - return nil -} - -func (self *Loop) handle_rune(raw rune) error { - if self.OnText != nil { - return self.OnText(self, string(raw), false, self.escape_code_parser.InBracketedPaste()) - } - return nil -} - -func (self *Loop) on_signal(s unix.Signal) error { - switch s { - case unix.SIGINT: - return self.on_SIGINT() - case unix.SIGPIPE: - return self.on_SIGPIPE() - case unix.SIGWINCH: - return self.on_SIGWINCH() - case unix.SIGTERM: - return self.on_SIGTERM() - case unix.SIGTSTP: - return self.on_SIGTSTP() - case unix.SIGHUP: - return self.on_SIGHUP() - default: - return nil - } -} - -func (self *Loop) on_SIGINT() error { - self.death_signal = unix.SIGINT - self.keep_going = false - return nil -} - -func (self *Loop) on_SIGPIPE() error { - return nil -} - -func (self *Loop) on_SIGWINCH() error { - self.screen_size.updated = false - if self.OnResize != nil { - old_size := self.screen_size - err := self.update_screen_size() - if err != nil { - return err - } - return self.OnResize(self, old_size, self.screen_size) - } - return nil -} - -func (self *Loop) on_SIGTERM() error { - self.death_signal = unix.SIGTERM - self.keep_going = false - return nil -} - -func (self *Loop) on_SIGTSTP() error { - return nil -} - -func (self *Loop) on_SIGHUP() error { - self.death_signal = unix.SIGHUP - self.keep_going = false - return nil -} - -func CreateLoop() (*Loop, error) { - l := Loop{controlling_term: nil, timers: make([]*timer, 0)} - l.terminal_options.alternate_screen = true - l.escape_code_parser.HandleCSI = l.handle_csi - l.escape_code_parser.HandleOSC = l.handle_osc - l.escape_code_parser.HandleDCS = l.handle_dcs - l.escape_code_parser.HandleAPC = l.handle_apc - l.escape_code_parser.HandleSOS = l.handle_sos - l.escape_code_parser.HandlePM = l.handle_pm - l.escape_code_parser.HandleRune = l.handle_rune - return &l, nil -} - -func (self *Loop) AddTimer(interval time.Duration, repeats bool, callback TimerCallback) IdType { - self.timer_id_counter++ - t := timer{interval: interval, repeats: repeats, callback: callback, id: self.timer_id_counter} - t.update_deadline(time.Now()) - self.timers = append(self.timers, &t) - self.sort_timers() - return t.id -} - -func (self *Loop) RemoveTimer(id IdType) bool { - for i := 0; i < len(self.timers); i++ { - if self.timers[i].id == id { - self.timers = append(self.timers[:i], self.timers[i+1:]...) - return true - } - } - return false -} - -func (self *Loop) NoAlternateScreen() { - self.terminal_options.alternate_screen = false -} - -func (self *Loop) MouseTracking(mt MouseTracking) { - self.terminal_options.mouse_tracking = mt -} - -func (self *Loop) DeathSignalName() string { - if self.death_signal != SIGNULL { - return self.death_signal.String() - } - return "" -} - -func (self *Loop) ScreenSize() (ScreenSize, error) { - if self.screen_size.updated { - return self.screen_size, nil - } - err := self.update_screen_size() - return self.screen_size, err -} - -func kill_self(sig unix.Signal) { - unix.Kill(os.Getpid(), sig) - // Give the signal time to be delivered - time.Sleep(20 * time.Millisecond) -} - -func (self *Loop) KillIfSignalled() { - if self.death_signal != SIGNULL { - kill_self(self.death_signal) - } -} - -func (self *Loop) DebugPrintln(args ...interface{}) { - if self.controlling_term != nil { - self.controlling_term.DebugPrintln(args...) - } -} - -func (self *Loop) Run() (err error) { - sigchnl := make(chan os.Signal, 256) - handled_signals := []os.Signal{unix.SIGINT, unix.SIGTERM, unix.SIGTSTP, unix.SIGHUP, unix.SIGWINCH, unix.SIGPIPE} - signal.Notify(sigchnl, handled_signals...) - defer signal.Reset(handled_signals...) - - controlling_term, err := tty.OpenControllingTerm() - if err != nil { - return err - } - self.controlling_term = controlling_term - defer func() { - self.controlling_term.RestoreAndClose() - self.controlling_term = nil - }() - err = self.controlling_term.ApplyOperations(tty.TCSANOW, tty.SetRaw) - if err != nil { - return nil - } - - self.keep_going = true - self.tty_read_channel = make(chan []byte) - self.tty_write_channel = make(chan *write_msg, 1) // buffered so there is no race between initial queueing and startup of writer thread - self.write_done_channel = make(chan IdType) - self.tty_writing_done_channel = make(chan byte) - self.tty_reading_done_channel = make(chan byte) - self.wakeup_channel = make(chan byte, 256) - self.pending_writes = make([]*write_msg, 0, 256) - self.err_channel = make(chan error, 8) - self.death_signal = SIGNULL - self.escape_code_parser.Reset() - self.exit_code = 0 - no_timeout_channel := make(<-chan time.Time) - finalizer := "" - - w_r, w_w, err := os.Pipe() - var r_r, r_w *os.File - if err == nil { - r_r, r_w, err = os.Pipe() - if err != nil { - w_r.Close() - w_w.Close() - return err - } - } else { - return err - } - self.QueueWriteBytesDangerous(self.terminal_options.SetStateEscapeCodes()) - - defer func() { - // notify tty reader that we are shutting down - r_w.Close() - close(self.tty_reading_done_channel) - - if finalizer != "" { - self.QueueWriteString(finalizer) - } - self.QueueWriteBytesDangerous(self.terminal_options.ResetStateEscapeCodes()) - // flush queued data and wait for it to be written for a timeout, then wait for writer to shutdown - flush_writer(w_w, self.tty_write_channel, self.tty_writing_done_channel, self.pending_writes, 2*time.Second) - self.pending_writes = nil - // wait for tty reader to exit cleanly - for more := true; more; _, more = <-self.tty_read_channel { - } - }() - - go write_to_tty(w_r, self.controlling_term, self.tty_write_channel, self.err_channel, self.write_done_channel, self.tty_writing_done_channel) - go read_from_tty(r_r, self.controlling_term, self.tty_read_channel, self.err_channel, self.tty_reading_done_channel) - - if self.OnInitialize != nil { - finalizer, err = self.OnInitialize(self) - if err != nil { - return err - } - } - - for self.keep_going { - self.queue_write_to_tty(nil) - timeout_chan := no_timeout_channel - if len(self.timers) > 0 { - now := time.Now() - err = self.dispatch_timers(now) - if err != nil { - return err - } - timeout := self.timers[0].deadline.Sub(now) - if timeout < 0 { - timeout = 0 - } - timeout_chan = time.After(timeout) - } - select { - case <-timeout_chan: - case <-self.wakeup_channel: - for len(self.wakeup_channel) > 0 { - <-self.wakeup_channel - } - case msg_id := <-self.write_done_channel: - self.queue_write_to_tty(nil) - if self.OnWriteComplete != nil { - err = self.OnWriteComplete(self, msg_id) - if err != nil { - return err - } - } - case s := <-sigchnl: - err = self.on_signal(s.(unix.Signal)) - if err != nil { - return err - } - case input_data, more := <-self.tty_read_channel: - if !more { - return io.EOF - } - err := self.dispatch_input_data(input_data) - if err != nil { - return err - } - - } - } - - return nil -} - -func (self *Loop) dispatch_input_data(data []byte) error { - if self.OnReceivedData != nil { - err := self.OnReceivedData(self, data) - if err != nil { - return err - } - } - err := self.escape_code_parser.Parse(data) - if err != nil { - return err - } - return nil -} - -func (self *Loop) print_stack() { - self.DebugPrintln(string(debug.Stack())) -} - -func (self *Loop) queue_write_to_tty(data *write_msg) { - for len(self.pending_writes) > 0 { - select { - case self.tty_write_channel <- self.pending_writes[0]: - n := copy(self.pending_writes, self.pending_writes[1:]) - self.pending_writes = self.pending_writes[:n] - default: - if data != nil { - self.pending_writes = append(self.pending_writes, data) - } - return - } - } - if data != nil { - select { - case self.tty_write_channel <- data: - default: - self.pending_writes = append(self.pending_writes, data) - } - } -} - -func (self *Loop) WakeupMainThread() { - self.wakeup_channel <- 1 -} - -func (self *Loop) QueueWriteString(data string) IdType { - self.write_msg_id_counter++ - msg := write_msg{str: data, id: self.write_msg_id_counter} - self.queue_write_to_tty(&msg) - return msg.id -} - -// This is dangerous as it is upto the calling code -// to ensure the data in the underlying array does not change -func (self *Loop) QueueWriteBytesDangerous(data []byte) IdType { - self.write_msg_id_counter++ - msg := write_msg{bytes: data, id: self.write_msg_id_counter} - self.queue_write_to_tty(&msg) - return msg.id -} - -func (self *Loop) QueueWriteBytesCopy(data []byte) IdType { - d := make([]byte, len(data)) - copy(d, data) - return self.QueueWriteBytesDangerous(d) -} - -func (self *Loop) ExitCode() int { - return self.exit_code -} - -func (self *Loop) Beep() { - self.QueueWriteString("\a") -} - -func (self *Loop) Quit(exit_code int) { - self.exit_code = exit_code - self.keep_going = false -} - -func read_from_tty(pipe_r *os.File, term *tty.Term, results_channel chan<- []byte, err_channel chan<- error, quit_channel <-chan byte) { - keep_going := true - pipe_fd := int(pipe_r.Fd()) - tty_fd := term.Fd() - selector := CreateSelect(2) - selector.RegisterRead(pipe_fd) - selector.RegisterRead(tty_fd) - - defer func() { - close(results_channel) - pipe_r.Close() - }() - - const bufsize = 2 * utils.DEFAULT_IO_BUFFER_SIZE - - wait_for_read_available := func() { - _, err := selector.WaitForever() - if err != nil { - err_channel <- err - keep_going = false - return - } - if selector.IsReadyToRead(pipe_fd) { - keep_going = false - return - } - if selector.IsReadyToRead(tty_fd) { - return - } - } - - buf := make([]byte, bufsize) - for keep_going { - if len(buf) == 0 { - buf = make([]byte, bufsize) - } - if wait_for_read_available(); !keep_going { - break - } - n, err := read_ignoring_temporary_errors(term, buf) - if err != nil { - err_channel <- err - keep_going = false - break - } - if n == 0 { - err_channel <- io.EOF - keep_going = false - break - } - send := buf[:n] - buf = buf[n:] - select { - case results_channel <- send: - case <-quit_channel: - keep_going = false - break - } - } -} - -type write_dispatcher struct { - str string - bytes []byte - is_string bool - is_empty bool -} - -func create_write_dispatcher(msg *write_msg) *write_dispatcher { - self := write_dispatcher{str: msg.str, bytes: msg.bytes, is_string: msg.bytes == nil} - if self.is_string { - self.is_empty = self.str == "" - } else { - self.is_empty = len(self.bytes) == 0 - } - return &self -} - -func (self *write_dispatcher) write(f *tty.Term) (int, error) { - if self.is_string { - return writestring_ignoring_temporary_errors(f, self.str) - } - return write_ignoring_temporary_errors(f, self.bytes) -} - -func (self *write_dispatcher) slice(n int) { - if self.is_string { - self.str = self.str[n:] - self.is_empty = self.str == "" - } else { - self.bytes = self.bytes[n:] - self.is_empty = len(self.bytes) == 0 - } -} - -func write_to_tty( - pipe_r *os.File, term *tty.Term, - job_channel <-chan *write_msg, err_channel chan<- error, write_done_channel chan<- IdType, completed_channel chan<- byte, -) { - keep_going := true - defer func() { - pipe_r.Close() - close(completed_channel) - }() - selector := CreateSelect(2) - pipe_fd := int(pipe_r.Fd()) - tty_fd := term.Fd() - selector.RegisterRead(pipe_fd) - selector.RegisterWrite(tty_fd) - - wait_for_write_available := func() { - _, err := selector.WaitForever() - if err != nil { - err_channel <- err - keep_going = false - return - } - if selector.IsReadyToWrite(tty_fd) { - return - } - if selector.IsReadyToRead(pipe_fd) { - keep_going = false - } - } - - write_data := func(msg *write_msg) { - data := create_write_dispatcher(msg) - for !data.is_empty { - wait_for_write_available() - if !keep_going { - return - } - n, err := data.write(term) - if err != nil { - err_channel <- err - keep_going = false - return - } - if n > 0 { - data.slice(n) - } - } - } - - for { - data, more := <-job_channel - if !more { - keep_going = false - break - } - write_data(data) - if keep_going { - write_done_channel <- data.id - } else { - break - } - } -} - -func flush_writer(pipe_w *os.File, tty_write_channel chan<- *write_msg, tty_writing_done_channel <-chan byte, pending_writes []*write_msg, timeout time.Duration) { - writer_quit := false - defer func() { - if tty_write_channel != nil { - close(tty_write_channel) - tty_write_channel = nil - } - pipe_w.Close() - if !writer_quit { - <-tty_writing_done_channel - writer_quit = true - } - }() - deadline := time.Now().Add(timeout) - for len(pending_writes) > 0 { - timeout = deadline.Sub(time.Now()) - if timeout <= 0 { - return - } - select { - case <-time.After(timeout): - return - case tty_write_channel <- pending_writes[0]: - pending_writes = pending_writes[1:] - } - } - close(tty_write_channel) - tty_write_channel = nil - timeout = deadline.Sub(time.Now()) - if timeout <= 0 { - return - } - select { - case <-tty_writing_done_channel: - writer_quit = true - case <-time.After(timeout): - } - return -} - -func (self *Loop) dispatch_timers(now time.Time) error { - updated := false - remove := make(map[IdType]bool, 0) - for _, t := range self.timers { - if now.After(t.deadline) { - err := t.callback(self, t.id) - if err != nil { - return err - } - if t.repeats { - t.update_deadline(now) - updated = true - } else { - remove[t.id] = true - } - } - } - if len(remove) > 0 { - timers := make([]*timer, len(self.timers)-len(remove)) - for _, t := range self.timers { - if !remove[t.id] { - timers = append(timers, t) - } - } - self.timers = timers - } - if updated { - self.sort_timers() - } - return nil -} - -func (self *Loop) sort_timers() { - sort.SliceStable(self.timers, func(a, b int) bool { return self.timers[a].deadline.Before(self.timers[b].deadline) }) -} diff --git a/tools/tui/loop/api.go b/tools/tui/loop/api.go new file mode 100644 index 000000000..5ef3b2c4a --- /dev/null +++ b/tools/tui/loop/api.go @@ -0,0 +1,184 @@ +// License: GPLv3 Copyright: 2022, Kovid Goyal, + +package loop + +import ( + "kitty/tools/tty" + "time" + + "golang.org/x/sys/unix" + + "kitty/tools/wcswidth" +) + +type ScreenSize struct { + WidthCells, HeightCells, WidthPx, HeightPx, CellWidth, CellHeight uint + updated bool +} + +type IdType uint64 +type TimerCallback func(timer_id IdType) error + +type timer struct { + interval time.Duration + deadline time.Time + repeats bool + id IdType + callback TimerCallback +} + +func (self *timer) update_deadline(now time.Time) { + self.deadline = now.Add(self.interval) +} + +type Loop struct { + controlling_term *tty.Term + terminal_options TerminalStateOptions + screen_size ScreenSize + escape_code_parser wcswidth.EscapeCodeParser + keep_going bool + death_signal unix.Signal + exit_code int + timers []*timer + timer_id_counter, write_msg_id_counter IdType + tty_read_channel chan []byte + tty_write_channel chan *write_msg + write_done_channel chan IdType + err_channel chan error + tty_writing_done_channel, tty_reading_done_channel, wakeup_channel chan byte + pending_writes []*write_msg + + // Callbacks + + // Called when the terminal has been fully setup. Any string returned is sent to + // the terminal on shutdown + OnInitialize func() (string, error) + + // Called when a key event happens + OnKeyEvent func(event *KeyEvent) error + + // Called when text is received either from a key event or directly from the terminal + OnText func(text string, from_key_event bool, in_bracketed_paste bool) error + + // Called when the terminal is resize + OnResize func(old_size ScreenSize, new_size ScreenSize) error + + // Called when writing is done + OnWriteComplete func(msg_id IdType) error + + // Called when a response to an rc command is received + OnRCResponse func(data []byte) error + + // Called when any input form tty is received + OnReceivedData func(data []byte) error +} + +func New() (*Loop, error) { + l := Loop{controlling_term: nil, timers: make([]*timer, 0)} + l.terminal_options.alternate_screen = true + l.escape_code_parser.HandleCSI = l.handle_csi + l.escape_code_parser.HandleOSC = l.handle_osc + l.escape_code_parser.HandleDCS = l.handle_dcs + l.escape_code_parser.HandleAPC = l.handle_apc + l.escape_code_parser.HandleSOS = l.handle_sos + l.escape_code_parser.HandlePM = l.handle_pm + l.escape_code_parser.HandleRune = l.handle_rune + return &l, nil +} + +func (self *Loop) AddTimer(interval time.Duration, repeats bool, callback TimerCallback) IdType { + self.timer_id_counter++ + t := timer{interval: interval, repeats: repeats, callback: callback, id: self.timer_id_counter} + t.update_deadline(time.Now()) + self.timers = append(self.timers, &t) + self.sort_timers() + return t.id +} + +func (self *Loop) RemoveTimer(id IdType) bool { + for i := 0; i < len(self.timers); i++ { + if self.timers[i].id == id { + self.timers = append(self.timers[:i], self.timers[i+1:]...) + return true + } + } + return false +} + +func (self *Loop) NoAlternateScreen() { + self.terminal_options.alternate_screen = false +} + +func (self *Loop) MouseTracking(mt MouseTracking) { + self.terminal_options.mouse_tracking = mt +} + +func (self *Loop) DeathSignalName() string { + if self.death_signal != SIGNULL { + return self.death_signal.String() + } + return "" +} + +func (self *Loop) ScreenSize() (ScreenSize, error) { + if self.screen_size.updated { + return self.screen_size, nil + } + err := self.update_screen_size() + return self.screen_size, err +} + +func (self *Loop) KillIfSignalled() { + if self.death_signal != SIGNULL { + kill_self(self.death_signal) + } +} + +func (self *Loop) DebugPrintln(args ...interface{}) { + if self.controlling_term != nil { + self.controlling_term.DebugPrintln(args...) + } +} + +func (self *Loop) Run() (err error) { + return self.run() +} + +func (self *Loop) WakeupMainThread() { + self.wakeup_channel <- 1 +} + +func (self *Loop) QueueWriteString(data string) IdType { + self.write_msg_id_counter++ + msg := write_msg{str: data, id: self.write_msg_id_counter} + self.queue_write_to_tty(&msg) + return msg.id +} + +// This is dangerous as it is upto the calling code +// to ensure the data in the underlying array does not change +func (self *Loop) QueueWriteBytesDangerous(data []byte) IdType { + self.write_msg_id_counter++ + msg := write_msg{bytes: data, id: self.write_msg_id_counter} + self.queue_write_to_tty(&msg) + return msg.id +} + +func (self *Loop) QueueWriteBytesCopy(data []byte) IdType { + d := make([]byte, len(data)) + copy(d, data) + return self.QueueWriteBytesDangerous(d) +} + +func (self *Loop) ExitCode() int { + return self.exit_code +} + +func (self *Loop) Beep() { + self.QueueWriteString("\a") +} + +func (self *Loop) Quit(exit_code int) { + self.exit_code = exit_code + self.keep_going = false +} diff --git a/tools/tui/key-encoding.go b/tools/tui/loop/key-encoding.go similarity index 99% rename from tools/tui/key-encoding.go rename to tools/tui/loop/key-encoding.go index 246ea7387..6afc44066 100644 --- a/tools/tui/key-encoding.go +++ b/tools/tui/loop/key-encoding.go @@ -1,6 +1,6 @@ // License: GPLv3 Copyright: 2022, Kovid Goyal, -package tui +package loop import ( "fmt" diff --git a/tools/tui/loop/read.go b/tools/tui/loop/read.go new file mode 100644 index 000000000..804f04f55 --- /dev/null +++ b/tools/tui/loop/read.go @@ -0,0 +1,99 @@ +// License: GPLv3 Copyright: 2022, Kovid Goyal, + +package loop + +import ( + "io" + "os" + + "golang.org/x/sys/unix" + + "kitty/tools/tty" + "kitty/tools/utils" +) + +func (self *Loop) dispatch_input_data(data []byte) error { + if self.OnReceivedData != nil { + err := self.OnReceivedData(data) + if err != nil { + return err + } + } + err := self.escape_code_parser.Parse(data) + if err != nil { + return err + } + return nil +} + +func read_ignoring_temporary_errors(f *tty.Term, buf []byte) (int, error) { + n, err := f.Read(buf) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + return 0, nil + } + if n == 0 { + return 0, io.EOF + } + return n, err +} + +func read_from_tty(pipe_r *os.File, term *tty.Term, results_channel chan<- []byte, err_channel chan<- error, quit_channel <-chan byte) { + keep_going := true + pipe_fd := int(pipe_r.Fd()) + tty_fd := term.Fd() + selector := utils.CreateSelect(2) + selector.RegisterRead(pipe_fd) + selector.RegisterRead(tty_fd) + + defer func() { + close(results_channel) + pipe_r.Close() + }() + + const bufsize = 2 * utils.DEFAULT_IO_BUFFER_SIZE + + wait_for_read_available := func() { + _, err := selector.WaitForever() + if err != nil { + err_channel <- err + keep_going = false + return + } + if selector.IsReadyToRead(pipe_fd) { + keep_going = false + return + } + if selector.IsReadyToRead(tty_fd) { + return + } + } + + buf := make([]byte, bufsize) + for keep_going { + if len(buf) == 0 { + buf = make([]byte, bufsize) + } + if wait_for_read_available(); !keep_going { + break + } + n, err := read_ignoring_temporary_errors(term, buf) + if err != nil { + err_channel <- err + keep_going = false + break + } + if n == 0 { + err_channel <- io.EOF + keep_going = false + break + } + send := buf[:n] + buf = buf[n:] + select { + case results_channel <- send: + case <-quit_channel: + keep_going = false + break + } + } +} diff --git a/tools/tui/loop/run.go b/tools/tui/loop/run.go new file mode 100644 index 000000000..611db9c05 --- /dev/null +++ b/tools/tui/loop/run.go @@ -0,0 +1,298 @@ +// License: GPLv3 Copyright: 2022, Kovid Goyal, + +package loop + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "os/signal" + "runtime/debug" + "time" + + "golang.org/x/sys/unix" + + "kitty/tools/tty" +) + +var SIGNULL unix.Signal + +func is_temporary_error(err error) bool { + return errors.Is(err, unix.EINTR) || errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) || errors.Is(err, io.ErrShortWrite) +} + +func kill_self(sig unix.Signal) { + unix.Kill(os.Getpid(), sig) + // Give the signal time to be delivered + time.Sleep(20 * time.Millisecond) +} + +func (self *Loop) print_stack() { + self.DebugPrintln(string(debug.Stack())) +} + +func (self *Loop) update_screen_size() error { + if self.controlling_term != nil { + return fmt.Errorf("No controlling terminal cannot update screen size") + } + ws, err := self.controlling_term.GetSize() + if err != nil { + return err + } + s := &self.screen_size + s.updated = true + s.HeightCells, s.WidthCells = uint(ws.Row), uint(ws.Col) + s.HeightPx, s.WidthPx = uint(ws.Ypixel), uint(ws.Xpixel) + s.CellWidth = s.WidthPx / s.WidthCells + s.CellHeight = s.HeightPx / s.HeightCells + return nil +} + +func (self *Loop) handle_csi(raw []byte) error { + csi := string(raw) + ke := KeyEventFromCSI(csi) + if ke != nil { + return self.handle_key_event(ke) + } + return nil +} + +func (self *Loop) handle_key_event(ev *KeyEvent) error { + // self.DebugPrintln(ev) + if self.OnKeyEvent != nil { + err := self.OnKeyEvent(ev) + if err != nil { + return err + } + if ev.Handled { + return nil + } + } + if ev.MatchesPressOrRepeat("ctrl+c") { + ev.Handled = true + return self.on_SIGINT() + } + if ev.MatchesPressOrRepeat("ctrl+z") { + ev.Handled = true + return self.on_SIGTSTP() + } + if ev.Text != "" && self.OnText != nil { + return self.OnText(ev.Text, true, false) + } + return nil +} + +func (self *Loop) handle_osc(raw []byte) error { + return nil +} + +func (self *Loop) handle_dcs(raw []byte) error { + if self.OnRCResponse != nil && bytes.HasPrefix(raw, []byte("@kitty-cmd")) { + return self.OnRCResponse(raw[len("@kitty-cmd"):]) + } + return nil +} + +func (self *Loop) handle_apc(raw []byte) error { + return nil +} + +func (self *Loop) handle_sos(raw []byte) error { + return nil +} + +func (self *Loop) handle_pm(raw []byte) error { + return nil +} + +func (self *Loop) handle_rune(raw rune) error { + if self.OnText != nil { + return self.OnText(string(raw), false, self.escape_code_parser.InBracketedPaste()) + } + return nil +} + +func (self *Loop) on_signal(s unix.Signal) error { + switch s { + case unix.SIGINT: + return self.on_SIGINT() + case unix.SIGPIPE: + return self.on_SIGPIPE() + case unix.SIGWINCH: + return self.on_SIGWINCH() + case unix.SIGTERM: + return self.on_SIGTERM() + case unix.SIGTSTP: + return self.on_SIGTSTP() + case unix.SIGHUP: + return self.on_SIGHUP() + default: + return nil + } +} + +func (self *Loop) on_SIGINT() error { + self.death_signal = unix.SIGINT + self.keep_going = false + return nil +} + +func (self *Loop) on_SIGPIPE() error { + return nil +} + +func (self *Loop) on_SIGWINCH() error { + self.screen_size.updated = false + if self.OnResize != nil { + old_size := self.screen_size + err := self.update_screen_size() + if err != nil { + return err + } + return self.OnResize(old_size, self.screen_size) + } + return nil +} + +func (self *Loop) on_SIGTERM() error { + self.death_signal = unix.SIGTERM + self.keep_going = false + return nil +} + +func (self *Loop) on_SIGTSTP() error { + return nil +} + +func (self *Loop) on_SIGHUP() error { + self.death_signal = unix.SIGHUP + self.keep_going = false + return nil +} + +func (self *Loop) run() (err error) { + sigchnl := make(chan os.Signal, 256) + handled_signals := []os.Signal{unix.SIGINT, unix.SIGTERM, unix.SIGTSTP, unix.SIGHUP, unix.SIGWINCH, unix.SIGPIPE} + signal.Notify(sigchnl, handled_signals...) + defer signal.Reset(handled_signals...) + + controlling_term, err := tty.OpenControllingTerm() + if err != nil { + return err + } + self.controlling_term = controlling_term + defer func() { + self.controlling_term.RestoreAndClose() + self.controlling_term = nil + }() + err = self.controlling_term.ApplyOperations(tty.TCSANOW, tty.SetRaw) + if err != nil { + return nil + } + + self.keep_going = true + self.tty_read_channel = make(chan []byte) + self.tty_write_channel = make(chan *write_msg, 1) // buffered so there is no race between initial queueing and startup of writer thread + self.write_done_channel = make(chan IdType) + self.tty_writing_done_channel = make(chan byte) + self.tty_reading_done_channel = make(chan byte) + self.wakeup_channel = make(chan byte, 256) + self.pending_writes = make([]*write_msg, 0, 256) + self.err_channel = make(chan error, 8) + self.death_signal = SIGNULL + self.escape_code_parser.Reset() + self.exit_code = 0 + no_timeout_channel := make(<-chan time.Time) + finalizer := "" + + w_r, w_w, err := os.Pipe() + var r_r, r_w *os.File + if err == nil { + r_r, r_w, err = os.Pipe() + if err != nil { + w_r.Close() + w_w.Close() + return err + } + } else { + return err + } + self.QueueWriteBytesDangerous(self.terminal_options.SetStateEscapeCodes()) + + defer func() { + // notify tty reader that we are shutting down + r_w.Close() + close(self.tty_reading_done_channel) + + if finalizer != "" { + self.QueueWriteString(finalizer) + } + self.QueueWriteBytesDangerous(self.terminal_options.ResetStateEscapeCodes()) + // flush queued data and wait for it to be written for a timeout, then wait for writer to shutdown + flush_writer(w_w, self.tty_write_channel, self.tty_writing_done_channel, self.pending_writes, 2*time.Second) + self.pending_writes = nil + // wait for tty reader to exit cleanly + for more := true; more; _, more = <-self.tty_read_channel { + } + }() + + go write_to_tty(w_r, self.controlling_term, self.tty_write_channel, self.err_channel, self.write_done_channel, self.tty_writing_done_channel) + go read_from_tty(r_r, self.controlling_term, self.tty_read_channel, self.err_channel, self.tty_reading_done_channel) + + if self.OnInitialize != nil { + finalizer, err = self.OnInitialize() + if err != nil { + return err + } + } + + for self.keep_going { + self.queue_write_to_tty(nil) + timeout_chan := no_timeout_channel + if len(self.timers) > 0 { + now := time.Now() + err = self.dispatch_timers(now) + if err != nil { + return err + } + timeout := self.timers[0].deadline.Sub(now) + if timeout < 0 { + timeout = 0 + } + timeout_chan = time.After(timeout) + } + select { + case <-timeout_chan: + case <-self.wakeup_channel: + for len(self.wakeup_channel) > 0 { + <-self.wakeup_channel + } + case msg_id := <-self.write_done_channel: + self.queue_write_to_tty(nil) + if self.OnWriteComplete != nil { + err = self.OnWriteComplete(msg_id) + if err != nil { + return err + } + } + case s := <-sigchnl: + err = self.on_signal(s.(unix.Signal)) + if err != nil { + return err + } + case input_data, more := <-self.tty_read_channel: + if !more { + return io.EOF + } + err := self.dispatch_input_data(input_data) + if err != nil { + return err + } + + } + } + + return nil +} diff --git a/tools/tui/terminal-state.go b/tools/tui/loop/terminal-state.go similarity index 99% rename from tools/tui/terminal-state.go rename to tools/tui/loop/terminal-state.go index 03a842264..2ce05afef 100644 --- a/tools/tui/terminal-state.go +++ b/tools/tui/loop/terminal-state.go @@ -1,6 +1,6 @@ // License: GPLv3 Copyright: 2022, Kovid Goyal, -package tui +package loop import ( "fmt" diff --git a/tools/tui/loop/timers.go b/tools/tui/loop/timers.go new file mode 100644 index 000000000..449a42f36 --- /dev/null +++ b/tools/tui/loop/timers.go @@ -0,0 +1,44 @@ +// License: GPLv3 Copyright: 2022, Kovid Goyal, + +package loop + +import ( + "sort" + "time" +) + +func (self *Loop) dispatch_timers(now time.Time) error { + updated := false + remove := make(map[IdType]bool, 0) + for _, t := range self.timers { + if now.After(t.deadline) { + err := t.callback(t.id) + if err != nil { + return err + } + if t.repeats { + t.update_deadline(now) + updated = true + } else { + remove[t.id] = true + } + } + } + if len(remove) > 0 { + timers := make([]*timer, len(self.timers)-len(remove)) + for _, t := range self.timers { + if !remove[t.id] { + timers = append(timers, t) + } + } + self.timers = timers + } + if updated { + self.sort_timers() + } + return nil +} + +func (self *Loop) sort_timers() { + sort.SliceStable(self.timers, func(a, b int) bool { return self.timers[a].deadline.Before(self.timers[b].deadline) }) +} diff --git a/tools/tui/loop/write.go b/tools/tui/loop/write.go new file mode 100644 index 000000000..73321e8e9 --- /dev/null +++ b/tools/tui/loop/write.go @@ -0,0 +1,211 @@ +// License: GPLv3 Copyright: 2022, Kovid Goyal, + +package loop + +import ( + "fmt" + "io" + "os" + "time" + + "kitty/tools/tty" + "kitty/tools/utils" +) + +type write_msg struct { + id IdType + bytes []byte + str string +} + +func (self *write_msg) String() string { + return fmt.Sprintf("write_msg{%v %#v %#v}", self.id, string(self.bytes), self.str) +} + +type write_dispatcher struct { + str string + bytes []byte + is_string bool + is_empty bool +} + +func write_ignoring_temporary_errors(f *tty.Term, buf []byte) (int, error) { + n, err := f.Write(buf) + if err != nil { + if is_temporary_error(err) { + err = nil + } + return n, err + } + if n == 0 { + return 0, io.EOF + } + return n, err +} + +func writestring_ignoring_temporary_errors(f *tty.Term, buf string) (int, error) { + n, err := f.WriteString(buf) + if err != nil { + if is_temporary_error(err) { + err = nil + } + return n, err + } + if n == 0 { + return 0, io.EOF + } + return n, err +} + +func (self *Loop) queue_write_to_tty(data *write_msg) { + for len(self.pending_writes) > 0 { + select { + case self.tty_write_channel <- self.pending_writes[0]: + n := copy(self.pending_writes, self.pending_writes[1:]) + self.pending_writes = self.pending_writes[:n] + default: + if data != nil { + self.pending_writes = append(self.pending_writes, data) + } + return + } + } + if data != nil { + select { + case self.tty_write_channel <- data: + default: + self.pending_writes = append(self.pending_writes, data) + } + } +} + +func create_write_dispatcher(msg *write_msg) *write_dispatcher { + self := write_dispatcher{str: msg.str, bytes: msg.bytes, is_string: msg.bytes == nil} + if self.is_string { + self.is_empty = self.str == "" + } else { + self.is_empty = len(self.bytes) == 0 + } + return &self +} + +func (self *write_dispatcher) write(f *tty.Term) (int, error) { + if self.is_string { + return writestring_ignoring_temporary_errors(f, self.str) + } + return write_ignoring_temporary_errors(f, self.bytes) +} + +func (self *write_dispatcher) slice(n int) { + if self.is_string { + self.str = self.str[n:] + self.is_empty = self.str == "" + } else { + self.bytes = self.bytes[n:] + self.is_empty = len(self.bytes) == 0 + } +} + +func write_to_tty( + pipe_r *os.File, term *tty.Term, + job_channel <-chan *write_msg, err_channel chan<- error, write_done_channel chan<- IdType, completed_channel chan<- byte, +) { + keep_going := true + defer func() { + pipe_r.Close() + close(completed_channel) + }() + selector := utils.CreateSelect(2) + pipe_fd := int(pipe_r.Fd()) + tty_fd := term.Fd() + selector.RegisterRead(pipe_fd) + selector.RegisterWrite(tty_fd) + + wait_for_write_available := func() { + _, err := selector.WaitForever() + if err != nil { + err_channel <- err + keep_going = false + return + } + if selector.IsReadyToWrite(tty_fd) { + return + } + if selector.IsReadyToRead(pipe_fd) { + keep_going = false + } + } + + write_data := func(msg *write_msg) { + data := create_write_dispatcher(msg) + for !data.is_empty { + wait_for_write_available() + if !keep_going { + return + } + n, err := data.write(term) + if err != nil { + err_channel <- err + keep_going = false + return + } + if n > 0 { + data.slice(n) + } + } + } + + for { + data, more := <-job_channel + if !more { + keep_going = false + break + } + write_data(data) + if keep_going { + write_done_channel <- data.id + } else { + break + } + } +} + +func flush_writer(pipe_w *os.File, tty_write_channel chan<- *write_msg, tty_writing_done_channel <-chan byte, pending_writes []*write_msg, timeout time.Duration) { + writer_quit := false + defer func() { + if tty_write_channel != nil { + close(tty_write_channel) + tty_write_channel = nil + } + pipe_w.Close() + if !writer_quit { + <-tty_writing_done_channel + writer_quit = true + } + }() + deadline := time.Now().Add(timeout) + for len(pending_writes) > 0 { + timeout = deadline.Sub(time.Now()) + if timeout <= 0 { + return + } + select { + case <-time.After(timeout): + return + case tty_write_channel <- pending_writes[0]: + pending_writes = pending_writes[1:] + } + } + close(tty_write_channel) + tty_write_channel = nil + timeout = deadline.Sub(time.Now()) + if timeout <= 0 { + return + } + select { + case <-tty_writing_done_channel: + writer_quit = true + case <-time.After(timeout): + } + return +} diff --git a/tools/tui/password.go b/tools/tui/password.go index 9b0bad803..eb1da2e68 100644 --- a/tools/tui/password.go +++ b/tools/tui/password.go @@ -7,6 +7,7 @@ import ( "fmt" "strings" + "kitty/tools/tui/loop" "kitty/tools/wcswidth" ) @@ -20,31 +21,31 @@ func (self *KilledBySignal) Error() string { return self.Msg } var Canceled = errors.New("Canceled by user") func ReadPassword(prompt string, kill_if_signaled bool) (password string, err error) { - loop, err := CreateLoop() - loop.NoAlternateScreen() + lp, err := loop.New() + lp.NoAlternateScreen() shadow := "" if err != nil { return } - loop.OnInitialize = func(loop *Loop) (string, error) { - loop.QueueWriteString(prompt) + lp.OnInitialize = func() (string, error) { + lp.QueueWriteString(prompt) return "\r\n", nil } - loop.OnText = func(loop *Loop, text string, from_key_event bool, in_bracketed_paste bool) error { + lp.OnText = func(text string, from_key_event bool, in_bracketed_paste bool) error { old_width := wcswidth.Stringwidth(password) password += text new_width := wcswidth.Stringwidth(password) if new_width > old_width { extra := strings.Repeat("*", new_width-old_width) - loop.QueueWriteString(extra) + lp.QueueWriteString(extra) shadow += extra } return nil } - loop.OnKeyEvent = func(loop *Loop, event *KeyEvent) error { + lp.OnKeyEvent = func(event *loop.KeyEvent) error { if event.MatchesPressOrRepeat("backspace") || event.MatchesPressOrRepeat("delete") { event.Handled = true if len(password) > 0 { @@ -57,41 +58,41 @@ func ReadPassword(prompt string, kill_if_signaled bool) (password string, err er delta = len(shadow) } shadow = shadow[:len(shadow)-delta] - loop.QueueWriteString(strings.Repeat("\x08\x1b[P", delta)) + lp.QueueWriteString(strings.Repeat("\x08\x1b[P", delta)) } } else { - loop.Beep() + lp.Beep() } } if event.MatchesPressOrRepeat("enter") || event.MatchesPressOrRepeat("return") { event.Handled = true if password == "" { - loop.Quit(1) + lp.Quit(1) } else { - loop.Quit(0) + lp.Quit(0) } } if event.MatchesPressOrRepeat("esc") { event.Handled = true - loop.Quit(1) + lp.Quit(1) return Canceled } return nil } - err = loop.Run() + err = lp.Run() if err != nil { return } - ds := loop.DeathSignalName() + ds := lp.DeathSignalName() if ds != "" { if kill_if_signaled { - loop.KillIfSignalled() + lp.KillIfSignalled() return } return "", &KilledBySignal{Msg: fmt.Sprint("Killed by signal: ", ds), SignalName: ds} } - if loop.ExitCode() != 0 { + if lp.ExitCode() != 0 { password = "" } return password, nil diff --git a/tools/tui/select.go b/tools/utils/select.go similarity index 60% rename from tools/tui/select.go rename to tools/utils/select.go index 2ed23cec3..1f3822f40 100644 --- a/tools/tui/select.go +++ b/tools/utils/select.go @@ -1,61 +1,59 @@ // License: GPLv3 Copyright: 2022, Kovid Goyal, -package tui +package utils import ( "time" "golang.org/x/sys/unix" - - "kitty/tools/utils" ) -type Select struct { +type Selector struct { read_set, write_set, err_set unix.FdSet read_fds, write_fds, err_fds map[int]bool } -func CreateSelect(expected_number_of_fds int) *Select { - var ans Select +func CreateSelect(expected_number_of_fds int) *Selector { + var ans Selector ans.read_fds = make(map[int]bool, expected_number_of_fds) ans.write_fds = make(map[int]bool, expected_number_of_fds) ans.err_fds = make(map[int]bool, expected_number_of_fds) return &ans } -func (self *Select) register(fd int, fdset *map[int]bool) { +func (self *Selector) register(fd int, fdset *map[int]bool) { (*fdset)[fd] = true } -func (self *Select) RegisterRead(fd int) { +func (self *Selector) RegisterRead(fd int) { self.register(fd, &self.read_fds) } -func (self *Select) RegisterWrite(fd int) { +func (self *Selector) RegisterWrite(fd int) { self.register(fd, &self.write_fds) } -func (self *Select) RegisterError(fd int) { +func (self *Selector) RegisterError(fd int) { self.register(fd, &self.err_fds) } -func (self *Select) unregister(fd int, fdset *map[int]bool) { +func (self *Selector) unregister(fd int, fdset *map[int]bool) { (*fdset)[fd] = false } -func (self *Select) UnRegisterRead(fd int) { +func (self *Selector) UnRegisterRead(fd int) { self.unregister(fd, &self.read_fds) } -func (self *Select) UnRegisterWrite(fd int) { +func (self *Selector) UnRegisterWrite(fd int) { self.unregister(fd, &self.write_fds) } -func (self *Select) UnRegisterError(fd int) { +func (self *Selector) UnRegisterError(fd int) { self.unregister(fd, &self.err_fds) } -func (self *Select) Wait(timeout time.Duration) (num_ready int, err error) { +func (self *Selector) Wait(timeout time.Duration) (num_ready int, err error) { self.read_set.Zero() self.write_set.Zero() self.err_set.Zero() @@ -75,30 +73,30 @@ func (self *Select) Wait(timeout time.Duration) (num_ready int, err error) { init_set(&self.read_set, &self.read_fds) init_set(&self.write_set, &self.write_fds) init_set(&self.err_set, &self.err_fds) - num_ready, err = utils.Select(max_fd_num+1, &self.read_set, &self.write_set, &self.err_set, timeout) + num_ready, err = Select(max_fd_num+1, &self.read_set, &self.write_set, &self.err_set, timeout) if err == unix.EINTR { return 0, nil } return } -func (self *Select) WaitForever() (num_ready int, err error) { +func (self *Selector) WaitForever() (num_ready int, err error) { return self.Wait(-1) } -func (self *Select) IsReadyToRead(fd int) bool { +func (self *Selector) IsReadyToRead(fd int) bool { return fd > -1 && self.read_set.IsSet(fd) } -func (self *Select) IsReadyToWrite(fd int) bool { +func (self *Selector) IsReadyToWrite(fd int) bool { return fd > -1 && self.write_set.IsSet(fd) } -func (self *Select) IsErrored(fd int) bool { +func (self *Selector) IsErrored(fd int) bool { return fd > -1 && self.err_set.IsSet(fd) } -func (self *Select) UnregisterAll() { +func (self *Selector) UnregisterAll() { self.read_fds = make(map[int]bool) self.write_fds = make(map[int]bool) self.err_fds = make(map[int]bool)