600 lines
14 KiB
C
600 lines
14 KiB
C
#include "pvec.h"
|
|
#include "ether.h"
|
|
#include "ipv4.h"
|
|
#include "socket.h"
|
|
#include "schedule.h"
|
|
#include "tcp.h"
|
|
#include "udp.h"
|
|
#include "memory.h"
|
|
#include "string.h"
|
|
#include "console.h"
|
|
#include "rtc.h"
|
|
|
|
|
|
extern struct task_t *current_task;
|
|
|
|
static unsigned int socket_serial = 0;
|
|
|
|
struct ptr_vector sockets;
|
|
|
|
extern struct ptr_vector ether_devs;
|
|
|
|
static unsigned char port_bitmap[8192];
|
|
|
|
extern void tcp_send(struct ether_t *ether, struct socket_t *sock, unsigned short flags, unsigned char *packet, unsigned int len);
|
|
|
|
struct socket_t *socket_get_from_serial(unsigned int serial) {
|
|
struct socket_t *sock = NULL;
|
|
|
|
for (int i = 0; i < ptr_vector_len(&sockets); i++) {
|
|
sock = ptr_vector_get(&sockets, i);
|
|
if (sock->serial == serial) {
|
|
return sock;
|
|
}
|
|
}
|
|
return NULL;
|
|
}
|
|
|
|
int socket_queue(struct socket_t* sock, unsigned char* payload, unsigned int len, unsigned short flags) {
|
|
if (sock->socket_type == 1) {
|
|
struct tcp_data_t *outpacket = (struct tcp_data_t *)malloc(sizeof(struct tcp_data_t));
|
|
|
|
outpacket->flags = flags;
|
|
if (len > 0) {
|
|
outpacket->data = (unsigned char *)malloc(len);
|
|
}
|
|
memcpy(outpacket->data, payload, len);
|
|
outpacket->len = len;
|
|
|
|
if (sock->send_payload == 1) {
|
|
tcp_send(sock->ether, sock, outpacket->flags, outpacket->data, outpacket->len);
|
|
|
|
if (len > 0) {
|
|
dbfree(outpacket->data, "socket_queue 1");
|
|
}
|
|
dbfree(outpacket, "socket_queue 2");
|
|
sock->send_payload = 0;
|
|
} else {
|
|
ptr_vector_append(&sock->out_packets, outpacket);
|
|
}
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
struct socket_t *socket_find(unsigned short dport, unsigned short sport, unsigned int src_ip, unsigned int seq) {
|
|
int i;
|
|
struct socket_t *new_socket;
|
|
|
|
for (i=0;i<ptr_vector_len(&sockets);i++) {
|
|
struct socket_t *s = ptr_vector_get(&sockets, i);
|
|
if (s->port_recv == dport && s->port_dest == sport && (s->addr == 0xffffffff || s->addr == src_ip)) {
|
|
if (s->status == 0 || s->status == 2 || s->status == 5) {
|
|
return s;
|
|
} else {
|
|
return (void *)0;
|
|
}
|
|
}
|
|
}
|
|
|
|
for (i=0;i<ptr_vector_len(&sockets);i++) {
|
|
struct socket_t *s = ptr_vector_get(&sockets, i);
|
|
if (s->port_recv == dport && s->addr == 0) {
|
|
if (s->status == 3) {
|
|
struct task_t *task = (struct task_t *)s->data;
|
|
|
|
if (task->waiting_socket_count < 25) {
|
|
|
|
new_socket = (struct socket_t *)malloc(sizeof(struct socket_t));
|
|
if (new_socket) {
|
|
memset(new_socket, 0, sizeof(struct socket_t));
|
|
new_socket->socket_type = 1;
|
|
|
|
new_socket->port_recv = dport;
|
|
new_socket->tcp_sock.seq_number = 0;
|
|
new_socket->tcp_sock.ack_number = seq + 1;
|
|
|
|
init_ptr_vector(&new_socket->out_packets);
|
|
init_ptr_vector(&new_socket->packets);
|
|
|
|
new_socket->addr = src_ip;
|
|
new_socket->bytes_available = 0;
|
|
new_socket->bytes_read = 0;
|
|
new_socket->port_dest = sport;
|
|
new_socket->offset = 0;
|
|
new_socket->ether = s->ether;
|
|
new_socket->ref = 1;
|
|
new_socket->serial = ++socket_serial;
|
|
|
|
new_socket->status = 0;
|
|
task->waiting_sockets[task->waiting_socket_count++] = new_socket;
|
|
if (task->state == TASK_SLEEPING && task->sleep_reason == SLEEP_TCP_ACCEPT) {
|
|
task->state = TASK_RUNNING;
|
|
}
|
|
}
|
|
} else {
|
|
kprintf("Task waiting count >= 25");
|
|
}
|
|
} else {
|
|
return (void *)0;
|
|
}
|
|
}
|
|
}
|
|
|
|
return (void *)0;
|
|
}
|
|
|
|
unsigned int socket_ioctl(int serial, int flag) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
sock->flags |= flag;
|
|
|
|
return sock->flags;
|
|
}
|
|
|
|
void *socket_get_packet(struct socket_t *sock) {
|
|
if (sock->socket_type == 1) {
|
|
struct tcp_data_t *packet;
|
|
|
|
if (ptr_vector_len(&sock->packets) == 0) {
|
|
return (void *)0;
|
|
}
|
|
|
|
packet = ptr_vector_del(&sock->packets, 0);
|
|
|
|
return packet;
|
|
}
|
|
return (void *)0;
|
|
}
|
|
|
|
int socket_status(unsigned int serial) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
return sock->status;
|
|
}
|
|
|
|
int socket_recv_from(unsigned int serial, char *buffer, int len, unsigned int *addr) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
if (sock->socket_type == 17) {
|
|
if (ptr_vector_len(&sock->packets) == 0) {
|
|
return -1;
|
|
}
|
|
|
|
struct udp_data_t *data = ptr_vector_del(&sock->packets, 0);
|
|
|
|
if (addr != (void *)0) {
|
|
*addr = sock->addr;
|
|
}
|
|
|
|
if (len > data->len) {
|
|
len = data->len;
|
|
}
|
|
|
|
memcpy(buffer, data->data, len);
|
|
|
|
return len;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int socket_read(unsigned int serial, char *buffer, int len) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
|
|
if (sock == NULL) {
|
|
return 0;
|
|
}
|
|
|
|
if (sock->status == 5 || sock->status == 1) {
|
|
return 0;
|
|
}
|
|
|
|
if (sock->socket_type == 1) {
|
|
|
|
struct tcp_data_t *data = (void *)0;
|
|
|
|
int size_to_read = 0;
|
|
int offset = sock->offset;
|
|
do {
|
|
if (sock->bytes_available) {
|
|
data = sock->curr_packet.tcp;
|
|
} else {
|
|
if (ptr_vector_len(&sock->packets) > 0) {
|
|
data = (struct tcp_data_t *)socket_get_packet(sock);
|
|
} else {
|
|
sock->offset = offset;
|
|
if (sock->flags & SOCKET_FLAG_NOBLOCK) {
|
|
return -2;
|
|
} else {
|
|
//((struct task_t *)sock->data)
|
|
current_task->sleep_reason = SLEEP_TCP_READ;
|
|
//((struct task_t *)sock->data)
|
|
current_task->state = TASK_SLEEPING;
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
sock->bytes_available = data->len;
|
|
sock->bytes_read = 0;
|
|
}
|
|
|
|
if (len < sock->bytes_available) {
|
|
size_to_read = len;
|
|
} else {
|
|
size_to_read = sock->bytes_available;
|
|
}
|
|
|
|
if (data->len > 0) {
|
|
memcpy(buffer + offset, data->data + sock->bytes_read, size_to_read);
|
|
}
|
|
|
|
offset += size_to_read;
|
|
|
|
if (size_to_read < sock->bytes_available) {
|
|
sock->bytes_available = sock->bytes_available - size_to_read;
|
|
sock->bytes_read += size_to_read;
|
|
sock->curr_packet.tcp = data;
|
|
} else {
|
|
sock->bytes_available = 0;
|
|
sock->curr_packet.tcp = (void *)0;
|
|
dbfree(data->data, "socket_read");
|
|
dbfree(data, "socket_read 2");
|
|
}
|
|
} while (!size_to_read);
|
|
|
|
sock->offset = 0;
|
|
return size_to_read;
|
|
} else {
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
void socket_clear_port_bitmap(unsigned short port) {
|
|
unsigned int index = port / 8;
|
|
unsigned int offset = port % 8;
|
|
|
|
port_bitmap[index] &= ~(1 << offset);
|
|
}
|
|
|
|
void socket_mark_port_bitmap(unsigned short port) {
|
|
unsigned int index = port / 8;
|
|
unsigned int offset = port % 8;
|
|
|
|
port_bitmap[index] |= (1 << offset);
|
|
}
|
|
|
|
int socket_check_port_bitmap(unsigned short port) {
|
|
unsigned int index = port / 8;
|
|
unsigned int offset = port % 8;
|
|
|
|
return (port_bitmap[index] & (1 << offset));
|
|
}
|
|
|
|
/*
 * Pick a free ephemeral port in 49152..65535.
 *
 * Probing starts at a random port in that range and wraps around, so every
 * port in the range is examined at most once. The port is NOT marked as used;
 * callers do that with socket_mark_port_bitmap().
 *
 * Returns the port, or 0 if every ephemeral port is taken.
 */
unsigned socket_get_port(void) {
	const unsigned int first = 49152;
	const unsigned int span = 65536 - first; /* 16384 ports, 49152..65535 inclusive */
	/* Use unsigned int arithmetic: the previous unsigned short counter wrapped
	 * from 65535 to 0 and then returned 0 ("no ports") incorrectly. */
	unsigned int start = (unsigned int)random() % span;
	unsigned int i;

	for (i = 0; i < span; i++) {
		unsigned short port = (unsigned short)(first + (start + i) % span);

		if (!socket_check_port_bitmap(port)) {
			return port;
		}
	}

	// run out of sockets!
	return 0;
}
|
|
|
|
|
|
int socket_bind(unsigned int serial, unsigned int dest_ip, unsigned short dest_port, unsigned short src_port) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
if (sock->socket_type != 17) {
|
|
return -1;
|
|
}
|
|
|
|
if (src_port == 0) {
|
|
sock->port_recv = socket_get_port();
|
|
|
|
if (sock->port_recv == 0) {
|
|
return -1;
|
|
}
|
|
socket_mark_port_bitmap(sock->port_recv);
|
|
} else {
|
|
if (!socket_check_port_bitmap(src_port)) {
|
|
sock->port_recv = src_port;
|
|
socket_mark_port_bitmap(src_port);
|
|
} else {
|
|
return -1;
|
|
}
|
|
}
|
|
sock->addr = dest_ip;
|
|
sock->port_dest = dest_port;
|
|
|
|
sock->ether = ether_find_from_ipv4(dest_ip);
|
|
if (!sock->ether) {
|
|
if (ptr_vector_len(ðer_devs) > 0) {
|
|
sock->ether = ptr_vector_get(ðer_devs, 0);
|
|
} else {
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
int socket_connect(unsigned int serial, unsigned int dest_ip, unsigned short dest_port) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
|
|
if (sock == NULL) {
|
|
return -1;
|
|
}
|
|
|
|
if (sock->socket_type != 1) {
|
|
return -1;
|
|
}
|
|
|
|
sock->port_recv = socket_get_port();
|
|
if (sock->port_recv == 0) {
|
|
return -1;
|
|
}
|
|
socket_mark_port_bitmap(sock->port_recv);
|
|
sock->tcp_sock.seq_number = 0;
|
|
sock->tcp_sock.ack_number = 0;
|
|
|
|
sock->addr = dest_ip;
|
|
sock->port_dest = dest_port;
|
|
|
|
sock->ether = ether_find_from_ipv4(dest_ip);
|
|
|
|
if (!sock->ether) {
|
|
return -1;
|
|
}
|
|
|
|
tcp_send(sock->ether, sock, (1<<1), (void *)0, 0);
|
|
|
|
return 0;
|
|
}
|
|
|
|
void socket_close(unsigned int serial) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
|
|
if (sock == NULL) {
|
|
return; // Already closed
|
|
}
|
|
|
|
int i;
|
|
int j;
|
|
// remove from task struct
|
|
sched_remove_socket(sock);
|
|
|
|
sock->ref--;
|
|
|
|
if (sock->ref == 0) {
|
|
if (sock->socket_type == 1) {
|
|
if (sock->status != 1) {
|
|
sock->status = 5;
|
|
} else {
|
|
for (i=0;i<ptr_vector_len(&sockets);i++) {
|
|
struct socket_t *s = ptr_vector_get(&sockets, i);
|
|
if (s == sock) {
|
|
socket_clear_port_bitmap(s->port_recv);
|
|
for (j=0;j<ptr_vector_len(&sock->packets);j++) {
|
|
struct tcp_data_t *d = ptr_vector_get(&sock->packets, j);
|
|
if (d->len > 0) {
|
|
dbfree(d->data, "socket close 1");
|
|
}
|
|
}
|
|
ptr_vector_apply(&sock->packets, free);
|
|
destroy_ptr_vector(&sock->packets);
|
|
for (j=0;j<ptr_vector_len(&sock->out_packets);j++) {
|
|
struct tcp_data_t *d = ptr_vector_get(&sock->out_packets, j);
|
|
if (d->len > 0) {
|
|
dbfree(d->data, "socket close 2");
|
|
}
|
|
}
|
|
ptr_vector_apply(&sock->out_packets, free);
|
|
destroy_ptr_vector(&sock->out_packets);
|
|
|
|
dbfree(sock, "socket close 3");
|
|
ptr_vector_del(&sockets, i);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
for (i=0;i<ptr_vector_len(&sockets);i++) {
|
|
struct socket_t *s = ptr_vector_get(&sockets, i);
|
|
if (s == sock) {
|
|
socket_clear_port_bitmap(sock->port_recv);
|
|
for (j=0;j<ptr_vector_len(&sock->packets);j++) {
|
|
struct udp_data_t *d = ptr_vector_get(&sock->packets, j);
|
|
free(d);
|
|
}
|
|
ptr_vector_apply(&sock->packets, free);
|
|
destroy_ptr_vector(&sock->packets);
|
|
for (j=0;j<ptr_vector_len(&sock->out_packets);j++) {
|
|
struct udp_data_t *d = ptr_vector_get(&sock->out_packets, j);
|
|
free(d);
|
|
}
|
|
ptr_vector_apply(&sock->out_packets, free);
|
|
destroy_ptr_vector(&sock->out_packets);
|
|
|
|
free(sock);
|
|
ptr_vector_del(&sockets, i);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void socket_doclose(struct socket_t *sock) {
|
|
int i;
|
|
int j;
|
|
|
|
if (sock->socket_type == 1) {
|
|
if (sock->status != 5) {
|
|
sock->status = 1;
|
|
} else {
|
|
if (sock->ref == 0) {
|
|
for (i=0;i<ptr_vector_len(&sockets);i++) {
|
|
struct socket_t *s = ptr_vector_get(&sockets, i);
|
|
if (s == sock) {
|
|
sched_remove_socket_from_all_tasks(s);
|
|
socket_clear_port_bitmap(sock->port_recv);
|
|
for (j=0;j<ptr_vector_len(&sock->packets);j++) {
|
|
struct tcp_data_t *d = ptr_vector_get(&sock->packets, j);
|
|
free(d);
|
|
}
|
|
ptr_vector_apply(&sock->packets, free);
|
|
destroy_ptr_vector(&sock->packets);
|
|
|
|
for (j=0;j<ptr_vector_len(&sock->out_packets);j++) {
|
|
struct tcp_data_t *d = ptr_vector_get(&sock->out_packets, j);
|
|
free(d);
|
|
}
|
|
ptr_vector_apply(&sock->out_packets, free);
|
|
destroy_ptr_vector(&sock->out_packets);
|
|
|
|
dbfree(sock, "socket_doclose 4");
|
|
|
|
ptr_vector_del(&sockets, i);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
unsigned int socket_open(unsigned char type) {
|
|
struct socket_t *sock;
|
|
if (type == 1) {
|
|
sock = (struct socket_t *)malloc(sizeof(struct socket_t));
|
|
if (!sock){
|
|
return 0;
|
|
}
|
|
|
|
memset(sock, 0, sizeof(struct socket_t));
|
|
sock->status = 0;
|
|
sock->socket_type = type;
|
|
sock->bytes_available = 0;
|
|
sock->bytes_read = 0;
|
|
sock->ref = 1;
|
|
sock->serial = ++socket_serial;
|
|
ptr_vector_append(&sockets, sock);
|
|
ptr_vector_append(¤t_task->sockets, sock);
|
|
} else if (type == 17) {
|
|
sock = (struct socket_t *)malloc(sizeof(struct socket_t));
|
|
if (!sock){
|
|
return 0;
|
|
}
|
|
|
|
memset(sock, 0, sizeof(struct socket_t));
|
|
|
|
sock->status = 0;
|
|
sock->socket_type = type;
|
|
sock->bytes_available = 0;
|
|
sock->bytes_read = 0;
|
|
sock->ref = 1;
|
|
sock->serial = ++socket_serial;
|
|
ptr_vector_append(&sockets, sock);
|
|
ptr_vector_append(¤t_task->sockets, sock);
|
|
} else {
|
|
kprintf("Unsupported Socket type!\n");
|
|
return 0;
|
|
}
|
|
return sock->serial;
|
|
}
|
|
|
|
int socket_listen(unsigned int serial, unsigned int listenip, unsigned short port) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
if (sock->socket_type != 1) {
|
|
kprintf("Wrong socket Type\n");
|
|
return -1;
|
|
}
|
|
|
|
sock->port_recv = port;
|
|
sock->tcp_sock.seq_number = 0;
|
|
sock->tcp_sock.ack_number = 0;
|
|
|
|
sock->addr = 0;
|
|
sock->port_dest = 0;
|
|
|
|
sock->ether = ether_find_from_ipv4(listenip);
|
|
|
|
sock->status = 3;
|
|
|
|
sock->data = (void *)current_task;
|
|
|
|
if (!sock->ether) {
|
|
kprintf("Cant find ether\n");
|
|
return -1;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
|
|
unsigned int socket_accept(unsigned int serial, struct inet_addr *client_addr) {
|
|
struct socket_t *socket = socket_get_from_serial(serial);
|
|
struct task_t *task = (struct task_t *)socket->data;
|
|
struct socket_t *new_socket;
|
|
int i;
|
|
|
|
if (task->waiting_socket_count > 0) {
|
|
new_socket = task->waiting_sockets[0];
|
|
|
|
for (i=1;i<task->waiting_socket_count;i++) {
|
|
task->waiting_sockets[i-1] = task->waiting_sockets[i];
|
|
}
|
|
task->waiting_socket_count--;
|
|
|
|
// send SYN ACK
|
|
|
|
if (new_socket->socket_type == 1) {
|
|
tcp_send(new_socket->ether, new_socket, TCP_FLAGS_SYN | TCP_FLAGS_ACK, (void *)0, 0);
|
|
new_socket->tcp_sock.seq_number = 1;
|
|
}
|
|
|
|
if (client_addr != (void *)0) {
|
|
client_addr->type = new_socket->socket_type;
|
|
client_addr->addr = new_socket->addr;
|
|
}
|
|
|
|
ptr_vector_append(&sockets, new_socket);
|
|
ptr_vector_append(&task->sockets, new_socket);
|
|
return new_socket->serial;
|
|
}
|
|
task->state = TASK_SLEEPING;
|
|
task->sleep_reason = SLEEP_TCP_ACCEPT;
|
|
|
|
return 0;
|
|
}
|
|
|
|
int socket_write(unsigned int serial, unsigned char* payload, unsigned int len) {
|
|
struct socket_t *sock = socket_get_from_serial(serial);
|
|
if (sock->socket_type == 1) {
|
|
return socket_queue(sock, payload, len, TCP_FLAGS_PSH | TCP_FLAGS_ACK);
|
|
} else if (sock->socket_type == 17) {
|
|
udp_send(sock->ether, sock, payload, len);
|
|
return len;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
void init_sockets() {
|
|
init_ptr_vector(&sockets);
|
|
memset(port_bitmap, 0, 8192);
|
|
}
|