quinn-os/socket.c
2022-07-21 14:52:54 +10:00

757 lines
20 KiB
C

#include "inttypes.h"
#include "pvec.h"
#include "ether.h"
#include "ipv4.h"
#include "socket.h"
#include "schedule.h"
#include "tcp.h"
#include "udp.h"
#include "memory.h"
#include "string.h"
#include "console.h"
#include "rtc.h"
/* Task currently executing; socket calls put it to sleep / read its pid. */
extern struct task_t *current_task;
/* Last serial number handed out. Serials are pre-incremented (++socket_serial),
 * so the first socket gets 1 and 0 is used as the "failure" return of socket_open(). */
static uint32_t socket_serial = 0;
/* All sockets known to the stack. Connections still pending accept() are kept in
 * the listening task's waiting_sockets array and only appended here by socket_accept(). */
struct ptr_vector sockets;
/* Network interfaces; defined elsewhere. */
extern struct ptr_vector ether_devs;
/* One bit per 16-bit port number (8192 bytes * 8 = 65536 bits); a set bit marks the
 * port as in use. Managed by socket_{mark,clear,check}_port_bitmap(). */
static uint8_t port_bitmap[8192];
/* Global tick counter. Socket deadlines (poll_timeout, timeout) are absolute values of it. */
extern uint32_t timer_ticks;
/* Implemented in the TCP code (not visible here). */
extern void tcp_send(struct ether_t *ether, struct socket_t *sock, uint16_t flags, uint8_t *packet, uint32_t len, int thrice);
/* Forward declaration: used by socket_timeout() before its definition. */
void socket_doclose(struct socket_t *sock);
/*
 * One record per socket, written by socket_info() into a caller-supplied buffer.
 * NOTE(review): presumably mirrored by a userspace definition; keep the field
 * order and types in sync with it (TODO confirm where it is consumed).
 */
struct sock_info_t {
/* Socket serial number (socket_t.serial). */
uint32_t serialno;
/* Copy of socket_t.status (a SOCKET_STATUS_* value). */
int sock_state;
/* Ticks until the poll deadline, or 0 when no poll is armed (see socket_info for expired deadlines). */
uint32_t poll_timeout;
/* Ticks until the close deadline, or 0 when none is set (see socket_info for expired deadlines). */
uint32_t timeout;
};
/*
 * Convert an absolute tick deadline into "ticks remaining" for socket_info().
 * A deadline of 0 means "not set" and is reported as 0.
 * NOTE(review): a deadline that has already passed is reported as its raw
 * absolute value (the historical behaviour); consumers may prefer 0 — confirm.
 */
static uint32_t socket_info_remaining_ticks(uint32_t deadline) {
    if (deadline == 0 || deadline < timer_ticks) {
        return deadline;
    }
    return deadline - timer_ticks;
}

/*
 * Fill `buffer` with struct sock_info_t records describing open sockets.
 *
 * Records start at the first socket (in `sockets` order) whose serial is
 * >= last_socket and continue through the rest of the list, until the list
 * ends or the buffer is full.
 *
 * buffer:      destination, treated as an array of struct sock_info_t.
 * len:         size of `buffer` in bytes; <= 0 yields no records.
 * last_socket: serial to resume from (for paging through the list).
 *
 * Returns the number of records written.
 */
