#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#ifdef _WIN32
#include <basetsd.h>
typedef SSIZE_T ssize_t;
#else
#include <unistd.h>
#endif


#include "ntp.h"
#include "socket.h"
#include "list.h"
#include "timestamp.h"
#include "portable_endian.h"
#include "str_helper.h"


////////////////////////////////////////////////////////////////
// Types
////////////////////////////////////////////////////////////////

/**
    NTP Packet Header (without optional sections)

    Reference: RFC5905 - Network Time Protocol (Version 4)
    https://datatracker.ietf.org/doc/html/rfc5905

    0                   1                   2                   3
    0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |LI | VN  |Mode |    Stratum     |     Poll      |  Precision   |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                         Root Delay                            |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                         Root Dispersion                       |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                          Reference ID                         |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                                                               |
    +                     Reference Timestamp (64)                  +
    |                                                               |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                                                               |
    +                      Origin Timestamp (64)                    +
    |                                                               |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                                                               |
    +                      Receive Timestamp (64)                   +
    |                                                               |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                                                               |
    +                      Transmit Timestamp (64)                  +
    |                                                               |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

    The struct below mirrors this diagram field by field. It is sent and
    received as-is (sizeof(Packet) bytes), so the field order and widths must
    not change. With natural alignment every field lands on its wire offset
    (0, 1, 2, 3, 4, 8, 12, 16, 24, 32, 40), which leaves no padding and gives
    the 48-byte header. That is assumed for the supported ABIs -- TODO confirm
    (e.g. with a static assertion) when adding new targets.

    Multi-byte fields hold raw network (big-endian) byte order. The code
    converts the timestamps with be64toh()/htobe64() when reading or writing
    them.
*/
typedef struct Packet {

    uint8_t li_vn_mode;         // LI (2 bits) | VN (3 bits) | Mode (3 bits)
    uint8_t stratum;
    uint8_t poll;
    uint8_t precision;
    uint32_t root_delay;
    uint32_t root_dispersion;
    uint32_t ref_id;            // in stratum-0 replies: four-character kiss code
    uint64_t ref_time;          // timestamps: 32-bit seconds . 32-bit fraction
    uint64_t orig_time;
    uint64_t rx_time;
    uint64_t tx_time;

} Packet;

/*
 * Accessors for the first header byte, which packs LI (2 bits), VN (3 bits)
 * and Mode (3 bits). Hex masks are used because 0b... binary literals are
 * only standard since C23 (earlier they are a GCC/Clang extension).
 */
#define PACKET_LI(packet)   (((packet).li_vn_mode >> 6) & 0x3)
#define PACKET_VN(packet)   (((packet).li_vn_mode >> 3) & 0x7)
#define PACKET_MODE(packet) ((packet).li_vn_mode & 0x7)

// Assemble the first header byte; each argument is masked to its field width.
#define PACKET_SET_LI_VN_MODE(packet, li, vn, mode) \
    ((packet).li_vn_mode = (uint8_t)((((li) & 0x3) << 6) | (((vn) & 0x7) << 3) | ((mode) & 0x7)))


/**
 * Per-client state. An NTP_Client handle is a pointer to one of these, and
 * each Context sits on exactly one of the idle/active lists.
 */
typedef struct Context {

    // Server endpoint, resolved once by ntp_client_new()
    struct sockaddr_in addr;
    socklen_t addrlen;

    // Completion callback and its opaque argument
    ntp_callback_fn callback;
    void *user_data;

    // Retry schedule, in timestamp_from_timespec() units
    // (interval is interval_ms * 1000, so presumably microseconds)
    int64_t target_timestamp;   // next send / timeout check is due at this time
    int64_t interval;           // spacing between consecutive sends
    uint32_t count;             // sends still allowed before timing out

} Context;


////////////////////////////////////////////////////////////////
// Constants
////////////////////////////////////////////////////////////////

// Header field values: the protocol version we send and require in replies,
// and the association modes for requests and replies.
#define VERSION 4
#define MODE_CLIENT 3
#define MODE_SERVER 4

/*
 * Pack four characters into a 32-bit value, first character in the least
 * significant byte.
 */
