quinn-os/socket.c
2017-10-08 11:10:14 +10:00

505 lines
12 KiB
C

#include "ipv4.h"
#include "tcp.h"
#include "udp.h"
#include "socket.h"
#include "memory.h"
#include "string.h"
#include "console.h"
struct socket_t **sockets;
unsigned int socket_count;
extern void tcp_send(struct ether_t *ether, struct socket_t *sock, unsigned short flags, unsigned char *packet, unsigned int len);
int socket_queue(struct socket_t* sock, unsigned char* payload, unsigned int len, unsigned short flags) {
int i;
if (sock->socket_type == 1) {
if (sock->out_packet_count == 0) {
sock->out_packets.tcp = (struct tcp_data_t **)malloc(sizeof(struct tcp_data_t *));
} else {
sock->out_packets.tcp = (struct tcp_data_t **)realloc(sock->out_packets.tcp, sizeof(struct tcp_data_t *) * (sock->out_packet_count + 1));
}
sock->out_packets.tcp[sock->out_packet_count] = (struct tcp_data_t *)malloc(sizeof(struct tcp_data_t));
sock->out_packets.tcp[sock->out_packet_count]->flags = flags;
if (len > 0) {
sock->out_packets.tcp[sock->out_packet_count]->data = (char *)malloc(len);
}
memcpy(sock->out_packets.tcp[sock->out_packet_count]->data, payload, len);
sock->out_packets.tcp[sock->out_packet_count]->len = len;
sock->out_packet_count++;
if (sock->send_payload == 1) {
struct tcp_data_t *payload_start = sock->out_packets.tcp[0];
for (i=1;i<sock->packet_count;i++) {
sock->out_packets.tcp[i-1] = sock->out_packets.tcp[i];
}
sock->out_packet_count--;
if (sock->out_packet_count == 0) {
free(sock->out_packets.tcp);
} else {
sock->out_packets.tcp = (struct tcp_data_t **)realloc(sock->out_packets.tcp, sizeof(struct tcp_data_t *) * sock->out_packet_count);
}
if (payload_start->flags & TCP_FLAGS_RES) {
sock->status = 1;
}
tcp_send(sock->ether, sock, payload_start->flags, payload_start->data, payload_start->len);
if (payload_start->flags == TCP_FLAGS_FIN) {
sock->status = 1;
socket_close((unsigned int)sock);
}
if (len > 0) {
free(payload_start->data);
}
free(payload_start);
sock->send_payload = 0;
}
}
return 0;
}
struct socket_t *socket_find(unsigned int dport, unsigned int sport, unsigned int src_ip, unsigned int seq) {
int i, j;
struct socket_t *new_socket;
for (i=0;i<socket_count;i++) {
if (sockets[i]->port_recv == dport && sockets[i]->port_dest == sport && sockets[i]->addr == src_ip) {
if (sockets[i]->status == 0 || sockets[i]->status == 2 || sockets[i]->status == 5) {
return sockets[i];
} else {
return (void *)0;
}
}
}
for (i=0;i<socket_count;i++) {
if (sockets[i]->port_recv == dport && sockets[i]->addr == 0) {
if (sockets[i]->status == 3) {
struct task_t *task = (struct task_t *)sockets[i]->data;
// for (j=0;j<task->waiting_socket_count;j++) {
// if (task->waiting_sockets[j]->port_recv == dport && task->waiting_sockets[j]->port_dest == sport && task->waiting_sockets[j]->addr == src_ip) {
// return (void *)0;
// }
// }
if (task->waiting_socket_count < 25) {
new_socket = (struct socket_t *)malloc(sizeof(struct socket_t));
if (new_socket) {
memset(new_socket, 0, sizeof(struct socket_t));
new_socket->socket_type = 1;
new_socket->port_recv = dport;
new_socket->tcp_sock.seq_number = 0;
new_socket->tcp_sock.ack_number = seq + 1;
new_socket->packets.tcp = (void *)0;
new_socket->packet_count = 0;
new_socket->addr = src_ip;
new_socket->bytes_available = 0;
new_socket->bytes_read = 0;
new_socket->port_dest = sport;
new_socket->offset = 0;
new_socket->ether = sockets[i]->ether;
new_socket->status = 0;
task->waiting_sockets[task->waiting_socket_count++] = new_socket;
if (task->state == TASK_SLEEPING && task->sleep_reason == SLEEP_TCP_ACCEPT) {
task->state = TASK_RUNNING;
}
}
}
} else {
return (void *)0;
}
}
}
return (void *)0;
}
void *socket_get_packet(struct socket_t *sock) {
int i;
if (sock->socket_type == 1) {
struct tcp_data_t *packet;
if (sock->packet_count == 0) {
return (void *)0;
}
packet = sock->packets.tcp[0];
for (i=0;i<sock->packet_count-1;i++) {
sock->packets.tcp[i] = sock->packets.tcp[i+1];
}
sock->packet_count--;
if (sock->packet_count == 0) {
free(sock->packets.tcp);
sock->packets.tcp = (void *)0;
} else {
sock->packets.tcp = (struct tcp_data_t **)realloc(sock->packets.tcp, sizeof(struct tcp_data_t *) * sock->packet_count);
}
return packet;
}
return (void *)0;
}
int socket_status(struct socket_t *sock) {
return sock->status;
}
int socket_recv_from(struct socket_t *sock, char *buffer, int len, unsigned int *addr) {
int i;
if (sock->socket_type == 17) {
if (sock->packet_count == 0) {
return -1;
}
struct udp_data_t *data = sock->packets.udp[0];
for (i=0;i<sock->packet_count-1;i++) {
sock->packets.udp[i] = sock->packets.udp[i+1];
}
if (sock->packet_count == 0) {
free(sock->packets.udp);
sock->packets.udp = (void *)0;
} else {
sock->packets.udp = (struct udp_data_t **)realloc(sock->packets.udp, sizeof(struct udp_data_t *) * sock->packet_count);
}
if (addr != (void *)0) {
*addr = sock->addr;
}
if (len > data->len) {
len = data->len;
}
memcpy(buffer, data->data, len);
return len;
}
return 0;
}
int socket_read(struct socket_t *sock, char *buffer, int len) {
if (sock->socket_type == 1) {
struct tcp_data_t *data = (void *)0;
int size_to_read = 0;
int offset = sock->offset;
do {
if (sock->bytes_available) {
data = sock->curr_packet.tcp;
} else {
if (sock->packet_count > 0) {
data = (struct tcp_data_t *)socket_get_packet(sock);
} else {
if (sock->status == 1) {
return 0;
}
sock->offset = offset;
((struct task_t *)sock->data)->sleep_reason = SLEEP_TCP_READ;
((struct task_t *)sock->data)->state = TASK_SLEEPING;
return -1; // return -1 to signal retry.
}
sock->bytes_available = data->len;
sock->bytes_read = 0;
}
if (len < offset + sock->bytes_available) {
size_to_read = len;
} else {
size_to_read = offset + sock->bytes_available;
}
if (data->len > 0) {
memcpy(buffer + offset, data->data + sock->bytes_read, size_to_read);
}
offset += size_to_read;
if (size_to_read < sock->bytes_available) {
sock->bytes_available = sock->bytes_available - size_to_read;
sock->bytes_read += size_to_read;
sock->curr_packet.tcp = data;
} else {
sock->bytes_available = 0;
sock->curr_packet.tcp = (void *)0;
free(data);
}
} while (!size_to_read);
sock->offset = 0;
return size_to_read;
} else {
return 0;
}
}
unsigned socket_get_port(void) {
static unsigned short next = 49152;
unsigned short out = next;
next++;
return out;
}
int socket_bind(struct socket_t* sock, unsigned int dest_ip, unsigned short dest_port) {
if (sock->socket_type != 17) {
return -1;
}
sock->port_recv = socket_get_port();
sock->addr = dest_ip;
sock->port_dest = dest_port;
sock->ether = ether_find_from_ipv4(dest_ip);
if (!sock->ether) {
return -1;
}
return 0;
}
int socket_connect(struct socket_t* sock, unsigned int dest_ip, unsigned short dest_port) {
if (sock->socket_type != 1) {
return -1;
}
sock->port_recv = socket_get_port();
sock->tcp_sock.seq_number = 0;
sock->tcp_sock.ack_number = 0;
sock->packets.tcp = (void *)0;
sock->packet_count = 0;
sock->addr = dest_ip;
sock->port_dest = dest_port;
sock->ether = ether_find_from_ipv4(dest_ip);
if (!sock->ether) {
return -1;
}
tcp_send(sock->ether, sock, (1<<1), (void *)0, 0);
return 0;
}
void socket_close(struct socket_t *sock) {
int i;
int j;
if (sock->socket_type == 1) {
if (sock->status != 1) {
sock->status = 5;
} else {
for (i=0;i<socket_count;i++) {
if (sockets[i] == sock) {
for (j=0;j<sock->packet_count;j++) {
free(sock->packets.tcp[j]->data);
free(sock->packets.tcp[j]);
}
free(sock->packets.tcp);
free(sock);
for (j=i+1;j<socket_count;j++) {
sockets[j-1] = sockets[j];
}
socket_count--;
}
}
}
} else {
for (i=0;i<socket_count;i++) {
if (sockets[i] == sock) {
for (j=0;j<sock->packet_count;j++) {
free(sock->packets.udp[j]->data);
free(sock->packets.udp[j]);
}
free(sock->packets.udp);
free(sock);
for (j=i+1;j<socket_count;j++) {
sockets[j-1] = sockets[j];
}
socket_count--;
}
}
}
}
void socket_doclose(struct socket_t *sock) {
int i;
int j;
if (sock->socket_type == 1) {
if (sock->status != 5) {
sock->status = 1;
} else {
for (i=0;i<socket_count;i++) {
if (sockets[i] == sock) {
for (j=0;j<sock->packet_count;j++) {
free(sock->packets.tcp[j]->data);
free(sock->packets.tcp[j]);
}
free(sock->packets.tcp);
free(sock);
for (j=i+1;j<socket_count;j++) {
sockets[j-1] = sockets[j];
}
socket_count--;
}
}
}
}
}
struct socket_t *socket_open(unsigned char type) {
struct socket_t *sock;
if (type == 1) {
sock = (struct socket_t *)malloc(sizeof(struct socket_t));
memset(sock, 0, sizeof(struct socket_t));
if (socket_count == 0) {
sockets = (struct socket_t **)malloc(sizeof(struct socket_t *));
} else {
sockets = (struct socket_t **)realloc(sockets, sizeof(struct socket_t *) * (socket_count + 1));
}
sock->status = 0;
sock->socket_type = type;
sock->bytes_available = 0;
sock->bytes_read = 0;
sockets[socket_count] = sock;
socket_count++;
} else if (type == 17) {
sock = (struct socket_t *)malloc(sizeof(struct socket_t));
memset(sock, 0, sizeof(struct socket_t));
if (socket_count == 0) {
sockets = (struct socket_t **)malloc(sizeof(struct socket_t *));
} else {
sockets = (struct socket_t **)realloc(sockets, sizeof(struct socket_t *) * (socket_count + 1));
}
sock->status = 0;
sock->socket_type = type;
sock->bytes_available = 0;
sock->bytes_read = 0;
sockets[socket_count] = sock;
socket_count++;
} else {
kprintf("Unsupported Socket type!\n");
return (void *)0;
}
return sock;
}
extern void *current_task;
int socket_listen(struct socket_t *sock, unsigned int listenip, unsigned short port) {
if (sock->socket_type != 1) {
kprintf("Wrong socket Type\n");
return -1;
}
sock->port_recv = port;
sock->tcp_sock.seq_number = 0;
sock->tcp_sock.ack_number = 0;
sock->packets.tcp = (void *)0;
sock->packet_count = 0;
sock->addr = 0;
sock->port_dest = 0;
sock->ether = ether_find_from_ipv4(listenip);
sock->status = 3;
sock->data = (void *)current_task;
if (!sock->ether) {
kprintf("Cant find ether\n");
return -1;
}
return 0;
}
struct socket_t *socket_accept(struct socket_t *socket, struct inet_addr *client_addr) {
struct task_t *task = (struct task_t *)socket->data;
struct socket_t *new_socket;
int i;
if (task->waiting_socket_count > 0) {
new_socket = task->waiting_sockets[0];
for (i=1;i<task->waiting_socket_count;i++) {
task->waiting_sockets[i-1] = task->waiting_sockets[i];
}
task->waiting_socket_count--;
if (socket_count == 0) {
sockets = (struct socket_t **)malloc(sizeof(struct socket_t *));
} else {
sockets = (struct socket_t **)realloc(sockets, sizeof(struct socket_t *) * (socket_count + 1));
}
sockets[socket_count] = new_socket;
socket_count++;
// send SYN ACK
if (new_socket->socket_type == 1) {
tcp_send(new_socket->ether, new_socket, TCP_FLAGS_SYN | TCP_FLAGS_ACK, (void *)0, 0);
}
if (client_addr != (void *)0) {
client_addr->type = new_socket->socket_type;
client_addr->addr = new_socket->addr;
}
return new_socket;
}
task->state = TASK_SLEEPING;
task->sleep_reason = SLEEP_TCP_ACCEPT;
return (void *)0;
}
int socket_write(struct socket_t* sock, unsigned char* payload, unsigned int len) {
if (sock->socket_type == 1) {
return socket_queue(sock, payload, len, TCP_FLAGS_PSH | TCP_FLAGS_ACK);
} else if (sock->socket_type == 17) {
udp_send(sock->ether, sock, payload, len);
return len;
}
return 0;
}
void init_sockets() {
socket_count = 0;
}