quinn-os/ipv4.c
2022-07-20 21:00:19 +10:00

398 lines
13 KiB
C

#include "inttypes.h"
#include "ether.h"
#include "ipv4.h"
#include "socket.h"
#include "udp.h"
#include "tcp.h"
#include "icmp.h"
#include "arp.h"
#include "memory.h"
#include "hashmap.h"
#include "console.h"
#include "string.h"
extern volatile uint32_t timer_ticks;
/* Turn a 32-bit one's-complement accumulator into a 16-bit Internet checksum:
 * repeatedly add the carry (upper half) back into the lower half until no
 * carry remains, then return the bitwise complement of the result. */
static inline uint16_t sum_to_checksum(uint32_t sum) {
    uint32_t acc = sum;
    for (;;) {
        uint32_t carry = acc >> 16;
        if (carry == 0) {
            break;
        }
        acc = (acc & 0xFFFFu) + carry;
    }
    return (uint16_t)~acc;
}
/* One complete IPv4 datagram waiting in the outbound queue (headed by
 * first_packet) until ipv4_process_outbound can resolve a next-hop MAC
 * address for it and hand it to ether_send. */
struct outbound_packet_t {
uint32_t dest_addr; /* destination exactly as given to ipv4_queue_outbound; ipv4_process_outbound applies htonl() before comparing it with ether->ipv4 */
char *packet; /* heap buffer holding the full IPv4 datagram (header + payload); released with dbfree by ipv4_process_outbound after sending */
int len; /* number of bytes in packet */
struct ether_t *ether; /* interface the datagram leaves through */
struct outbound_packet_t *next_packet; /* next queued datagram, or NULL at the end of the list */
};
/* IPv4 header as it appears on the wire (RFC 791). The struct is packed so it
 * can be overlaid directly on a received buffer (see ipv4_process_packet).
 * Multi-byte fields hold network byte order; this file converts them with
 * htons()/htonl() whenever it reads or writes them. */
struct ipv4_header_t {
uint8_t vhl; /* version in the high nibble, header length in 32-bit words in the low nibble */
uint8_t tos; /* type of service */
uint16_t len; /* total length of header + payload in bytes */
uint16_t ipid; /* identification; all fragments of one datagram share it (used as the reassembly key) */
uint16_t ipoffset; /* flags in the top 3 bits, fragment offset in 8-byte units in the low 13 bits */
uint8_t ttl; /* time to live */
uint8_t protocol; /* upper-layer protocol: ipv4_process_packet handles 1 (ICMP), 6 (TCP), 17 (UDP) */
uint16_t chksum; /* one's-complement checksum over the header only */
uint32_t src_addr; /* source address */
uint32_t dest_addr; /* destination address */
uint8_t payload[]; /* bytes immediately after a 20-byte header; NOTE(review): does not account for IP options (IHL > 5) */
} __attribute__((packed));
/* State for one fragmented datagram being reassembled; stored in
 * ipv4_frag_map under `key`. Each stored fragment is a heap copy of the
 * received packet, header included.
 * NOTE(review): keyed on the identification field alone, so fragments from
 * different senders/protocols that happen to share an id are mixed together. */
struct ipv4_frag_t {
char key[6]; /* decimal text of the IP identification field, NUL-terminated (at most 5 digits) */
struct ipv4_header_t *first_packet; /* copy of the offset-0 fragment, or NULL until it arrives */
struct ipv4_header_t **middle_packets; /* heap array of copies of the fragments between first and last, or NULL */
int middle_packet_count; /* number of entries in middle_packets */
struct ipv4_header_t *last_packet; /* copy of the final fragment, or NULL until it arrives */
uint32_t time_created; /* timer_ticks when the entry was created; ipv4_fragment_trim uses it to expire the entry */
};
/* Fragmented datagrams awaiting reassembly (values are struct ipv4_frag_t *); created in init_ipv4. */
map_t ipv4_frag_map;
/* Head of the singly linked queue of outbound datagrams not yet passed to ether_send; emptied in init_ipv4. */
struct outbound_packet_t *first_packet;
/* hashmap_iterate callback used by ipv4_fragment_trim.
 *
 * Appends every reassembly entry that has been waiting longer than the
 * timeout to the vector supplied as the iteration context.
 *
 * item: the struct ptr_vector * passed to hashmap_iterate (assumed to be
 *       forwarded unchanged as the first callback argument - confirm in hashmap.c).
 * data: one struct ipv4_frag_t * stored in ipv4_frag_map.
 * Returns MAP_OK so the whole map is visited.
 */
