#define _LARGEFILE64_SOURCE
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <dirent.h>
#include <fcntl.h>
#include <pthread.h>
#include <signal.h>
#include <unistd.h>

#include <linux/if_ether.h>
#include <linux/tcp.h>

#include <mos_api.h>

#include "cpu.h"
#include "http_parsing.h"
#include "debug.h"
#include "applib.h"
/*----------------------------------------------------------------------------*/
/* default configuration file path */
#define MOS_CONFIG_FILE		     "config/mos.conf"
/* max length per line in firewall config file */
#define CONF_MAX_LINE_LEN        1024
/* number of array elements (only valid on real arrays, not on pointers) */
#define NELEMS(x)           (sizeof(x) / sizeof((x)[0]))
/*
 * Cursor helpers for the rule parser.  Each takes a char-pointer lvalue
 * and advances it in place.  Characters are cast to unsigned char because
 * passing a negative char value to <ctype.h> functions is undefined.
 */
/* skip whitespace */
#define SKIP_SPACES(x) \
	do { while (*(x) && isspace((unsigned char)*(x))) (x)++; } while (0)
/* skip a run of non-whitespace characters */
#define SKIP_CHAR(x) \
	do { while (*(x) && !isspace((unsigned char)*(x))) (x)++; } while (0)
/* skip a run of decimal digits */
#define SKIP_DIGIT(x) \
	do { while (*(x) && isdigit((unsigned char)*(x))) (x)++; } while (0)
/*
 * Keep only the leading y bits (the network prefix) of IPv4 address x.
 * x and the result are in network byte order; the mask is built in host
 * order and converted with htonl() so it is correct on any endianness.
 * y == 0 is special-cased because shifting a 32-bit value by 32 is
 * undefined behavior.  y must be in [0, 32].
 */
#define IP_NETMASK(x, y) \
	((x) & ((y) == 0 ? 0u : htonl(0xFFFFFFFFu << (32 - (y)))))
/* print an error tagged with function, line and errno, then exit */
#define EXIT_WITH_ERROR(f, m...) do {                                  \
		fprintf(stderr, "[%10s:%4d] errno: %d " f,                     \
			__FUNCTION__, __LINE__, errno, ##m);                       \
		exit(EXIT_FAILURE);                                            \
	} while (0)
/* boolean for function return value */
#define SUCCESS 1
#define FAILURE 0
/*----------------------------------------------------------------------------*/
/*
 * Verdict attached to a rule.  FRA_INVALID is pinned to 0 so that the
 * zero-initialized, unused tail of g_FWRules marks the end of the table.
 */
typedef enum {
	FRA_INVALID = 0,
	FRA_ACCEPT,
	FRA_DROP
} FRAction;

/* keywords understood in the firewall rule file */
#define FR_ACCEPT "ACCEPT"
#define FR_DROP   "DROP"
#define FR_SPORT  "sport:"
#define FR_DPORT  "dport:"

/* capacity of the rule table */
#define MAX_RULES 1024
/* enough for "255.255.255.255/32" plus the terminating NUL */
#define MAX_IP_ADDR_LEN 19

/*
 * One parsed rule.  Addresses and ports are stored in network byte order;
 * the prefix lengths, the action and the hit counter are ordinary host
 * integers.  A zero address or zero port acts as a wildcard.
 */
typedef struct FirewallRule {
	in_addr_t fr_srcIP;         /* client address */
	int       fr_srcIPmask;     /* client prefix length */
	in_addr_t fr_dstIP;         /* server address */
	int       fr_dstIPmask;     /* server prefix length */
	in_port_t fr_srcPort;       /* client port, 0 = any */
	in_port_t fr_dstPort;       /* server port, 0 = any */
	FRAction  fr_action;        /* verdict for matching flows */
	uint32_t  fr_count;         /* flows matched by this rule */
} FirewallRule;

