#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <errno.h>
#include <netdb.h>
#include <pthread.h>
#include <unistd.h>
#include <netinet/ip.h>
#include <sys/socket.h>
#include <sys/types.h>
#ifdef __gnu_linux__
#define __USE_GNU
#include <fcntl.h>
#endif

/*
 * Receiver thread body: copies everything that arrives on the socket to
 * standard output (fd 1) until the peer stops sending or stdout stops
 * accepting data, then shuts down the read side of the socket and closes
 * stdout.
 *
 * arg: the connected socket descriptor, carried through the void pointer as
 *      an intptr_t.
 * Returns NULL. An I/O error is reported with perror() and terminates the
 * whole process via exit(-1).
 */
static void * receiver(void *arg) {
	int sock = (intptr_t) arg;
#ifdef SPLICE_F_MOVE
	/* Largest page-aligned length, so each call may move as much as is available. */
	size_t splice_len = INTPTR_MAX & ~(getpagesize() - 1);
	ssize_t moved = splice(sock, NULL, 1, NULL, splice_len, SPLICE_F_MOVE | SPLICE_F_MORE);
	if (moved >= 0) {
		while (moved > 0) {
			moved = splice(sock, NULL, 1, NULL, splice_len, SPLICE_F_MOVE | SPLICE_F_MORE);
		}
		/* A failure after splice has already worked is a real error, not end of stream. */
		if (moved < 0) {
			perror("splice");
			exit(-1);
		}
	}
	else
	/* The very first splice failed (e.g. neither descriptor is a pipe): fall back to a copy loop. */
#endif
	{
		char buf[2048];
		for (;;) {
			ssize_t n = recv(sock, buf, sizeof buf, 0);
			if (n == 0) {
				break; /* peer closed its sending side */
			}
			if (n < 0) {
				if (errno == EINTR) {
					continue;
				}
				perror("recv");
				exit(-1);
			}
			/* write() may accept only part of the data; resume from the unwritten remainder. */
			const char *pos = buf;
			while (n > 0) {
				ssize_t o = write(1, pos, (size_t) n);
				if (o < 0) {
					if (errno == EINTR) {
						continue;
					}
					perror("write");
					exit(-1);
				}
				if (o == 0) {
					goto done; /* stdout accepts nothing more; stop receiving */
				}
				pos += o;
				n -= o;
			}
		}
	}
done:
	shutdown(sock, SHUT_RD);
	close(1);
	return NULL;
}

/* Prints "<what> a.b.c.d:port" on stderr for the given IPv4 peer address. */
static void report_peer(const char *what, const struct sockaddr_in *peer) {
	uint32_t host = ntohl(peer->sin_addr.s_addr);
	unsigned peer_port = ntohs(peer->sin_port);
	fprintf(stderr, "%s %u.%u.%u.%u:%u\n", what,
		(unsigned) (host >> 24), (unsigned) (host >> 16 & 0xFF),
		(unsigned) (host >> 8 & 0xFF), (unsigned) (host & 0xFF), peer_port);
}

/*
 * Listens on the given TCP port on all IPv4 interfaces, accepts exactly one
 * connection and closes the listening socket.
 * Returns the connected socket, or -1 after printing the failing call.
 */
static int accept_one(uint16_t port) {
	struct sockaddr_in sai = { .sin_family = AF_INET, .sin_port = htons(port), .sin_addr = { .s_addr = htonl(INADDR_ANY) } };
	socklen_t sai_len = sizeof sai;
	int lsock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (lsock < 0) {
		perror("socket");
		return -1;
	}
	if (bind(lsock, (struct sockaddr *) &sai, sai_len) != 0) {
		perror("bind");
		close(lsock);
		return -1;
	}
	if (listen(lsock, 1) != 0) {
		perror("listen");
		close(lsock);
		return -1;
	}
	/* accept() overwrites sai with the address of the connecting peer. */
	int sock = accept(lsock, (struct sockaddr *) &sai, &sai_len);
	if (sock < 0) {
		perror("accept");
		close(lsock);
		return -1;
	}
	if (close(lsock) != 0) {
		perror("close");
		return -1;
	}
	report_peer("accepted connection from", &sai);
	return sock;
}

/*
 * Resolves host to IPv4 addresses and connects to the first one that accepts
 * a TCP connection on the given port.
 * Returns the connected socket, or -1 after printing the failing call.
 */
