#define pr_fmt(fmt) "[alloc]: " fmt
//#define DEBUG
#include "alloc.h"
#include "errno.h"
#include "irq.h"
#include "kernel.h"
#include "klibc.h"
#include "list.h"
#include "math.h"
#include "mem.h"
#include "paging.h"

// True when a slab entry's descriptor is stored at the very start of the pages it
// manages (its recorded page address equals the descriptor's own address).
#define IS_SELF_CONTAINED(desc) ((vaddr_t)((desc)->page) == (vaddr_t)(desc))

// Booked slabs, one descriptor per element size, kept sorted by ascending element
// size by allocBookSlab(). The sizes booked at boot are listed in initSlab.
static struct slabDesc *slub;

int allocSlab(struct slabDesc **desc, size_t sizeEl, size_t sizeSlab, int selfContained);
int allocSlabEntry(struct slabEntry **desc, size_t sizeEl, size_t sizeSlab, int selfContained);
static int formatPage(struct slabEntry *desc, size_t size, size_t sizeSlab, int selfContained);

static struct {
    size_t elementSize;
    size_t slabSize;
    unsigned char isSelf;
} initSlab[] = {{sizeof(struct slabDesc), PAGE_SIZE, 1},
                {sizeof(struct slabEntry), PAGE_SIZE, 1},
                {4, PAGE_SIZE, 0},
                {8, PAGE_SIZE, 0},
                {16, PAGE_SIZE, 0},
                {32, PAGE_SIZE, 0},
                {64, PAGE_SIZE, 0},
                {128, PAGE_SIZE, 0},
                {256, 2 * PAGE_SIZE, 0},
                {1024, 2 * PAGE_SIZE, 0},
                {2048, 3 * PAGE_SIZE, 0},
                {4096, 4 * PAGE_SIZE, 0},
                {8192, 8 * PAGE_SIZE, 0},
                {16384, 12 * PAGE_SIZE, 0},
                {0, 0, 0}};

/*
 * Initialise the slab list and book every slab described in initSlab.
 *
 * A row whose element size is already booked (allocBookSlab() returns -EEXIST)
 * is skipped silently.
 *
 * Return: 0 on success, otherwise the first other error from allocBookSlab().
 */
int allocSetup(void)
{
    list_init(slub);

    for (uint i = 0; initSlab[i].elementSize != 0; i++) {
        int ret = allocBookSlab(initSlab[i].elementSize, initSlab[i].slabSize,
                                initSlab[i].isSelf);
        if (ret == 0 || ret == -EEXIST)
            continue;
        // Report the element size that failed, not a synthetic power of two.
        pr_devel("Fail to allocBookSlab %d for %d \n", ret, initSlab[i].elementSize);
        return ret;
    }
    return 0;
}

int allocBookSlab(size_t sizeEl, size_t sizeSlab, int selfContained)
{
    pr_devel("%s for element of size %d is self %d\n", __func__, sizeEl, selfContained);
    struct slabDesc *slab    = NULL;
    struct slabDesc *newSlab = NULL;
    int slabIdx;
    int ret;
    int flags;

    disable_IRQs(flags);
    list_foreach(slub, slab, slabIdx)
    {
        if (slab->size == sizeEl) {
            restore_IRQs(flags);
            return -EEXIST;
        }
        if (slab->size > sizeEl) {
            break;
        }
    }

    if ((ret = allocSlab(&newSlab, sizeEl, sizeSlab, selfContained))) {
        restore_IRQs(flags);
        return ret;
    }

    if (list_foreach_early_break(slub, slab, slabIdx)) {
        list_insert_before(slub, slab, newSlab);
    } else {
        list_add_tail(slub, newSlab);
    }

    restore_IRQs(flags);
    return 0;
}

/*
 * Allocate a new slab: its descriptor plus a first, formatted backing entry.
 *
 * @desc          out: the new descriptor, or NULL on failure.
 * @size          size in bytes of every element carved out of the slab.
 * @sizeSlab      requested backing size in bytes; raised to at least PAGE_SIZE and
 *                covered by a whole number of pages.
 * @selfContained when non-zero the descriptor is stored at the start of the backing
 *                pages; otherwise it is obtained from malloc().
 *
 * Return: 0 on success, -ENOENT if an element is larger than the slab,
 *         -ENOMEM if the pages or the descriptor cannot be obtained, or the error
 *         returned by formatPage().
 */