int socket_info(uint8_t *buffer, int len, uint32_t last_socket) {
    struct sock_info_t *sockinfo = (struct sock_info_t *)buffer;
    size_t count = ptr_vector_len(&sockets);
    size_t capacity;
    size_t i = 0;
    int ret = 0;

    if (buffer == NULL || len <= 0) {
        return 0;
    }
    capacity = (size_t)len / sizeof(struct sock_info_t);

    /* Skip sockets older than the resume point. */
    while (i < count && ((struct socket_t *)ptr_vector_get(&sockets, i))->serial < last_socket) {
        i++;
    }

    /* Bound by buffer capacity on the number of records written (ret), not on
     * the socket index — the index starts at i and can exceed capacity. */
    for (; i < count && (size_t)ret < capacity; i++) {
        struct socket_t *s = (struct socket_t *)ptr_vector_get(&sockets, i);
        sockinfo[ret].serialno = s->serial;
        sockinfo[ret].sock_state = s->status;
        sockinfo[ret].poll_timeout = socket_info_remaining_ticks(s->poll_timeout);
        sockinfo[ret].timeout = socket_info_remaining_ticks(s->timeout);
        ret++;
    }
    return ret;
}
void socket_timeout() {
for (size_t i = 0; i < ptr_vector_len(&sockets); i++) {
struct socket_t *s = (struct socket_t *)ptr_vector_get(&sockets, i);
if (!s->ether)
continue;
if (s->socket_type == 1 && (s->status == SOCKET_STATUS_CONNECTED || s->status == SOCKET_STATUS_CLOSEWAIT || s->status == SOCKET_STATUS_OPENED)) {
if (s->status != SOCKET_STATUS_CLOSEWAIT) {
if (s->tcp_keep_alive + 720000 < timer_ticks) {
if (s->tcp_keep_alive_probes == 0) {
tcp_send(s->ether, s, TCP_FLAGS_ACK, NULL, 0, 0);
s->tcp_keep_alive_probes = 1;
s->tcp_keep_alive_interval = timer_ticks;
} else if (s->tcp_keep_alive_probes < 9 && s->tcp_keep_alive_interval + 7500 < timer_ticks) {
tcp_send(s->ether, s, TCP_FLAGS_ACK, NULL, 0, 0);
s->tcp_keep_alive_probes++;
s->tcp_keep_alive_interval = timer_ticks;
} else {
sched_tcp_read_wakeup(s);
s->status = SOCKET_STATUS_CLOSEWAIT;
s->timeout = 12000;
tcp_send(s->ether, s, TCP_FLAGS_RES, NULL, 0, 0);
}
}
}
if (s->poll_timeout != 0 && (s->poll_timeout <= timer_ticks || s->status == SOCKET_STATUS_CLOSEWAIT)) {
sched_tcp_read_wakeup(s);
}
}
if (s->socket_type == 1 && s->timeout != 0 && s->timeout <= timer_ticks) {
s->status = SOCKET_STATUS_CLOSE;
socket_doclose(s);
}
}
}
/*
 * Look up a socket in the global `sockets` list by its serial number.
 * Returns the first socket whose serial matches, or NULL if there is none.
 */
struct socket_t *socket_get_from_serial(uint32_t serial) {
    size_t idx = 0;
    size_t total = ptr_vector_len(&sockets);

    while (idx < total) {
        struct socket_t *candidate = (struct socket_t *)ptr_vector_get(&sockets, idx);
        if (candidate->serial == serial) {
            return candidate;
        }
        idx++;
    }
    return NULL;
}
extern int tcp_process_socket(struct socket_t *sock);
/*
 * Queue an outgoing TCP segment on `sock` and kick the TCP state machine.
 *
 * payload/len: data to copy into the segment (may be NULL/0 for control-only
 *              segments such as FIN or RST).
 * flags:       TCP flags for the segment.
 *
 * Returns `len` on success, -1 if memory could not be allocated, and 0 for
 * non-TCP sockets (nothing is queued).
 *
 * The segment's `data` is NULL when len == 0. This matches socket_doclose(),
 * which only frees `data` for len > 0.
 */
int socket_queue(struct socket_t *sock, uint8_t *payload, uint32_t len, uint16_t flags) {
    struct tcp_data_t *outpacket;

    if (sock->socket_type != 1) {
        return 0;
    }
    outpacket = (struct tcp_data_t *)dbmalloc(sizeof(struct tcp_data_t), "socket queue 1");
    if (!outpacket) {
        return -1;
    }
    outpacket->flags = flags;
    outpacket->len = len;
    outpacket->data = NULL;
    if (len > 0) {
        outpacket->data = (uint8_t *)dbmalloc(len, "socket queue 2");
        if (!outpacket->data) {
            /* Queued segments are released with free() in socket_doclose(). */
            free(outpacket);
            return -1;
        }
        memcpy(outpacket->data, payload, len);
    }
    ptr_vector_append(&sock->out_packets, outpacket);
    tcp_process_socket(sock);
    return len;
}
/*
 * Demultiplex an inbound TCP segment to a socket.
 *
 * Pass 1 looks for an existing connection matching (local port dport, remote
 * port sport, remote address src_ip or 0xffffffff).
 *  - If it is in a data-carrying state, its keep-alive tracking is reset and
 *    the socket is returned.
 *  - If it matches but is in another state, NULL is returned.
 *
 * Otherwise, a SYN addressed to a listening socket (addr == 0) on dport
 * creates a new pending connection.
 *  - The new socket goes into the listening task's waiting_sockets array,
 *    which is capped at 25 entries.
 *  - The task is woken if it sleeps in accept().
 *  - NULL is returned in this case; socket_accept() later moves the new
 *    socket into `sockets` and sends the SYN-ACK.
 *
 * A retransmitted SYN for a connection that is already pending is ignored, so
 * no duplicate pending socket is created for the same peer.
 */
