#include <sys/types.h>
#include <sys/socket.h>

#include <assert.h>
#include <endian.h>
#include <errno.h>
#include <inttypes.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include "mu.h"
#include "list.h"


/* backlog passed to listen(2) for the server socket */
#define BACKLOG 100
/* wire request: two big-endian uint32 operands (see client_handle_request) */
#define REQUEST_SIZE 8
/* wire response: one big-endian uint32 sum (see client_handle_request) */
#define RESPONSE_SIZE 4
/* size of client->buf; must be >= both REQUEST_SIZE and RESPONSE_SIZE */
#define MAX_MESSAGE_SIZE 8


/* State of the single-threaded, poll(2)-driven server. */
struct server {
    /* connected clients (struct client, linked through client->list) */
    struct list_head clients;
    /* number of entries in `clients`; maintained by add/close */
    size_t num_clients;

    /* listening socket */
    int sk;

    /* poll set rebuilt by server_reset_pollfds(): [0] = sk, then one slot
     * per client; num_clients + 1 entries */
    struct pollfd *pollfds;
};


/*
 * Per-connection state. A client sends one REQUEST_SIZE-byte request,
 * receives one RESPONSE_SIZE-byte reply, and is then closed by the server.
 */
struct client {
    /* link in server->clients */
    struct list_head list;
    /* connected socket, set non-blocking by client_new() */
    int sk;
    /* printable peer address, used only in log messages */
    char peer_str[MU_LIMITS_MAX_INET_STR_SIZE];

    /* holds the request while it is received, then the response */
    uint8_t buf[MAX_MESSAGE_SIZE];
    size_t n; /* number of bytes in buf */
    size_t pos; /* buffer pos (for sending) */

    /* poll(2) events wanted next: POLLIN while reading, POLLOUT while replying */
    short events;
};


/*
 * Allocate a client for the accepted socket `sk`, fill in a printable peer
 * address for logging, and switch the socket to non-blocking mode.
 *
 * getpeername(2) can fail for reasons that are the peer's doing, not ours;
 * e.g. ENOTCONN when the connection was reset between accept(2) and this
 * call. Such a failure must not terminate the whole server, so the peer
 * string falls back to "<unknown>" and the connection is left to the normal
 * I/O path, which will see the dead socket on its first recv/poll.
 *
 * NOTE(review): relies on MU_NEW returning a zero-initialised object so that
 * n and pos start at 0 -- confirm in mu.h.
 */
static struct client *
client_new(int sk)
{
    MU_NEW(client, client);
    struct sockaddr_in addr;
    socklen_t addr_size = sizeof(addr);

    if (getpeername(sk, (struct sockaddr *)&addr, &addr_size) == -1)
        snprintf(client->peer_str, sizeof(client->peer_str), "%s", "<unknown>");
    else
        mu_sockaddr_in_to_str(&addr, client->peer_str, sizeof(client->peer_str));

    client->sk = sk;
    mu_set_nonblocking(client->sk);

    return client;
}


/*
 * Release a client object's memory. The socket is not closed here;
 * server_close_client() closes it before handing the client to this function.
 */
static void
client_free(struct client *client)
{
    void *mem = client;

    free(mem);
}


/*
 * Process the complete request in client->buf: two 32-bit unsigned operands,
 * each in big-endian byte order, back to back. Their sum (with ordinary
 * uint32_t wrap-around) is written back to the start of client->buf in
 * big-endian order, and n/pos are reset so the reply is sent from byte 0.
 */
static void
client_handle_request(struct client *client)
{
    const uint8_t *in = client->buf;
    uint32_t lhs = ((uint32_t)in[0] << 24) | ((uint32_t)in[1] << 16) |
                   ((uint32_t)in[2] << 8) | (uint32_t)in[3];
    uint32_t rhs = ((uint32_t)in[4] << 24) | ((uint32_t)in[5] << 16) |
                   ((uint32_t)in[6] << 8) | (uint32_t)in[7];
    uint32_t sum = lhs + rhs;

    mu_pr_debug("%s: %" PRIu32 " + %" PRIu32 " = %" PRIu32,
            client->peer_str, lhs, rhs, sum);

    /* both operands are fully decoded above, so overwriting buf is safe */
    client->buf[0] = (uint8_t)(sum >> 24);
    client->buf[1] = (uint8_t)(sum >> 16);
    client->buf[2] = (uint8_t)(sum >> 8);
    client->buf[3] = (uint8_t)sum;
    client->n = sizeof(sum);
    client->pos = 0;
}


/*
 * Create an IPv4 stream socket bound to ip:port and put it in the listening
 * state with a backlog of BACKLOG. Before binding it calls mu_reuseaddr()
 * (presumably SO_REUSEADDR). The address is built by mu_init_sockaddr_in().
 * Failures of socket/bind/listen are reported via mu_die_errno(), which is
 * expected not to return. Returns the listening descriptor.
 */
static int
tcp_server_create(const char *ip, const char *port)
{
    struct sockaddr_in addr;
    int fd = socket(AF_INET, SOCK_STREAM, 0);

    if (fd == -1)
        mu_die_errno(errno, "socket");

    mu_reuseaddr(fd);
    mu_init_sockaddr_in(&addr, ip, port);

    if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
        mu_die_errno(errno, "bind");

    if (listen(fd, BACKLOG) == -1)
        mu_die_errno(errno, "listen");

    return fd;
}


/*
 * Rebuild server->pollfds for the next poll(2) call.
 *
 * Layout: slot 0 is the listening socket (POLLIN); slots 1..num_clients hold
 * one client each, with the events that client currently waits for. The
 * array therefore has num_clients + 1 entries, which is the nfds the caller
 * passes to poll(2). revents start out zero because the array comes from
 * mu_calloc (presumably a zeroing calloc wrapper -- confirm in mu.h).
 */
