Proper parsing of safe_string fields

This commit is contained in:
Kovid Goyal
2023-05-28 13:18:31 +05:30
parent 425e993ab7
commit 6c79ae4443
6 changed files with 47 additions and 55 deletions

View File

@@ -7,13 +7,14 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/fs" "io/fs"
"kitty"
"kitty/tools/utils"
"kitty/tools/wcswidth"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"kitty"
"kitty/tools/utils"
) )
var _ = fmt.Print var _ = fmt.Print
@@ -154,6 +155,14 @@ var ftc_field_map = utils.Once(func() map[string]reflect.StructField {
return ans return ans
}) })
var safe_string_pat = utils.Once(func() *regexp.Regexp {
return regexp.MustCompile(`[^0-9a-zA-Z_:.,/!@#$%^&*()[\]{}~` + "`" + `?"'\\|=+-]`)
})
func safe_string(x string) string {
return safe_string_pat().ReplaceAllLiteralString(x, ``)
}
func (self *FileTransmissionCommand) Serialize(prefix_with_osc_code ...bool) string { func (self *FileTransmissionCommand) Serialize(prefix_with_osc_code ...bool) string {
ans := strings.Builder{} ans := strings.Builder{}
v := reflect.ValueOf(*self) v := reflect.ValueOf(*self)
@@ -173,7 +182,7 @@ func (self *FileTransmissionCommand) Serialize(prefix_with_osc_code ...bool) str
case "base64": case "base64":
encoded_val = base64.RawStdEncoding.EncodeToString(utils.UnsafeStringToBytes(sval)) encoded_val = base64.RawStdEncoding.EncodeToString(utils.UnsafeStringToBytes(sval))
default: default:
encoded_val = escape_semicolons(wcswidth.StripEscapeCodes(sval)) encoded_val = safe_string(sval)
} }
} }
case reflect.Slice: case reflect.Slice:
@@ -229,11 +238,11 @@ func (self FileTransmissionCommand) String() string {
func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand, err error) { func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand, err error) {
ans = &FileTransmissionCommand{} ans = &FileTransmissionCommand{}
key_length, key_start, val_start, val_length := 0, 0, 0, 0 key_length, key_start, val_start, val_length := 0, 0, 0, 0
has_semicolons := false
field_map := ftc_field_map() field_map := ftc_field_map()
v := reflect.Indirect(reflect.ValueOf(ans)) v := reflect.Indirect(reflect.ValueOf(ans))
handle_value := func(key, serialized_val string, has_semicolons bool) error { handle_value := func(key, serialized_val string) error {
key = strings.TrimLeft(key, `;;`)
if field, ok := field_map[key]; ok { if field, ok := field_map[key]; ok {
val := v.FieldByIndex(field.Index) val := v.FieldByIndex(field.Index)
switch val.Kind() { switch val.Kind() {
@@ -246,10 +255,7 @@ func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand
} }
val.SetString(utils.UnsafeBytesToString(b)) val.SetString(utils.UnsafeBytesToString(b))
default: default:
if has_semicolons { val.SetString(safe_string(serialized_val))
serialized_val = strings.ReplaceAll(serialized_val, `;;`, `;`)
}
val.SetString(serialized_val)
} }
case reflect.Slice: case reflect.Slice:
switch val.Type().Elem().Kind() { switch val.Type().Elem().Kind() {
@@ -302,17 +308,12 @@ func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand
if ch == '=' { if ch == '=' {
key_length = i - key_start key_length = i - key_start
val_start = i + 1 val_start = i + 1
has_semicolons = false
} }
} else { } else {
if ch == ';' { if ch == ';' {
if i+1 < len(serialized) && serialized[i+1] == ';' {
has_semicolons = true
i++
} else {
val_length = i - val_start val_length = i - val_start
if key_length > 0 && val_start > 0 { if key_length > 0 && val_start > 0 {
err = handle_value(serialized[key_start:key_start+key_length], serialized[val_start:val_start+val_length], has_semicolons) err = handle_value(serialized[key_start:key_start+key_length], serialized[val_start:val_start+val_length])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -324,9 +325,8 @@ func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand
} }
} }
} }
}
if key_length > 0 && val_start > 0 { if key_length > 0 && val_start > 0 {
err = handle_value(serialized[key_start:key_start+key_length], serialized[val_start:], has_semicolons) err = handle_value(serialized[key_start:key_start+key_length], serialized[val_start:])
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -43,4 +43,9 @@ func TestFTCSerialization(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
q(n.Serialize()) q(n.Serialize())
unsafe := "moo\x1b;;[?*.-se1"
if safe_string(unsafe) != "moo[?*.-se1" {
t.Fatalf("safe_string() failed for %#v yielding: %#v", unsafe, safe_string(unsafe))
}
} }

View File

@@ -162,12 +162,13 @@ begin_patch(PyObject *self UNUSED, PyObject *callback) {
} }
static bool static bool
call_ftc_callback(PyObject *callback, char *src, Py_ssize_t key_start, Py_ssize_t key_length, Py_ssize_t val_start, Py_ssize_t val_length, PyObject *has_semicolons) { call_ftc_callback(PyObject *callback, char *src, Py_ssize_t key_start, Py_ssize_t key_length, Py_ssize_t val_start, Py_ssize_t val_length) {
while(src[key_start] == ';' && key_length > 0 ) { key_start++; key_length--; }
DECREF_AFTER_FUNCTION PyObject *k = PyMemoryView_FromMemory(src + key_start, key_length, PyBUF_READ); DECREF_AFTER_FUNCTION PyObject *k = PyMemoryView_FromMemory(src + key_start, key_length, PyBUF_READ);
if (!k) return false; if (!k) return false;
DECREF_AFTER_FUNCTION PyObject *v = PyMemoryView_FromMemory(src + val_start, val_length, PyBUF_READ); DECREF_AFTER_FUNCTION PyObject *v = PyMemoryView_FromMemory(src + val_start, val_length, PyBUF_READ);
if (!v) return false; if (!v) return false;
DECREF_AFTER_FUNCTION PyObject *ret = PyObject_CallFunctionObjArgs(callback, k, v, has_semicolons, NULL); DECREF_AFTER_FUNCTION PyObject *ret = PyObject_CallFunctionObjArgs(callback, k, v, NULL);
return ret != NULL; return ret != NULL;
} }
@@ -187,31 +188,24 @@ parse_ftc(PyObject *self UNUSED, PyObject *args) {
char *src = buf.buf; char *src = buf.buf;
size_t sz = buf.len; size_t sz = buf.len;
if (!PyCallable_Check(callback)) { PyErr_SetString(PyExc_TypeError, "callback must be callable"); return NULL; } if (!PyCallable_Check(callback)) { PyErr_SetString(PyExc_TypeError, "callback must be callable"); return NULL; }
PyObject *has_semicolons = Py_False;
for (i = 0; i < sz; i++) { for (i = 0; i < sz; i++) {
char ch = src[i]; char ch = src[i];
if (key_length == 0) { if (key_length == 0) {
if (ch == '=') { if (ch == '=') {
key_length = i - key_start; key_length = i - key_start;
val_start = i + 1; val_start = i + 1;
has_semicolons = Py_False;
} }
} else { } else {
if (ch == ';') { if (ch == ';') {
if (i + 1 < sz && src[i + 1] == ';') {
has_semicolons = Py_True;
i++;
} else {
val_length = i - val_start; val_length = i - val_start;
if (!call_ftc_callback(callback, src, key_start, key_length, val_start, val_length, has_semicolons)) return NULL; if (!call_ftc_callback(callback, src, key_start, key_length, val_start, val_length)) return NULL;
key_length = 0; key_start = i + 1; val_start = 0; key_length = 0; key_start = i + 1; val_start = 0;
} }
} }
} }
}
if (key_length && val_start) { if (key_length && val_start) {
val_length = sz - val_start; val_length = sz - val_start;
if (!call_ftc_callback(callback, src, key_start, key_length, val_start, val_length, has_semicolons)) return NULL; if (!call_ftc_callback(callback, src, key_start, key_length, val_start, val_length)) return NULL;
} }
Py_RETURN_NONE; Py_RETURN_NONE;
} }

View File

@@ -39,7 +39,7 @@ def iter_job(job_capsule: JobCapsule, input_data: bytes, output_buf: bytearray)
pass pass
def parse_ftc(src: Union[str, bytes, memoryview], callback: Callable[[memoryview, memoryview, bool], None]) -> None: def parse_ftc(src: Union[str, bytes, memoryview], callback: Callable[[memoryview, memoryview], None]) -> None:
pass pass

View File

@@ -29,10 +29,7 @@ MAX_ACTIVE_RECEIVES = MAX_ACTIVE_SENDS = 10
ftc_prefix = str(FILE_TRANSFER_CODE) ftc_prefix = str(FILE_TRANSFER_CODE)
def escape_semicolons(x: str) -> str: @run_once
return x.replace(';', ';;')
def safe_string_pat() -> 're.Pattern[str]': def safe_string_pat() -> 're.Pattern[str]':
return re.compile(r'[^0-9a-zA-Z_:./@-]') return re.compile(r'[^0-9a-zA-Z_:./@-]')
@@ -316,7 +313,7 @@ class FileTransmissionCommand:
if k.metadata.get('base64'): if k.metadata.get('base64'):
yield standard_b64encode(val.encode('utf-8')) yield standard_b64encode(val.encode('utf-8'))
else: else:
yield escape_semicolons(safe_string(val)) yield safe_string(val)
elif k.type is int: elif k.type is int:
yield str(val) yield str(val)
else: else:
@@ -331,7 +328,7 @@ class FileTransmissionCommand:
fmap = serialized_to_field_map() fmap = serialized_to_field_map()
from kittens.transfer.rsync import decode_utf8_buffer, parse_ftc from kittens.transfer.rsync import decode_utf8_buffer, parse_ftc
def handle_item(key: memoryview, val: memoryview, has_semicolons: bool) -> None: def handle_item(key: memoryview, val: memoryview) -> None:
field = fmap.get(key) field = fmap.get(key)
if field is None: if field is None:
return return
@@ -345,9 +342,7 @@ class FileTransmissionCommand:
if field.metadata.get('base64'): if field.metadata.get('base64'):
sval = standard_b64decode(val).decode('utf-8') sval = standard_b64decode(val).decode('utf-8')
else: else:
sval = decode_utf8_buffer(val) sval = safe_string(decode_utf8_buffer(val))
if has_semicolons:
sval = sval.replace(';;', ';')
setattr(ans, field.name, safe_string(sval)) setattr(ans, field.name, safe_string(sval))
parse_ftc(data, handle_item) parse_ftc(data, handle_item)

View File

@@ -353,10 +353,8 @@ class TestFileTransmission(BaseTest):
def t(raw, *expected): def t(raw, *expected):
a = [] a = []
def c(k, v, has_semicolons): def c(k, v):
a.append(decode_utf8_buffer(k)) a.append(decode_utf8_buffer(k))
if has_semicolons:
v = bytes(v).replace(b';;', b';')
a.append(decode_utf8_buffer(v)) a.append(decode_utf8_buffer(v))
parse_ftc(raw, c) parse_ftc(raw, c)
@@ -364,9 +362,9 @@ class TestFileTransmission(BaseTest):
t('a=b', 'a', 'b') t('a=b', 'a', 'b')
t('a=b;', 'a', 'b') t('a=b;', 'a', 'b')
t('a1=b1;c=d;;', 'a1', 'b1', 'c', 'd;') t('a1=b1;c=d;;', 'a1', 'b1', 'c', 'd')
t('a1=b1;c=d;;e', 'a1', 'b1', 'c', 'd;e') t('a1=b1;c=d;;e', 'a1', 'b1', 'c', 'd')
t('a1=b1;c=d;;;1=1', 'a1', 'b1', 'c', 'd;', '1', '1') t('a1=b1;c=d;;;1=1', 'a1', 'b1', 'c', 'd', '1', '1')
def test_path_mapping_receive(self): def test_path_mapping_receive(self):
opts = parse_transfer_args([])[0] opts = parse_transfer_args([])[0]