struct socket_t *socket_find(uint16_t dport, uint16_t sport, uint32_t src_ip, uint32_t seq, uint16_t flags) {
    size_t i;
    int k;
    struct socket_t *new_socket;

    /* Pass 1: established / in-progress connections. */
    for (i = 0; i < ptr_vector_len(&sockets); i++) {
        struct socket_t *s = (struct socket_t *)ptr_vector_get(&sockets, i);
        if (!s->ether)
            continue;
        if (s->port_recv == dport && s->port_dest == sport && (s->addr == 0xffffffff || s->addr == src_ip)) {
            if (s->status == SOCKET_STATUS_OPENED || s->status == SOCKET_STATUS_CONNECTED || s->status == SOCKET_STATUS_CLOSE2 || s->status == SOCKET_STATUS_FINACK ||
                s->status == SOCKET_STATUS_FINACK2) {
                /* Inbound traffic proves the peer is alive. */
                s->tcp_keep_alive = timer_ticks;
                s->tcp_keep_alive_probes = 0;
                s->tcp_keep_alive_interval = 0;
                return s;
            }
            return NULL;
        }
    }

    /* Only a SYN may open a new connection. */
    if (!(flags & TCP_FLAGS_SYN)) {
        return NULL;
    }

    /* Pass 2: listening sockets. */
    for (i = 0; i < ptr_vector_len(&sockets); i++) {
        struct socket_t *s = (struct socket_t *)ptr_vector_get(&sockets, i);
        struct task_t *task;
        if (!s->ether)
            continue;
        if (s->port_recv != dport || s->addr != 0)
            continue;
        if (s->status != SOCKET_STATUS_LISTEN) {
            return NULL;
        }
        task = (struct task_t *)s->data;

        /* Pending connections are not in `sockets` yet, so pass 1 cannot see
         * them. Drop a retransmitted SYN for a connection already waiting. */
        for (k = 0; k < task->waiting_socket_count; k++) {
            struct socket_t *w = task->waiting_sockets[k];
            if (w->port_recv == dport && w->port_dest == sport && w->addr == src_ip) {
                return NULL;
            }
        }

        if (task->waiting_socket_count >= 25) {
            kprintf("Task waiting count >= 25");
            return NULL;
        }

        new_socket = (struct socket_t *)dbmalloc(sizeof(struct socket_t), "socket find 1");
        if (new_socket) {
            memset(new_socket, 0, sizeof(struct socket_t));
            new_socket->socket_type = 1;
            new_socket->port_recv = dport;
            new_socket->tcp_sock.seq_number = random();
            new_socket->tcp_sock.ack_number = seq + 1;
            init_ptr_vector(&new_socket->out_packets);
            init_ptr_vector(&new_socket->packets);
            new_socket->addr = src_ip;
            new_socket->bytes_available = 0;
            new_socket->bytes_read = 0;
            new_socket->port_dest = sport;
            new_socket->offset = 0;
            new_socket->ether = s->ether;
            new_socket->ref = 1;
            new_socket->serial = ++socket_serial;
            new_socket->tcp_keep_alive = timer_ticks;
            new_socket->tcp_keep_alive_probes = 0;
            new_socket->tcp_keep_alive_interval = 0;
            new_socket->poll_timeout = 0;
            new_socket->timeout = 0;
            new_socket->status = SOCKET_STATUS_OPENED;
            task->waiting_sockets[task->waiting_socket_count++] = new_socket;
            if (task->state == TASK_SLEEPING && task->sleep_reason == SLEEP_TCP_ACCEPT) {
                task->state = TASK_RUNNING;
            }
        }
        return NULL;
    }
    return NULL;
}
/*
 * OR `flag` into the socket's flag word (e.g. SOCKET_FLAG_NOBLOCK).
 * Returns the updated flags, or 0 when no socket has this serial.
 */