static int connect_to(const char *host, uint16_t port) {
	struct addrinfo hints = { .ai_family = AF_INET, .ai_socktype = SOCK_STREAM, .ai_protocol = IPPROTO_TCP }, *info;
	int error = getaddrinfo(host, NULL, &hints, &info);
	if (error != 0) {
		fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(error));
		return -1;
	}
	int sock = -1;
	int connect_errno = 0;
	for (struct addrinfo *p = info; p != NULL; p = p->ai_next) {
		sock = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
		if (sock < 0) {
			perror("socket");
			freeaddrinfo(info);
			return -1;
		}
		/* No service was passed to getaddrinfo, so the port is filled in by hand (AF_INET only, per hints). */
		((struct sockaddr_in *) p->ai_addr)->sin_port = htons(port);
		if (connect(sock, p->ai_addr, p->ai_addrlen) == 0) {
			break;
		}
		/* Remember why connect failed before close() gets a chance to change errno. */
		connect_errno = errno;
		if (close(sock) != 0) {
			perror("close");
			freeaddrinfo(info);
			return -1;
		}
		sock = -1;
	}
	freeaddrinfo(info);
	if (sock < 0) {
		errno = connect_errno;
		perror("connect");
		return -1;
	}
	{
		struct sockaddr_in sai;
		socklen_t sai_len = sizeof sai;
		/* Only print the peer when it was actually retrieved; the message is informational. */
		if (getpeername(sock, (struct sockaddr *) &sai, &sai_len) == 0) {
			report_peer("connected to", &sai);
		}
		else {
			perror("getpeername");
		}
	}
	return sock;
}

/*
 * Copies standard input (fd 0) to the socket until end of input.
 * Returns 0 when input is exhausted (or the socket accepts nothing more),
 * -1 after printing the failing call.
 */
static int send_stdin(int sock) {
#ifdef SPLICE_F_MOVE
	/* Largest page-aligned length, so each call may move as much as is available. */
	size_t splice_len = INTPTR_MAX & ~(getpagesize() - 1);
	ssize_t moved = splice(0, NULL, sock, NULL, splice_len, SPLICE_F_MOVE | SPLICE_F_MORE);
	if (moved >= 0) {
		while (moved > 0) {
			moved = splice(0, NULL, sock, NULL, splice_len, SPLICE_F_MOVE | SPLICE_F_MORE);
		}
		/* A failure after splice has already worked is a real error, not end of input. */
		if (moved < 0) {
			perror("splice");
			return -1;
		}
		return 0;
	}
	/* The very first splice failed (e.g. stdin is not a pipe): fall back to a copy loop. */
#endif
	char buf[2048];
	for (;;) {
		ssize_t n = read(0, buf, sizeof buf);
		if (n == 0) {
			return 0;
		}
		if (n < 0) {
			if (errno == EINTR) {
				continue;
			}
			perror("read");
			return -1;
		}
		/* send() may accept only part of the data; resume from the unsent remainder. */
		const char *pos = buf;
		while (n > 0) {
			ssize_t o = send(sock, pos, (size_t) n, 0);
			if (o < 0) {
				if (errno == EINTR) {
					continue;
				}
				perror("send");
				return -1;
			}
			if (o == 0) {
				return 0;
			}
			pos += o;
			n -= o;
		}
	}
}

/*
 * netpipe: connects stdin/stdout to a single TCP connection.
 *
 * usage: <prog> [<host>] <port>
 *   With only <port>, listens on that port and serves the first connection.
 *   With <host> <port>, connects to that host.
 *
 * A detached receiver thread copies socket -> stdout while the main thread
 * copies stdin -> socket. Returns -1 on any setup or send-side error.
 */
int main(int argc, char **argv) {
	unsigned long port;
	char *end;
	int sock;
	if (argc < 2 || argc > 3) {
		fprintf(stderr, "usage: %s [<host>] <port>\n", argc > 0 ? argv[0] : "netpipe");
		return -1;
	}
	/* Base 0 keeps accepting decimal, octal and hex; trailing junk such as "80x" is rejected. */
	port = strtoul(argv[argc - 1], &end, 0);
	if (end == argv[argc - 1] || *end != '\0' || port == 0 || port > 0xFFFF) {
		fputs("invalid port\n", stderr);
		return -1;
	}
	if (argc == 2) {
		sock = accept_one((uint16_t) port);
	}
	else {
		sock = connect_to(argv[1], (uint16_t) port);
	}
	if (sock < 0) {
		return -1;
	}
	{
		pthread_attr_t attr;
		pthread_t thread;
		/* pthread functions return the error number instead of setting errno; store it so perror() can print it. */
		if ((errno = pthread_attr_init(&attr)) != 0) {
			perror("pthread_attr_init");
			return -1;
		}
		if ((errno = pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED)) != 0) {
			perror("pthread_attr_setdetachstate");
			return -1;
		}
		if ((errno = pthread_create(&thread, &attr, receiver, (void *) (intptr_t) sock)) != 0) {
			perror("pthread_create");
			return -1;
		}
		if ((errno = pthread_attr_destroy(&attr)) != 0) {
			perror("pthread_attr_destroy");
			return -1;
		}
	}
	if (send_stdin(sock) != 0) {
		return -1;
	}
	close(0);
	shutdown(sock, SHUT_WR);
	/* End only the main thread so the detached receiver can finish draining the socket. */
	pthread_exit(NULL);
	return 0;
}