int ipv4_fragment_trim_iter(any_t item, any_t data) {
    /* Lifetime of an incomplete reassembly, in timer_ticks units.
     * NOTE(review): tick rate is not visible here - confirm the intended duration. */
    enum { IPV4_FRAG_TIMEOUT_TICKS = 180000 };

    struct ptr_vector *pv = (struct ptr_vector *)item;
    struct ipv4_frag_t *f = (struct ipv4_frag_t *)data;

    /* Unsigned subtraction yields the correct age even after timer_ticks wraps.
     * Only entries OLDER than the timeout are selected; the previous test
     * (time_created + 180000 >= timer_ticks) picked the fresh ones instead,
     * discarding in-progress reassemblies and never expiring stale ones. */
    uint32_t age = (uint32_t)(timer_ticks - f->time_created);
    if (age > IPV4_FRAG_TIMEOUT_TICKS) {
        ptr_vector_append(pv, f);
    }
    return MAP_OK;
}
/* Discard every fragment reassembly selected by ipv4_fragment_trim_iter:
 * remove it from ipv4_frag_map and free all stored fragment copies and the
 * bookkeeping structure itself. */
void ipv4_fragment_trim() {
    struct ptr_vector totrim;
    init_ptr_vector(&totrim);

    /* Collect candidates first and remove them afterwards, so the map is not
     * modified from inside the hashmap_iterate callback. */
    hashmap_iterate(ipv4_frag_map, ipv4_fragment_trim_iter, &totrim);

    /* Drain the vector from the front until it is empty. Counting an index
     * upward while also deleting element 0 on every pass (the old loop)
     * stopped halfway and left the rest of the expired entries in the map. */
    while (ptr_vector_len(&totrim) > 0) {
        struct ipv4_frag_t *frag = (struct ipv4_frag_t *)ptr_vector_del(&totrim, 0);

        hashmap_remove(ipv4_frag_map, frag->key);
        if (frag->first_packet != NULL) {
            dbfree(frag->first_packet, "ipv4_fragment_trim 1");
        }
        if (frag->last_packet != NULL) {
            dbfree(frag->last_packet, "ipv4_fragment_trim 2");
        }
        for (int j = 0; j < frag->middle_packet_count; j++) {
            dbfree(frag->middle_packets[j], "ipv4_fragment_trim 3");
        }
        if (frag->middle_packet_count > 0) {
            dbfree(frag->middle_packets, "ipv4_fragment_trim 4");
        }
        dbfree(frag, "ipv4_fragment_trim 5");
    }
    destroy_ptr_vector(&totrim);
}
void ipv4_process_outbound() {
struct outbound_packet_t *packet = first_packet;
struct outbound_packet_t *prev_packet = NULL;
char *mac;
while (packet != NULL) {
if ((htonl(packet->dest_addr) & packet->ether->mask) != (packet->ether->ipv4 & packet->ether->mask)) {
if (packet->ether->default_gw != 0) {
mac = arp_req_ipv4(packet->ether, packet->ether->default_gw);
} else {
// drop packet.
if (prev_packet == NULL) {
first_packet = packet->next_packet;
free(packet);
packet = first_packet;
} else {
prev_packet->next_packet = packet->next_packet;
free(packet);
packet = prev_packet->next_packet;
}
continue;
}
} else {
mac = arp_req_ipv4(packet->ether, htonl(packet->dest_addr));
}
if (mac != NULL) {
struct ether_packet_t *etherp = (struct ether_packet_t *)dbmalloc(packet->len + sizeof(struct ether_packet_t), "ipv4_process_outbound 1");
etherp->destmac[0] = mac[0];
etherp->destmac[1] = mac[1];
etherp->destmac[2] = mac[2];
etherp->destmac[3] = mac[3];
etherp->destmac[4] = mac[4];
etherp->destmac[5] = mac[5];
etherp->sourcemac[0] = packet->ether->mac[0];
etherp->sourcemac[1] = packet->ether->mac[1];
etherp->sourcemac[2] = packet->ether->mac[2];
etherp->sourcemac[3] = packet->ether->mac[3];
etherp->sourcemac[4] = packet->ether->mac[4];
etherp->sourcemac[5] = packet->ether->mac[5];
etherp->ethertype = htons(0x0800);
memcpy(etherp->payload, packet->packet, packet->len);
dbfree(packet->packet, "ipv4_process_outbound 2");
ether_send(packet->ether, (char *)etherp, packet->len + sizeof(struct ether_packet_t));
dbfree(etherp, "ipv4_process_outbound 3");
if (prev_packet == NULL) {
first_packet = packet->next_packet;
dbfree(packet, "ipv4_process_outbound 4");
packet = first_packet;
} else {
prev_packet->next_packet = packet->next_packet;
dbfree(packet, "ipv4_process_outbound 5");
packet = prev_packet->next_packet;
}
} else {
prev_packet = packet;
packet = packet->next_packet;
}
}
}
/* Queue a complete IPv4 datagram for transmission by ipv4_process_outbound.
 *
 * ether:  interface to send on.
 * dest:   destination address, stored as-is in the queue entry.
 * packet: heap buffer (dbmalloc) holding the datagram; ownership passes to
 *         the queue, which releases it after sending or dropping.
 * len:    size of packet in bytes.
 *
 * Entries are appended at the tail so datagrams leave in the order they were
 * queued. Prepending, as before, sent them newest-first and reordered
 * back-to-back segments. If the queue node cannot be allocated, the datagram
 * is dropped and its buffer freed.
 */