uint32_t socket_ioctl(int serial, int flag) {
    struct socket_t *target = socket_get_from_serial(serial);

    if (target != NULL) {
        target->flags |= flag;
        return target->flags;
    }
    return 0;
}
/*
 * Dequeue the oldest received packet of a TCP socket.
 * Returns the removed struct tcp_data_t (the caller now owns it), or NULL
 * for non-TCP sockets or when the receive queue is empty.
 */
void *socket_get_packet(struct socket_t *sock) {
    if (sock->socket_type != 1 || ptr_vector_len(&sock->packets) == 0) {
        return NULL;
    }
    return (struct tcp_data_t *)ptr_vector_del(&sock->packets, 0);
}
/*
 * Report the SOCKET_STATUS_* value of the socket with this serial,
 * or -1 if no such socket exists.
 */
int socket_status(uint32_t serial) {
    struct socket_t *found = socket_get_from_serial(serial);

    if (found != NULL) {
        return found->status;
    }
    return -1;
}
/*
 * Receive one datagram from a UDP socket.
 *
 * Copies at most `len` bytes of the oldest queued datagram into `buffer`.
 * The datagram is consumed even if it had to be truncated.
 *
 * If `addr` is non-NULL it receives sock->addr. NOTE(review): that is the
 * address the socket was bound to, not necessarily the datagram's sender —
 * confirm against udp.c.
 *
 * Returns:
 *   - the number of bytes copied;
 *   - -1 when no datagram is queued;
 *   - 0 when the socket does not exist, is not UDP, or len is negative.
 */
int socket_recv_from(uint32_t serial, char *buffer, int len, uint32_t *addr) {
    struct socket_t *sock = socket_get_from_serial(serial);
    struct udp_data_t *data;

    if (sock == NULL)
        return 0;
    if (sock->socket_type != 17)
        return 0;
    /* A negative length would otherwise become a huge memcpy size. */
    if (len < 0)
        return 0;
    if (ptr_vector_len(&sock->packets) == 0) {
        return -1;
    }
    data = (struct udp_data_t *)ptr_vector_del(&sock->packets, 0);
    if (addr != NULL) {
        *addr = sock->addr;
    }
    if (len > data->len) {
        len = data->len;
    }
    memcpy(buffer, data->data, len);
    /* The datagram was removed from the queue, so release it here; this
     * mirrors how socket_doclose() frees queued UDP packets. */
    free(data);
    return len;
}
/*
 * Poll a socket for readable data; intended to be called twice.
 *  - state == 0 is the first call. It arms poll_timeout
 *    (timer_ticks + timeout * 100) and puts the current task to sleep.
 *  - Any other state is the re-check after wake-up.
 *
 * Returns:
 *    1  data is available;
 *   -1  the socket is missing, or closed (a CLOSEWAIT socket is closed here);
 *   -2  the caller should sleep, or is still within its deadline;
 *    0  the deadline expired without data.
 * poll_timeout is cleared whenever a final answer (1, -1, 0) is returned for
 * an existing socket.
 */
