quinn-os/socket.c
2021-12-24 20:42:51 +10:00

600 lines
14 KiB
C

#include "pvec.h"
#include "ether.h"
#include "ipv4.h"
#include "socket.h"
#include "schedule.h"
#include "tcp.h"
#include "udp.h"
#include "memory.h"
#include "string.h"
#include "console.h"
#include "rtc.h"
/* Scheduler's current task: socket_open() records new sockets in its socket
 * list and a blocking socket_read() puts it to sleep. */
extern struct task_t *current_task;
/* Last serial handed out. It is pre-incremented, so the first socket gets
 * serial 1 and 0 stays free as socket_open()'s failure value. */
static unsigned int socket_serial = 0;
/* Every registered socket, searched by serial and by port/address.
 * Non-static, so presumably referenced from other files — TODO confirm. */
struct ptr_vector sockets;
/* Registered network devices; entry 0 is the fallback device in socket_bind(). */
extern struct ptr_vector ether_devs;
/* Port allocation bitmap: 8192 bytes x 8 bits = one bit per 16-bit port
 * number; a set bit marks the port as in use. */
static unsigned char port_bitmap[8192];
/* Defined elsewhere; presumably builds and transmits a TCP segment with the
 * given flags and payload for sock on interface ether. */
extern void tcp_send(struct ether_t *ether, struct socket_t *sock, unsigned short flags, unsigned char *packet, unsigned int len);
/*
 * Look up a registered socket by its serial number.
 * Returns the socket, or NULL when no socket in `sockets` carries that serial.
 */
struct socket_t *socket_get_from_serial(unsigned int serial) {
    int idx = 0;
    while (idx < ptr_vector_len(&sockets)) {
        struct socket_t *candidate = ptr_vector_get(&sockets, idx++);
        if (candidate->serial == serial) {
            return candidate;
        }
    }
    return NULL;
}
/*
 * Send or queue a TCP segment that carries a copy of payload[0..len).
 *
 * Only TCP sockets (type 1) are handled; other types are ignored.
 * - If sock->send_payload is 1, the segment goes straight to tcp_send(), the
 *   temporary copy is released, and send_payload is cleared.
 * - Otherwise the copy is appended to sock->out_packets (presumably flushed
 *   later by the TCP layer).
 *
 * Returns 0 on success (and for non-TCP sockets), or -1 if memory for the
 * copy could not be allocated.
 */
int socket_queue(struct socket_t* sock, unsigned char* payload, unsigned int len, unsigned short flags) {
    if (sock->socket_type != 1) {
        return 0;
    }
    struct tcp_data_t *outpacket = (struct tcp_data_t *)malloc(sizeof(struct tcp_data_t));
    if (outpacket == NULL) {
        return -1;
    }
    outpacket->flags = flags;
    outpacket->len = len;
    // A zero-length segment has no buffer. Use NULL so that data is never an
    // uninitialised pointer; memcpy requires valid pointers even for size 0.
    outpacket->data = NULL;
    if (len > 0) {
        outpacket->data = (unsigned char *)malloc(len);
        if (outpacket->data == NULL) {
            free(outpacket);
            return -1;
        }
        memcpy(outpacket->data, payload, len);
    }
    if (sock->send_payload == 1) {
        tcp_send(sock->ether, sock, outpacket->flags, outpacket->data, outpacket->len);
        if (len > 0) {
            dbfree(outpacket->data, "socket_queue 1");
        }
        dbfree(outpacket, "socket_queue 2");
        sock->send_payload = 0;
    } else {
        ptr_vector_append(&sock->out_packets, outpacket);
    }
    return 0;
}
/*
 * Find the socket that an incoming TCP segment is addressed to.
 *
 * dport is matched against a socket's local port (port_recv) and sport
 * against its remote port (port_dest). src_ip is the sender's address, and
 * seq + 1 seeds ack_number of any connection created here.
 *
 * Pass 1: a socket on (dport, sport) whose addr is src_ip or 0xffffffff is
 *   returned if its status is 0, 2 or 5. For any other status the search
 *   stops and NULL is returned.
 * Pass 2: for a listening socket (addr == 0, status 3) on dport, a newly
 *   allocated socket is pushed onto the owning task's waiting_sockets (at
 *   most 25), and that task is woken if it sleeps with SLEEP_TCP_ACCEPT.
 *   The new socket is not returned and is not yet in `sockets`;
 *   socket_accept() adds it there. If a socket on dport with addr == 0 is
 *   not in status 3, NULL is returned.
 *
 * Returns the matching existing socket, or NULL.
 */