void ipv4_queue_outbound(struct ether_t *ether, uint32_t dest, char *packet, int len) {
    struct outbound_packet_t *newpacket = (struct outbound_packet_t *)dbmalloc(sizeof(struct outbound_packet_t), "ipv4 queue outbound");
    if (newpacket == NULL) {
        dbfree(packet, "ipv4 queue outbound drop");
        return;
    }
    newpacket->dest_addr = dest;
    newpacket->packet = packet;
    newpacket->len = len;
    newpacket->ether = ether;
    newpacket->next_packet = NULL;

    /* Find the terminating NULL link and attach the new node there. */
    struct outbound_packet_t **tail = &first_packet;
    while (*tail != NULL) {
        tail = &(*tail)->next_packet;
    }
    *tail = newpacket;
}
/* Wrap `len` bytes of upper-layer payload in a 20-byte IPv4 header (no
 * options, DF set, TTL 64, id 0) and queue the datagram on `ether`.
 *
 * type:   IP protocol number written into the header.
 * dest:   destination address. It is copied into the header without
 *         conversion, so callers presumably pass network byte order.
 * packet: payload buffer. It is always released with dbfree by this function,
 *         whether or not the datagram is sent.
 *
 * Payloads that would overflow the 16-bit total-length field, or for which no
 * datagram buffer can be allocated, are dropped.
 */