int socket_poll(uint32_t serial, uint32_t timeout, int state) {
    struct socket_t *sock = socket_get_from_serial(serial);
    int readable;

    if (!sock) {
        return -1;
    }
    if (sock->status == SOCKET_STATUS_CLOSE) {
        sock->poll_timeout = 0;
        return -1;
    }

    readable = sock->bytes_available || ptr_vector_len(&sock->packets) > 0;
    if (readable) {
        sock->poll_timeout = 0;
        return 1;
    }

    if (sock->status == SOCKET_STATUS_CLOSEWAIT) {
        sock->status = SOCKET_STATUS_CLOSE;
        sock->poll_timeout = 0;
        socket_doclose(sock);
        return -1;
    }

    if (state == 0) {
        /* First pass: arm the deadline and block the caller. */
        sock->poll_timeout = timer_ticks + (timeout * 100);
        current_task->sleep_reason = SLEEP_TCP_POLL;
        current_task->state = TASK_SLEEPING;
        return -2;
    }

    if (sock->poll_timeout > timer_ticks) {
        /* Woken before the deadline with nothing to read. */
        kprintf("PID %d %d\n", current_task->pid, current_task->state);
        return -2;
    }

    sock->poll_timeout = 0;
    return 0;
}
int socket_read(uint32_t serial, char *buffer, int len) {
struct socket_t *sock = socket_get_from_serial(serial);
if (sock == NULL) {
return 0;
}
if (sock->status == SOCKET_STATUS_CLOSE) {
return 0;
}
if (sock->socket_type == 1) {
struct tcp_data_t *data = NULL;
uint32_t size_to_read;
int flag_psh = 0;
sock->offset = 0;
do {
if (sock->bytes_available) {
data = sock->curr_packet.tcp;
} else {
if (ptr_vector_len(&sock->packets) > 0) {
data = (struct tcp_data_t *)socket_get_packet(sock);
} else {
if (sock->status == SOCKET_STATUS_CLOSEWAIT) {
sock->status = SOCKET_STATUS_CLOSE;
socket_doclose(sock);
return 0;
} else if (sock->status == SOCKET_STATUS_CONNECTED || sock->status == SOCKET_STATUS_OPENED) {
if (sock->flags & SOCKET_FLAG_NOBLOCK) {
return -2;
} else {
current_task->sleep_reason = SLEEP_TCP_READ;
current_task->state = TASK_SLEEPING;
return -1;
}
} else {
return 0;
}
}
sock->bytes_available = data->len;
sock->bytes_read = 0;
}
if (len - sock->offset < sock->bytes_available) {
size_to_read = len - sock->offset;
} else {
size_to_read = sock->bytes_available;
}
if (data->len > 0) {
memcpy(buffer + sock->offset, data->data + sock->bytes_read, size_to_read);
}
sock->offset += size_to_read;
if (htons(data->flags) & TCP_FLAGS_PSH) {
flag_psh = 1;
}
if (size_to_read < sock->bytes_available) {
sock->bytes_available -= size_to_read;
sock->bytes_read += size_to_read;
sock->curr_packet.tcp = data;
} else {
sock->bytes_available = 0;
sock->curr_packet.tcp = NULL;
dbfree(data->data, "socket_read");
dbfree(data, "socket_read 2");
}
if (flag_psh) {
break;
}
} while (sock->offset < len);
int ret = sock->offset;
sock->offset = 0;
return ret;
} else {
return 0;
}
}
/* Mark `port` as free in port_bitmap (byte port/8, bit port%8). */
void socket_clear_port_bitmap(uint16_t port) {
    uint8_t mask = (uint8_t)(1u << (port & 7u));
    port_bitmap[port >> 3] &= (uint8_t)~mask;
}
/* Mark `port` as in use in port_bitmap (byte port/8, bit port%8). */
void socket_mark_port_bitmap(uint16_t port) {
    port_bitmap[port >> 3] |= (uint8_t)(1u << (port & 7u));
}
/*
 * Test whether `port` is marked in use.
 * Returns 0 when free; otherwise the non-zero bit value (1 << (port % 8)).
 */
int socket_check_port_bitmap(uint16_t port) {
    const uint8_t byte = port_bitmap[port >> 3];
    const int bit = 1 << (port & 7);
    return byte & bit;
}
/*
 * Pick a free ephemeral port from the dynamic range 49152-65535 (inclusive).
 *
 * The scan starts at a random port and walks upward, wrapping within the
 * range, so every port is examined exactly once. The arithmetic is done in
 * 32 bits so it cannot overflow the 16-bit port type. The previous version
 * wrapped at 65535 and could overflow to port 0 or loop forever.
 *
 * The chosen port is NOT reserved here; callers mark it with
 * socket_mark_port_bitmap().
 *
 * Returns the port, or 0 when every port in the range is in use.
 */
