quinn-os/ipv4.c

327 lines
10 KiB
C

#include "ether.h"
#include "ipv4.h"
#include "arp.h"
#include "memory.h"
#include "hashmap.h"
/*
 * Turn a 32-bit running sum of 16-bit words into an Internet checksum:
 * repeatedly add the carry bits above bit 15 back into the low half
 * (end-around carry), then return the one's complement of the result.
 */
static inline unsigned short sum_to_checksum(unsigned int sum) {
	unsigned int folded = sum;

	/* Anything above 0xFFFF is a carry that has not been folded in yet. */
	for (; folded > 0xFFFF; folded = (folded & 0xFFFF) + (folded >> 16))
		;

	return (unsigned short)~folded;
}
/*
 * One IPv4 datagram waiting in the outbound queue (a singly linked list
 * headed by the global first_packet) until ARP can resolve the MAC address
 * needed to send it.
 */
struct outbound_packet_t {
unsigned int dest_addr;                 /* destination address as written into the IP header; ipv4_process_outbound applies htonl() to it before the ARP lookup */
char *packet;                           /* heap buffer holding the complete IP datagram (header + payload); freed once transmitted */
int len;                                /* size of packet in bytes */
struct ether_t *ether;                  /* interface the datagram is sent on */
struct outbound_packet_t *next_packet;  /* next queued datagram, or (void *)0 at the end of the list */
};
/*
 * IPv4 header in wire layout (RFC 791). It is packed so a received buffer
 * can be cast to it directly (ipv4_process_packet does exactly that).
 * Multi-byte fields hold network byte order; this file converts them with
 * htons()/htonl() when reading or writing.
 * payload[] starts at byte 20, so it is only the real payload when the
 * header carries no options (IHL == 5).
 */
struct ipv4_header_t {
unsigned char vhl;        /* version in the high nibble, header length in 32-bit words (IHL) in the low nibble */
unsigned char tos;        /* type of service */
unsigned short len;       /* total datagram length (header + payload) in bytes */
unsigned short ipid;      /* identification; fragments of one datagram share it */
unsigned short ipoffset;  /* flags in the top 3 bits, fragment offset (8-byte units) in the low 13 bits */
unsigned char ttl;        /* time to live */
unsigned char protocol;   /* upper-layer protocol number (1 ICMP, 6 TCP, 17 UDP are handled below) */
unsigned short chksum;    /* header checksum */
unsigned int src_addr;    /* source IPv4 address */
unsigned int dest_addr;   /* destination IPv4 address */
unsigned char payload[];  /* data following a 20-byte header */
}__attribute__((packed));
struct ipv4_frag_t {
struct ipv4_header_t *first_packet;
struct ipv4_header_t **middle_packets;
int middle_packet_count;
struct ipv4_header_t *last_packet;
};
/* In-progress fragment reassemblies (struct ipv4_frag_t *), keyed by a string
 * derived from the datagram's IP header; created in init_ipv4(). */
map_t ipv4_frag_map;
/* Head of the outbound queue of datagrams waiting for ARP resolution;
 * (void *)0 when empty.  NOTE(review): no locking is visible in this file —
 * confirm that callers serialise access to the queue. */
struct outbound_packet_t *first_packet;
/*
 * Walk the outbound queue and transmit every datagram whose next-hop MAC
 * address ARP can resolve right now. Each sent datagram is wrapped in an
 * Ethernet frame, unlinked from the queue and freed. Datagrams still
 * waiting on ARP stay queued, in their current order, for a later call.
 */