struct socket_t *socket_find(unsigned short dport, unsigned short sport, unsigned int src_ip, unsigned int seq) {
int i;
struct socket_t *new_socket;
// Pass 1: existing connection on (local port, remote port, remote address).
for (i=0;i<ptr_vector_len(&sockets);i++) {
struct socket_t *s = ptr_vector_get(&sockets, i);
if (s->port_recv == dport && s->port_dest == sport && (s->addr == 0xffffffff || s->addr == src_ip)) {
if (s->status == 0 || s->status == 2 || s->status == 5) {
return s;
} else {
// The first match decides: a match in any other status is not delivered.
return (void *)0;
}
}
}
// Pass 2: no connection yet; look for a listener (addr == 0) on the local port.
// NOTE(review): scanning continues after a connection is queued, so a second
// matching listener would queue another socket for the same segment.
for (i=0;i<ptr_vector_len(&sockets);i++) {
struct socket_t *s = ptr_vector_get(&sockets, i);
if (s->port_recv == dport && s->addr == 0) {
if (s->status == 3) {
// socket_listen() stores the owning task in s->data.
struct task_t *task = (struct task_t *)s->data;
if (task->waiting_socket_count < 25) {
new_socket = (struct socket_t *)malloc(sizeof(struct socket_t));
// NOTE(review): on allocation failure the connection attempt is dropped silently.
if (new_socket) {
memset(new_socket, 0, sizeof(struct socket_t));
new_socket->socket_type = 1;
new_socket->port_recv = dport;
new_socket->tcp_sock.seq_number = 0;
new_socket->tcp_sock.ack_number = seq + 1;
init_ptr_vector(&new_socket->out_packets);
init_ptr_vector(&new_socket->packets);
new_socket->addr = src_ip;
new_socket->bytes_available = 0;
new_socket->bytes_read = 0;
new_socket->port_dest = sport;
new_socket->offset = 0;
new_socket->ether = s->ether;
new_socket->ref = 1;
new_socket->serial = ++socket_serial;
new_socket->status = 0;
// Queue for socket_accept(), which registers it in `sockets`.
task->waiting_sockets[task->waiting_socket_count++] = new_socket;
// Wake the listener if it is blocked in socket_accept().
if (task->state == TASK_SLEEPING && task->sleep_reason == SLEEP_TCP_ACCEPT) {
task->state = TASK_RUNNING;
}
}
} else {
// NOTE(review): no trailing "\n", unlike the other kprintf messages in this file.
kprintf("Task waiting count >= 25");
}
} else {
return (void *)0;
}
}
}
return (void *)0;
}
/*
 * OR flag into the socket's flag word and return the updated flags.
 * SOCKET_FLAG_NOBLOCK is one such flag; socket_read() honours it.
 * Returns 0 if serial names no socket, instead of dereferencing NULL.
 */
unsigned int socket_ioctl(int serial, int flag) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL) {
        return 0;
    }
    sock->flags |= flag;
    return sock->flags;
}
/*
 * Dequeue the oldest received packet of a TCP socket (type 1).
 * Returns NULL for non-TCP sockets or when no packet is queued. Ownership of
 * the returned packet passes to the caller.
 */
void *socket_get_packet(struct socket_t *sock) {
    if (sock->socket_type != 1 || ptr_vector_len(&sock->packets) == 0) {
        return (void *)0;
    }
    return ptr_vector_del(&sock->packets, 0);
}
int socket_status(unsigned int serial) {
struct socket_t *sock = socket_get_from_serial(serial);
return sock->status;
}
/*
 * Receive one queued datagram from a UDP socket (type 17).
 *
 * Copies at most len bytes of the oldest datagram into buffer and stores the
 * socket's peer address in *addr when addr is non-NULL. The datagram is
 * consumed even if it was truncated.
 *
 * Returns the number of bytes copied, -1 when no datagram is queued, and 0
 * for an unknown serial or a non-UDP socket.
 */