unsigned socket_get_port(void) {
    const uint32_t first = 49152;
    const uint32_t span = 16384; /* 49152..65535 inclusive */
    uint32_t start = (uint32_t)random() % span;

    for (uint32_t n = 0; n < span; n++) {
        uint32_t port = first + (start + n) % span;
        if (!socket_check_port_bitmap((uint16_t)port)) {
            return port;
        }
    }
    /* run out of ports! */
    return 0;
}
int socket_bind(uint32_t serial, uint32_t dest_ip, uint16_t dest_port, uint16_t src_port) {
struct socket_t *sock = socket_get_from_serial(serial);
if (sock->socket_type != 17) {
return -1;
}
if (src_port == 0) {
sock->port_recv = socket_get_port();
if (sock->port_recv == 0) {
sock->status = SOCKET_STATUS_CLOSE;
socket_doclose(sock);
return -1;
}
socket_mark_port_bitmap(sock->port_recv);
} else {
if (!socket_check_port_bitmap(src_port)) {
sock->port_recv = src_port;
socket_mark_port_bitmap(src_port);
} else {
sock->status = SOCKET_STATUS_CLOSE;
socket_doclose(sock);
return -1;
}
}
sock->addr = dest_ip;
sock->port_dest = dest_port;
sock->ether = ether_find_from_ipv4(dest_ip);
if (!sock->ether) {
if (ptr_vector_len(&ether_devs) > 0) {
sock->ether = (struct ether_t *)ptr_vector_get(&ether_devs, 0);
} else {
sock->status = SOCKET_STATUS_CLOSE;
socket_doclose(sock);
return -1;
}
}
return 0;
}
/*
 * Start an active TCP open to dest_ip:dest_port.
 *
 * On success:
 *  - the interface is resolved from dest_ip;
 *  - an ephemeral local port is allocated and marked in use;
 *  - a random initial sequence number is chosen;
 *  - a SYN is sent.
 * Returns 0.
 *
 * Returns -1 for an unknown serial, leaving nothing changed. For a non-TCP
 * socket, a missing interface or no free port, the socket is closed and
 * -1 is returned.
 */
int socket_connect(uint32_t serial, uint32_t dest_ip, uint16_t dest_port) {
    struct socket_t *sock = socket_get_from_serial(serial);

    if (sock == NULL) {
        return -1;
    }
    if (sock->socket_type != 1) {
        goto fail;
    }

    sock->ether = ether_find_from_ipv4(dest_ip);
    if (!sock->ether) {
        goto fail;
    }

    sock->port_recv = socket_get_port();
    if (sock->port_recv == 0) {
        goto fail;
    }
    socket_mark_port_bitmap(sock->port_recv);

    sock->tcp_sock.seq_number = random();
    sock->tcp_sock.ack_number = 0;
    sock->addr = dest_ip;
    sock->port_dest = dest_port;
    tcp_send(sock->ether, sock, TCP_FLAGS_SYN, NULL, 0, 0);
    return 0;

fail:
    sock->status = SOCKET_STATUS_CLOSE;
    socket_doclose(sock);
    return -1;
}
/*
 * Tear down a socket: detach it from all tasks, remove it from `sockets`,
 * release its port, free its queued packets and free the socket itself.
 *
 * TCP sockets are only torn down when status == SOCKET_STATUS_CLOSE (a
 * diagnostic is printed otherwise). Blocked readers are woken first.
 *
 * The socket is freed only if it is found in `sockets`.
 *
 * Fixes:
 *  - the UDP path freed every packet twice (an explicit loop plus
 *    ptr_vector_apply);
 *  - the TCP path leaked the partially-read packet that socket_read() keeps
 *    in curr_packet.tcp.
 */