void ipv4_process_outbound() {
	/* link points at the pointer that references the current node, so
	 * unlinking the head and unlinking an interior node are the same operation. */
	struct outbound_packet_t **link = &first_packet;

	while (*link != (void *)0) {
		struct outbound_packet_t *packet = *link;
		char *mac = arp_req_ipv4(packet->ether, htonl(packet->dest_addr));

		if (mac == (void *)0) {
			/* Not resolved yet: leave it queued and look at the next one. */
			link = &packet->next_packet;
			continue;
		}

		struct ether_packet_t *etherp = (struct ether_packet_t *)malloc(packet->len + sizeof(struct ether_packet_t));
		if (etherp == (void *)0) {
			/* Out of memory: keep this datagram queued and retry on a later call. */
			break;
		}

		/* Detach before transmitting; *link now refers to the following node,
		 * so the loop must not advance link on this path. */
		*link = packet->next_packet;

		memcpy(etherp->destmac, mac, 6);
		memcpy(etherp->sourcemac, packet->ether->mac, 6);
		etherp->ethertype = htons(0x0800); /* EtherType: IPv4 */
		memcpy(etherp->payload, packet->packet, packet->len);

		ether_send(packet->ether, etherp, packet->len + sizeof(struct ether_packet_t));

		free(etherp);
		free(packet->packet);
		free(packet);
	}
}
/*
 * Append a complete IPv4 datagram to the tail of the outbound queue, where
 * it waits until ipv4_process_outbound can resolve its MAC address.
 *   ether  - interface to transmit on
 *   dest   - destination address as written into the IP header
 *   packet - heap buffer of len bytes; ownership passes to the queue
 *            (ipv4_process_outbound frees it after sending)
 * If the queue node cannot be allocated, the datagram is dropped and freed
 * here rather than dereferencing a NULL node.
 */
void ipv4_queue_outbound(struct ether_t *ether, unsigned int dest, char *packet, int len) {
	struct outbound_packet_t *newpacket = (struct outbound_packet_t *)malloc(sizeof(struct outbound_packet_t));
	struct outbound_packet_t **tail;

	if (newpacket == (void *)0) {
		/* IP is best-effort: dropping is preferable to crashing. */
		free(packet);
		return;
	}

	newpacket->dest_addr = dest;
	newpacket->packet = packet;
	newpacket->len = len;
	newpacket->ether = ether;
	newpacket->next_packet = (void *)0;

	/* Append at the tail so datagrams are sent in the order they were queued
	 * (inserting at the head reversed them). */
	tail = &first_packet;
	while (*tail != (void *)0) {
		tail = &(*tail)->next_packet;
	}
	*tail = newpacket;
}
/*
 * Wrap a transport payload in a 20-byte IPv4 header (no options, DF set)
 * and queue the datagram for transmission on `ether`.
 *   type   - IP protocol number written into the header (e.g. 1 ICMP, 6 TCP)
 *   dest   - destination address, copied into the header unchanged, so it
 *            must already be in network byte order
 *   packet - heap buffer of len payload bytes; always freed here, including
 *            when the datagram is dropped
 */