int socket_recv_from(unsigned int serial, char *buffer, int len, unsigned int *addr) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL || sock->socket_type != 17) {
        return 0;
    }
    if (ptr_vector_len(&sock->packets) == 0) {
        return -1;
    }
    struct udp_data_t *data = ptr_vector_del(&sock->packets, 0);
    if (addr != NULL) {
        // NOTE(review): reports the socket's configured peer, not a per-packet source.
        *addr = sock->addr;
    }
    if (len > data->len) {
        len = data->len;
    }
    memcpy(buffer, data->data, len);
    // The datagram has left sock->packets, so this is its last reference.
    // Free it here as socket_close() does for queued UDP packets; previously
    // it leaked.
    // NOTE(review): if udp_data_t::data is a separate heap buffer, it needs
    // freeing too — confirm against the UDP receive path.
    free(data);
    return len;
}
/*
 * Read up to len bytes of queued TCP payload into buffer.
 *
 * Returns:
 *   > 0  number of bytes copied (at most one packet's worth per call);
 *   0    unknown serial, status 1 or 5, non-TCP socket, or len <= 0;
 *   -1   nothing queued on a blocking socket: current_task is put to sleep
 *        with SLEEP_TCP_READ (presumably the call is retried after wake-up);
 *   -2   nothing queued and SOCKET_FLAG_NOBLOCK is set.
 *
 * A packet larger than len is parked in sock->curr_packet.tcp, with
 * bytes_available and bytes_read tracking the unread remainder.
 */
int socket_read(unsigned int serial, char *buffer, int len) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL) {
        return 0;
    }
    if (sock->status == 5 || sock->status == 1) {
        return 0;
    }
    if (sock->socket_type != 1) {
        return 0;
    }
    // With len == 0 and a partly read packet, size_to_read would stay 0 and
    // the loop below would never terminate. A negative len would reach memcpy
    // as a huge size.
    if (len <= 0) {
        return 0;
    }

    struct tcp_data_t *data = (void *)0;
    int size_to_read = 0;
    // offset only grows on the iteration that ends the loop, so within this
    // file it always equals sock->offset, which is kept at 0.
    int offset = sock->offset;
    do {
        if (sock->bytes_available) {
            // Resume the packet left partly unread by an earlier call.
            data = sock->curr_packet.tcp;
        } else {
            if (ptr_vector_len(&sock->packets) > 0) {
                data = (struct tcp_data_t *)socket_get_packet(sock);
            } else {
                sock->offset = offset;
                if (sock->flags & SOCKET_FLAG_NOBLOCK) {
                    return -2;
                }
                current_task->sleep_reason = SLEEP_TCP_READ;
                current_task->state = TASK_SLEEPING;
                return -1;
            }
            sock->bytes_available = data->len;
            sock->bytes_read = 0;
        }
        if (len < sock->bytes_available) {
            size_to_read = len;
        } else {
            size_to_read = sock->bytes_available;
        }
        if (data->len > 0) {
            memcpy(buffer + offset, data->data + sock->bytes_read, size_to_read);
        }
        offset += size_to_read;
        if (size_to_read < sock->bytes_available) {
            sock->bytes_available = sock->bytes_available - size_to_read;
            sock->bytes_read += size_to_read;
            sock->curr_packet.tcp = data;
        } else {
            // The packet is fully consumed (or was empty), so release it. If
            // nothing was copied, move on to the next queued packet.
            sock->bytes_available = 0;
            sock->curr_packet.tcp = (void *)0;
            // NOTE(review): data->data is freed even when data->len == 0, whereas
            // socket_close() only frees it for len > 0 — confirm the receive
            // path always allocates it.
            dbfree(data->data, "socket_read");
            dbfree(data, "socket_read 2");
        }
    } while (!size_to_read);
    sock->offset = 0;
    return size_to_read;
}
/* Mark port as free: clear bit (port % 8) of byte (port / 8) in port_bitmap. */
void socket_clear_port_bitmap(unsigned short port) {
    port_bitmap[port >> 3] &= (unsigned char)~(1u << (port & 7u));
}
/* Mark port as in use: set bit (port % 8) of byte (port / 8) in port_bitmap. */
void socket_mark_port_bitmap(unsigned short port) {
    port_bitmap[port >> 3] |= (unsigned char)(1u << (port & 7u));
}
/*
 * Test whether port is marked in use. The result is the masked bit value
 * itself (non-zero when in use, 0 when free), not a normalised 0/1.
 */