void ipv4_send(struct ether_t *ether, int type, uint32_t dest, char *packet, int len) {
    const size_t hdr_len = sizeof(struct ipv4_header_t);

    if (len < 0 || (size_t)len > 0xFFFFu - hdr_len) {
        dbfree(packet, "ipv4 send 2");
        return;
    }

    struct ipv4_header_t *iph = (struct ipv4_header_t *)dbmalloc(hdr_len + (size_t)len, "ipv4 send 1");
    if (iph == NULL) {
        dbfree(packet, "ipv4 send 2");
        return;
    }

    iph->vhl = (4 << 4) | (hdr_len / 4); /* version 4, header length in 32-bit words */
    iph->tos = 0;
    iph->len = htons(hdr_len + (size_t)len);
    iph->ipid = 0;
    /* DF flag (0x4000), fragment offset 0, in network byte order. This produces
     * the same bytes as the former literal (1 << 6) on little-endian hosts and
     * is also correct on big-endian ones. */
    iph->ipoffset = htons(0x4000);
    iph->ttl = 0x40;
    iph->protocol = type;
    iph->chksum = 0; /* must be zero while the checksum is computed */
    iph->src_addr = htonl(ether->ipv4);
    iph->dest_addr = dest;

    /* One's-complement sum over the header as 16-bit words; the result does not
     * depend on host byte order. */
    uint32_t sum = 0;
    for (size_t w = 0; w < hdr_len / 2; w++) {
        sum += ((uint16_t *)iph)[w];
    }
    iph->chksum = sum_to_checksum(sum);

    memcpy(iph->payload, packet, (size_t)len);
    dbfree(packet, "ipv4 send 2");
    ipv4_queue_outbound(ether, dest, (char *)iph, (int)(hdr_len + (size_t)len));
}
void ipv4_reassemble_packet(struct ether_t *ether, struct ipv4_header_t *iph) {
uint16_t key = htons(iph->ipid);
char shortkey[6];
char *ptr;
struct ipv4_frag_t *frag;
int error;
int i;
shortkey[5] = '\0';
ptr = &shortkey[5];
while (key > 0) {
ptr--;
*ptr = (key % 10) + '0';
key = key / 10;
}
error = hashmap_get(ipv4_frag_map, shortkey, (void **)&frag);
if (error == MAP_OK) {
if (!(htons(iph->ipoffset) & 0x8000)) {
// last fragment
if (frag->last_packet != NULL) {
kprintf("Got two last fragments!");
} else {
frag->last_packet = (struct ipv4_header_t *)dbmalloc(htons(iph->len), "ipv4 reassemble packet 1");
memcpy(frag->last_packet, iph, htons(iph->len));
}
} else if (!(htons(iph->ipoffset) & 0x1fff)) {
// first fragment
if (frag->first_packet != NULL) {
kprintf("Got two first fragments!");
} else {
frag->first_packet = (struct ipv4_header_t *)dbmalloc(htons(iph->len), "ipv4 reassemble packet 2");
memcpy(frag->first_packet, iph, htons(iph->len));
}
} else {
// middle fragment
if (frag->middle_packet_count == 0) {
frag->middle_packets = (struct ipv4_header_t **)dbmalloc(sizeof(struct ipv4_header_t *), "ipv4 reassemble packet 3");
} else {
frag->middle_packets = (struct ipv4_header_t **)dbrealloc(frag->middle_packets, sizeof(struct ipv4_header_t *) * (frag->middle_packet_count + 1), "ipv4 reassemble packet 4");
}
frag->middle_packets[frag->middle_packet_count] = (struct ipv4_header_t *)dbmalloc(htons(iph->len), "ipv4 reassemble packet 5");
memcpy(frag->middle_packets[frag->middle_packet_count], iph, htons(iph->len));
frag->middle_packet_count++;
}
} else if (error == MAP_MISSING) {
frag = (struct ipv4_frag_t *)dbmalloc(sizeof(struct ipv4_frag_t), "ipv4 reassemble packet 6");
frag->first_packet = NULL;
frag->last_packet = NULL;
frag->middle_packets = NULL;
frag->middle_packet_count = 0;
frag->time_created = timer_ticks;
memcpy(frag->key, shortkey, 6);
if (!(htons(iph->ipoffset) & 0x8000)) {
// last fragment
frag->last_packet = (struct ipv4_header_t *)dbmalloc(htons(iph->len), "ipv4 reassemble packet 7");
memcpy(frag->last_packet, iph, htons(iph->len));
} else if (!(htons(iph->ipoffset) & 0x1fff)) {
// first fragment
frag->first_packet = (struct ipv4_header_t *)dbmalloc(htons(iph->len), "ipv4 reassemble packet 8");
memcpy(frag->first_packet, iph, htons(iph->len));
} else {
// middle fragment
if (frag->middle_packet_count == 0) {
frag->middle_packets = (struct ipv4_header_t **)dbmalloc(sizeof(struct ipv4_header_t *), "ipv4 reassemble packet 9");
} else {
frag->middle_packets = (struct ipv4_header_t **)dbrealloc(frag->middle_packets, sizeof(struct ipv4_header_t *) * (frag->middle_packet_count + 1), "ipv4 reassemble packet 10");
}
frag->middle_packets[frag->middle_packet_count] = (struct ipv4_header_t *)dbmalloc(htons(iph->len), "ipv4 reassemble packet 11");
memcpy(frag->middle_packets[frag->middle_packet_count], iph, htons(iph->len));
frag->middle_packet_count++;
}
hashmap_put(ipv4_frag_map, shortkey, frag);
}
// check if fragment is complete...
if (frag->first_packet == NULL || frag->last_packet == NULL) {
return;
}
char *data;
uint32_t dlen;
if (frag->middle_packet_count == 0) {
// check if we're 2 packets only frag
if (htons((frag->last_packet->ipoffset) & 0x1fff) * 8 == htons(frag->first_packet->len) - sizeof(struct ipv4_header_t)) {
// it's complete
dlen = (htons(frag->first_packet->len) - sizeof(struct ipv4_header_t)) + htons(frag->last_packet->len) - sizeof(struct ipv4_header_t);
data = (char *)dbmalloc(dlen, "ipv4 reassemble packet 12");
memcpy(data, frag->first_packet->payload, htons(frag->first_packet->len) - sizeof(struct ipv4_header_t));
memcpy(&data[htons(frag->first_packet->len) - sizeof(struct ipv4_header_t)], frag->last_packet->payload,
htons(frag->last_packet->len) - sizeof(struct ipv4_header_t));
} else {
// it's not complete
return;
}
} else {
// sort the middle packets
int c, d;
struct ipv4_header_t *swap;
for (c = 0; c < (frag->middle_packet_count - 1); c++) {
for (d = 0; d < frag->middle_packet_count - c - 1; d++) {
if ((htons(frag->middle_packets[d]->ipoffset) & 0x1fff) > (htons(frag->middle_packets[d + 1]->ipoffset) & 0x1fff)) {
swap = frag->middle_packets[d];
frag->middle_packets[d] = frag->middle_packets[d + 1];
frag->middle_packets[d + 1] = swap;
}
}
}
// check for holes
int len = htons(frag->first_packet->len) - sizeof(struct ipv4_header_t);
for (i = 0; i < frag->middle_packet_count; i++) {
if ((htons(frag->middle_packets[i]->ipoffset) & 0x1fff) * 8 == len) {
len += htons(frag->middle_packets[i]->len) - sizeof(struct ipv4_header_t);
} else {
return;
}
}
if (htons(frag->last_packet->ipoffset & 0x1fff) * 8 == len) {
len += htons(frag->last_packet->len) - sizeof(struct ipv4_header_t);
dlen = len;
data = (char *)dbmalloc(dlen, "ipv4 reassemble packet 13");
memcpy(data, frag->first_packet->payload, htons(frag->first_packet->len) - sizeof(struct ipv4_header_t));
len = htons(frag->first_packet->len) - sizeof(struct ipv4_header_t);
for (i = 0; i < frag->middle_packet_count; i++) {
memcpy(&data[len], frag->middle_packets[i]->payload, htons(frag->middle_packets[i]->len) - sizeof(struct ipv4_header_t));
len += htons(frag->middle_packets[i]->len) - sizeof(struct ipv4_header_t);
}
memcpy(&data[len], frag->last_packet->payload, htons(frag->last_packet->len) - sizeof(struct ipv4_header_t));
} else {
return;
}
}
if (((frag->first_packet->vhl & 0xF0) >> 4) == 0x4) {
switch (frag->first_packet->protocol) {
case 1:
icmp_process_packet(ether, frag->first_packet->src_addr, data, dlen);
break;
case 6:
tcp_process_packet(ether, frag->first_packet->src_addr, frag->first_packet->dest_addr, (struct tcp_header_t *)data, dlen);
break;
}
}
// free up stuff
dbfree(data, "ipv4 reassemble packet 14");
hashmap_remove(ipv4_frag_map, shortkey);
dbfree(frag->first_packet, "ipv4 reassemble packet 15");
dbfree(frag->last_packet, "ipv4 reassemble packet 16");
for (i = 0; i < frag->middle_packet_count; i++) {
dbfree(frag->middle_packets[i], "ipv4 reassemble packet 17");
}
dbfree(frag->middle_packets, "ipv4 reassemble packet 18");
dbfree(frag, "ipv4 reassemble packet 19");
}
/* Entry point for a received IPv4 packet (`len` bytes at `packet`).
 *
 * The packet is handled in this order:
 * - Non-IPv4 packets are reported and ignored.
 * - Packets whose header-length or total-length fields do not fit inside
 *   `len` are dropped silently, because those fields come from the network.
 * - Fragments (MF flag set or non-zero offset) go to ipv4_reassemble_packet.
 * - Everything else is dispatched by protocol number to ICMP, TCP or UDP,
 *   with the payload starting after the full header (options included).
 *
 * Bytes beyond the IP total length (e.g. Ethernet padding) are ignored.
 */
