#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <sys/types.h>

#ifdef _WIN32
#include <basetsd.h>
typedef SSIZE_T ssize_t;
#else
#include <sys/time.h>
#include <unistd.h>
#endif

#include "discovery.h"
#include "socket.h"
#include "select_helper.h"
#include "portable_endian.h"
#include "list.h"
#include "timestamp.h"


/*
 * One broadcast destination: the UDP socket to send on, plus the IPv4
 * broadcast address (and discovery port) of a single network interface.
 * init_sockets() stores the same module-level socket in every entry.
 */
typedef struct BroadcastSocket {
    socket_t sock;              /* UDP socket used for sendto()/recvfrom() */
    struct sockaddr_in addr;    /* broadcast address + port of one interface */
} BroadcastSocket;


/*
 * HELLO request datagram broadcast by tenkinet_discovery_broadcast().
 *
 *   [0]    request code REQUEST_HELLO
 *   [1..2] TENKINET_MAGIC_NUMBER, high byte first (responses are checked
 *          with be16toh() in handle_response())
 *   [3..4] TENKINET_PROTOCOL, sent twice
 *          NOTE(review): presumably a min/max protocol range; confirm
 *          against the server implementation.
 */
static const char HELLO[] = {
    REQUEST_HELLO,
    (unsigned char)(TENKINET_MAGIC_NUMBER >> 8), (unsigned char)(TENKINET_MAGIC_NUMBER & 0xff),
    TENKINET_PROTOCOL, TENKINET_PROTOCOL
};

/* File-local helpers, defined further below. */
static int init_sockets(uint16_t port);
static int find_server_info(struct sockaddr_in *addr);
static int handle_response(BroadcastSocket *bs);

/* The single UDP socket opened by init_sockets() and shared by every list entry. */
static socket_t sock;

/* Receive buffer holding one discovery response datagram. */
static char buffer[TENKINET_DISCOVERY_RESPONSE_LEN_MAX];

/* BroadcastSocket entries, one per usable IPv4 interface. */
static LIST(socket_list);

/* select() bookkeeping; the shared socket is registered in its read set. */
static select_helper_data seldata;

/* Optional user callback, invoked once for each newly discovered server. */
static tenkinet_discovery_callback_fn callback = NULL;

/* Singly linked list of discovered servers; new entries are prepended. */
static TenkinetServerInfo *results = NULL;


#ifdef _WIN32

/*** BEGIN WIN32 CODE ***/

static int init_sockets(uint16_t port) {

    // Initialize Winsock

    WORD winsock_version = MAKEWORD(2, 2);
    WSADATA wsa_data;

    int winsock_error = WSAStartup(winsock_version, &wsa_data);
    if (winsock_error) {
        return winsock_error;
    }

    // Open socket
    sock = socket_open_nonblocking(AF_INET, SOCK_DGRAM, 0);
    if (sock == -1) {
        return -1;
    }

    // Enable broadcast on socket
    if (setsockopt(sock, SOL_SOCKET, SO_BROADCAST, &(char){1}, sizeof(char)) < 0) {
        return -1;
    }

    // Enable SO_REUSEADDR on socket (probably not needed, but why not)
    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &(char){1}, sizeof(char)) < 0) {
        // ignore error, since this isn't mandatory
    }

    FD_SET(sock, &seldata.readfds);
    select_helper_nfds_update(sock, &seldata);

    // Get network interfaces (IPv4 only for now)

    MIB_IPADDRTABLE *table;
    DWORD size = 0;

    table = malloc(sizeof(MIB_IPADDRTABLE));

    int error = GetIpAddrTable(table, &size, 0);

    if (error == ERROR_INSUFFICIENT_BUFFER) {
        table = realloc(table, size);
        error = GetIpAddrTable(table, &size, 0);
    }

    if (error) {
        free(table);
        WSACleanup();
        return error;
    }

    // Iterate over network interfaces

    for (int i = 0; i < table->dwNumEntries; i++) {

        MIB_IPADDRROW *row = &(table->table[i]);

        // Skip loopback interface
        if (row->dwAddr == htobe32(INADDR_LOOPBACK)) {
            continue;
        }

        BroadcastSocket *bs = malloc(sizeof(BroadcastSocket));

        bs->sock = sock;
        bs->addr.sin_family = AF_INET;
        bs->addr.sin_port = htobe16(port);

        // row->dwBCastAddr is actually not set by GetIpAddrTable() !!
        // Let's compute it ourselves
        int broadcast_mask = ~(row->dwMask);
        bs->addr.sin_addr.s_addr = row->dwAddr | broadcast_mask;

        list_add(&socket_list, bs);

    }

    free(table);

    if (socket_list.size == 0) {
        return -1;  // TODO error code
    }

    return 0;

}