int socket_check_port_bitmap(unsigned short port) {
    const unsigned char byte = port_bitmap[port >> 3];
    return byte & (1 << (port & 7));
}
/*
 * Pick a free ephemeral port in the range 49152..65535.
 *
 * The scan starts at a random position in the range and wraps around, so
 * every port is tried exactly once. The chosen port is NOT marked; callers do
 * that with socket_mark_port_bitmap().
 *
 * Returns the port, or 0 if every port in the range is in use.
 */
unsigned socket_get_port(void) {
    // Count in unsigned int, not unsigned short: a 16-bit counter wraps from
    // 65535 to 0. That can yield 0 (the failure value) or leave the range.
    const unsigned int first_port = 49152;
    const unsigned int port_count = 65536 - 49152; /* 16384 ports */
    unsigned int start = (unsigned int)random() % port_count;
    for (unsigned int i = 0; i < port_count; i++) {
        unsigned int port = first_port + (start + i) % port_count;
        if (!socket_check_port_bitmap((unsigned short)port)) {
            return port;
        }
    }
    // run out of sockets!
    return 0;
}
/*
 * Bind a UDP socket (type 17) to a local port and a fixed peer.
 *
 * If src_port is 0, a free ephemeral port is chosen. Otherwise src_port is
 * used, provided it is not already marked in use. The socket's peer becomes
 * dest_ip:dest_port. The interface is the one that serves dest_ip, or else
 * the first registered device.
 *
 * Returns 0 on success, or -1 for any of:
 *   - an unknown serial;
 *   - a non-UDP socket;
 *   - no free or requested port available;
 *   - no network device.
 */
int socket_bind(unsigned int serial, unsigned int dest_ip, unsigned short dest_port, unsigned short src_port) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL || sock->socket_type != 17) {
        return -1;
    }
    if (src_port == 0) {
        sock->port_recv = socket_get_port();
        if (sock->port_recv == 0) {
            return -1;
        }
        socket_mark_port_bitmap(sock->port_recv);
    } else {
        if (socket_check_port_bitmap(src_port)) {
            return -1;
        }
        sock->port_recv = src_port;
        socket_mark_port_bitmap(src_port);
    }
    sock->addr = dest_ip;
    sock->port_dest = dest_port;
    sock->ether = ether_find_from_ipv4(dest_ip);
    if (!sock->ether) {
        if (ptr_vector_len(&ether_devs) > 0) {
            sock->ether = ptr_vector_get(&ether_devs, 0);
        } else {
            return -1;
        }
    }
    return 0;
}
/*
 * Begin a TCP connection to dest_ip:dest_port from a freshly allocated
 * ephemeral port. The opening segment is sent with flags (1<<1), presumably
 * SYN, through the interface that serves dest_ip.
 *
 * Returns 0 once that segment has been passed to tcp_send(), or -1 if:
 *   - the serial is unknown;
 *   - the socket is not TCP;
 *   - no port is free;
 *   - no interface serves dest_ip.
 */
int socket_connect(unsigned int serial, unsigned int dest_ip, unsigned short dest_port) {
    struct socket_t *conn = socket_get_from_serial(serial);
    if (conn == NULL || conn->socket_type != 1) {
        return -1;
    }

    conn->port_recv = socket_get_port();
    if (!conn->port_recv) {
        return -1;
    }
    socket_mark_port_bitmap(conn->port_recv);

    /* Fresh sequence state and peer identity for this connection. */
    conn->tcp_sock.seq_number = 0;
    conn->tcp_sock.ack_number = 0;
    conn->addr = dest_ip;
    conn->port_dest = dest_port;

    conn->ether = ether_find_from_ipv4(dest_ip);
    if (conn->ether == NULL) {
        return -1;
    }
    tcp_send(conn->ether, conn, (1 << 1), (void *)0, 0);
    return 0;
}
/*
 * Drop one reference to a socket on behalf of the current task.
 *
 * The socket is detached from the calling task with sched_remove_socket(),
 * and its reference count is decremented. When the count reaches zero:
 *   - A TCP socket (type 1) whose status is not 1 is only marked status 5;
 *     its memory is released later by socket_doclose().
 *   - A TCP socket in status 1, or a socket of any other type, is released
 *     now. Its local port bit is cleared, all queued packets are freed, and
 *     it is removed from `sockets`.
 * Unknown serials are ignored.
 */