void socket_doclose(struct socket_t *sock) {
    size_t i, j;

    if (sock->socket_type == 1) {
        if (sock->status != SOCKET_STATUS_CLOSE) {
            kprintf("DO CLOSE with SOCK STATUS != CLOSE %d\n", current_task->pid);
            return;
        }
        sched_tcp_read_wakeup(sock);
        sched_remove_socket_from_all_tasks(sock);
        for (i = 0; i < ptr_vector_len(&sockets); i++) {
            if ((struct socket_t *)ptr_vector_get(&sockets, i) != sock) {
                continue;
            }
            ptr_vector_del(&sockets, i);
            socket_clear_port_bitmap(sock->port_recv);

            /* socket_read() parks a partially consumed packet here after
             * removing it from `packets`, so the loops below never see it. */
            if (sock->bytes_available && sock->curr_packet.tcp != NULL) {
                struct tcp_data_t *held = sock->curr_packet.tcp;
                if (held->len > 0) {
                    dbfree(held->data, "socket close 4");
                }
                dbfree(held, "socket close 5");
                sock->curr_packet.tcp = NULL;
                sock->bytes_available = 0;
            }

            /* Payload buffers exist only for non-empty segments. */
            for (j = 0; j < ptr_vector_len(&sock->packets); j++) {
                struct tcp_data_t *d = (struct tcp_data_t *)ptr_vector_get(&sock->packets, j);
                if (d->len > 0) {
                    dbfree(d->data, "socket close 1");
                }
            }
            ptr_vector_apply(&sock->packets, free);
            destroy_ptr_vector(&sock->packets);
            for (j = 0; j < ptr_vector_len(&sock->out_packets); j++) {
                struct tcp_data_t *d = (struct tcp_data_t *)ptr_vector_get(&sock->out_packets, j);
                if (d->len > 0) {
                    dbfree(d->data, "socket close 2");
                }
            }
            ptr_vector_apply(&sock->out_packets, free);
            destroy_ptr_vector(&sock->out_packets);
            dbfree(sock, "socket close 3");
            break;
        }
        return;
    }

    /* UDP (and any other non-TCP) socket. */
    sched_remove_socket_from_all_tasks(sock);
    for (i = 0; i < ptr_vector_len(&sockets); i++) {
        if ((struct socket_t *)ptr_vector_get(&sockets, i) != sock) {
            continue;
        }
        ptr_vector_del(&sockets, i);
        socket_clear_port_bitmap(sock->port_recv);
        /* ptr_vector_apply frees each queued datagram exactly once. */
        ptr_vector_apply(&sock->packets, free);
        destroy_ptr_vector(&sock->packets);
        ptr_vector_apply(&sock->out_packets, free);
        destroy_ptr_vector(&sock->out_packets);
        free(sock);
        break;
    }
}
/*
 * Drop one reference to the socket owned by the calling task.
 *
 * When the last reference goes away:
 *  - a connected TCP socket queues FIN|ACK and is left for the TCP state
 *    machine to finish closing;
 *  - any other TCP socket with an interface queues RST and is torn down at once;
 *  - all other sockets are torn down at once.
 * Unknown serials are ignored.
 */
void socket_close(uint32_t serial) {
    struct socket_t *sock = socket_get_from_serial(serial);

    if (!sock) {
        return; // Already closed
    }

    /* Detach from the calling task's socket list. */
    sched_remove_socket(sock);
    if (--sock->ref != 0) {
        return;
    }

    if (sock->socket_type == 1 && sock->ether != NULL) {
        sock->send_payload = 1;
        if (sock->status == SOCKET_STATUS_CONNECTED) {
            /* Graceful shutdown: the FIN handshake completes asynchronously. */
            socket_queue(sock, NULL, 0, TCP_FLAGS_FIN | TCP_FLAGS_ACK);
            return;
        }
        socket_queue(sock, NULL, 0, TCP_FLAGS_RES);
    }

    sock->status = SOCKET_STATUS_CLOSE;
    socket_doclose(sock);
}
/*
 * Create a new socket of the given protocol type (1 = TCP, 17 = UDP).
 *
 * The socket is zero-initialised, gets status OPENED, one reference and a
 * fresh serial, and is registered both globally and with the calling task.
 *
 * Returns the new serial, or 0 for an unsupported type or allocation failure.
 */
uint32_t socket_open(uint8_t type) {
    struct socket_t *sock;

    if (type != 1 && type != 17) {
        kprintf("Unsupported Socket type!\n");
        return 0;
    }

    sock = (struct socket_t *)dbmalloc(sizeof(struct socket_t), type == 1 ? "socket open 1" : "socket open 2");
    if (sock == NULL) {
        return 0;
    }

    memset(sock, 0, sizeof(struct socket_t));
    sock->status = SOCKET_STATUS_OPENED;
    sock->socket_type = type;
    sock->bytes_available = 0;
    sock->bytes_read = 0;
    sock->ref = 1;
    sock->serial = ++socket_serial;
    if (type == 1) {
        /* TCP: no sequence state and no interface until connect/listen. */
        sock->tcp_sock.seq_number = 0;
        sock->tcp_sock.ack_number = 0;
        sock->ether = NULL;
    }

    ptr_vector_append(&sockets, sock);
    ptr_vector_append(&current_task->sockets, sock);
    return sock->serial;
}
/*
 * Put a TCP socket into the listening state on `port`.
 *
 * The interface is resolved from `listenip`. The calling task is recorded in
 * sock->data: socket_find() queues incoming connections on that task and
 * socket_accept() hands them out.
 *
 * Returns 0 on success and -1 on failure. There are two kinds of failure:
 *  - unknown serial, or a non-TCP socket: the socket is left untouched;
 *  - no interface: the socket is closed.
 */