#define KISS(a, b, c, d) (                 \
      (((uint32_t)(a)) & 0xFF)             \
    | ((((uint32_t)(b)) & 0xFF) << 8)      \
    | ((((uint32_t)(c)) & 0xFF) << 16)     \
    | ((((uint32_t)(d)) & 0xFF) << 24)     \
)

// Kiss-o'-Death codes carried in the Reference ID of stratum-0 replies
// (RFC 5905, section 7.4)
#define KISS_DENY KISS('D', 'E', 'N', 'Y')
#define KISS_RSTR KISS('R', 'S', 'T', 'R')
#define KISS_RATE KISS('R', 'A', 'T', 'E')

// Seconds from the NTP epoch (1900-01-01) to the UNIX epoch (1970-01-01):
// 25567 days (70 years, 17 of them leap years) * 86400 s.
// The suffix keeps the constant 64-bit unsigned on both LP64 and 32-bit-long models.
#if __SIZEOF_LONG__ == 8
#define NTP_EPOCH_OFFSET 2208988800ul
#else
#define NTP_EPOCH_OFFSET 2208988800ull
#endif

#define ONE_BILLION 1000000000   // nanoseconds per second

const char NTP_PORT_STR[] = "123";


////////////////////////////////////////////////////////////////
// Static variables
////////////////////////////////////////////////////////////////

// Shared UDP socket used by every client; replies are matched back to a
// client by their source address.
static socket_t sock;

// Event-loop state handed over in ntp_init().
static select_helper_data *seldata;

// Each Context is on exactly one list: idle (created or finished) or active
// (query in flight, resent / timed out by ntp_process()).
static List idle_contexts;
static List active_contexts;

// Scratch buffers reused by every send, receive and callback invocation;
// presumably single-threaded use only.
static Packet packet;
static NTP_Result result;


////////////////////////////////////////////////////////////////
// Static functions
////////////////////////////////////////////////////////////////

/**
 * Split a 64-bit NTP fixed-point value (32-bit seconds . 32-bit fraction) into
 * a timespec, subtracting `epoch_offset` from the seconds part.
 * Pass 0 as `epoch_offset` for durations.
 */
static void ntp2tp(uint64_t ntp_time, struct timespec *tp, uint64_t epoch_offset) {

    const uint64_t whole = ntp_time >> 32;          // integer seconds
    const uint64_t frac = ntp_time & 0xFFFFFFFF;    // units of 1/(2^32) s

    // frac < 2^32 and ONE_BILLION < 2^30, so the product fits in 64 bits.
    tp->tv_nsec = (frac * ONE_BILLION) >> 32;
    tp->tv_sec = whole - epoch_offset;

}

/**
 * Encode a timespec as 64-bit NTP fixed point (32-bit seconds . 32-bit
 * fraction), adding `epoch_offset` to the seconds first.
 * Pass 0 as `epoch_offset` for durations.
 */
static uint64_t tp2ntp(struct timespec *tp, uint64_t epoch_offset) {

    // Seconds moved to the target epoch, reduced modulo 2^32 (the NTP seconds
    // field rolls over every 2^32 seconds).
    const uint64_t seconds = ((uint64_t)tp->tv_sec + epoch_offset) & 0xFFFFFFFF;

    // Nanoseconds rescaled to units of 1/(2^32) s. The largest possible value
    // is 0xFFFFFFFB, so it always fits in the low 32 bits.
    const uint64_t fraction = (((uint64_t)tp->tv_nsec) << 32) / ONE_BILLION;

    return (seconds << 32) | fraction;

}

#define ntp_timestamp_to_timespec(ntp_timestamp, tp) ntp2tp(ntp_timestamp, tp, NTP_EPOCH_OFFSET)

/**
 * Convert a signed NTP-format interval (32.32 fixed point) to a normalized
 * timespec: tv_nsec is always in [0, ONE_BILLION), and negative intervals
 * carry the sign in tv_sec (e.g. -0.25 s -> { -1, 750000000 }).
 */