void ipv4_send(struct ether_t *ether, int type, unsigned int dest, char *packet, int len) {
	const int hdr_len = (int)sizeof(struct ipv4_header_t);
	struct ipv4_header_t *iph;
	unsigned int sum = 0;

	/* The 16-bit total-length field must be able to hold header + payload. */
	if (len < 0 || len > 0xFFFF - hdr_len) {
		kprintf("ipv4_send: bad payload length\n");
		free(packet);
		return;
	}

	iph = (struct ipv4_header_t *)malloc(hdr_len + len);
	if (iph == (void *)0) {
		free(packet);
		return;
	}

	iph->vhl = (4 << 4) | (sizeof(struct ipv4_header_t) / 4); /* version 4, IHL 5 */
	iph->tos = 0;
	iph->len = htons(hdr_len + len);
	iph->ipid = 0;                 /* DF is set, so no fragment ever needs a distinct ID */
	iph->ipoffset = htons(0x4000); /* Don't Fragment, offset 0; correct on any host byte order */
	iph->ttl = 0x40;
	iph->protocol = type;
	iph->chksum = 0;               /* must be zero while the checksum is computed */
	iph->src_addr = htonl(ether->ipv4);
	iph->dest_addr = dest;

	/* Header checksum: one's-complement sum over the header's 16-bit words. */
	for (unsigned int l = 0; l < (sizeof(struct ipv4_header_t) / 2); l++) {
		sum += ((unsigned short *)iph)[l];
	}
	iph->chksum = sum_to_checksum(sum);

	memcpy(iph->payload, packet, len);
	free(packet);

	ipv4_queue_outbound(ether, dest, (char *)iph, hdr_len + len);
}
void ipv4_reassemble_packet(struct ether_t *ether, struct ipv4_header_t *iph) {
unsigned short key = htons(iph->ipid);
char shortkey[6];
char *ptr;
struct ipv4_frag_t *frag;
int error;
int i;
shortkey[5] == '\0';
ptr = &shortkey[5];
while (key > 0) {
ptr--;
*ptr = (key % 10) + '0';
key = key / 10;
}
error = hashmap_get(ipv4_frag_map, shortkey, (void **)&frag);
if (error == MAP_OK) {
if (!(htons(iph->ipoffset) & 0x8000)) {
// last fragment
if (frag->last_packet != (void *)0) {
kprintf("Got two last fragments!");
}
frag->last_packet = (struct ipv4_header_t *)malloc(htons(iph->len));
memcpy(frag->last_packet, iph, iph->len);
} else if (!(htons(iph->ipoffset) & 0x1fff)) {
// first fragment
if (frag->first_packet != (void *)0) {
kprintf("Got two last fragments!");
}
frag->first_packet = (struct ipv4_header_t *)malloc(htons(iph->len));
memcpy(frag->first_packet, iph, iph->len);
} else {
// middle fragment
if (frag->middle_packet_count == 0) {
frag->middle_packets = (struct ipv4_header_t **)malloc(sizeof(struct ipv4_header_t *));
} else {
frag->middle_packets = (struct ipv4_header_t **)realloc(frag->middle_packets, sizeof(struct ipv4_header_t *) * (frag->middle_packet_count + 1));
}
frag->middle_packets[frag->middle_packet_count] = (struct ipv4_header_t *)malloc(htons(iph->len));
memcpy(frag->middle_packets[frag->middle_packet_count], iph, iph->len);
frag->middle_packet_count++;
}
} else if (error == MAP_MISSING) {
frag = (struct ipv4_frag_t *)malloc(sizeof(struct ipv4_frag_t));
frag->first_packet = (void *)0;
frag->last_packet = (void *)0;
frag->middle_packet_count = 0;
if (!(htons(iph->ipoffset) & 0x8000)) {
// last fragment
frag->last_packet = (struct ipv4_header_t *)malloc(htons(iph->len));
memcpy(frag->last_packet, iph, iph->len);
} else if (!(htons(iph->ipoffset) & 0x1fff)) {
// first fragment
frag->first_packet = (struct ipv4_header_t *)malloc(htons(iph->len));
memcpy(frag->first_packet, iph, iph->len);
} else {
// middle fragment
if (frag->middle_packet_count == 0) {
frag->middle_packets = (struct ipv4_header_t **)malloc(sizeof(struct ipv4_header_t *));
} else {
frag->middle_packets = (struct ipv4_header_t **)realloc(frag->middle_packets, sizeof(struct ipv4_header_t *) * (frag->middle_packet_count + 1));
}
frag->middle_packets[frag->middle_packet_count] = (struct ipv4_header_t *)malloc(htons(iph->len));
memcpy(frag->middle_packets[frag->middle_packet_count], iph, iph->len);
frag->middle_packet_count++;
}
hashmap_put(ipv4_frag_map, shortkey, frag);
}
// check if fragment is complete...
if (frag->first_packet == (void *)0 || frag->last_packet == (void *)0) {
return;
}
char *data;
unsigned int dlen;
if (frag->middle_packet_count == 0) {
// check if we're 2 packets only frag
if (htons((frag->last_packet->ipoffset) & 0x1fff) * 8 == htons(frag->first_packet->len) - sizeof(struct ipv4_header_t)) {
// it's complete
dlen = (htons(frag->first_packet->len) - sizeof(struct ipv4_header_t)) + htons(frag->last_packet->len) - sizeof(struct ipv4_header_t);
data = (char *)malloc(dlen);
memcpy(data, frag->first_packet->payload, htons(frag->first_packet->len) - sizeof(struct ipv4_header_t));
memcpy(&data[htons(frag->first_packet->len) - sizeof(struct ipv4_header_t)], frag->last_packet->payload, htons(frag->last_packet->len) - sizeof(struct ipv4_header_t));
} else {
// it's not complete
return;
}
} else {
// sort the middle packets
int c, d;
struct ipv4_header_t *swap;
for (c = 0 ; c < ( frag->middle_packet_count - 1 ); c++) {
for (d = 0 ; d < frag->middle_packet_count - c - 1; d++) {
if (htons(frag->middle_packets[d]->ipoffset) & 0x1fff > htons(frag->middle_packets[d+1]->ipoffset) & 0x1fff) {
swap = frag->middle_packets[d];
frag->middle_packets[d] = frag->middle_packets[d+1];
frag->middle_packets[d+1] = swap;
}
}
}
// check for holes
int len = htons(frag->first_packet->len) - sizeof(struct ipv4_header_t);
for (i=0;i<frag->middle_packet_count;i++) {
if (htons(frag->middle_packets[i]->ipoffset) & 0x1fff * 8 == len) {
len += htons(frag->middle_packets[i]->len) - sizeof(struct ipv4_header_t);
} else {
return;
}
}
if (htons(frag->last_packet->ipoffset & 0x1fff) * 8 == len) {
len += htons(frag->last_packet->len) - sizeof(struct ipv4_header_t);
dlen = len;
data = malloc(dlen);
memcpy(data, frag->first_packet->payload, htons(frag->first_packet->len) - sizeof(struct ipv4_header_t));
len = htons(frag->first_packet->len) - sizeof(struct ipv4_header_t);
for (i=0;i<frag->middle_packet_count;i++) {
memcpy(&data[len], frag->middle_packets[i]->payload, htons(frag->middle_packets[i]->len) - sizeof(struct ipv4_header_t));
len += htons(frag->middle_packets[i]->len) - sizeof(struct ipv4_header_t);
}
memcpy(&data[len], frag->last_packet->payload, htons(frag->last_packet->len) - sizeof(struct ipv4_header_t));
} else {
return;
}
}
if (((frag->first_packet->vhl & 0xF0) >> 4) == 0x4) {
switch(frag->first_packet->protocol) {
case 1:
icmp_process_packet(ether, frag->first_packet->src_addr, data, dlen);
break;
case 6:
tcp_process_packet(ether, frag->first_packet->src_addr, frag->first_packet->dest_addr, data, dlen);
break;
}
}
// free up stuff
free(data);
hashmap_remove(ipv4_frag_map, shortkey);
free(frag->first_packet);
free(frag->last_packet);
for (i=0;i<frag->middle_packet_count;i++) {
free(frag->middle_packets[i]);
}
free(frag->middle_packets);
free(frag);
}
/*
 * Entry point for a received IPv4 datagram.
 *   packet - start of the IP header inside the receive buffer
 *   len    - number of bytes available at packet
 * The header is validated against len before any field is trusted.
 * Fragments are passed to ipv4_reassemble_packet. Complete ICMP and TCP
 * datagrams go to their protocol handlers. UDP is accepted and ignored.
 */