int allocSlab(struct slabDesc **desc, size_t size, size_t sizeSlab, int selfContained)
{
    uint nbPage, i;
    int err = -ENOMEM;

    pr_devel("%s for size %d is self %d\n", __func__, size, selfContained);
    sizeSlab = MAX(sizeSlab, PAGE_SIZE);
    if (size > sizeSlab) {
        pr_devel("%s size of element %d are bigger than slab size %d\n", __func__, size,
                 sizeSlab);
        return -ENOENT;
    }

    nbPage        = DIV_ROUND_UP(sizeSlab, PAGE_SIZE);
    paddr_t alloc = allocPhyPage(nbPage);
    if (alloc == (paddr_t)NULL)
        return -ENOMEM;
    // Pages are mapped at a virtual address equal to their physical address.
    for (i = 0; i < nbPage; i++) {
        if (pageMap((vaddr_t)alloc + i * PAGE_SIZE, alloc + i * PAGE_SIZE, PAGING_MEM_WRITE))
            goto unmap_pages;
    }
    // From here on i == nbPage, so the cleanup path unmaps every page.

    if (selfContained) {
        *desc                  = (struct slabDesc *)alloc;
        ((*desc)->slab).freeEl = (char *)(*desc) + sizeof(struct slabDesc);
    } else {
        *desc = malloc(sizeof(struct slabDesc));
        if (*desc == NULL)
            goto unmap_pages;
        (*desc)->slab.freeEl = (void *)alloc;
    }
    struct slabEntry *slab = &(*desc)->slab;
    list_singleton(slab, slab);
    slab->page    = (vaddr_t)alloc;
    slab->full    = 0;
    slab->size    = sizeSlab;
    (*desc)->size = size;

    err = formatPage(slab, size, sizeSlab, selfContained);
    if (err == 0)
        return 0;
    if (!selfContained)
        free(*desc);

unmap_pages:
    *desc = NULL;
    // Only pages [0, i) were successfully mapped.
    for (uint j = 0; j < i; j++)
        pageUnmap((vaddr_t)alloc + j * PAGE_SIZE);
    // NOTE(review): the physical pages obtained from allocPhyPage() are not handed
    // back here; confirm the matching release call and add it.
    return err;
}

/*
 * Allocate one formatted backing entry for a slab (used to grow a full slab).
 *
 * @desc          out: the new entry, or NULL on failure.
 * @size          size in bytes of every element carved out of the entry.
 * @sizeSlab      requested backing size in bytes; raised to at least PAGE_SIZE and
 *                covered by a whole number of pages.
 * @selfContained when non-zero the entry header is stored at the start of the backing
 *                pages; otherwise it is obtained from malloc().
 *
 * Return: 0 on success, -ENOENT if an element is larger than the entry,
 *         -ENOMEM if the pages or the header cannot be obtained, or the error
 *         returned by formatPage().
 */
int allocSlabEntry(struct slabEntry **desc, size_t size, size_t sizeSlab, int selfContained)
{
    uint nbPage, i;
    int err = -ENOMEM;

    pr_devel("%s for size %d is self %d\n", __func__, size, selfContained);
    sizeSlab = MAX(sizeSlab, PAGE_SIZE);
    if (size > sizeSlab) {
        pr_devel("%s size of element %d are bigger than slab size %d\n", __func__, size,
                 sizeSlab);
        return -ENOENT;
    }

    nbPage        = DIV_ROUND_UP(sizeSlab, PAGE_SIZE);
    paddr_t alloc = allocPhyPage(nbPage);
    if (alloc == (paddr_t)NULL)
        return -ENOMEM;
    // Pages are mapped at a virtual address equal to their physical address.
    for (i = 0; i < nbPage; i++) {
        if (pageMap((vaddr_t)alloc + i * PAGE_SIZE, alloc + i * PAGE_SIZE, PAGING_MEM_WRITE))
            goto unmap_pages;
    }
    // From here on i == nbPage, so the cleanup path unmaps every page.

    if (selfContained) {
        *desc           = (struct slabEntry *)alloc;
        (*desc)->freeEl = (char *)(*desc) + sizeof(struct slabEntry);
    } else {
        *desc = malloc(sizeof(struct slabEntry));
        if (*desc == NULL)
            goto unmap_pages;
        (*desc)->freeEl = (void *)alloc;
    }
    list_singleton(*desc, *desc);
    (*desc)->page = (vaddr_t)alloc;
    (*desc)->full = 0;
    (*desc)->size = sizeSlab;

    err = formatPage(*desc, size, sizeSlab, selfContained);
    if (err == 0)
        return 0;
    if (!selfContained)
        free(*desc);

unmap_pages:
    *desc = NULL;
    // Only pages [0, i) were successfully mapped.
    for (uint j = 0; j < i; j++)
        pageUnmap((vaddr_t)alloc + j * PAGE_SIZE);
    // NOTE(review): the physical pages obtained from allocPhyPage() are not handed
    // back here; confirm the matching release call and add it.
    return err;
}

