mirror of
https://github.com/kovidgoyal/kitty
synced 2026-06-08 22:28:24 +02:00
Code to parse socket addresses
This commit is contained in:
@@ -29,20 +29,12 @@ func add_bool_set(cmd *cobra.Command, name string, short string, usage string) *
|
||||
}
|
||||
|
||||
type GlobalOptions struct {
|
||||
to_address, password string
|
||||
to_address_is_from_env_var bool
|
||||
to_network, to_address, password string
|
||||
to_address_is_from_env_var bool
|
||||
}
|
||||
|
||||
var global_options GlobalOptions
|
||||
|
||||
func cut(a string, sep string) (string, string, bool) {
|
||||
idx := strings.Index(a, sep)
|
||||
if idx < 0 {
|
||||
return "", "", false
|
||||
}
|
||||
return a[:idx], a[idx+len(sep):], true
|
||||
}
|
||||
|
||||
func get_pubkey(encoded_key string) (encryption_version string, pubkey []byte, err error) {
|
||||
if encoded_key == "" {
|
||||
encoded_key = os.Getenv("KITTY_PUBLIC_KEY")
|
||||
@@ -51,7 +43,7 @@ func get_pubkey(encoded_key string) (encryption_version string, pubkey []byte, e
|
||||
return
|
||||
}
|
||||
}
|
||||
encryption_version, encoded_key, found := cut(encoded_key, ":")
|
||||
encryption_version, encoded_key, found := utils.Cut(encoded_key, ":")
|
||||
if !found {
|
||||
err = fmt.Errorf("KITTY_PUBLIC_KEY environment variable does not have a : in it")
|
||||
return
|
||||
@@ -77,23 +69,28 @@ type serializer_func func(rc *utils.RemoteControlCmd) ([]byte, error)
|
||||
|
||||
var serializer serializer_func = simple_serializer
|
||||
|
||||
func create_serializer(password string, encoded_pubkey string) (ans serializer_func, err error) {
|
||||
func create_serializer(password string, encoded_pubkey string, response_timeout float64) (ans serializer_func, timeout float64, err error) {
|
||||
timeout = response_timeout
|
||||
if password != "" {
|
||||
encryption_version, pubkey, err := get_pubkey(encoded_pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, timeout, err
|
||||
}
|
||||
ans = func(rc *utils.RemoteControlCmd) (ans []byte, err error) {
|
||||
ec, err := crypto.Encrypt_cmd(rc, global_options.password, pubkey, encryption_version)
|
||||
ans, err = json.Marshal(ec)
|
||||
return
|
||||
}
|
||||
if timeout < 120 {
|
||||
timeout = 120
|
||||
}
|
||||
return ans, timeout, nil
|
||||
}
|
||||
return simple_serializer, nil
|
||||
return simple_serializer, timeout, nil
|
||||
}
|
||||
|
||||
func send_rc_command(rc *utils.RemoteControlCmd, timeout float64) (err error) {
|
||||
serializer, err = create_serializer(global_options.password, "")
|
||||
serializer, timeout, err = create_serializer(global_options.password, "", timeout)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -165,7 +162,14 @@ func EntryPoint(tool_root *cobra.Command) *cobra.Command {
|
||||
*to = os.Getenv("KITTY_LISTEN_ON")
|
||||
global_options.to_address_is_from_env_var = true
|
||||
}
|
||||
global_options.to_address = *to
|
||||
if *to != "" {
|
||||
network, address, err := utils.ParseSocketAddress(*to)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
global_options.to_network = network
|
||||
global_options.to_address = address
|
||||
}
|
||||
q, err := get_password(*password, *password_file, *password_env, use_password.Choice)
|
||||
global_options.password = q
|
||||
return err
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestCommandToJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRCSerialization(t *testing.T) {
|
||||
serializer, err := create_serializer("", "")
|
||||
serializer, _, err := create_serializer("", "", 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func TestRCSerialization(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serializer, err = create_serializer("tpw", pubkey)
|
||||
serializer, _, err = create_serializer("tpw", pubkey, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,12 @@ func run_CMD_NAME(cmd *cobra.Command, args []string) (err error) {
|
||||
if err == nil {
|
||||
rc.NoResponse = nrv
|
||||
}
|
||||
err = send_rc_command(rc, WAIT_TIMEOUT)
|
||||
var timeout float64 = WAIT_TIMEOUT
|
||||
rt, err := cmd.Flags().GetFloat64("response-timeout")
|
||||
if err == nil {
|
||||
timeout = rt
|
||||
}
|
||||
err = send_rc_command(rc, timeout)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
46
tools/utils/sockets.go
Normal file
46
tools/utils/sockets.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/seancfoley/ipaddress-go/ipaddr"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Cut(s string, sep string) (string, string, bool) {
|
||||
if i := strings.Index(s, sep); i >= 0 {
|
||||
return s[:i], s[i+len(sep):], true
|
||||
}
|
||||
return s, "", false
|
||||
}
|
||||
|
||||
func ParseSocketAddress(spec string) (network string, addr string, err error) {
|
||||
network, addr, found := Cut(spec, ":")
|
||||
if !found {
|
||||
err = fmt.Errorf("Invalid socket address: %s", spec)
|
||||
return
|
||||
}
|
||||
if network == "unix" {
|
||||
if strings.HasSuffix(addr, "@") && runtime.GOOS != "linux" {
|
||||
err = fmt.Errorf("Abstract UNIX sockets are only supported on Linux. Cannot use: %s", spec)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if network == "tcp" || network == "tcp6" || network == "tcp4" {
|
||||
host := ipaddr.NewHostName(addr)
|
||||
if host.IsAddress() {
|
||||
network = "ip"
|
||||
}
|
||||
return
|
||||
}
|
||||
if network == "ip" || network == "ip6" || network == "ip4" {
|
||||
host := ipaddr.NewHostName(addr)
|
||||
if !host.IsAddress() {
|
||||
err = fmt.Errorf("Not a valid IP address: %#v. Cannot use: %s", addr, spec)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("Unknown network type: %#v in socket address: %s", network, spec)
|
||||
return
|
||||
}
|
||||
57
tools/utils/sockets_test.go
Normal file
57
tools/utils/sockets_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseSocketAddress(t *testing.T) {
|
||||
en := "unix"
|
||||
ea := "/tmp/test"
|
||||
var eerr error = nil
|
||||
|
||||
test := func(spec string) {
|
||||
n, a, err := ParseSocketAddress(spec)
|
||||
if err != eerr {
|
||||
if eerr == nil {
|
||||
t.Fatalf("Parsing of %s failed with unexpected error: %s", spec, err)
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatalf("Parsing of %s did not fail, unexpectedly", spec)
|
||||
}
|
||||
return
|
||||
}
|
||||
if a != ea {
|
||||
t.Fatalf("actual != expected, %s != %s, when parsing %s", a, ea, spec)
|
||||
}
|
||||
if n != en {
|
||||
t.Fatalf("actual != expected, %s != %s, when parsing %s", n, en, spec)
|
||||
}
|
||||
}
|
||||
|
||||
testf := func(spec string, netw string, addr string) {
|
||||
eerr = nil
|
||||
en = netw
|
||||
ea = addr
|
||||
test(spec)
|
||||
}
|
||||
teste := func(spec string, e string) {
|
||||
eerr = fmt.Errorf(e)
|
||||
test(spec)
|
||||
}
|
||||
|
||||
test("unix:/tmp/test")
|
||||
if runtime.GOOS == "linux" {
|
||||
ea = "@test"
|
||||
} else {
|
||||
eerr = fmt.Errorf("bad kitty")
|
||||
}
|
||||
test("unix:@test")
|
||||
testf("tcp:localhost:123", "tcp", "localhost:123")
|
||||
testf("tcp:1.1.1.1:123", "ip", "1.1.1.1:123")
|
||||
testf("tcp:fe80::1", "ip", "fe80::1")
|
||||
teste("xxx", "bad kitty")
|
||||
teste("xxx:yyy", "bad kitty")
|
||||
teste(":yyy", "bad kitty")
|
||||
}
|
||||
Reference in New Issue
Block a user