void ipv4_process_packet(struct ether_t *ether, char *packet, int len) {
	struct ipv4_header_t *iph = (struct ipv4_header_t *)packet;
	unsigned int hlen;
	unsigned int total;
	unsigned char *payload;
	unsigned int plen;

	/* Need at least a minimal header before reading any field. */
	if (len < (int)sizeof(struct ipv4_header_t)) {
		kprintf("Truncated IPv4 packet...\n");
		return;
	}
	if (((iph->vhl & 0xF0) >> 4) != 0x4) {
		kprintf("Unsupported IP version...\n");
		return;
	}

	/* IHL counts 32-bit words. The total length must cover the header and fit
	 * in what was received (it may be shorter, because of link-layer padding). */
	hlen = (unsigned int)(iph->vhl & 0x0F) * 4;
	total = htons(iph->len);
	if (hlen < sizeof(struct ipv4_header_t) || total < hlen || total > (unsigned int)len) {
		kprintf("Malformed IPv4 packet...\n");
		return;
	}

	/* A fragment has either "more fragments" (0x2000) set or a nonzero offset (0x1fff). */
	if (htons(iph->ipoffset) & (0x2000 | 0x1fff)) {
		ipv4_reassemble_packet(ether, iph);
		return;
	}

	/* The payload starts after the full header, including any options. */
	payload = (unsigned char *)packet + hlen;
	plen = total - hlen;

	switch (iph->protocol) {
	case 1: //ICMP
		icmp_process_packet(ether, iph->src_addr, payload, plen);
		break;
	case 6: //TCP
		tcp_process_packet(ether, iph->dest_addr, iph->src_addr, payload, plen);
		break;
	case 17: //UDP
		break;
	default:
		kprintf("Unsupported IP Protocol...\n");
		break;
	}
}
/*
 * One-time setup of the IPv4 layer: create the fragment-reassembly map,
 * start with an empty outbound queue, then bring up the socket layer.
 */
void init_ipv4() {
	ipv4_frag_map = hashmap_new();
	first_packet = (void *)0; /* nothing queued for transmission yet */

	init_sockets();
}