quinn-os/tcp.c
2017-10-08 16:49:34 +10:00

229 lines
6.7 KiB
C

#include "ether.h"
#include "tcp.h"
#include "socket.h"
#include "ipv4.h"
#include "memory.h"
#include "console.h"
#include "string.h"
/*
 * Turn a 32-bit running ones'-complement sum into a 16-bit Internet checksum:
 * repeatedly add the carries above bit 15 back into the low half, then
 * return the bitwise complement of what remains.
 */
static inline unsigned short sum_to_checksum(unsigned int sum) {
	unsigned int folded = sum;

	while (folded > 0xFFFF) {
		folded = (folded & 0xFFFF) + (folded >> 16);
	}
	return (unsigned short)~folded;
}
void tcp_sock_add_packet(struct socket_t *sock, struct tcp_data_t *packet) {
if (sock->packet_count == 0) {
sock->packets.tcp = (struct tcp_data_t **)malloc(sizeof(struct tcp_data_t *));
} else {
sock->packets.tcp = (struct tcp_data_t **)realloc(sock->packets.tcp, sizeof(struct tcp_data_t *) * (sock->packet_count + 1));
}
sock->packets.tcp[sock->packet_count] = packet;
sock->packet_count++;
}
void tcp_send(struct ether_t *ether, struct socket_t *sock, unsigned short flags, unsigned char *packet, unsigned int len) {
struct tcp_header_t *header = (struct tcp_header_t *)malloc(sizeof(struct tcp_header_t) + len);
header->source_port = htons(sock->port_recv);
header->dest_port = htons(sock->port_dest);
header->seq_number = htonl(sock->tcp_sock.seq_number);
header->ack_number = flags & (TCP_FLAGS_ACK) ? htonl(sock->tcp_sock.ack_number) : 0;
header->flags = htons(0x5000 ^ (flags & 0xFF));
header->window = htons(1800);
header->checksum = 0; // Fill in later
header->urgent = 0;
if (flags & TCP_FLAGS_SYN) {
sock->tcp_sock.seq_number += 1;
} else {
sock->tcp_sock.seq_number += len;
}
if (packet != (void *)0 && len != 0) {
memcpy(header->payload, packet, len);
}
//data
unsigned int sum = 0;
unsigned int l = len + sizeof(struct tcp_header_t);
unsigned short *ptr = (unsigned short *)header;
for (; l > 1; l -= 2) {
sum += *ptr++;
if (sum & 0x80000000)
sum = (sum & 0xffff) + (sum >> 16);
}
if(l & 1) {
sum += *((unsigned char *) ptr);
}
//pseudo header
sum += htons((ether->ipv4 >> 16) & 0xffff);
sum += htons(ether->ipv4 & 0xffff);
sum += htons((sock->addr >> 16) & 0xffff);
sum += htons(sock->addr & 0xffff);
sum += htons(0x06);
sum += htons(sizeof(struct tcp_header_t) + len);
header->checksum = sum_to_checksum(sum);
ipv4_send(ether, 0x06, htonl(sock->addr), header, len + sizeof(struct tcp_header_t));
}
void tcp_process_packet(struct ether_t *ether, unsigned int dest, unsigned int src, struct tcp_header_t *packet, unsigned int len) {
// find socket
int i;
struct tcp_data_t *data;
unsigned int dlen = len - ((htons(packet->flags) >> 12) * 4);
struct socket_t *sock = socket_find(htons(packet->dest_port), htons(packet->source_port), htonl(src), htonl(packet->seq_number));
if (sock != (void *)0) {
if (((struct task_t *)sock->data)->state == TASK_SLEEPING && ((struct task_t *)sock->data)->sleep_reason == SLEEP_TCP_READ) {
((struct task_t *)sock->data)->state = TASK_RUNNING;
}
if ((htons(packet->flags) & TCP_FLAGS_FIN) && (htons(packet->flags) & TCP_FLAGS_ACK)) {
sock->tcp_sock.ack_number = htonl(packet->seq_number) + 1;
tcp_send(ether, sock, TCP_FLAGS_ACK, (void *)0, 0);
socket_doclose(sock);
return;
}
if (sock->tcp_sock.seq_number != htonl(packet->ack_number)) {
//kprintf("Warning, dropping packet, Wrong Ack Expecting %d Got %d\n", sock->tcp_sock.seq_number, htonl(packet->ack_number));
return;
}
if ((htons(packet->flags) & TCP_FLAGS_SYN) && (htons(packet->flags) & TCP_FLAGS_ACK)) {
sock->tcp_sock.ack_number = htonl(packet->seq_number) + dlen + 1;
tcp_send(ether, sock, TCP_FLAGS_ACK, (void *)0, 0);
sock->status = 2;
if (sock->out_packet_count > 0) {
struct tcp_data_t *payload = sock->out_packets.tcp[0];
for (i=1;i<sock->out_packet_count;i++) {
sock->out_packets.tcp[i-1] = sock->out_packets.tcp[i];
}
sock->out_packet_count--;
if (payload->flags & TCP_FLAGS_RES) {
sock->status = 1;
}
tcp_send(sock->ether, sock, payload->flags, payload->data, payload->len);
if (payload->flags == (TCP_FLAGS_FIN | TCP_FLAGS_ACK)) {
sock->status = 1;
socket_doclose((unsigned int)sock);
if (payload->len > 0) {
free(payload->data);
}
free(payload);
return;
}
if (payload->len > 0) {
free(payload->data);
}
free(payload);
} else {
sock->send_payload = 1;
}
} else if (htons(packet->flags) & TCP_FLAGS_RES) {
socket_doclose(sock);
return;
} else {
if (dlen == 0) {
if (htons(packet->flags) & TCP_FLAGS_ACK) {
if (sock->out_packet_count > 0) {
struct tcp_data_t *payload = sock->out_packets.tcp[0];
for (i=1;i<sock->out_packet_count;i++) {
sock->out_packets.tcp[i-1] = sock->out_packets.tcp[i];
}
sock->out_packet_count--;
if (sock->out_packet_count == 0) {
free(sock->out_packets.tcp);
} else {
sock->out_packets.tcp = (struct tcp_data_t **)realloc(sock->out_packets.tcp, sizeof(struct tcp_data_t *) * sock->out_packet_count);
}
if (payload->flags & TCP_FLAGS_RES) {
sock->status = 1;
}
tcp_send(sock->ether, sock, payload->flags, payload->data, payload->len);
if (payload->flags == (TCP_FLAGS_FIN | TCP_FLAGS_ACK)) {
if (payload->len > 0) {
free(payload->data);
}
free(payload);
return;
}
if (payload->len > 0) {
free(payload->data);
}
free(payload);
} else {
sock->send_payload = 1;
}
} else {
kprintf("Got something else?\n");
}
return;
} else {
data = (struct tcp_data_t *)malloc(sizeof(struct tcp_data_t));
data->len = dlen;
}
if (data->len > 0) {
data->data = (char *)malloc(data->len);
memcpy(data->data, packet->payload, data->len);
}
sock->tcp_sock.ack_number = htonl(packet->seq_number) + dlen;
if ((htons(packet->flags) & TCP_FLAGS_SYN) && (htons(packet->flags) & TCP_FLAGS_ACK) && dlen == 0) {
sock->tcp_sock.ack_number += 1;
}
sock->tcp_sock.ack_number = htonl(packet->seq_number) + data->len;
tcp_send(ether, sock, TCP_FLAGS_ACK, (void *)0, 0);
tcp_sock_add_packet(sock, data);
if (sock->out_packet_count > 0) {
struct tcp_data_t *payload = sock->out_packets.tcp[0];
for (i=1;i<sock->out_packet_count;i++) {
sock->out_packets.tcp[i-1] = sock->out_packets.tcp[i];
}
sock->out_packet_count--;
if (payload->flags & TCP_FLAGS_RES) {
sock->status = 1;
}
tcp_send(sock->ether, sock, payload->flags, payload->data, payload->len);
if (payload->flags == (TCP_FLAGS_FIN | TCP_FLAGS_ACK)) {
sock->status = 1;
socket_doclose((unsigned int)sock);
if (payload->len > 0) {
free(payload->data);
}
free(payload);
return;
}
if (payload->len > 0) {
free(payload->data);
}
free(payload);
} else {
sock->send_payload = 1;
}
}
} else {
// got data for non existant socket.
}
}