diff --git a/core/klibc.c b/core/klibc.c index ffd2bc2..ac17d73 100644 --- a/core/klibc.c +++ b/core/klibc.c @@ -1,3 +1,4 @@ +#include "alloc.h" #include "klibc.h" #include "serial.h" #include "vga.h" @@ -211,7 +212,7 @@ int vsnprintf(char *str, size_t size, const char *format, va_list ap) int i = 0; int c = 0; - while (format[i] != '\0' && size) { + while (format[i] != '\0' && (size|| !str)) { switch (format[i]) { case '%': switch (format[i + 1]) { @@ -381,3 +382,47 @@ int vprintf(const char *fmt, va_list ap) } return ret; } + +int asprintf(char **strp, const char *fmt, ...) +{ + int ret; + + va_list ap; + va_start(ap, fmt); + ret = vasprintf(strp, fmt, ap); + va_end(ap); + + return ret; +} + +int vasprintf(char **strp, const char *fmt, va_list ap) +{ + int n = 0; + size_t size = 0; + char *p = NULL; + + /* Determine required size */ + + n = vsnprintf(p, size, fmt, ap); + printf("Found length %d\n", n); + + if (n < 0) + return -1; + + /* One extra byte for '\0' */ + + size = (size_t)n + 1; + p = malloc(size); + if (p == NULL) + return -1; + + n = vsnprintf(p, size, fmt, ap); + + if (n < 0) { + free(p); + return -1; + } + *strp = p; + + return size; +} diff --git a/core/klibc.h b/core/klibc.h index 913ba91..ea1b89e 100644 --- a/core/klibc.h +++ b/core/klibc.h @@ -25,6 +25,10 @@ int vsnprintf(char *str, size_t size, const char *format, va_list ap); int vprintf(const char *format, va_list ap); int printf(const char *format, ...); +// Could be used after malloc is available +int asprintf(char **strp, const char *fmt, ...); +int vasprintf(char **strp, const char *fmt, va_list ap); + /* * Dummy printk for disabled debugging statements to use whilst maintaining * gcc's format checking. diff --git a/tests/test.c b/tests/test.c index 6b4637a..41e80a9 100644 --- a/tests/test.c +++ b/tests/test.c @@ -312,14 +312,23 @@ void testKthread() void run_test(void) { - int test = 1000; - long long int test64 = 0x100000000; - assert(printf("hello") == 5); - assert(printf("hello\n") == 6); - assert(printf("hello %d\n", test) == 11); - assert(printf("hello %llx\n", test64) == 16); - assert(printf("hello %c\n", 'a') == 8); - assert(printf("hello %s\n", "world") == 12); + { + int test = 1000; + long long int test64 = 0x100000000; + assert(printf("hello") == 5); + assert(printf("hello\n") == 6); + assert(printf("hello %d\n", test) == 11); + assert(printf("hello %llx\n", test64) == 16); + assert(printf("hello %c\n", 'a') == 8); + assert(printf("hello %s\n", "world") == 12); + } + { + char *strAlloc; + int ret = asprintf(&strAlloc, "hello %s\n", "world"); + printf("asprint ret %d %s\n", ret, strAlloc); + assert(ret == 13); //include the '\0' + free(strAlloc); + } testPaging(); printf("Testing Serial\n"); serialPutc('h');