#include <atomic>
#include <chrono>
#include <csignal>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <stdbool.h>
#include <system_error>
#include <thread>
#include <vector>

#ifdef __linux__
#	include <arpa/inet.h>
#	include <fcntl.h>
#	include <filesystem>
#	include <libevdev/libevdev.h>
#	include <linux/uinput.h>
#	include <memory>
#	include <netdb.h>
#	include <netinet/in.h>
#	include <netinet/tcp.h>
#	include <stddef.h>
#	include <stdio.h>
#	include <string.h>
#	include <sys/socket.h>
#	include <termios.h>
#	include <unistd.h>
#else
#	include <winsock2.h> // "Please include winsock2.h before windows.h" ???

#	include <windows.h>
#	include <ws2tcpip.h>

#	pragma comment (lib, "Ws2_32.lib")
#	define close(x) closesocket(x)
#endif

// The key replayed on the receiving machine, as a Linux keycode (31 is 's', see
// https://github.com/torvalds/linux/blob/master/drivers/tty/vt/defkeymap.map)
static constexpr int KEY = 31;
// ...and as a Windows virtual-key code (see
// https://learn.microsoft.com/en-us/windows/win32/inputdev/virtual-key-codes)
static constexpr char KEY_WIN = 'S';

#ifdef __linux__
// The evdev input device picked by the user: its open file descriptor and the
// libevdev handle created from that descriptor.
struct device {
	int fd;
	struct libevdev* dev;
};
#endif

// Worker threads; joined by main() and by the shutdown handler.
std::unique_ptr<std::thread> dev_thread;
std::unique_ptr<std::thread> recv_thread;

// Listening socket (server mode) or connected socket (client mode); -1 once closed.
static int sockfd{-1};
// Raised when either side should wind down.
static std::atomic<bool> stopping{false};
// Joins a worker thread from the shutdown handler. The calling thread is skipped
// (a process-directed signal may be delivered on a worker), and the
// std::system_error thrown when the thread cannot be joined here (e.g. main() is
// already blocked joining it) is swallowed: the process is exiting anyway.
static void join_worker(const std::unique_ptr<std::thread>& worker) {
	if(!worker || !worker->joinable() || worker->get_id() == std::this_thread::get_id()) return;
	try { worker->join(); }
	catch(const std::system_error&) {}
}