/*** END WIN32 CODE ***/

#else

/*** BEGIN LINUX/UNIX CODE ***/

/*
 * POSIX implementation: open one non-blocking broadcast UDP socket and add one
 * BroadcastSocket entry to `socket_list` for each non-loopback IPv4 interface
 * that has a broadcast address.
 *
 * Returns 0 on success, -1 on failure. On failure the socket is closed again.
 * Only on success is the socket registered in `seldata`, so a failed init
 * leaves no stale descriptor behind.
 */
static int init_sockets(uint16_t port) {

    // Open socket
    sock = socket_open_nonblocking(AF_INET, SOCK_DGRAM, 0);
    if (sock == -1) {
        return -1;
    }

    // Enable broadcast on socket
    if (setsockopt(sock, SOL_SOCKET, SO_BROADCAST, &(int){1}, sizeof(int)) < 0) {
        socket_close(sock);
        return -1;
    }

    // Enable SO_REUSEADDR on socket (probably not needed, but why not)
    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int)) < 0) {
        // ignore error, since this isn't mandatory
    }

    // Get the broadcast address of each network interface

    struct ifaddrs *ifa;
    if (getifaddrs(&ifa)) {
        socket_close(sock);
        return -1;  // TODO error code
    }

    for (const struct ifaddrs *p = ifa; p != NULL; p = p->ifa_next) {

        // Skip non-IPv4 interface
        if (p->ifa_addr == NULL || p->ifa_addr->sa_family != AF_INET) {
            continue;
        }

        const struct sockaddr_in *addr = (const struct sockaddr_in *)(p->ifa_addr);

        // Skip loopback interface
        if (addr->sin_addr.s_addr == htobe32(INADDR_LOOPBACK)) {
            continue;
        }

        // `ifa_broadaddr` is the portable spelling. glibc/musl/bionic define it as
        // ifa_ifu.ifu_broadaddr, while macOS/BSD define it as ifa_dstaddr.
        // It is NULL for interfaces without a broadcast (or peer) address.
        const struct sockaddr *bcast = p->ifa_broadaddr;
        if (bcast == NULL || bcast->sa_family != AF_INET) {
            continue;
        }

        BroadcastSocket *bs = malloc(sizeof(*bs));
        if (bs == NULL) {
            continue;  // out of memory: skip this interface
        }

        bs->sock = sock;
        memcpy(&bs->addr, bcast, sizeof(bs->addr));
        bs->addr.sin_port = htobe16(port);

        list_add(&socket_list, bs);

    }

    freeifaddrs(ifa);

    if (socket_list.size == 0) {
        socket_close(sock);
        return -1;  // TODO error code
    }

    FD_SET(sock, &seldata.readfds);
    select_helper_nfds_update(sock, &seldata);

    return 0;

}

/*** END LINUX/UNIX CODE ***/

#endif

/*
 * Begin a discovery session: remember the user callback and set up the
 * broadcast socket for `port` via init_sockets().
 *
 * Returns -1 if a previous session is still active (socket list non-empty),
 * otherwise init_sockets()'s result (0 on success).
 */
int tenkinet_discovery_init(uint16_t port, tenkinet_discovery_callback_fn cb) {

    const int already_running = (socket_list.size != 0);
    if (already_running) {
        return -1;  // TODO error code: discovery already underway
    }

    callback = cb;

    int status = init_sockets(port);
    return status;

}

void tenkinet_discovery_exit() {

    LIST_FOR(&socket_list) {
        BroadcastSocket *bs = LIST_CUR(BroadcastSocket);
        socket_close(bs->sock);
        free(bs);
    }

    list_clear(&socket_list);
    memset(&seldata, 0, sizeof(seldata));
    callback = NULL;

    tenkinet_discovery_results_free();

    #ifdef _WIN32
        WSACleanup();
    #endif

}