void socket_close(unsigned int serial) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL) {
        return; // Already closed
    }
    // remove from task struct
    sched_remove_socket(sock);
    sock->ref--;
    if (sock->ref != 0) {
        return;
    }
    if (sock->socket_type == 1 && sock->status != 1) {
        sock->status = 5;
        return;
    }
    for (int i = 0; i < ptr_vector_len(&sockets); i++) {
        if (ptr_vector_get(&sockets, i) != sock) {
            continue;
        }
        socket_clear_port_bitmap(sock->port_recv);
        if (sock->socket_type == 1) {
            for (int j = 0; j < ptr_vector_len(&sock->packets); j++) {
                struct tcp_data_t *d = ptr_vector_get(&sock->packets, j);
                if (d->len > 0) {
                    dbfree(d->data, "socket close 1");
                }
            }
            ptr_vector_apply(&sock->packets, free);
            destroy_ptr_vector(&sock->packets);
            for (int j = 0; j < ptr_vector_len(&sock->out_packets); j++) {
                struct tcp_data_t *d = ptr_vector_get(&sock->out_packets, j);
                if (d->len > 0) {
                    dbfree(d->data, "socket close 2");
                }
            }
            ptr_vector_apply(&sock->out_packets, free);
            destroy_ptr_vector(&sock->out_packets);
            // socket_read() has already removed a partly read packet from
            // sock->packets, so curr_packet is its only owner.
            if (sock->curr_packet.tcp != NULL) {
                struct tcp_data_t *d = sock->curr_packet.tcp;
                if (d->len > 0) {
                    dbfree(d->data, "socket close 4");
                }
                dbfree(d, "socket close 5");
                sock->curr_packet.tcp = NULL;
            }
            dbfree(sock, "socket close 3");
        } else {
            // ptr_vector_apply() frees each queued UDP packet exactly once;
            // freeing them in a separate loop as well would free them twice.
            ptr_vector_apply(&sock->packets, free);
            destroy_ptr_vector(&sock->packets);
            ptr_vector_apply(&sock->out_packets, free);
            destroy_ptr_vector(&sock->out_packets);
            free(sock);
        }
        ptr_vector_del(&sockets, i);
        break;
    }
}
void socket_doclose(struct socket_t *sock) {
int i;
int j;
if (sock->socket_type == 1) {
if (sock->status != 5) {
sock->status = 1;
} else {
if (sock->ref == 0) {
for (i=0;i<ptr_vector_len(&sockets);i++) {
struct socket_t *s = ptr_vector_get(&sockets, i);
if (s == sock) {
sched_remove_socket_from_all_tasks(s);
socket_clear_port_bitmap(sock->port_recv);
for (j=0;j<ptr_vector_len(&sock->packets);j++) {
struct tcp_data_t *d = ptr_vector_get(&sock->packets, j);
free(d);
}
ptr_vector_apply(&sock->packets, free);
destroy_ptr_vector(&sock->packets);
for (j=0;j<ptr_vector_len(&sock->out_packets);j++) {
struct tcp_data_t *d = ptr_vector_get(&sock->out_packets, j);
free(d);
}
ptr_vector_apply(&sock->out_packets, free);
destroy_ptr_vector(&sock->out_packets);
dbfree(sock, "socket_doclose 4");
ptr_vector_del(&sockets, i);
break;
}
}
}
}
}
}
/*
 * Create a socket of the given type and register it globally and with the
 * current task.
 *
 * Supported types: 1 (used with tcp_send elsewhere) and 17 (used with
 * udp_send). Returns the new socket's serial (>= 1), or 0 for an
 * unsupported type or if allocation fails.
 */
unsigned int socket_open(unsigned char type) {
    if (type != 1 && type != 17) {
        kprintf("Unsupported Socket type!\n");
        return 0;
    }

    struct socket_t *fresh = (struct socket_t *)malloc(sizeof(struct socket_t));
    if (fresh == NULL) {
        return 0;
    }
    /* Zero-filling also leaves status, bytes_available and bytes_read at 0. */
    memset(fresh, 0, sizeof(struct socket_t));
    fresh->socket_type = type;
    fresh->ref = 1;
    fresh->serial = ++socket_serial;

    ptr_vector_append(&sockets, fresh);
    ptr_vector_append(&current_task->sockets, fresh);
    return fresh->serial;
}
/*
 * Put a TCP socket into listening state (status 3) on the given local port.
 *
 * Incoming connections are matched by socket_find() and queued on the task
 * that called this function; that task is recorded in sock->data.
 *
 * Returns 0 on success, or -1 if the serial is unknown, the socket is not
 * TCP, or no interface serves listenip. On failure the socket is left
 * unchanged, so it cannot sit in listening state with no interface.
 */