#ifdef __linux__
// SIGINT/SIGTERM handler: reinstalls the default handlers (so a second signal
// terminates the process), requests shutdown, closes the socket, waits for the
// worker threads and exits.
static void signal(int) {
	std::signal(SIGINT, SIG_DFL);
	std::signal(SIGTERM, SIG_DFL);
	// unfortunately closing evdev device FDs or freeing the device objects does not interrupt next_event
	std::cout << "\r\033[KSubmit an input on the device that you used to exit" << std::endl;
	std::cout << "INT again if this does not exit after that" << std::endl;
#else
// Console control handler: unregisters itself (so a second event falls through to
// the default handler), requests shutdown, closes the socket, waits for the
// receive thread and exits.
BOOL WINAPI ConsoleHandler(DWORD) {
	SetConsoleCtrlHandler(ConsoleHandler, FALSE);
#endif
	stopping = true;
#ifdef __linux__
	close(STDIN_FILENO);
#endif
	if(sockfd != -1) {
		// close() alone does not wake a reader blocked in recv() on this socket;
		// shutdown() is used to wake it. If sockfd is only a listening socket, a
		// reader blocked on a different connection is not woken (INT twice then).
#ifdef __linux__
		shutdown(sockfd, SHUT_RDWR);
#else
		shutdown(sockfd, SD_BOTH);
#endif
		close(sockfd);
		sockfd = -1;
	}
	join_worker(recv_thread);
#ifdef __linux__
	join_worker(dev_thread);
#endif
	// _Exit instead of exit: it is async-signal-safe and skips static destructors,
	// so a worker that could not be joined above does not reach ~std::thread
	// (which would call std::terminate).
	std::_Exit(0);
}

#ifdef __linux__
static void select_device(device* dev) {
	bool refresh;
	do {
		refresh = false;
		size_t dev_count = (size_t)std::distance(std::filesystem::directory_iterator{"/dev/input/"}, std::filesystem::directory_iterator{});
		for(size_t dev_id = 0; dev_id < dev_count; dev_id++) {
			if((dev->fd = open(("/dev/input/event" + std::to_string(dev_id)).c_str(), O_RDONLY)) == -1) continue;
			if(libevdev_new_from_fd(dev->fd, &dev->dev) == 0) {
				std::cout << dev_id << ": " << libevdev_get_name(dev->dev) << std::endl;
				libevdev_free(dev->dev);
			}
			close(dev->fd);
		}
		dev->fd = -1;
		std::cout << "Press R to refresh" << std::endl;
		while(dev->fd == -1) {
			std::cout << "Device [0-" << dev_count << "]: ";
			std::string dev_id;
			int c;
			while((c = getchar()) != EOF) {
				if(c == '\n') {
					if(dev_id.size() > 0) break;
					std::cout << "\033[A\r\033[KDevice [0-" << dev_count << "]: ";
					continue;
				}
				else if(c == 'r' || c == 'R') {
					dev_id.clear();
					refresh = true;
					std::cout << std::endl;
					break;
				}
				else if(c == '\177') {
					std::cout << "\b\b  \b\b";
					if(!dev_id.empty()) {
						std::cout << "\b \b";
						dev_id.erase(dev_id.size()-1);
					}
					continue;
				}
				else if(c < '0' || c > '9') {
					if(c >= ' ') std::cout << "\b \b";
					else std::cout << "\r\033[KDevice [0-" << dev_count << "]: " << dev_id;
					continue;
				}
				dev_id += c;
			}
			if(c == EOF) exit(1);
			else if(refresh) break;
			if((dev->fd = open(("/dev/input/event" + dev_id).c_str(), O_RDONLY)) == -1) {
				std::cerr << "\033[A\r\033[KFailed to open device " << dev_id << "!" << std::endl;
				continue; }
			//if(ioctl(dev->fd, EVIOCGRAB, 1) < 0) {
			//	std::cerr << "\033[A\r\033[KFailed to get device " << dev_id << " exclusively!" << std::endl;
			//	close(dev->fd); dev->fd = -1; continue; }
			if(libevdev_new_from_fd(dev->fd, &dev->dev) != 0) {
				std::cerr << "\033[A\r\033[KFailed to load device " << dev_id << " through evdev!" << std::endl;
				ioctl(dev->fd, EVIOCGRAB, 0);
				close(dev->fd);
				continue; }
		}
	} while(refresh);
}
#endif

// Sends one 5-byte event to the peer: the wall-clock second modulo 10, the
// millisecond within that second as three digits, then '1' for a press or '0'
// for a release. Because only second % 10 is sent, a link delay of 5 s or more
// cannot be told apart from wrap-around on the receiving side.
// Sets `stopping` if the event could not be sent.
static inline void send_event(int connfd, bool down) {
	// Wall clock: the peer compares this stamp with its own clock, so both need a
	// shared epoch. high_resolution_clock is steady_clock on MSVC and libc++.
	const auto now = std::chrono::system_clock::now().time_since_epoch();
	const auto sec = (std::chrono::floor<std::chrono::seconds>(now)%std::chrono::seconds(10)).count();
	const auto ms = (std::chrono::floor<std::chrono::milliseconds>(now)%std::chrono::milliseconds(1000)).count();
	const char event[5] = {
		static_cast<char>('0' + sec),
		static_cast<char>('0' + ms/100),
		static_cast<char>('0' + (ms%100)/10),
		static_cast<char>('0' + ms%10),
		down ? '1' : '0',
	};
#ifdef __linux__
	// A peer that has gone away then yields EPIPE here instead of a SIGPIPE that
	// would kill the whole process.
	const int flags = MSG_NOSIGNAL;
#else
	const int flags = 0;
#endif
	if(send(connfd, event, 5, flags) != 5) stopping = true;
}

#ifndef __linux__
// Connection the keyboard hook forwards space-bar events to (set by device_thread).
static int hook_connfd = -1;
// Milliseconds between the first two calibration receipts (set by receive_thread).
unsigned int round_trip = 0;
// WH_KEYBOARD_LL hook: forwards space-bar state changes to the peer and tracks
// whether KEY_WIN is held. When KEY_WIN is held and space goes down, it sleeps
// for round_trip/2 ms after sending. Always passes the event on.
LRESULT CALLBACK LowLevelKeyboardProc(int nCode, WPARAM wParam, LPARAM lParam) {
	// Negative codes must be passed to CallNextHookEx without further processing.
	if(nCode < 0) return CallNextHookEx(NULL, nCode, wParam, lParam);
	const KBDLLHOOKSTRUCT* key = reinterpret_cast<const KBDLLHOOKSTRUCT*>(lParam);
	// Presses made while Alt is held arrive as WM_SYSKEYDOWN rather than WM_KEYDOWN.
	const bool down = wParam == WM_KEYDOWN || wParam == WM_SYSKEYDOWN;
	static bool other_down = false;
	if(key->vkCode == KEY_WIN) other_down = down;
	if(key->vkCode == VK_SPACE) {
		static bool last = false;
		if(last != down) {
			last = down;
			send_event(hook_connfd, down);
			if(other_down && down) std::this_thread::sleep_for(std::chrono::milliseconds(round_trip/2));
		}
	}
	return CallNextHookEx(NULL, nCode, wParam, lParam);
}
#endif
static void device_thread(
#ifdef __linux__
	device* dev,
#endif
int connfd) {
#ifdef __linux__
	while(true) {
		struct input_event ev;
		int rc = libevdev_next_event(dev->dev, LIBEVDEV_READ_FLAG_NORMAL, &ev);
		if(stopping) break;
		if(rc == -EAGAIN) continue;
		else if(rc != 0) {
			std::cerr << "Error reading from device (" << rc << ")!" << std::endl;
			break;
		}
		else {
			switch(ev.type) {
				case EV_KEY:
					if(ev.code == KEY_SPACE && ev.value != 2) {
						static bool last = false;
						if(last != (ev.value == 1)) {
							last = !last;
							send_event(connfd, ev.value == 1);
						}
					}
					break;
				//case EV_REL:
				case EV_ABS:
					break;
				case EV_SYN:
					break;
			}
		}
	}
	libevdev_free(dev->dev);
	ioctl(dev->fd, EVIOCGRAB, 0);
	close(dev->fd);
#else
	hook_connfd = connfd;
	SetWindowsHookEx(WH_KEYBOARD_LL, LowLevelKeyboardProc, GetModuleHandle(0), 0);
	MSG msg; while(GetMessage(&msg, NULL, 0, 0)) continue;
#endif
}

#ifdef __linux__
// Writes one input_event with the given type, code and value to the uinput
// device `fd`. Every other field, including the timestamp, is zero.
static inline void emit(int fd, int type, int code, int val) {
	// Value-initialise instead of assigning ie.time: on 32-bit targets built with
	// 64-bit time_t, linux/input.h has no `time` member (it uses __sec/__usec).
	struct input_event ie{};
	ie.type = type;
	ie.code = code;
	ie.value = val;
	// Best effort: a short or failed write simply drops the event.
	write(fd, &ie, sizeof(ie));
}
#endif
// Reads exactly `len` bytes from `fd` into `buf`, looping over short reads (TCP
// is a byte stream, so one 5-byte event may arrive in pieces). Returns false if
// recv() fails or the peer closes the connection first.
static bool recv_exact(int fd, char* buf, int len) {
	int got = 0;
	while(got < len) {
		const int n = (int)recv(fd, buf + got, len - got, 0);
		if(n <= 0) return false;
		got += n;
	}
	return true;
}

// Receives the peer's 5-byte events on connfd and replays them locally: as KEY
// through the uinput device kb (Linux) or as KEY_WIN through SendInput (Windows).
// Calibration: one event is sent up front and each of the first four receipts is
// answered; the average delay of the first five receipts (one-way delay plus any
// clock offset between the machines) + 10 ms becomes the latency unless
// `latency` is non-zero. On Windows the interval between the first two receipts
// is stored in round_trip. Afterwards each event is held back by
// latency - delay ms (when that exceeds 2 ms) before being replayed.
// Sets `stopping` once the connection ends.
static void receive_thread(int latency, int connfd, int
#ifdef __linux__
	kb
#endif
) {
	char event[5];
	std::vector<int> delays_ms;
	send_event(connfd, false);
	while(recv_exact(connfd, event, 5)) {
		// Wall clock: the event was stamped on the peer's machine, so only a clock
		// with a shared epoch makes the difference meaningful
		// (high_resolution_clock is steady_clock on MSVC and libc++).
		auto now_ts = std::chrono::system_clock::now();
		auto now = now_ts.time_since_epoch();
		auto sec = (std::chrono::floor<std::chrono::seconds>(now)%std::chrono::seconds(10)).count();
		int ms = (int)(std::chrono::floor<std::chrono::milliseconds>(now)%std::chrono::milliseconds(1000)).count();
		decltype(sec) e_sec = event[0]-'0';
		int e_ms = 100*(event[1]-'0')+10*(event[2]-'0')+(event[3]-'0');
		// Delay in ms modulo 10 s, folded into [-5000, 5000).
		int e_delay_ms = 1000*(10-e_sec+sec)+(ms-e_ms);
		if(e_delay_ms >= 5000) e_delay_ms -= 10000;
		if(e_delay_ms >= 5000) e_delay_ms -= 10000;
		if(delays_ms.size() < 5) {
			delays_ms.push_back(e_delay_ms);
			if(delays_ms.size() < 5) {
				send_event(connfd, false);
#ifndef __linux__
				static decltype(now_ts) round_trip_start;
				if(delays_ms.size() == 1) round_trip_start = now_ts;
				else if(delays_ms.size() == 2) round_trip = std::chrono::duration_cast<std::chrono::milliseconds>(now_ts-round_trip_start).count();
#endif
			}
			else {
				delays_ms[0] = std::accumulate(delays_ms.begin(), delays_ms.end(), 0)/5;
				if(delays_ms[0] < 0 && delays_ms[0] > -100) std::cerr << "(Don't worry about this being negative)" << std::endl;
				if(latency == 0) { latency = delays_ms[0]+10; std::cout << "Measured " << delays_ms[0] << "ms of latency, using " << (delays_ms[0]+10) << "ms" << std::endl; }
				else std::cout << "Measured " << delays_ms[0] << "ms of latency and would use " << (delays_ms[0]+10) << "ms, but using " << latency << "ms as requested" << std::endl;
			}
			continue;
		}
		if(latency-e_delay_ms > 2) std::this_thread::sleep_for(std::chrono::milliseconds(latency-e_delay_ms));
#ifdef __linux__
		emit(kb, EV_KEY, KEY, (event[4] == '1') ? 1 : 0);
		emit(kb, EV_SYN, SYN_REPORT, 0);
#else
		INPUT ip{};
		ip.type = INPUT_KEYBOARD;
		static UINT scancode = 0;
		if(scancode == 0) scancode = MapVirtualKeyA(KEY_WIN, MAPVK_VK_TO_VSC);
		ip.ki.wScan = scancode;
		ip.ki.time = 0; // Windows fills in the timestamp
		ip.ki.dwExtraInfo = 0;
		ip.ki.wVk = KEY_WIN;
		ip.ki.dwFlags = KEYEVENTF_SCANCODE;
		// Keep KEYEVENTF_SCANCODE on the release so press and release identify the key the same way.
		if(event[4] == '0') ip.ki.dwFlags |= KEYEVENTF_KEYUP;
		SendInput(1, &ip, sizeof(INPUT));
#endif
	}
	stopping = true;
}

// Writes the command-line help text to stderr.
static void usage() {
	static const char* const help_lines[] = {
		"Usage:\n",
		"\tOne of (required):\n",
		"\t\t-s     \thost server\n",
		"\t\t-c [IP]\tconnect to IP\n",
		"\t-p [PORT]   \t(required) port number to bind or connect to\n",
		"\t-l [LATENCY]\tlatency in ms to use instead of tested value\n",
	};
	for(const char* line : help_lines) fputs(line, stderr);
}

#define usage() { usage(); return 1; }
// Entry point: parses the command line, (Linux) lets the user pick an evdev
// device, connects to the peer, then runs the device and receive threads until
// both finish. Returns 0 after a normal shutdown and 1 on a usage or setup error.
int main(int argc, char** argv) {
	int mode = -1;    // 0 = host (-s), 1 = connect (-c)
	int address = 0;  // argv index of the IP given to -c
	int port = 0;
	int latency = 0;  // 0 = use the measured latency

	for(int i = 1; i < argc; i++) {
		if(argv[i][0] != '-' || argv[i][1] == '\0' || argv[i][2] != '\0') usage();
		int value = 0;
		switch(argv[i][1]) {
			case 's': if(mode != -1) usage(); mode = 0; break;
			case 'c': if(mode != -1) usage(); mode = 1; if((address = ++i) == argc) usage(); break;
			case 'p':
			case 'l':
				if(++i == argc) usage();
				for(int j = 0; argv[i][j] != '\0'; j++) {
					if(argv[i][j] < '0' || argv[i][j] > '9') { std::cerr << argv[i] << " is not a positive number!" << std::endl; return 1; }
					// Stop before value*10 could overflow int (undefined behaviour).
					if(value > 99999999) { std::cerr << argv[i] << " is too large!" << std::endl; return 1; }
					value *= 10;
					value += argv[i][j]-'0';
				}
				switch(argv[i-1][1]) {
					case 'p':    port = value; break;
					case 'l': latency = value; break;
				}
				break;
			default: usage();
		}
	}
	// -s or -c is required (otherwise argv[0] would be parsed as the IP), and the
	// port must fit in 16 bits.
	if(port == 0 || port > 65535 || mode == -1) usage();

#ifdef __linux__
	// Non-canonical input so select_device() sees keystrokes as they are typed;
	// the original settings are restored once a device has been chosen.
	struct termios oldt{}, newt{};
	const bool is_tty = tcgetattr(STDIN_FILENO, &oldt) == 0;
	if(is_tty) {
		newt = oldt;
		newt.c_lflag &= ~(ICANON);
		tcsetattr(STDIN_FILENO, TCSANOW, &newt);
	}
	setbuf(stdout, NULL);
	std::cout.setf(std::ios::unitbuf);
	std::signal(SIGINT, signal);
	std::signal(SIGTERM, signal);

	struct device dev;
	select_device(&dev);
	if(is_tty) tcsetattr(STDIN_FILENO, TCSANOW, &oldt);
	if(stopping) {
		libevdev_free(dev.dev);
		ioctl(dev.fd, EVIOCGRAB, 0);
		close(dev.fd);
		return 0;
	}
#else
	SetConsoleCtrlHandler(ConsoleHandler, TRUE);
	WSAData data; if(WSAStartup(MAKEWORD(2, 2), &data) != 0) { std::cerr << "WSAStartup() failed!" << std::endl; return 1; }
#endif

	struct sockaddr_in addr, client;
	sockfd = socket(AF_INET, SOCK_STREAM, 0);
	if(sockfd == -1) { std::cerr << "socket() failed!" << std::endl; return 1; }
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = htons(port);
	int connfd;
	if(mode == 0) {
		addr.sin_addr.s_addr = htonl(INADDR_ANY);
		if(bind(sockfd, (sockaddr*)&addr, sizeof(addr)) == -1) { std::cerr << "bind() failed!" << std::endl; close(sockfd); return 1; }
		if(listen(sockfd, 5) == -1) { std::cerr << "listen() failed!" << std::endl; close(sockfd); return 1; }
		socklen_t len = sizeof(client);
		std::cout << "Waiting for connections on port " << port << std::endl;
		connfd = accept(sockfd, (sockaddr*)&client, &len);
		if(connfd == -1) { std::cerr << "accept() failed!" << std::endl; close(sockfd); return 1; }
		// Only one peer is served: release the listening socket and track the
		// connection in sockfd, so the shutdown handler closes the connection itself.
		close(sockfd);
		sockfd = connfd;
		std::cout << "Client connected! Running. . ." << std::endl;
	}
	else {
		if(inet_pton(AF_INET, argv[address], &addr.sin_addr) <= 0) { std::cerr << "inet_pton() failed!" << std::endl; close(sockfd); return 1; }
		std::cout << "Connecting. . ." << std::endl;
		if(connect(sockfd, (sockaddr*)&addr, sizeof(addr)) == -1) { std::cerr << "connect() failed!" << std::endl; close(sockfd); return 1; }
		std::cout << "Connected! Running. . ." << std::endl;
		connfd = sockfd;
	}

	// Events are tiny and latency-critical, so disable Nagle in both modes. It has
	// to be set on the connected socket: setting it on a listening socket after
	// accept() does not affect the connection that was already accepted.
	// IPPROTO_TCP is the level name both POSIX and Winsock accept.
	int one = 1;
	if(setsockopt(connfd, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<const char*>(&one), sizeof(one)) == -1)
		std::cerr << "Failed to set TCP_NODELAY, but continuing" << std::endl;

#ifdef __linux__
	// Virtual keyboard through which received events are replayed.
	int kb = open("/dev/uinput", O_WRONLY | O_NONBLOCK);
	if(kb == -1) {
		std::cerr << "Failed to create input device!" << std::endl;
		libevdev_free(dev.dev);
		close(dev.fd);
		close(connfd);
		return 1;
	}
	ioctl(kb, UI_SET_EVBIT, EV_KEY);
	ioctl(kb, UI_SET_KEYBIT, KEY);
	struct uinput_setup usetup;
	memset(&usetup, 0, sizeof(usetup));
	usetup.id.bustype = BUS_USB;
	usetup.id.vendor = 0x1234;
	usetup.id.product = 0x5678;
	strcpy(usetup.name, "RhythmDoctorKeyboard");
	ioctl(kb, UI_DEV_SETUP, &usetup);
	ioctl(kb, UI_DEV_CREATE);
#else
	int kb = 0;
#endif

	dev_thread = std::make_unique<std::thread>(device_thread,
#ifdef __linux__
		&dev,
#endif
		connfd
	);
	recv_thread = std::make_unique<std::thread>(receive_thread, latency, connfd, kb);
	dev_thread->join();
	recv_thread->join();

#ifdef __linux__
	ioctl(kb, UI_DEV_DESTROY);
	close(kb);
#endif
	// connfd is the only open socket in both modes (sockfd refers to it).
	close(connfd);
#ifndef __linux__
	WSACleanup();
#endif
	return 0;
}