/*
 * Send one HELLO datagram to the broadcast address of every entry in
 * socket_list. Send failures are deliberately ignored: datagrams are unreliable
 * anyway, and tenkinet_discovery_loop() repeats the broadcast.
 */
static void tenkinet_discovery_broadcast() {

    LIST_FOR(&socket_list) {
        const BroadcastSocket *target = LIST_CUR(BroadcastSocket);
        const struct sockaddr *dest = (const struct sockaddr *)&target->addr;
        (void)sendto(target->sock, HELLO, sizeof HELLO, 0, dest, sizeof(struct sockaddr_in));
    }

}

/*
 * Broadcast HELLO `count` times, `interval_ms` apart. Between broadcasts, wait
 * in select() and hand every readable datagram to handle_response().
 * After the last broadcast it listens for one more full interval, then returns.
 * It also returns early if select_helper() fails, or immediately if `count` is 0
 * or discovery has not been initialized.
 *
 * Timestamps (timestamp_now(), seldata.timestamp) are combined with a
 * microsecond interval here.
 * NOTE(review): assumes timestamp.h works in microseconds; confirm.
 */
void tenkinet_discovery_loop(uint32_t interval_ms, uint32_t count) {

    if (count == 0 || socket_list.size == 0) {
        return;
    }

    // Widen before multiplying: `interval_ms * 1000` is evaluated in 32-bit
    // unsigned arithmetic and wraps for intervals above ~71.5 minutes.
    const int64_t interval_us = (int64_t)interval_ms * 1000;

    int64_t timeout = interval_us;
    int64_t end_timestamp = timestamp_now() + timeout;

    tenkinet_discovery_broadcast();
    count--;

    while (1) {

        struct timeval tv;
        timestamp_to_timeval(timeout, &tv);

        if (select_helper(&seldata, &tv) < 0) {
            break;
        }

        // All entries share one socket, so only the first handle_response() per
        // wakeup typically gets a datagram. Later calls fail and are ignored.
        LIST_FOR(&socket_list) {
            BroadcastSocket *bs = LIST_CUR(BroadcastSocket);
            if (FD_ISSET(bs->sock, &seldata.readfds_ready) && handle_response(bs)) {
                // Ignore errors
            }
        }

        timeout = end_timestamp - seldata.timestamp;

        if (timeout <= 0) {

            if (count == 0) {
                break;
            }

            timeout = interval_us;
            end_timestamp = seldata.timestamp + timeout;

            tenkinet_discovery_broadcast();
            count--;

        }

    }

}

/*
 * Head of the discovered-server list (newest first), or NULL if none.
 * The list remains owned by this module; release it with
 * tenkinet_discovery_results_free() or tenkinet_discovery_exit().
 */
TenkinetServerInfo *tenkinet_discovery_results() {
    TenkinetServerInfo *head = results;
    return head;
}

/*
 * Free every TenkinetServerInfo in the results list and leave `results` NULL.
 * Pointers previously obtained from tenkinet_discovery_results() become invalid.
 */
void tenkinet_discovery_results_free() {

    while (results != NULL) {
        TenkinetServerInfo *doomed = results;
        results = results->next;
        free(doomed);
    }

}

/*
 * Return 1 if `results` already contains a server with the same IPv4 address
 * and port as `addr`, 0 otherwise.
 */
static int find_server_info(struct sockaddr_in *addr) {

    const TenkinetServerInfo *info = results;

    while (info != NULL) {
        const int same_ip = (info->address.sin_addr.s_addr == addr->sin_addr.s_addr);
        const int same_port = (info->address.sin_port == addr->sin_port);
        if (same_ip && same_port) {
            return 1;
        }
        info = info->next;
    }

    return 0;

}

/*
 * Receive one datagram on bs->sock. If it is a well-formed discovery response
 * from a server that is not yet known, prepend a new TenkinetServerInfo to
 * `results` and invoke the user callback with it.
 *
 * Returns 0 if the response was recorded or came from an already-known server.
 * Returns -1 on receive error, too-short datagram, wrong magic number, or
 * allocation failure. tenkinet_discovery_loop() ignores the return value.
 */