static void ntp_interval_to_timespec(int64_t ntp_interval, struct timespec *tp) {

    if (ntp_interval >= 0) {
        ntp2tp((uint64_t)ntp_interval, tp, 0);
        return;
    }

    // Negate in unsigned arithmetic: -ntp_interval is undefined behavior for
    // INT64_MIN, a value a malformed server reply can yield (e.g. through the
    // delay computed in handle_packet()).
    uint64_t magnitude = 0 - (uint64_t)ntp_interval;
    ntp2tp(magnitude, tp, 0);

    // tp holds |interval| = s + ns. Rewrite -(s + ns) as -(s + 1) seconds
    // plus (ONE_BILLION - ns) nanoseconds to keep tv_nsec non-negative.
    if (tp->tv_nsec > 0) {
        tp->tv_sec = -(tp->tv_sec + 1);
        tp->tv_nsec = ONE_BILLION - tp->tv_nsec;
    }
    else {
        tp->tv_sec = -(tp->tv_sec);
    }

}

// Absolute time: UNIX-epoch timespec -> NTP timestamp (seconds since 1900).
#define timespec_to_ntp_timestamp(tp) tp2ntp((tp), NTP_EPOCH_OFFSET)

// Relative time: duration -> NTP 32.32 fixed point with no epoch shift
// (assumes a non-negative duration).
#define timespec_to_ntp_interval(tp) tp2ntp((tp), 0)

/**
 * Find the Context in `list` whose server address and port equal `addr`.
 * If `detach` is non-zero, the match is also unlinked from `list`.
 * Returns NULL when nothing matches.
 */
static Context *find_context(List *list, const struct sockaddr_in *addr, char detach) {

    LIST_FOR(list) {
        Context *candidate = LIST_CUR(Context);

        int same_host = candidate->addr.sin_addr.s_addr == addr->sin_addr.s_addr;
        int same_port = candidate->addr.sin_port == addr->sin_port;
        if (!same_host || !same_port) {
            continue;
        }

        if (detach) {
            LIST_CUR_REMOVE(list);
        }
        return candidate;
    }

    return NULL;

}

/**
 * Notify `context`'s callback of a failure. Only `status` is set (to
 * `error`); every other result field is zeroed.
 */
static void ntp_error(Context *context, int32_t error) {

    memset(&result, 0, sizeof(result));
    result.status = error;

    ntp_callback_fn notify = context->callback;
    notify((NTP_Client)context, &result, context->user_data);

}

/**
 * Compute round-trip delay and clock offset from a validated reply held in
 * `packet`, and report them through `context`'s callback.
 *
 * Timestamps (64-bit NTP format, host order):
 *   t1 - client transmit time, echoed back by the server as Origin Timestamp
 *   t2 - server receive time
 *   t3 - server transmit time
 *   t4 - client receive time, taken from seldata->tp (presumably when select() returned)
 */
static void handle_packet(Context *context) {

    const uint64_t t1 = be64toh(packet.orig_time);
    const uint64_t t2 = be64toh(packet.rx_time);
    const uint64_t t3 = be64toh(packet.tx_time);
    const uint64_t t4 = timespec_to_ntp_timestamp(&seldata->tp);

    // Round trip minus server processing time. The unsigned subtractions are
    // modulo 2^64, so the differences stay correct across an era rollover.
    const int64_t round_trip = (t4 - t1) - (t3 - t2);

    // Offset = ((t2 - t1) + (t3 - t4)) / 2, halving each term before adding
    // so the sum cannot overflow.
    const int64_t forward = t2 - t1;
    const int64_t backward = t3 - t4;
    const int64_t clock_offset = (forward / 2) + (backward / 2);

    struct timespec span;

    ntp_interval_to_timespec(round_trip, &span);
    result.delay_us = timestamp_from_timespec(&span);

    ntp_interval_to_timespec(clock_offset, &span);
    result.offset_us = timestamp_from_timespec(&span);

    result.status = NTP_CLIENT_SUCCESS;
    context->callback((NTP_Client)context, &result, context->user_data);

}

/**
 * Read one datagram from the shared socket and complete the matching query.
 *
 * Datagrams from addresses with no query in flight are dropped. Any datagram
 * from a queried server ends that query: the context returns to the idle
 * list, and the callback receives either the measurement (handle_packet()) or
 * NTP_CLIENT_ERROR if the reply fails validation.
 */
