#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <zstd.h>
#include "com.h"
#include "error.h"
#include "packetio.h"

/*
 * Per-process state shared by the packet-layer and communication callbacks
 * through their user_data pointer.
 */
struct Worker {
    Com_t *com;             /* channel returned by com_init() on stdin/stdout */
    PacketIO_t *packet;     /* packet layer returned by packetio_init_streaming() */
    int init_frame;         /* nonzero once the current packet's first fragment was handled */
    int error;              /* latched on failure or after the frame completes;
                             * receive_callback() ignores fragments while set */
    ZSTD_DStream *zds;      /* streaming decompressor, re-initialised per packet */
    ZSTD_outBuffer out_buf; /* decompression output buffer (ZSTD_DStreamOutSize() bytes) */
    uint8_t *tmp_buf;       /* staging area for compressed input */
    size_t tmp_buf_pos;     /* bytes currently held in tmp_buf (same type as its capacity) */
    size_t tmp_buf_size;    /* capacity of tmp_buf in bytes */
};

/*
 * Packet-layer callback run before the fragments of a new packet arrive.
 * Resets the per-packet state in the Worker and starts a fresh zstd
 * decompression session.
 *
 * max_len:   value supplied by the packet layer (presumably a length bound —
 *            confirm in packetio.h); not used here.
 * user_data: the struct Worker registered with packetio_init_streaming().
 */
void receive_init(size_t max_len, void* user_data) {
    struct Worker *worker = (struct Worker*)user_data;
    (void)max_len;

    worker->init_frame = 0;
    worker->error = 0;
    worker->out_buf.pos = 0;
    worker->tmp_buf_pos = 0;

    /* ZSTD_initDStream() returns a ZSTD error code on failure. Latch the
     * error flag so receive_callback() drops this packet instead of
     * decoding with a stream in an undefined state. */
    if (ZSTD_isError(ZSTD_initDStream(worker->zds))) {
        worker->error = 1;
    }
}

void receive_callback(uint8_t *fragment, size_t len, int last, void* user_data) {
    ZSTD_inBuffer in_buf;
    struct Worker *worker = (struct Worker*)user_data;
    if (worker->error) {
        return;
    }
    if (worker->init_frame == 0) {
        worker->init_frame = 1;
        uint64_t uncompressed_len = ZSTD_getFrameContentSize(fragment, len);
        if (ZSTD_isError(uncompressed_len)) {
            worker->error = 1;
            return;
        }
        packetio_fragment_init(uncompressed_len, worker->packet);
    }

    if (worker->tmp_buf_pos + len > worker->tmp_buf_size) {
        worker->error = 1;
        return;
    }
    memcpy(worker->tmp_buf + worker->tmp_buf_pos, fragment, len);
    worker->tmp_buf_pos += len;
    in_buf.src = worker->tmp_buf;
    in_buf.pos = 0;
    in_buf.size = worker->tmp_buf_pos;
    uint64_t ret = ZSTD_decompressStream(worker->zds, &worker->out_buf, &in_buf);
    if (ZSTD_isError(ret)) {
        worker->error = 1;
        return;
    }
    packetio_fragment_send(worker->out_buf.dst, worker->out_buf.pos, worker->packet);
    if (ret == 0) {
        packetio_fragment_send(NULL, 0, worker->packet);
        worker->error = 1;
        return;
    }
    if (in_buf.pos != in_buf.size) {
        memmove(worker->tmp_buf, in_buf.src + in_buf.pos, in_buf.size - in_buf.pos);
    }
    worker->tmp_buf_pos = in_buf.size - in_buf.pos;
}

void send_callback(uint8_t *frame, size_t len, void* user_data) {
    struct Worker *worker = (struct Worker*)user_data;
    com_send(frame, len, worker->com);
}

int com_receive(uint8_t *buf, size_t n, void *user_data) {
    struct Worker *worker = (struct Worker*)user_data;
    packetio_frame_receive(buf, n, worker->packet);
    return 0;
}

/*
 * Worker entry point. It sets up communication on stdin/stdout, the
 * streaming packet layer, and a zstd decompression stream with its
 * input/output buffers. It then runs the receive loop.
 *
 * NOTE(review): the checks below assume error() (declared in error.h, not
 * shown here) does not return. If it can return, execution continues with
 * a NULL handle.
 */
int main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;

    struct Worker worker = { 0 };
    worker.com = com_init(STDIN_FILENO, STDOUT_FILENO, com_receive, &worker);
    if (worker.com == NULL) {
        error("worker: error initializing communication\n");
    }
    worker.packet = packetio_init_streaming(MAX_FRAME_LEN, receive_init, receive_callback, &worker, send_callback, &worker);
    if (worker.packet == NULL) {
        error("worker: error initializing packet layer\n");
    }
    worker.zds = ZSTD_createDStream();
    if (worker.zds == NULL) {
        error("worker: error initializing decompression stream\n");
    }
    /* Room for one recommended zstd input chunk plus one full frame. */
    worker.tmp_buf_size = ZSTD_DStreamInSize() + MAX_FRAME_LEN;
    worker.tmp_buf = malloc(worker.tmp_buf_size);
    if (worker.tmp_buf == NULL) {
        error("worker: error allocating input buffer\n");
    }
    worker.out_buf.size = ZSTD_DStreamOutSize();
    worker.out_buf.dst = malloc(worker.out_buf.size);
    if (worker.out_buf.dst == NULL) {
        error("worker: error allocating output buffer\n");
    }

    com_receive_loop(worker.com);

    /* Release the decoder and its buffers once the receive loop returns. */
    ZSTD_freeDStream(worker.zds);
    free(worker.tmp_buf);
    free(worker.out_buf.dst);
    return 0;
}