void ipv4_process_packet(struct ether_t *ether, char *packet, int len) {
    struct ipv4_header_t *iph = (struct ipv4_header_t *)packet;

    /* A minimal header must be present before any field is read. */
    if (packet == NULL || len < (int)sizeof(struct ipv4_header_t)) {
        return;
    }
    if (((iph->vhl & 0xF0) >> 4) != 0x4) {
        kprintf("Unsupported IP version...\n");
        return;
    }

    uint32_t hdr_len = (uint32_t)(iph->vhl & 0x0F) * 4;
    uint32_t total_len = htons(iph->len);
    if (hdr_len < sizeof(struct ipv4_header_t) || total_len < hdr_len || total_len > (uint32_t)len) {
        return;
    }

    /* Flags/offset field: MF is 0x2000 (0x8000 is the reserved bit), the low
     * 13 bits are the fragment offset. */
    uint16_t flags_off = htons(iph->ipoffset);
    if ((flags_off & 0x2000) || (flags_off & 0x1fff)) {
        // part of a fragmented packet
        ipv4_reassemble_packet(ether, iph);
        return;
    }

    char *payload = packet + hdr_len;
    uint32_t payload_len = total_len - hdr_len;
    switch (iph->protocol) {
    case 1: // ICMP
        icmp_process_packet(ether, iph->src_addr, payload, payload_len);
        break;
    case 6: // TCP
        tcp_process_packet(ether, iph->dest_addr, iph->src_addr, (struct tcp_header_t *)payload, payload_len);
        break;
    case 17: // UDP
        udp_process_packet(ether, iph->dest_addr, iph->src_addr, (struct udp_header_t *)payload, payload_len);
        break;
    default:
        kprintf("Unsupported IP Protocol (%d)...\n", iph->protocol);
        break;
    }
}
/* Bring up the IPv4 layer: create the fragment-reassembly map, empty the
 * outbound queue, then initialise the socket layer via init_sockets(). */
void init_ipv4() {
    /* Table of partially reassembled datagrams starts out empty. */
    ipv4_frag_map = hashmap_new();
    /* No datagrams are waiting to be sent yet. */
    first_packet = NULL;
    /* Called last, after both pieces of IPv4 state above are set. */
    init_sockets();
}