/* rule table; slots past the last parsed rule stay zero (FRA_INVALID) */
static FirewallRule g_FWRules[MAX_RULES];
70 /*----------------------------------------------------------------------------*/
71 struct thread_context
72 {
73 	mctx_t mctx;         /* per-thread mos context */
74 	int mon_listener;    /* listening socket for flow monitoring */
75 };
76 /*----------------------------------------------------------------------------*/
77 /* Print the entire firewall rule and status table */
78 static void
DumpFWRuleTable(mctx_t mctx,int sock,int side,uint64_t events,filter_arg_t * arg)79 DumpFWRuleTable(mctx_t mctx, int sock, int side,
80 		uint64_t events, filter_arg_t *arg)
81 {
82 	int i;
83   	FirewallRule *fwr;
84 	char cip_str[MAX_IP_ADDR_LEN];
85 	char sip_str[MAX_IP_ADDR_LEN];
86 	struct timeval tv_1sec = { /* 1 second */
87 		.tv_sec = 1,
88 		.tv_usec = 0
89 	};
90 
91 	printf("-----------------------------------------------------------------------\n");
92 	printf("Firewall rule table\n");
93 	printf("idx   flows   target   client             server             port\n");
94 
95 	for (i = 0;  i < MAX_RULES; i++) {
96 		fwr = &g_FWRules[i];
97 
98 		/* we've searched till the end */
99 		if (fwr->fr_action == FRA_INVALID)
100 			break;
101 
102 		/* print out each rule */
103 		if (!inet_ntop(AF_INET, &(fwr->fr_srcIP), cip_str, INET_ADDRSTRLEN) ||
104 			!inet_ntop(AF_INET, &(fwr->fr_dstIP), sip_str, INET_ADDRSTRLEN))
105 			EXIT_WITH_ERROR("inet_ntop() error\n");
106 
107 		if (fwr->fr_srcIPmask != 32)
108 			sprintf(cip_str, "%s/%d", cip_str, fwr->fr_srcIPmask);
109 		if (fwr->fr_dstIPmask != 32)
110 			sprintf(sip_str, "%s/%d", sip_str, fwr->fr_dstIPmask);
111 		printf("%-6u%-8u%-9s%-19s%-19s",
112 			   (i + 1), fwr->fr_count,
113 			   (fwr->fr_action == FRA_DROP)? FR_DROP : FR_ACCEPT,
114 			   cip_str, sip_str);
115 		if (fwr->fr_srcPort)
116 			printf("sport:%-6d", ntohs(fwr->fr_srcPort));
117 		if (fwr->fr_dstPort)
118 			printf("dport:%-6d", ntohs(fwr->fr_dstPort));
119 		printf("\n");
120 	}
121 	printf("-----------------------------------------------------------------------\n");
122 
123 	/* Set a timer for next printing */
124 	if (mtcp_settimer(mctx, sock, &tv_1sec, DumpFWRuleTable))
125 		EXIT_WITH_ERROR("mtcp_settimer() error\n");
126 }
127 /*----------------------------------------------------------------------------*/
128 static inline char*
ExtractPort(char * buf,in_port_t * sport,in_port_t * dport)129 ExtractPort(char* buf, in_port_t* sport, in_port_t* dport)
130 {
131 	in_port_t* p = NULL;
132 	char* temp = (char*)buf;
133 	char* check;
134 	int port;
135 	char s = 0;             /* swap character */
136 
137 	SKIP_CHAR(temp);	    /* skip characters */
138 	s = *temp; *temp = 0;	/* replace the end character with null */
139 
140 	/* check if the port format is correct */
141 	if (!strncmp(buf, FR_SPORT, sizeof(FR_SPORT) - 1)) {
142 		p = sport;
143 		buf += (sizeof(FR_SPORT) - 1);
144 	}
145 	else if (!strncmp(buf, FR_DPORT, sizeof(FR_DPORT) - 1)) {
146 		p = dport;
147 		buf += (sizeof(FR_DPORT) - 1);
148 	}
149 	else
150 		EXIT_WITH_ERROR("Invalid rule in port setup [%s]\n", buf);
151 
152 	check = buf;
153 	SKIP_DIGIT(check);
154 	if (check != temp)
155 		EXIT_WITH_ERROR("Invalid port format [%s]\n", buf);
156 
157 	/* convert to port number */
158 	port = atoi(buf);
159 	if (port < 0 || port > 65536)
160 		EXIT_WITH_ERROR("Invalid port [%d]\n", port);
161 	(*p) = htons(port);
162 
163 	(*temp) = s;	/* recover the original character */
164 	buf = temp;	    /* move buf pointer to next string */
165 	SKIP_SPACES(buf);
166 
167 	return buf;
168 }
169 /*----------------------------------------------------------------------------*/
/*
 * Parse an IPv4 address, optionally followed by "/<prefix-length>", at
 * the start of buf.
 *
 * On return *addr holds the address in network byte order with the bits
 * outside the prefix cleared, and *addrmask holds the prefix length
 * (32 when no '/' suffix is present).  The input is modified in place
 * while parsing and restored.  Exits the program on an invalid address
 * or a missing, non-numeric or out-of-range (0..32) prefix length.
 *
 * Returns a pointer to the next non-space character after the address.
 */