static int formatPage(struct slabEntry *desc, size_t size, size_t sizeSlab, int selfContained)
{
    char *cur  = desc->freeEl;
    ulong nbEl = sizeSlab / size - 1;
    if (selfContained)
        nbEl = (sizeSlab - sizeof(struct slabDesc)) / size - 1;
    ulong i;
    for (i = 0; i < nbEl; i++) {
        *((vaddr_t *)cur) = (vaddr_t)cur + size;
        cur += size;
    }
    *((vaddr_t *)cur) = (vaddr_t)NULL;
    // pr_devel("last at %d allocated %d\n", cur, i + 1);
    return 0;
}

/*
 * Pop the head of a slab entry's free list.
 *
 * The caller must only call this on an entry whose full flag is clear. When the
 * popped slot was the last one, freeEl is left unchanged and the entry is flagged
 * full instead.
 *
 * Return: the popped slot.
 */
static void *allocFromSlab(struct slabEntry *slab)
{
    vaddr_t *obj      = slab->freeEl;
    vaddr_t successor = *obj;

    if (successor != (vaddr_t)NULL) {
        slab->freeEl = (void *)successor;
        return (void *)obj;
    }

    pr_devel("Slab @%d is now full\n", slab);
    slab->full = 1;
    return (void *)obj;
}

void *malloc(size_t size)
{
    struct slabDesc *slab = NULL;
    uint slubIdx;
    void *ret;
    int flags;

    disable_IRQs(flags);

    list_foreach(slub, slab, slubIdx)
    {
        if (size <= slab->size)
            break;
    }

    if (!list_foreach_early_break(slub, slab, slubIdx)) {
        pr_devel("No slab found for %d\n", size);
        return NULL;
    }

    struct slabEntry *slabEntry;
    int slabIdx;
    list_foreach(&slab->slab, slabEntry, slabIdx)
    {
        if (!slabEntry->full) {
            // pr_devel("found place in slub %d at idx %d for size %d\n", slubIdx,
            //         slabIdx, size);
            ret = allocFromSlab(slabEntry);
            restore_IRQs(flags);
            return ret;
        }
    }

    // No room found
    struct slabEntry *newSlabEntry;
    struct slabEntry *slabList = &slab->slab;
    size_t slabSize            = MAX(PAGE_SIZE, size);
    int retSlab;
    if ((retSlab = allocSlabEntry(&newSlabEntry, slab->size, slabSize,
                                  IS_SELF_CONTAINED(&slab->slab)))) {
        pr_devel("Fail to allocSlabEntry %d\n", retSlab);
        restore_IRQs(flags);
        return NULL;
    }
    pr_devel("Allocate new slab for object of size %d\n", slab->size);
    list_add_tail(slabList, newSlabEntry);
    ret = allocFromSlab(newSlabEntry);
    restore_IRQs(flags);
    return ret;
}

int freeFromSlab(void *ptr, struct slabEntry *slab)
{
    struct slabEntry *slabEntry;
    int slabIdx;
    list_foreach(slab, slabEntry, slabIdx)
    {
        if ((slabEntry->page <= (vaddr_t)ptr) &&
            ((vaddr_t)ptr < (slabEntry->page + slabEntry->size))) {
            // pr_devel("free place! was %d is now %d\n", slabEntry->freeEl, ptr);
            if (slabEntry->full) {
                *((vaddr_t *)ptr) = (vaddr_t)NULL;
            } else {
                *((vaddr_t *)ptr) = (vaddr_t)slabEntry->freeEl;
            }
            slabEntry->freeEl = ptr;
            slabEntry->full   = 0;
            return 1;
        }
    }
    return 0;
}
/*
 * Release an object previously returned by malloc().
 *
 * NULL is ignored. Each booked slab is asked, once, whether one of its entries owns
 * ptr. A pointer owned by no slab is only reported via pr_devel.
 *
 * Runs with IRQs disabled.
 */
void free(void *ptr)
{
    if (!ptr)
        return;

    struct slabDesc *slab;
    int slabIdx;
    int flags;

    disable_IRQs(flags);
    list_foreach(slub, slab, slabIdx)
    {
        // freeFromSlab() already walks the whole entry list starting at &slab->slab,
        // so calling it for every entry of that list only repeated the same scan.
        if (freeFromSlab(ptr, &slab->slab)) {
            restore_IRQs(flags);
            return;
        }
    }
    restore_IRQs(flags);
    pr_devel("free: slab not found\n");
}