static void ntp_receive(void) {

    struct sockaddr_in addr;
    socklen_t addr_len = sizeof(addr);

    ssize_t n = recvfrom(sock, (void *)&packet, sizeof(Packet), 0,
        (struct sockaddr *)(&addr), &addr_len);

    if (n < 0) {
        // Nothing was read (e.g. no datagram pending on the non-blocking socket).
        return;
    }

    Context *context = find_context(&active_contexts, &addr, 1);
    if (!context) {
        // Packet came from an unknown address... Nothing to do here.
        return;
    }

    // Park the context before any callback runs, so the callback may start a
    // new query (ntp_query() needs an idle client) or delete the client.
    list_add(&idle_contexts, context);

    // ref_id carries the kiss code as four raw ASCII bytes. KISS() places the
    // first character in the least significant byte, which is how those bytes
    // read as a little-endian integer. le32toh() therefore makes the
    // comparison correct on big-endian hosts too.
    uint32_t kiss_code = le32toh(packet.ref_id);

    // Validate packet
    int error =
        ((size_t)n < sizeof(Packet)) ||             // truncated header
        (PACKET_LI(packet) == 3) ||                 // LI 3: server clock unsynchronized
        (PACKET_VN(packet) != VERSION) ||
        (PACKET_MODE(packet) != MODE_SERVER) ||
        (packet.tx_time == 0) ||                    // no transmit timestamp
        (
            packet.stratum == 0 &&
            (
                kiss_code == KISS_DENY ||
                kiss_code == KISS_RSTR ||
                kiss_code == KISS_RATE
            )
        );

    if (error) {
        ntp_error(context, NTP_CLIENT_ERROR);
    }
    else {
        handle_packet(context);
    }

}

/**
 * Send one client-mode request to `context`'s server. The next resend or
 * timeout check is scheduled one interval from now, and one attempt is
 * consumed.
 */
static void ntp_send(Context *context) {

    // Prepare packet: all zero except LI/VN/Mode and the transmit timestamp,
    // which the server echoes back as the Origin Timestamp.
    memset(&packet, 0, sizeof(Packet));
    PACKET_SET_LI_VN_MODE(packet, 0, VERSION, MODE_CLIENT);

    // Invoke clock immediately before sending for the best possible precision
    struct timespec tp;
    clock_gettime(CLOCK_REALTIME, &tp);
    uint64_t timestamp = timespec_to_ntp_timestamp(&tp);
    packet.tx_time = htobe64(timestamp);

    // Best effort: a failed send is handled like a lost datagram. ntp_process()
    // resends or times out once target_timestamp has passed.
    (void)sendto(sock, (void*)&packet, sizeof(Packet), 0,
        (struct sockaddr *)(&context->addr), sizeof(struct sockaddr_in));

    context->target_timestamp = timestamp_from_timespec(&tp) + context->interval;

    // A query started with count == 0 would otherwise wrap to UINT32_MAX and
    // keep resending practically forever.
    if (context->count > 0) {
        context->count--;
    }

}


////////////////////////////////////////////////////////////////
// Public function definitions
////////////////////////////////////////////////////////////////

/**
 * Initialize the module: start Winsock (Windows only), open the shared
 * non-blocking UDP socket and register it with the select() helper.
 *
 * @param _seldata  event-loop state; stored and used by every other function here.
 * @return 0 on success, the WSAStartup() error code on Windows, or -1 if the
 *         socket could not be opened.
 */
int ntp_init(select_helper_data *_seldata) {

    #ifdef _WIN32
        WORD winsock_version = MAKEWORD(2, 2);
        WSADATA wsa_data;
        int winsock_error = WSAStartup(winsock_version, &wsa_data);
        if (winsock_error) {
            return winsock_error;
        }
    #endif

    sock = socket_open_nonblocking(AF_INET, SOCK_DGRAM, 0);
    if (sock == -1) {
        #ifdef _WIN32
            // Balance the successful WSAStartup() above; a failed init has no
            // ntp_exit() that would release it.
            WSACleanup();
        #endif
        return -1;
    }

    seldata = _seldata;
    select_helper_nfds_update(sock, _seldata);

    return 0;

}

void ntp_exit(NTP_Client client) {

    FD_CLR(sock, &seldata->readfds);
    socket_close(sock);

    #ifdef _WIN32
        WSACleanup();
    #endif

    LIST_FOR(&idle_contexts) {
        Context *context = LIST_CUR(Context);
        free(context);
    }

    list_clear(&idle_contexts);

    LIST_FOR(&active_contexts) {
        Context *context = LIST_CUR(Context);
        free(context);
    }

    list_clear(&active_contexts);

}