static inline char*
ExtractIPAddress(char* buf, in_addr_t* addr, int* addrmask)
{
	struct in_addr addr_conv;
	char* temp = buf;
	char* check;
	long netmask = 32;
	char s;        /* character overwritten by the temporary NUL */

	/* the address ends at whitespace, '/', or the end of the string */
	while (*temp && !isspace((unsigned char)*temp) && *temp != '/')
		temp++;

	s = *temp; *temp = 0;
	if (inet_aton(buf, &addr_conv) == 0)
		EXIT_WITH_ERROR("Invalid IP address [%s]\n", buf);
	*addr = addr_conv.s_addr;
	*temp = s;

	/* if the rule contains netmask */
	if (*temp == '/') {
		buf = temp + 1;
		SKIP_CHAR(temp);
		s = *temp; *temp = 0;

		/* the prefix length must be a non-empty run of digits */
		check = buf;
		SKIP_DIGIT(check);
		if (check == buf || check != temp)
			EXIT_WITH_ERROR("Invalid netmask format [%s]\n", buf);

		/* strtol rather than atoi: atoi is undefined on overflow */
		errno = 0;
		netmask = strtol(buf, NULL, 10);
		if (errno == ERANGE || netmask < 0 || netmask > 32)
			EXIT_WITH_ERROR("Invalid netmask [%s]\n", buf);

		/* a /0 prefix keeps no address bits; handle it here so that
		 * IP_NETMASK never shifts a 32-bit value by 32 */
		if (netmask == 0) {
			*addr = 0;
		} else {
			*addr = IP_NETMASK(*addr, (int)netmask);
		}
		*temp = s;
	}

	/* move buf pointer to next string */
	buf = temp;
	SKIP_SPACES(buf);

	*addrmask = (int)netmask;

	return buf;
}
216 /*----------------------------------------------------------------------------*/
217 static void
ParseConfigFile(char * configF)218 ParseConfigFile(char* configF)
219 {
220   	FirewallRule *fwr;
221 	FILE *fp;
222 	char line_buf[CONF_MAX_LINE_LEN] = {0};
223 	char *line, *p;
224 	int i = 0;
225 
226 	/* config file path should not be null */
227 	assert(configF != NULL);
228 
229 	/* open firewall rule file */
230 	if ((fp = fopen(configF, "r")) == NULL)
231 		EXIT_WITH_ERROR("Firewall rule file %s is not found.\n", configF);
232 
233 	/* read each line */
234 	while ((line = fgets(line_buf, CONF_MAX_LINE_LEN, fp)) != NULL) {
235 
236 		/* each line represents a rule */
237 		fwr = &g_FWRules[i];
238 		if (line[CONF_MAX_LINE_LEN - 1])
239 			EXIT_WITH_ERROR("%s has a line longer than %d\n",
240 						configF, CONF_MAX_LINE_LEN);
241 
242 		SKIP_SPACES(line); /* remove spaces */
243 		if (*line == '\0' || *line == '#')
244 			continue;
245 		if ((p = strchr(line, '#'))) /* skip comments in the line */
246 			*p = '\0';
247 		while (isspace(line[strlen(line) - 1])) /* remove spaces */
248 			line[strlen(line) - 1] = '\0';
249 
250 		/* read firewall rule action */
251 		p = line;
252 		if (!strncmp(p, FR_ACCEPT, sizeof(FR_ACCEPT) - 1)) {
253 			fwr->fr_action = FRA_ACCEPT;
254 			p += (sizeof(FR_ACCEPT) - 1);
255 		}
256 		else if (!strncmp(p, FR_DROP, sizeof(FR_DROP) - 1)) {
257 			fwr->fr_action = FRA_DROP;
258 			p += (sizeof(FR_DROP) - 1);
259 		}
260 		else
261 			EXIT_WITH_ERROR("Unknown rule action [%s].\n", line);
262 
263 		if (!isspace(*p)) /* invalid if no space exists after action */
264 			EXIT_WITH_ERROR("Invalid format [%s].\n", line);
265 		SKIP_SPACES(p);
266 
267 		/* read client ip address */
268 		if (*p)
269 			p = ExtractIPAddress(p, &fwr->fr_srcIP, &(fwr->fr_srcIPmask));
270 		else
271 			EXIT_WITH_ERROR("Invalid format [%s].\n", line);
272 
273 		/* read server ip address */
274 		if (*p)
275 			p = ExtractIPAddress(p, &fwr->fr_dstIP, &(fwr->fr_dstIPmask));
276 		else
277 			EXIT_WITH_ERROR("Invalid format [%s].\n", line);
278 
279 		/* read port filter information */
280 		while (*p)
281 			p = ExtractPort(p, &(fwr->fr_srcPort), &(fwr->fr_dstPort));
282 
283 		fwr->fr_count = 0;
284 		if ((i++) >= MAX_RULES)
285 			EXIT_WITH_ERROR("Exceeded max number of rules (%d)\n", MAX_RULES);
286 	}
287 
288 	fclose(fp);
289 }
290 /*----------------------------------------------------------------------------*/
/*
 * Return non-zero when ip falls inside the rule prefix fw_ip/netmask.
 * Both addresses are in network byte order.  A rule address of 0 matches
 * any ip.
 *
 * The mask is built in host order and converted with htonl(), so the test
 * is endian-independent.  A /0 prefix is special-cased because shifting a
 * 32-bit value by 32 is undefined.  Both sides are masked, so the result
 * does not rely on fw_ip having been pre-masked.  netmask values outside
 * [0, 32] are clamped.
 */
