Use a simple arena allocator for CLISpec

This commit is contained in:
Kovid Goyal
2025-04-27 12:55:50 +05:30
parent d548a6fcf4
commit 40f0f3d3eb

View File

@@ -7,7 +7,6 @@
#pragma once
#include "listobject.h"
#include <Python.h>
#include <stdbool.h>
#include <stddef.h>
@@ -21,6 +20,11 @@ static inline void cleanup_decref2(PyObject **p) { Py_CLEAR(*p); }
// Cleanup handler: p is the address of a pointer variable; the pointed-to
// allocation is released with free().
static inline void
cleanup_free(void *p) {
    void **slot = p;
    free(*slot);
}
// Declare `type *name` initialized with `initializer` and tagged with the
// cleanup attribute so that cleanup_free() is invoked on it (freeing the
// pointer) when the variable goes out of scope.
#define RAII_ALLOC(type, name, initializer) __attribute__((cleanup(cleanup_free))) type *name = initializer
// Larger of x and y. Each argument is evaluated exactly once; when neither
// compares greater, the value of y is produced.
#define MAX(x, y) __extension__ ({ \
    const __typeof__ (x) max_lhs_ = (x); \
    const __typeof__ (y) max_rhs_ = (y); \
    (max_rhs_ < max_lhs_) ? max_lhs_ : max_rhs_; })