static void
server_reset_pollfds(struct server *server)
{
    struct client *client;
    size_t i = 1;

    free(server->pollfds);
    server->pollfds = mu_calloc(server->num_clients + 1, sizeof(struct pollfd));

    server->pollfds[0].fd = server->sk;
    server->pollfds[0].events = POLLIN;

    list_for_each_entry(client, &server->clients, list) {
        /* num_clients sized the array; the list must not be longer */
        assert(i <= server->num_clients);
        server->pollfds[i].fd = client->sk;
        server->pollfds[i].events = client->events;
        i++; /* each client gets its own slot */
    }

    /* every slot must be filled, otherwise poll() sees fd 0 entries */
    assert(i == server->num_clients + 1);
}


/*
 * Allocate a server with an empty client list and a non-blocking listening
 * socket on ip:port. pollfds is not built here; server_serve_forever()
 * calls server_reset_pollfds() before every poll(2).
 *
 * NOTE(review): relies on MU_NEW zero-initialising the struct so that
 * num_clients starts at 0 and pollfds at NULL -- confirm in mu.h.
 */
static struct server *
server_new(const char *ip, const char *port)
{
    MU_NEW(server, server);
    int listen_sk;

    INIT_LIST_HEAD(&server->clients);

    listen_sk = tcp_server_create(ip, port);
    mu_set_nonblocking(listen_sk);
    server->sk = listen_sk;

    return server;
}


/*
 * Register a freshly accepted socket: wrap it in a client that starts out
 * waiting for request bytes (POLLIN) and append it to the client list.
 */
static void
server_add_connection(struct server *server, int sk)
{
    struct client *fresh = client_new(sk);

    fresh->events = POLLIN;
    list_add_tail(&fresh->list, &server->clients);
    server->num_clients += 1;

    mu_pr_debug("%s: connected (sk=%d)", fresh->peer_str, fresh->sk);
}


/*
 * Tear down one client: log it, close its socket, unlink it from the
 * server's list, adjust the count and free it. `client` is dangling on
 * return and must not be used by the caller.
 */
static void
server_close_client(struct server *server, struct client *client)
{
    mu_pr_debug("%s: closing (sk=%d)", client->peer_str, client->sk);

    close(client->sk);
    list_del(&client->list);
    server->num_clients -= 1;

    client_free(client);
}


/*
 * Linear search of the client list for the client that owns socket `sk`
 * (first match wins). Callers only pass descriptors taken from the poll set,
 * so a miss is a logic error: it trips the assert, and with NDEBUG the
 * function returns NULL.
 */
static struct client *
server_get_client_by_sk(const struct server *server, int sk)
{
    struct client *found = NULL;
    struct client *cur;

    list_for_each_entry(cur, &server->clients, list) {
        if (cur->sk == sk) {
            found = cur;
            break;
        }
    }

    assert(found != NULL);
    return found;
}


static void
server_serve_forever(struct server *server)
{
    int n;
    nfds_t i, nfds;
    struct pollfd *pollfd;
    int conn;
    struct client *client;
    ssize_t got, put;

    while (1) {
        server_reset_pollfds(server);
        nfds = server->num_clients + 1;
        n = poll(server->pollfds, nfds, -1);
        if (n == -1)
            mu_die_errno(errno, "poll");

        for (i = 0; i < nfds; i++) {
            pollfd = &server->pollfds[i];
            assert(!(pollfd->revents & POLLNVAL));

            if (pollfd->revents & (POLLERR | POLLHUP)) {
                server_close_client(server, client);
                continue;
            }

            if (pollfd->revents & POLLIN) {
                if (pollfd->fd == server->sk) {
                    /* new connction */
                    conn = accept(server->sk, NULL, 0);
                    if (conn == -1)
                        mu_die_errno(errno, "accept");
                    server_add_connection(server, conn); 
                } else {
                    /* recv data from existing client */
                    client = server_get_client_by_sk(server, pollfd->fd);
                    assert(client->n < REQUEST_SIZE);
                    got = recv(client->sk, client->buf + client->n,
                            REQUEST_SIZE - client->n, 0);
                    if (got == -1) {
                        if (errno == EAGAIN || errno == EWOULDBLOCK) {
                            continue;
                        } else {
                            mu_stderr_errno(errno, "%s: error handling TCP request", client->peer_str);
                            server_close_client(server, client);
                        }
                    } else {
                        client->n += got;
                        if (client->n == REQUEST_SIZE) {
                            client_handle_request(client);
                            client->events = POLLOUT;
                        }
                    }
                }
            } 
            
            if (pollfd->revents & POLLOUT) {
                /* send data to client */
                client = server_get_client_by_sk(server, pollfd->fd);
                assert(client->pos < RESPONSE_SIZE);
                put = send(client->sk, client->buf + client->pos,
                            RESPONSE_SIZE - client->pos, 0);
                if (put == -1) {
                    if (errno == EAGAIN || errno == EWOULDBLOCK) {
                        continue;
                    } else {
                        mu_stderr_errno(errno, "%s: error handling TCP request", client->peer_str);
                        server_close_client(server, client);
                    }
                } else {
                    client->pos += put;
                    if (client->pos == RESPONSE_SIZE) {
                        server_close_client(server, client);
                    }
                }

            }
        }
    }
}


/*
 * Entry point: `prog PORT`. Listens on all IPv4 interfaces (0.0.0.0) at PORT
 * and serves add requests until the process is killed.
 */
int
main(int argc, char *argv[])
{
    const char *port;

    if (argc != 2)
        mu_die("Usage: %s PORT", argv[0]);

    port = argv[1];
    server_serve_forever(server_new("0.0.0.0", port));

    return 0;
}