int socket_listen(uint32_t serial, uint32_t listenip, uint16_t port) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL) {
        return -1;
    }
    if (sock->socket_type != 1) {
        kprintf("Wrong socket Type\n");
        return -1;
    }
    sock->port_recv = port;
    sock->tcp_sock.seq_number = random();
    sock->tcp_sock.ack_number = 0;
    /* addr == 0 is what socket_find() uses to recognise a listener. */
    sock->addr = 0;
    sock->port_dest = 0;
    sock->ether = ether_find_from_ipv4(listenip);
    if (!sock->ether) {
        kprintf("Cant find ether\n");
        sock->status = SOCKET_STATUS_CLOSE;
        socket_doclose(sock);
        return -1;
    }
    /* NOTE(review): 3 must equal SOCKET_STATUS_LISTEN, the value socket_find()
     * matches listeners against; consider using the named constant. */
    sock->status = 3;
    sock->data = (void *)current_task;
    return 0;
}
/*
 * Accept the oldest pending connection on a listening socket.
 *
 * If a connection is waiting:
 *  - it is removed from the listener task's waiting_sockets;
 *  - a SYN-ACK is sent and the sequence number advanced past the SYN;
 *  - it is registered globally and with the task;
 *  - the peer's address is optionally reported;
 *  - its serial is returned.
 *
 * Otherwise the listener task is put to sleep (SLEEP_TCP_ACCEPT) and 0 is
 * returned.
 *
 * An unknown serial, or a socket with no listener task (never listened),
 * returns 0 without sleeping.
 */
uint32_t socket_accept(uint32_t serial, struct inet_addr *client_addr) {
    struct socket_t *socket = socket_get_from_serial(serial);
    struct task_t *task;
    struct socket_t *new_socket;
    int i;

    if (socket == NULL || socket->data == NULL) {
        return 0;
    }
    task = (struct task_t *)socket->data;

    if (task->waiting_socket_count > 0) {
        new_socket = task->waiting_sockets[0];
        /* Pop the front of the pending queue. */
        for (i = 1; i < task->waiting_socket_count; i++) {
            task->waiting_sockets[i - 1] = task->waiting_sockets[i];
        }
        task->waiting_socket_count--;
        // send SYN ACK
        if (new_socket->socket_type == 1) {
            tcp_send(new_socket->ether, new_socket, TCP_FLAGS_SYN | TCP_FLAGS_ACK, NULL, 0, 0);
            /* The SYN consumes one sequence number. */
            new_socket->tcp_sock.seq_number += 1;
        }
        if (client_addr != NULL) {
            client_addr->type = new_socket->socket_type;
            client_addr->addr = new_socket->addr;
        }
        ptr_vector_append(&sockets, new_socket);
        ptr_vector_append(&task->sockets, new_socket);
        return new_socket->serial;
    }
    task->state = TASK_SLEEPING;
    task->sleep_reason = SLEEP_TCP_ACCEPT;
    return 0;
}
/*
 * Send `len` bytes from `payload` on a socket.
 *  - TCP: the data is queued as a PSH|ACK segment via socket_queue(). Its
 *    result is returned: len, or -1 on allocation failure.
 *  - UDP: the data is sent immediately and len is returned.
 *
 * Returns:
 *   - -1 for an unknown serial;
 *   - -1 for a UDP socket that has no interface (never bound);
 *   - 0 for any other socket type.
 */
int socket_write(uint32_t serial, uint8_t *payload, uint32_t len) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL) {
        return -1;
    }
    if (sock->socket_type == 1) {
        return socket_queue(sock, payload, len, TCP_FLAGS_PSH | TCP_FLAGS_ACK);
    } else if (sock->socket_type == 17) {
        /* socket_bind() sets ether; without it there is no route to send on. */
        if (sock->ether == NULL) {
            return -1;
        }
        udp_send(sock->ether, sock, payload, len);
        return len;
    }
    return 0;
}
/* One-time initialisation: no ports in use, empty socket list. */
void init_sockets() {
    memset(port_bitmap, 0, sizeof port_bitmap);
    init_ptr_vector(&sockets);
}