static inline int
MatchAddr(in_addr_t ip, in_addr_t fw_ip, int netmask)
{
	uint32_t mask;

	/* 0 means '*' */
	if (fw_ip == 0)
		return 1;

	if (netmask <= 0)
		mask = 0;
	else if (netmask >= 32)
		mask = 0xFFFFFFFFu;
	else
		mask = htonl(0xFFFFFFFFu << (32 - netmask));

	return ((ip ^ fw_ip) & mask) == 0;
}
299 /*----------------------------------------------------------------------------*/
/*
 * Port test used by the rule lookup: a rule port of 0 is a wildcard,
 * otherwise the packet port must be identical.  Both values are compared
 * as stored (network byte order), so no conversion is needed.
 */
static inline int
MatchPort(in_port_t port, in_port_t fw_port)
{
	if (fw_port != 0)
		return port == fw_port;
	return 1;
}
306 /*----------------------------------------------------------------------------*/
307 static int
FWRLookup(in_addr_t sip,in_addr_t dip,in_port_t sp,in_port_t dp)308 FWRLookup(in_addr_t sip, in_addr_t dip, in_port_t sp, in_port_t dp)
309 {
310 	int i;
311 	FirewallRule *p = g_FWRules;
312 
313 	for (i = 0;  i < MAX_RULES; i++) {
314 		if (p[i].fr_action == FRA_INVALID) {
315 			/* We've searched till the end. By default, allow any flow */
316 			return (FRA_ACCEPT);
317 		}
318 
319 		if (MatchAddr(sip, p[i].fr_srcIP, p[i].fr_srcIPmask) &&
320 			MatchAddr(dip, p[i].fr_dstIP, p[i].fr_dstIPmask) &&
321 			MatchPort(sp, p[i].fr_srcPort) &&
322 			MatchPort(dp, p[i].fr_dstPort)) {
323 			p[i].fr_count++;
324 			return p[i].fr_action;
325 		}
326 	}
327 
328 	assert(0); /* can't reach here */
329 	return  (FRA_ACCEPT);
330 }
331 /*----------------------------------------------------------------------------*/
332 static void
ApplyActionPerFlow(mctx_t mctx,int msock,int side,uint64_t events,filter_arg_t * arg)333 ApplyActionPerFlow(mctx_t mctx, int msock, int side,
334 		   		     uint64_t events, filter_arg_t *arg)
335 
336 {
337 	/* this function is called at the first SYN */
338 	struct pkt_info p;
339 	int opt;
340 	FRAction action;
341 
342 	if (mtcp_getlastpkt(mctx, msock, side, &p) < 0)
343 		EXIT_WITH_ERROR("Failed to get packet context!\n");
344 
345 	/* look up the firewall rules */
346 	action = FWRLookup(p.iph->saddr, p.iph->daddr,
347 					    p.tcph->source, p.tcph->dest);
348 
349 	if (action == FRA_DROP) {
350 		mtcp_setlastpkt(mctx, msock, side, 0, NULL, 0, MOS_DROP);
351 	} else {
352 		assert(action == FRA_ACCEPT);
353 		/* no need to monitor this flow any more */
354 		opt = MOS_SIDE_BOTH;
355 		if (mtcp_setsockopt(mctx, msock, SOL_MONSOCKET,
356 				    MOS_STOP_MON, &opt, sizeof(opt)) < 0)
357 			EXIT_WITH_ERROR("Failed to stop monitoring conn with sockid: %d\n",
358 					msock);
359 	}
360 }
361 /*----------------------------------------------------------------------------*/
362 static bool
CatchInitSYN(mctx_t mctx,int sockid,int side,uint64_t events,filter_arg_t * arg)363 CatchInitSYN(mctx_t mctx, int sockid,
364 			int side, uint64_t events, filter_arg_t *arg)
365 {
366 	struct pkt_info p;
367 
368 	if (mtcp_getlastpkt(mctx, sockid, side, &p) < 0)
369 		EXIT_WITH_ERROR("Failed to get packet context!!!\n");
370 
371 	return (p.tcph->syn && !p.tcph->ack);
372 }
373 /*----------------------------------------------------------------------------*/
374 static void
CreateAndInitThreadContext(struct thread_context * ctx,int core,event_t udeForSYN)375 CreateAndInitThreadContext(struct thread_context* ctx,
376 						      int core, event_t  udeForSYN)
377 {
378 	struct timeval tv_1sec = { /* 1 second */
379 		.tv_sec = 1,
380 		.tv_usec = 0
381 	};
382 
383 	ctx->mctx = mtcp_create_context(core);
384 
385 	/* create socket  */
386 	ctx->mon_listener = mtcp_socket(ctx->mctx, AF_INET,
387 					MOS_SOCK_MONITOR_STREAM, 0);
388 	if (ctx->mon_listener < 0)
389 		EXIT_WITH_ERROR("Failed to create monitor listening socket!\n");
390 
391 	/* register callback */
392 	if (mtcp_register_callback(ctx->mctx, ctx->mon_listener,
393 							   udeForSYN,
394 							   MOS_HK_SND,
395 				   			   ApplyActionPerFlow) == -1)
396 		EXIT_WITH_ERROR("Failed to register callback func!\n");
397 
398 	/* CPU 0 is in charge of printing stats */
399 	if (ctx->mctx->cpu == 0 &&
400 		mtcp_settimer(ctx->mctx, ctx->mon_listener,
401 					  &tv_1sec, DumpFWRuleTable))
402 		EXIT_WITH_ERROR("Failed to register timer callback func!\n");
403 
404 }
405 /*----------------------------------------------------------------------------*/
406 static void
WaitAndCleanupThreadContext(struct thread_context * ctx)407 WaitAndCleanupThreadContext(struct thread_context* ctx)
408 {
409 	/* wait for the TCP thread to finish */
410 	mtcp_app_join(ctx->mctx);
411 
412 	/* close the monitoring socket */
413 	mtcp_close(ctx->mctx, ctx->mon_listener);
414 
415 	/* tear down */
416 	mtcp_destroy_context(ctx->mctx);
417 }
418 /*----------------------------------------------------------------------------*/
419 int
main(int argc,char ** argv)420 main(int argc, char **argv)
421 {
422 	int ret, i;
423 	char *fname = MOS_CONFIG_FILE; /* path to the default mos config file */
424 	struct mtcp_conf mcfg;
425 	char simple_firewall_file[1024] = "config/simple_firewall.conf";
426 	struct thread_context ctx[MAX_CPUS] = {{0}}; /* init all fields to 0 */
427 	event_t initSYNEvent;
428 	int num_cpus;
429 	int opt, rc;
430 
431 	/* get the total # of cpu cores */
432 	num_cpus = GetNumCPUs();
433 
434 	while ((opt = getopt(argc, argv, "c:f:n:")) != -1) {
435 		switch (opt) {
436 			case 'c':
437 				fname = optarg;
438 				break;
439 		        case 'f':
440 				strcpy(simple_firewall_file, optarg);
441 				break;
442 			case 'n':
443 				if ((rc=atoi(optarg)) > num_cpus) {
444 					EXIT_WITH_ERROR("Available number of CPU cores is %d "
445 							"while requested cores is %d\n",
446 							num_cpus, rc);
447 				}
448 				num_cpus = rc;
449 				break;
450 			default:
451 				printf("Usage: %s [-c mos_config_file] "
452 				       "[-f simple_firewall_config_file]\n",
453 				       argv[0]);
454 				return 0;
455 		}
456 	}
457 
458 	/* parse mos configuration file */
459 	ret = mtcp_init(fname);
460 	if (ret)
461 		EXIT_WITH_ERROR("Failed to initialize mtcp.\n");
462 
463 	/* set the core limit */
464 	mtcp_getconf(&mcfg);
465 	mcfg.num_cores = num_cpus;
466 	mtcp_setconf(&mcfg);
467 
468 	/* parse simple firewall-specfic startup file */
469 	ParseConfigFile(simple_firewall_file);
470 
471 	/* populate local mos-specific mcfg struct for later usage */
472 	mtcp_getconf(&mcfg);
473 
474 	/* event for the initial SYN packet */
475 	initSYNEvent = mtcp_define_event(MOS_ON_PKT_IN, CatchInitSYN, NULL);
476 	if (initSYNEvent == MOS_NULL_EVENT)
477 		EXIT_WITH_ERROR("mtcp_define_event() failed!");
478 
479 	/* initialize monitor threads */
480 	for (i = 0; i < mcfg.num_cores; i++)
481 		CreateAndInitThreadContext(&ctx[i], i, initSYNEvent);
482 
483 	/* wait until all threads finish */
484 	for (i = 0; i < mcfg.num_cores; i++) {
485 		WaitAndCleanupThreadContext(&ctx[i]);
486 	  	TRACE_INFO("Message test thread %d joined.\n", i);
487 	}
488 
489 	mtcp_destroy();
490 
491 	return EXIT_SUCCESS;
492 }
493 /*----------------------------------------------------------------------------*/
494