static int handle_response(BroadcastSocket *bs) {

    struct sockaddr_in server_addr;
    socklen_t server_addr_len = sizeof(server_addr);

    ssize_t n = recvfrom(bs->sock, buffer, sizeof(buffer), 0,
        (struct sockaddr *)(&server_addr), &server_addr_len);

    // Test n < 0 on its own. If the header-length constant has an unsigned type,
    // a failed recvfrom() (-1) would be converted to a huge unsigned value and
    // slip past the size comparison.
    if (n < 0 || n < TENKINET_DISCOVERY_RESPONSE_HEADER_LEN) {
        // read error, or datagram is too small
        return -1;  // TODO error code
    }

    // `buffer` is plain char, which is signed on most platforms. Read single-byte
    // fields through an unsigned view so values >= 0x80 are not sign-extended.
    const unsigned char *bytes = (const unsigned char *)buffer;

    uint16_t magic;
    memcpy(&magic, buffer, sizeof(uint16_t));
    magic = be16toh(magic);
    if (magic != TENKINET_MAGIC_NUMBER) {
        return -1;  // TODO error code
    }

    // Check if we've already discovered this server
    if (find_server_info(&server_addr)) {
        return 0;
    }

    // calloc so every field has a defined value even if a conversion below fails
    TenkinetServerInfo *info = calloc(1, sizeof(*info));
    if (info == NULL) {
        return -1;  // TODO error code (out of memory)
    }

    #ifdef _WIN32
        // !!! WARNING !!!
        // For simplicity, we're using inet_ntoa(), but it only works with IPv4 addresses.
        // Before switching to WSAAddressToString(), which is more modern and flexible,
        // be aware that it copies BOTH the address AND the port number to the result string!
        // So instead of "127.0.0.1" you would get "127.0.0.1:10395". This behaviour differs
        // from the *nix implementations of similar functions. Keep that in mind!
        // inet_ntoa() returns NULL on error; fall back to an empty string.
        const char *s = inet_ntoa(server_addr.sin_addr);
        snprintf(info->address_str, sizeof(info->address_str), "%s", s ? s : "");
    #else
        if (inet_ntop(
                server_addr.sin_family,
                &(server_addr.sin_addr),
                info->address_str,
                sizeof(info->address_str)) == NULL) {
            info->address_str[0] = '\0';
        }
    #endif

    snprintf(info->port_str, sizeof(info->port_str), "%hu", be16toh(server_addr.sin_port));

    info->address = server_addr;
    info->protocol = bytes[TENKINET_DISCOVERY_RESPONSE_OFFSET_PROTOCOL];
    info->version.major = bytes[TENKINET_DISCOVERY_RESPONSE_OFFSET_VERSION];
    info->version.minor = bytes[TENKINET_DISCOVERY_RESPONSE_OFFSET_VERSION + 1];
    info->version.revision = bytes[TENKINET_DISCOVERY_RESPONSE_OFFSET_VERSION + 2];
    info->device_count = bytes[TENKINET_DISCOVERY_RESPONSE_OFFSET_DEVICE_COUNT];

    memcpy(info->serial_number, buffer + TENKINET_DISCOVERY_RESPONSE_OFFSET_SERIAL_NUMBER, TENKINET_SERIAL_NUMBER_LEN);
    info->serial_number[TENKINET_SERIAL_NUMBER_LEN] = '\0';

    // Clamp the advertised name length first to the bytes actually received after
    // the header, then to the capacity of name_str. name_str is assumed to hold
    // TENKINET_SERVER_NAME_LEN + 1 bytes, given the terminator write below.
    // NOTE(review): the first clamp assumes the name starts right after the
    // header (OFFSET_NAME == HEADER_LEN); confirm in discovery.h.
    uint8_t name_len = bytes[TENKINET_DISCOVERY_RESPONSE_OFFSET_NAME_LEN];
    n -= TENKINET_DISCOVERY_RESPONSE_HEADER_LEN;

    if (name_len > n) {
        name_len = (uint8_t)n;
    }

    if (name_len > TENKINET_SERVER_NAME_LEN) {
        name_len = TENKINET_SERVER_NAME_LEN;
    }

    info->name_len = name_len;
    memcpy(info->name_str, buffer + TENKINET_DISCOVERY_RESPONSE_OFFSET_NAME, name_len);
    info->name_str[name_len] = '\0';

    info->next = results;
    results = info;

    if (callback) {
        callback(info);
    }

    return 0;

}