#endif
static inline void
@@ -40,7 +44,6 @@ typedef struct CLIValue {
struct {
const char* * items;
size_t count, capacity;
bool needs_free;
} listval;
} CLIValue;
@@ -63,7 +66,7 @@ typedef struct FlagSpec {
#define NAME flag_hash
#define KEY_TY const char*
#define VAL_TY const FlagSpec*
#define VAL_TY FlagSpec
#include "../kitty-verstable.h"
#define flag_map_for_loop(x) vt_create_for_loop(flag_hash_itr, itr, x)
@@ -73,7 +76,11 @@ typedef struct CLISpec {
alias_hash alias_map;
flag_hash flag_map;
char **argv; int argc; // leftover args
char err[1024];
const char* errmsg;
struct {
struct { char *buf; size_t capacity, used; } *items;
size_t count, capacity;
} blocks;
} CLISpec;
static void
@@ -83,12 +90,49 @@ out_of_memory(int line) {
}
#define OOM out_of_memory(__LINE__)
/*
 * Allocate sz bytes, plus one extra trailing byte that is set to NUL, from the
 * list of blocks stored in spec->blocks.
 *
 * Blocks are malloc()ed in chunks of at least 8192 bytes and handed out
 * sequentially; a new block is appended when the current one cannot hold the
 * request. Apart from the trailing NUL byte the returned memory is not
 * initialized. Every returned pointer is rounded up to a multiple of
 * 2 * sizeof(void*) within its block so that it can hold pointer arrays as
 * well as strings.
 *
 * Returns NULL if memory cannot be obtained. On failure the bookkeeping in
 * spec->blocks is left unchanged, so the spec can still be used and every
 * block counted in spec->blocks.count remains valid to free.
 */
static void*
alloc_for_cli(CLISpec *spec, size_t sz) {
    if (sz == (size_t)-1) return NULL;  // sz + 1 would wrap around to zero
    sz++;  // room for the trailing NUL written below

    if (!spec->blocks.capacity) {
        // First use: create the block list with one empty (buf == NULL) entry.
        // capacity is recorded only once the allocation has succeeded.
        spec->blocks.items = calloc(8, sizeof(spec->blocks.items[0]));
        if (!spec->blocks.items) return NULL;
        spec->blocks.capacity = 8;
        spec->blocks.count = 1;
    }

    const size_t align = 2 * sizeof(void*);
    __typeof__(spec->blocks.items) blk = &spec->blocks.items[spec->blocks.count - 1];
    size_t start = (blk->used + align - 1) / align * align;

    if (!blk->buf || start > blk->capacity || sz > blk->capacity - start) {
        // The current block cannot satisfy the request. Reuse its slot if it
        // never received a buffer, otherwise take the next slot.
        size_t idx = blk->buf ? spec->blocks.count : spec->blocks.count - 1;
        if (idx >= spec->blocks.capacity) {
            size_t new_capacity = spec->blocks.capacity * 2;
            // Keep the old array reachable if realloc fails.
            void *grown = realloc(spec->blocks.items, new_capacity * sizeof(spec->blocks.items[0]));
            if (!grown) return NULL;
            spec->blocks.items = grown;
            spec->blocks.capacity = new_capacity;
        }
        size_t block_size = sz > 8192 ? sz : 8192;
        char *buf = malloc(block_size);
        if (!buf) return NULL;
        // Commit the new block only after every allocation has succeeded.
        blk = &spec->blocks.items[idx];
        blk->buf = buf;
        blk->capacity = block_size;
        blk->used = 0;
        spec->blocks.count = idx + 1;
        start = 0;
    }

    char *ans = blk->buf + start;
    blk->used = start + sz;
    ans[sz - 1] = 0;
    return ans;
}
// Format an error message printf-style and store it in spec->errmsg.
// Requires a pointer named `spec` to be in scope at the expansion site; the
// message is stored in memory obtained from alloc_for_cli(spec, ...).
// alloc_for_cli() reserves one byte more than requested for the terminating
// NUL, so the second snprintf() is given length + 1 to avoid dropping the last
// character of the message. If formatting or allocation fails, errmsg is
// pointed at a static string so that the error is still reported.
// Wrapped in do { } while (0) so that `set_err(...);` is a single statement.
#define set_err(fmt, ...) do { \
    const char *set_err_msg_ = "Failed to format error message"; \
    int set_err_len_ = snprintf(NULL, 0, fmt, __VA_ARGS__); \
    if (set_err_len_ >= 0) { \
        char *set_err_buf_ = alloc_for_cli(spec, (size_t)set_err_len_); \
        if (set_err_buf_) { \
            snprintf(set_err_buf_, (size_t)set_err_len_ + 1, fmt, __VA_ARGS__); \
            set_err_msg_ = set_err_buf_; \
        } else { \
            OOM; \
            set_err_msg_ = "Out of memory"; \
        } \
    } \
    spec->errmsg = set_err_msg_; \
} while (0)
static const char*
dest_for_alias(CLISpec *spec, const char *alias) {
alias_hash_itr itr = vt_get(&spec->alias_map, alias);
if (vt_is_end(itr)) {
snprintf(spec->err, sizeof(spec->err), "Unknown flag: %s use --help", alias);
set_err("Unknown flag: %s use --help", alias);
return NULL;
}
return itr.data->val;
@@ -99,7 +143,7 @@ is_alias_bool(CLISpec* spec, const char *alias) {
const char *dest = dest_for_alias(spec, alias);
if (!dest) return false;
flag_hash_itr itr = vt_get(&spec->flag_map, dest);
return itr.data->val->defval.type == CLI_VALUE_BOOL;
return itr.data->val.defval.type == CLI_VALUE_BOOL;
}
static void
@@ -123,7 +167,7 @@ process_cli_arg(CLISpec* spec, const char *alias, const char *payload) {
const char *dest = dest_for_alias(spec, alias);
if (!dest) return false;
flag_hash_itr itr = vt_get(&spec->flag_map, dest);
const FlagSpec *flag = itr.data->val;
const FlagSpec *flag = &itr.data->val;
CLIValue val = {.type=flag->defval.type};
#define streq(q) (strcmp(payload, #q) == 0)
switch(val.type) {
@@ -133,7 +177,7 @@ process_cli_arg(CLISpec* spec, const char *alias, const char *payload) {
if (streq(y) || streq(yes) || streq(true)) val.boolval = true;
else if (streq(n) || streq(no) || streq(false)) val.boolval = false;
else {
snprintf(spec->err, sizeof(spec->err), "%s is an invalid value for %s. Valid values are: y, yes, true, n, no and false.",
set_err("%s is an invalid value for %s. Valid values are: y, yes, true, n, no and false.",
payload[0] ? payload : "<empty>", alias);
return false;
}
@@ -145,24 +189,28 @@ process_cli_arg(CLISpec* spec, const char *alias, const char *payload) {
if (strcmp(payload, flag->defval.listval.items[c]) == 0) { val.strval = payload; break; }
}
if (!val.strval) {
int n = snprintf(spec->err, sizeof(spec->err), "%s is an invalid value for %s. Valid values are:",
payload[0] ? payload : "<empty>", alias);
size_t bufsz = 128 + strlen(alias) + strlen(payload);
for (size_t c = 0; c < flag->defval.listval.count; c++) bufsz += strlen(flag->defval.listval.items[c]) + 8;
char *buf = alloc_for_cli(spec, bufsz);
int n = snprintf(buf, bufsz, "%s is an invalid value for %s. Valid values are:",
payload[0] ? payload : "<empty>", alias);
for (size_t c = 0; c < flag->defval.listval.count; c++)
n += snprintf(spec->err + n, sizeof(spec->err) - n, " %s,", flag->defval.listval.items[c]);
spec->err[n-1] = '.';
n += snprintf(buf + n, bufsz - n, " %s,", flag->defval.listval.items[c]);
buf[n-1] = '.';
spec->errmsg = buf;
return false;
}
break;
case CLI_VALUE_INT:
errno = 0; val.intval = strtoll(payload, NULL, 10);
if (errno) {
snprintf(spec->err, sizeof(spec->err), "%s is an invalid value for %s, it must be an integer number.", payload, alias);
set_err("%s is an invalid value for %s, it must be an integer number.", payload, alias);
return false;
} break;
case CLI_VALUE_FLOAT:
errno = 0; val.floatval = strtod(payload, NULL);
if (errno) {
snprintf(spec->err, sizeof(spec->err), "%s is an invalid value for %s, it must be a number.", payload, alias);
set_err("%s is an invalid value for %s, it must be a number.", payload, alias);
return false;
} break;
case CLI_VALUE_LIST: add_list_value(spec, flag->dest, payload); return true;
@@ -179,20 +227,11 @@ alloc_cli_spec(CLISpec *spec) {
vt_init(&spec->flag_map);
}
static void
dealloc_cli_value(CLIValue v) {
if (v.listval.needs_free) free((void*)v.listval.items);
}
static void
dealloc_cli_spec(void *v) {
CLISpec *spec = v;
value_map_for_loop(&spec->value_map) {
dealloc_cli_value(itr.data->val);
}
flag_map_for_loop(&spec->flag_map) {
dealloc_cli_value(itr.data->val->defval);
}
for (size_t i = 0; i < spec->blocks.count; i++) free(spec->blocks.items[i].buf);
free(spec->blocks.items);
vt_cleanup(&spec->value_map);
vt_cleanup(&spec->alias_map);
vt_cleanup(&spec->flag_map);
@@ -203,7 +242,7 @@ dealloc_cli_spec(void *v) {
static bool
parse_cli_loop(CLISpec *spec, int argc, char **argv) { // argv must contain arg1 and beyond
enum { NORMAL, EXPECTING_ARG } state = NORMAL;
spec->argc = 0; spec->argv = NULL; spec->err[0] = 0;
spec->argc = 0; spec->argv = NULL; spec->errmsg = NULL;
char flag[3] = {'-', 0, 0};
const char *current_option = NULL;
for (int i = 0; i < argc; i++) {
@@ -234,7 +273,7 @@ parse_cli_loop(CLISpec *spec, int argc, char **argv) { // argv must contain arg
current_option = arg;
}
}
if (spec->err[0]) return false;
if (spec->errmsg) return false;
} else {
for (const char *letter = arg + 1; *letter; letter++) {
flag[1] = *letter;
@@ -247,7 +286,7 @@ parse_cli_loop(CLISpec *spec, int argc, char **argv) { // argv must contain arg
state = EXPECTING_ARG;
current_option = arg;
}
if (spec->err[0]) return false;
if (spec->errmsg) return false;
}
}
}
@@ -263,19 +302,19 @@ parse_cli_loop(CLISpec *spec, int argc, char **argv) { // argv must contain arg
} break;
}
}
if (state == EXPECTING_ARG) snprintf(spec->err, sizeof(spec->err), "The %s flag must be followed by an argument.", current_option ? current_option : "");
return spec->err[0] == 0;
if (state == EXPECTING_ARG) set_err("The %s flag must be followed by an argument.", current_option ? current_option : "");
return spec->errmsg != NULL;
}
static PyObject*
cli_parse_result_as_python(CLISpec *spec) {
if (PyErr_Occurred()) return NULL;
if (spec->err[0]) {
PyErr_SetString(PyExc_ValueError, spec->err); return NULL;
if (spec->errmsg) {
PyErr_SetString(PyExc_ValueError, spec->errmsg); return NULL;
}
RAII_PyObject(ans, PyDict_New()); if (!ans) return NULL;
flag_map_for_loop(&spec->flag_map) {
const FlagSpec *flag = itr.data->val;
const FlagSpec *flag = &itr.data->val;
cli_hash_itr i = vt_get(&spec->value_map, flag->dest);
PyObject *is_seen = vt_is_end(i) ? Py_False : Py_True;
const CLIValue *v = is_seen == Py_True ? &i.data->val : &flag->defval;
@@ -317,50 +356,47 @@ parse_cli_from_python_spec(PyObject *self, PyObject *args) {
argv[i] = strdup(PyUnicode_AsUTF8(PyList_GET_ITEM(pyargs, i)));
if (!argv[i]) return PyErr_NoMemory();
}
RAII_ALLOC(FlagSpec, flags, calloc(PyDict_GET_SIZE(names_map), sizeof(FlagSpec))); if (!flags) return PyErr_NoMemory();
RAII_CLISpec(spec);
PyObject *key = NULL, *opt = NULL;
Py_ssize_t pos = 0, flag_num = 0;
Py_ssize_t pos = 0;
while (PyDict_Next(names_map, &pos, &key, &opt)) {
FlagSpec *flag = &flags[flag_num++];
flag->dest = PyUnicode_AsUTF8(key);
FlagSpec flag = {.dest=PyUnicode_AsUTF8(key)};
PyObject *pytype = PyDict_GetItemString(opt, "type");
const char *type = pytype ? PyUnicode_AsUTF8(pytype) : "";
PyObject *defval = PyDict_GetItemWithError(defval_map, key); if (!defval && PyErr_Occurred()) return NULL;
PyObject *pyaliases = PyDict_GetItemString(opt, "aliases");
for (int a = 0; a < PyTuple_GET_SIZE(pyaliases); a++) {
const char *alias = PyUnicode_AsUTF8(PyTuple_GET_ITEM(pyaliases, a));
if (vt_is_end(vt_insert(&spec.alias_map, alias, flag->dest))) return PyErr_NoMemory();
if (vt_is_end(vt_insert(&spec.alias_map, alias, flag.dest))) return PyErr_NoMemory();
}
if (strstr(type, "bool-") == type) {
flag->defval.type = CLI_VALUE_BOOL;
flag->defval.boolval = PyObject_IsTrue(defval);
flag.defval.type = CLI_VALUE_BOOL;
flag.defval.boolval = PyObject_IsTrue(defval);
} else if (strcmp(type, "int") == 0) {
flag->defval.type = CLI_VALUE_INT;
flag->defval.intval = PyLong_AsLongLong(defval);
flag.defval.type = CLI_VALUE_INT;
flag.defval.intval = PyLong_AsLongLong(defval);
} else if (strcmp(type, "float") == 0) {
flag->defval.type = CLI_VALUE_FLOAT;
flag->defval.floatval = PyFloat_AsDouble(defval);
flag.defval.type = CLI_VALUE_FLOAT;
flag.defval.floatval = PyFloat_AsDouble(defval);
} else if (strcmp(type, "list") == 0) {
flag->defval.type = CLI_VALUE_LIST;
flag.defval.type = CLI_VALUE_LIST;
} else if (strcmp(type, "choices") == 0) {
flag->defval.type = CLI_VALUE_CHOICE;
flag->defval.strval = PyUnicode_AsUTF8(defval);
flag.defval.type = CLI_VALUE_CHOICE;
flag.defval.strval = PyUnicode_AsUTF8(defval);
PyObject *pyc = PyDict_GetItemString(opt, "choices");
flag->defval.listval.items = malloc(PyTuple_GET_SIZE(pyc) * sizeof(char*));
if (!flag->defval.listval.items) return PyErr_NoMemory();
flag->defval.listval.count = PyTuple_GET_SIZE(pyc);
flag->defval.listval.needs_free = true;
flag->defval.listval.capacity = PyTuple_GET_SIZE(pyc);
for (size_t n = 0; n < flag->defval.listval.count; n++) {
flag->defval.listval.items[n] = PyUnicode_AsUTF8(PyTuple_GET_ITEM(pyc, n));
if (!flag->defval.listval.items[n]) return NULL;
flag.defval.listval.items = alloc_for_cli(&spec, PyTuple_GET_SIZE(pyc) * sizeof(char*));
if (!flag.defval.listval.items) return PyErr_NoMemory();
flag.defval.listval.count = PyTuple_GET_SIZE(pyc);
flag.defval.listval.capacity = PyTuple_GET_SIZE(pyc);
for (size_t n = 0; n < flag.defval.listval.count; n++) {
flag.defval.listval.items[n] = PyUnicode_AsUTF8(PyTuple_GET_ITEM(pyc, n));
if (!flag.defval.listval.items[n]) return NULL;
}
} else {
flag->defval.type = CLI_VALUE_STRING;
flag->defval.strval = PyUnicode_Check(defval) ? PyUnicode_AsUTF8(defval) : NULL;
flag.defval.type = CLI_VALUE_STRING;
flag.defval.strval = PyUnicode_Check(defval) ? PyUnicode_AsUTF8(defval) : NULL;
}
if (vt_is_end(vt_insert(&spec.flag_map, flag->dest, flag))) return PyErr_NoMemory();
if (vt_is_end(vt_insert(&spec.flag_map, flag.dest, flag))) return PyErr_NoMemory();
}
if (PyErr_Occurred()) return NULL;
parse_cli_loop(&spec, argc, argv);