int socket_listen(unsigned int serial, unsigned int listenip, unsigned short port) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL) {
        return -1;
    }
    if (sock->socket_type != 1) {
        kprintf("Wrong socket Type\n");
        return -1;
    }
    struct ether_t *ether = ether_find_from_ipv4(listenip);
    if (!ether) {
        kprintf("Cant find ether\n");
        return -1;
    }
    sock->port_recv = port;
    sock->tcp_sock.seq_number = 0;
    sock->tcp_sock.ack_number = 0;
    sock->addr = 0;
    sock->port_dest = 0;
    sock->ether = ether;
    sock->status = 3;
    sock->data = (void *)current_task;
    return 0;
}
/*
 * Accept the oldest pending connection on a listening TCP socket.
 *
 * socket_find() queues pending connections on the listener's owning task.
 * The first one is dequeued, a SYN|ACK is sent for it, and it is registered
 * in `sockets` and in the task's socket list. If client_addr is non-NULL, it
 * receives the new socket's type and peer address.
 *
 * Returns the new socket's serial. Returns 0 in either of these cases:
 *   - nothing is pending: the owning task is put to sleep with
 *     SLEEP_TCP_ACCEPT;
 *   - the serial is unknown, or the socket has no owning task recorded
 *     (socket_listen() was never applied). This case previously dereferenced
 *     NULL.
 */
unsigned int socket_accept(unsigned int serial, struct inet_addr *client_addr) {
    struct socket_t *socket = socket_get_from_serial(serial);
    if (socket == NULL || socket->data == NULL) {
        // NOTE(review): returns 0 without sleeping; confirm the syscall layer
        // does not busy-retry on 0 for an invalid listener.
        return 0;
    }
    struct task_t *task = (struct task_t *)socket->data;
    if (task->waiting_socket_count > 0) {
        struct socket_t *new_socket = task->waiting_sockets[0];
        for (int i = 1; i < task->waiting_socket_count; i++) {
            task->waiting_sockets[i - 1] = task->waiting_sockets[i];
        }
        task->waiting_socket_count--;
        // send SYN ACK
        if (new_socket->socket_type == 1) {
            tcp_send(new_socket->ether, new_socket, TCP_FLAGS_SYN | TCP_FLAGS_ACK, (void *)0, 0);
            new_socket->tcp_sock.seq_number = 1;
        }
        if (client_addr != NULL) {
            client_addr->type = new_socket->socket_type;
            client_addr->addr = new_socket->addr;
        }
        ptr_vector_append(&sockets, new_socket);
        ptr_vector_append(&task->sockets, new_socket);
        return new_socket->serial;
    }
    task->state = TASK_SLEEPING;
    task->sleep_reason = SLEEP_TCP_ACCEPT;
    return 0;
}
/*
 * Write len bytes from payload to a socket.
 *
 * - TCP (type 1): the data goes to socket_queue(), and its result is
 *   returned (0 on success).
 * - UDP (type 17): one datagram is sent with udp_send(), and len is
 *   returned.
 * - Unknown serial: returns -1 instead of dereferencing NULL.
 * - Any other socket type: returns 0.
 */
int socket_write(unsigned int serial, unsigned char* payload, unsigned int len) {
    struct socket_t *sock = socket_get_from_serial(serial);
    if (sock == NULL) {
        return -1;
    }
    if (sock->socket_type == 1) {
        return socket_queue(sock, payload, len, TCP_FLAGS_PSH | TCP_FLAGS_ACK);
    }
    if (sock->socket_type == 17) {
        udp_send(sock->ether, sock, payload, len);
        return len;
    }
    return 0;
}
/* Reset socket bookkeeping: mark every port free and empty the socket list. */
void init_sockets() {
    memset(port_bitmap, 0, sizeof(port_bitmap));
    init_ptr_vector(&sockets);
}