/**
 * Drive the module from the event loop after select() returns. Steps:
 *   1. read a pending reply, if any;
 *   2. for every active client whose deadline has passed, resend the request,
 *      or report NTP_CLIENT_TIMEOUT once no attempts remain;
 *   3. stop watching the socket when no query is left in flight.
 */
void ntp_process() {

    if (FD_ISSET(sock, &seldata->readfds_ready)) {
        ntp_receive();
    }

    LIST_FOR(&active_contexts) {

        Context *context = LIST_CUR(Context);

        if (seldata->timestamp < context->target_timestamp) {
            continue;   // not due yet
        }

        if (context->count == 0) {
            // Out of attempts. Park the context on the idle list *before*
            // notifying, as ntp_receive() does for replies. The callback can
            // then re-query it (ntp_query() requires an idle client) or free
            // it (ntp_client_delete()) without this loop touching freed
            // memory afterwards.
            // NOTE(review): whether LIST_FOR tolerates the callback adding to
            // or removing from active_contexts depends on list.h -- confirm.
            LIST_CUR_REMOVE(&active_contexts);
            list_add(&idle_contexts, context);
            ntp_error(context, NTP_CLIENT_TIMEOUT);
            continue;
        }

        ntp_send(context);

    }

    if (active_contexts.size == 0) {
        FD_CLR(sock, &seldata->readfds);
    }

}

/**
 * Create a client for the NTP server `host`:`port`.
 *
 * The address is resolved immediately and is IPv4 only, matching the AF_INET
 * shared socket. The client starts idle; call ntp_query() to start a
 * measurement. `cb` later receives the result together with `user_data`.
 *
 * @param port  service/port string; NULL selects the standard NTP port
 *              (NTP_PORT_STR, "123").
 * @return the new client, or NULL if name resolution or allocation failed.
 */
NTP_Client ntp_client_new(const char *host, const char *port, ntp_callback_fn cb, void *user_data) {

    // Resolve host name

    struct addrinfo hints;
    struct addrinfo *info;

    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_DGRAM;

    if (!port) {
        port = NTP_PORT_STR;
    }

    // TODO error codes
    if (getaddrinfo(host, port, &hints, &info)) {
        return NULL;
    }

    // calloc: start from a fully zeroed Context; the timing fields are only
    // filled in later by ntp_query().
    Context *context = calloc(1, sizeof(*context));
    if (!context) {
        freeaddrinfo(info);
        return NULL;
    }

    // Never copy more than the destination holds, whatever ai_addrlen reports.
    socklen_t addrlen = info->ai_addrlen;
    if ((size_t)addrlen > sizeof(context->addr)) {
        addrlen = (socklen_t)sizeof(context->addr);
    }
    memcpy(&context->addr, info->ai_addr, (size_t)addrlen);
    context->addrlen = addrlen;

    freeaddrinfo(info);

    context->callback = cb;
    context->user_data = user_data;

    list_add(&idle_contexts, context);

    return (NTP_Client)context;

}

/**
 * Destroy a client created by ntp_client_new(). A query in flight is
 * cancelled without invoking the callback.
 */
void ntp_client_delete(NTP_Client client) {

    Context *victim = (Context *)client;

    // The context is on one of the two lists; unlinking from both covers either case.
    list_remove(&idle_contexts, victim);
    list_remove(&active_contexts, victim);

    free(victim);

}

/**
 * Start a query on an idle client.
 *
 * A request is sent immediately. Until a reply arrives, it is resent every
 * `interval_ms` milliseconds. After `count` sends without a reply (count is
 * expected to be >= 1), the callback receives NTP_CLIENT_TIMEOUT.
 *
 * Does nothing if the client is not idle, e.g. a query is already in flight.
 */
void ntp_query(NTP_Client client, uint32_t interval_ms, uint32_t count) {

    Context *ctx = (Context *)client;

    // Only idle clients may start a query.
    if (!list_remove(&idle_contexts, ctx)) {
        return;
    }
    list_add(&active_contexts, ctx);

    ctx->count = count;
    ctx->interval = (int64_t)interval_ms * 1000;   // ms -> timestamp units

    ntp_send(ctx);

    // Watch the shared socket for the reply.
    FD_SET(sock, &seldata->readfds);

}
