#define _LARGEFILE64_SOURCE
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <dirent.h>
#include <pthread.h>
#include <linux/if_ether.h>
#include <linux/tcp.h>
#include <mos_api.h>
#include "cpu.h"
#include "http_parsing.h"
#include "debug.h"
#include "applib.h"
/*----------------------------------------------------------------------------*/
/* default configuration file path */
#define MOS_CONFIG_FILE "config/mos.conf"
/* max length per line in firewall config file (including the newline) */
#define CONF_MAX_LINE_LEN 1024
/* number of array elements (only valid on real arrays, not pointers) */
#define NELEMS(x) (sizeof(x) / sizeof((x)[0]))
/*
 * Character-class scanners: advance the char pointer lvalue x while the
 * current character is whitespace / non-whitespace / a decimal digit,
 * stopping at the terminating NUL. The cast to unsigned char is required
 * because passing a negative plain char to <ctype.h> functions is undefined.
 */
#define SKIP_SPACES(x) do { \
	while (*(x) && isspace((unsigned char)*(x))) \
		(x)++; \
} while (0)
#define SKIP_CHAR(x) do { \
	while (*(x) && !isspace((unsigned char)*(x))) \
		(x)++; \
} while (0)
#define SKIP_DIGIT(x) do { \
	while (*(x) && isdigit((unsigned char)*(x))) \
		(x)++; \
} while (0)
/*
 * Keep only the leading y bits (prefix length 0..32) of the IPv4 address x,
 * which is in network byte order. The mask is built in host order and then
 * converted with htonl() so the result is correct on any endianness.
 * A prefix of 0 yields 0 explicitly: shifting a 32-bit value by 32 is
 * undefined. Note that y is evaluated twice.
 */
#define IP_NETMASK(x, y) \
	((y) <= 0 ? 0u : ((x) & htonl(0xFFFFFFFFu << (32 - (y)))))
/* print location, errno and a printf-style message to stderr, then exit */
#define EXIT_WITH_ERROR(f, m...) do { \
	fprintf(stderr, "[%10s:%4d] errno: %d " f, __func__, __LINE__, errno, ##m); \
	exit(EXIT_FAILURE); \
} while (0)
/* boolean for function return value */
#define SUCCESS 1
#define FAILURE 0
/*----------------------------------------------------------------------------*/
/* Action applied to a flow that matches a rule. FRA_INVALID (0) marks an
 * unused slot and therefore terminates the rule table. */
typedef enum {
	FRA_INVALID = 0,
	FRA_ACCEPT  = 1,
	FRA_DROP    = 2
} FRAction;

/* keywords recognized in the firewall rule file */
#define FR_ACCEPT	"ACCEPT"
#define FR_DROP		"DROP"
#define FR_SPORT	"sport:"
#define FR_DPORT	"dport:"

/* capacity of the rule table */
#define MAX_RULES	1024
/* room for "255.255.255.255/32" (CIDR format) plus the terminating NUL */
#define MAX_IP_ADDR_LEN	19

/*
 * One firewall rule. Addresses and ports are kept in network byte order;
 * the lookup code treats an address or port of 0 as a wildcard. The two
 * mask fields are prefix lengths (0..32), and fr_count is the number of
 * times the rule has matched.
 */
struct FirewallRule {
	in_addr_t fr_srcIP;	/* client (source) address */
	int       fr_srcIPmask;	/* client prefix length */
	in_addr_t fr_dstIP;	/* server (destination) address */
	int       fr_dstIPmask;	/* server prefix length */
	in_port_t fr_srcPort;	/* client port */
	in_port_t fr_dstPort;	/* server port */
	FRAction  fr_action;	/* what to do on a match */
	uint32_t  fr_count;	/* match counter */
};
typedef struct FirewallRule FirewallRule;

/* the rule table; static storage is zeroed, so unused slots are FRA_INVALID */
static FirewallRule g_FWRules[MAX_RULES];
70 /*----------------------------------------------------------------------------*/
71 struct thread_context
72 {
73 mctx_t mctx; /* per-thread mos context */
74 int mon_listener; /* listening socket for flow monitoring */
75 };
76 /*----------------------------------------------------------------------------*/
77 /* Print the entire firewall rule and status table */
78 static void
DumpFWRuleTable(mctx_t mctx,int sock,int side,uint64_t events,filter_arg_t * arg)79 DumpFWRuleTable(mctx_t mctx, int sock, int side,
80 uint64_t events, filter_arg_t *arg)
81 {
82 int i;
83 FirewallRule *fwr;
84 char cip_str[MAX_IP_ADDR_LEN];
85 char sip_str[MAX_IP_ADDR_LEN];
86 struct timeval tv_1sec = { /* 1 second */
87 .tv_sec = 1,
88 .tv_usec = 0
89 };
90
91 printf("-----------------------------------------------------------------------\n");
92 printf("Firewall rule table\n");
93 printf("idx flows target client server port\n");
94
95 for (i = 0; i < MAX_RULES; i++) {
96 fwr = &g_FWRules[i];
97
98 /* we've searched till the end */
99 if (fwr->fr_action == FRA_INVALID)
100 break;
101
102 /* print out each rule */
103 if (!inet_ntop(AF_INET, &(fwr->fr_srcIP), cip_str, INET_ADDRSTRLEN) ||
104 !inet_ntop(AF_INET, &(fwr->fr_dstIP), sip_str, INET_ADDRSTRLEN))
105 EXIT_WITH_ERROR("inet_ntop() error\n");
106
107 if (fwr->fr_srcIPmask != 32)
108 sprintf(cip_str, "%s/%d", cip_str, fwr->fr_srcIPmask);
109 if (fwr->fr_dstIPmask != 32)
110 sprintf(sip_str, "%s/%d", sip_str, fwr->fr_dstIPmask);
111 printf("%-6u%-8u%-9s%-19s%-19s",
112 (i + 1), fwr->fr_count,
113 (fwr->fr_action == FRA_DROP)? FR_DROP : FR_ACCEPT,
114 cip_str, sip_str);
115 if (fwr->fr_srcPort)
116 printf("sport:%-6d", ntohs(fwr->fr_srcPort));
117 if (fwr->fr_dstPort)
118 printf("dport:%-6d", ntohs(fwr->fr_dstPort));
119 printf("\n");
120 }
121 printf("-----------------------------------------------------------------------\n");
122
123 /* Set a timer for next printing */
124 if (mtcp_settimer(mctx, sock, &tv_1sec, DumpFWRuleTable))
125 EXIT_WITH_ERROR("mtcp_settimer() error\n");
126 }
127 /*----------------------------------------------------------------------------*/
128 static inline char*
ExtractPort(char * buf,in_port_t * sport,in_port_t * dport)129 ExtractPort(char* buf, in_port_t* sport, in_port_t* dport)
130 {
131 in_port_t* p = NULL;
132 char* temp = (char*)buf;
133 char* check;
134 int port;
135 char s = 0; /* swap character */
136
137 SKIP_CHAR(temp); /* skip characters */
138 s = *temp; *temp = 0; /* replace the end character with null */
139
140 /* check if the port format is correct */
141 if (!strncmp(buf, FR_SPORT, sizeof(FR_SPORT) - 1)) {
142 p = sport;
143 buf += (sizeof(FR_SPORT) - 1);
144 }
145 else if (!strncmp(buf, FR_DPORT, sizeof(FR_DPORT) - 1)) {
146 p = dport;
147 buf += (sizeof(FR_DPORT) - 1);
148 }
149 else
150 EXIT_WITH_ERROR("Invalid rule in port setup [%s]\n", buf);
151
152 check = buf;
153 SKIP_DIGIT(check);
154 if (check != temp)
155 EXIT_WITH_ERROR("Invalid port format [%s]\n", buf);
156
157 /* convert to port number */
158 port = atoi(buf);
159 if (port < 0 || port > 65536)
160 EXIT_WITH_ERROR("Invalid port [%d]\n", port);
161 (*p) = htons(port);
162
163 (*temp) = s; /* recover the original character */
164 buf = temp; /* move buf pointer to next string */
165 SKIP_SPACES(buf);
166
167 return buf;
168 }
169 /*----------------------------------------------------------------------------*/
/*
 * Parse an IPv4 address at the start of buf, optionally followed by
 * "/<prefix length>" (e.g. "10.0.0.0/8").
 * On success, *addr receives the address in network byte order, masked
 * down to its prefix, and *addrmask receives the prefix length (32 when
 * none is given). The buffer is temporarily modified but restored.
 * Returns a pointer to the next non-space character after the token.
 * Exits the process on a malformed address or prefix length.
 */
static inline char*
ExtractIPAddress(char *buf, in_addr_t *addr, int *addrmask)
{
	struct in_addr addr_conv;
	char *temp = buf;
	char *end;
	long netmask = 32;
	char s;		/* swap character */

	/* the address part ends at whitespace, '/' or end of string */
	while (*temp && !isspace((unsigned char)*temp) && *temp != '/')
		temp++;

	s = *temp; *temp = '\0';
	if (inet_aton(buf, &addr_conv) == 0)
		EXIT_WITH_ERROR("Invalid IP address [%s]\n", buf);
	*addr = addr_conv.s_addr;
	*temp = s;

	/* if the rule contains netmask */
	if (*temp == '/') {
		buf = temp + 1;
		SKIP_CHAR(temp);
		s = *temp; *temp = '\0';

		/* Require a non-empty run of digits, so that "a.b.c.d/" is
		 * rejected instead of being read as /0 (match everything). */
		if (!isdigit((unsigned char)*buf))
			EXIT_WITH_ERROR("Invalid netmask format [%s]\n", buf);
		errno = 0;
		netmask = strtol(buf, &end, 10);
		if (end != temp)
			EXIT_WITH_ERROR("Invalid netmask format [%s]\n", buf);
		if (errno == ERANGE || netmask < 0 || netmask > 32)
			EXIT_WITH_ERROR("Invalid netmask [%s]\n", buf);

		/* /0 keeps no address bits; handled here because shifting a
		 * 32-bit mask by 32 bits is undefined */
		if (netmask == 0)
			*addr = 0;
		else
			*addr = IP_NETMASK(*addr, (int)netmask);
		*temp = s;
	}

	/* move buf pointer to next string */
	buf = temp;
	SKIP_SPACES(buf);

	*addrmask = (int)netmask;

	return buf;
}
216 /*----------------------------------------------------------------------------*/
217 static void
ParseConfigFile(char * configF)218 ParseConfigFile(char* configF)
219 {
220 FirewallRule *fwr;
221 FILE *fp;
222 char line_buf[CONF_MAX_LINE_LEN] = {0};
223 char *line, *p;
224 int i = 0;
225
226 /* config file path should not be null */
227 assert(configF != NULL);
228
229 /* open firewall rule file */
230 if ((fp = fopen(configF, "r")) == NULL)
231 EXIT_WITH_ERROR("Firewall rule file %s is not found.\n", configF);
232
233 /* read each line */
234 while ((line = fgets(line_buf, CONF_MAX_LINE_LEN, fp)) != NULL) {
235
236 /* each line represents a rule */
237 fwr = &g_FWRules[i];
238 if (line[CONF_MAX_LINE_LEN - 1])
239 EXIT_WITH_ERROR("%s has a line longer than %d\n",
240 configF, CONF_MAX_LINE_LEN);
241
242 SKIP_SPACES(line); /* remove spaces */
243 if (*line == '\0' || *line == '#')
244 continue;
245 if ((p = strchr(line, '#'))) /* skip comments in the line */
246 *p = '\0';
247 while (isspace(line[strlen(line) - 1])) /* remove spaces */
248 line[strlen(line) - 1] = '\0';
249
250 /* read firewall rule action */
251 p = line;
252 if (!strncmp(p, FR_ACCEPT, sizeof(FR_ACCEPT) - 1)) {
253 fwr->fr_action = FRA_ACCEPT;
254 p += (sizeof(FR_ACCEPT) - 1);
255 }
256 else if (!strncmp(p, FR_DROP, sizeof(FR_DROP) - 1)) {
257 fwr->fr_action = FRA_DROP;
258 p += (sizeof(FR_DROP) - 1);
259 }
260 else
261 EXIT_WITH_ERROR("Unknown rule action [%s].\n", line);
262
263 if (!isspace(*p)) /* invalid if no space exists after action */
264 EXIT_WITH_ERROR("Invalid format [%s].\n", line);
265 SKIP_SPACES(p);
266
267 /* read client ip address */
268 if (*p)
269 p = ExtractIPAddress(p, &fwr->fr_srcIP, &(fwr->fr_srcIPmask));
270 else
271 EXIT_WITH_ERROR("Invalid format [%s].\n", line);
272
273 /* read server ip address */
274 if (*p)
275 p = ExtractIPAddress(p, &fwr->fr_dstIP, &(fwr->fr_dstIPmask));
276 else
277 EXIT_WITH_ERROR("Invalid format [%s].\n", line);
278
279 /* read port filter information */
280 while (*p)
281 p = ExtractPort(p, &(fwr->fr_srcPort), &(fwr->fr_dstPort));
282
283 fwr->fr_count = 0;
284 if ((i++) >= MAX_RULES)
285 EXIT_WITH_ERROR("Exceeded max number of rules (%d)\n", MAX_RULES);
286 }
287
288 fclose(fp);
289 }
290 /*----------------------------------------------------------------------------*/
/*
 * Return nonzero if ip (network byte order) lies inside the rule prefix
 * fw_ip/netmask. A rule address of 0 means '*', and so does a /0 prefix.
 */
static inline int
MatchAddr(in_addr_t ip, in_addr_t fw_ip, int netmask)
{
	/* Test the wildcards first. This also keeps netmask 0 away from
	 * IP_NETMASK, where it would mean a shift by 32 bits (undefined). */
	if (fw_ip == 0 || netmask <= 0)
		return 1;

	ip = IP_NETMASK(ip, netmask);
	return ip == fw_ip;
}
299 /*----------------------------------------------------------------------------*/
/*
 * Return nonzero if port equals the rule port fw_port (both compared as
 * stored); a rule port of 0 acts as the '*' wildcard.
 */
static inline int
MatchPort(in_port_t port, in_port_t fw_port)
{
	if (fw_port == 0)	/* '*' */
		return 1;
	return port == fw_port;
}
306 /*----------------------------------------------------------------------------*/
307 static int
FWRLookup(in_addr_t sip,in_addr_t dip,in_port_t sp,in_port_t dp)308 FWRLookup(in_addr_t sip, in_addr_t dip, in_port_t sp, in_port_t dp)
309 {
310 int i;
311 FirewallRule *p = g_FWRules;
312
313 for (i = 0; i < MAX_RULES; i++) {
314 if (p[i].fr_action == FRA_INVALID) {
315 /* We've searched till the end. By default, allow any flow */
316 return (FRA_ACCEPT);
317 }
318
319 if (MatchAddr(sip, p[i].fr_srcIP, p[i].fr_srcIPmask) &&
320 MatchAddr(dip, p[i].fr_dstIP, p[i].fr_dstIPmask) &&
321 MatchPort(sp, p[i].fr_srcPort) &&
322 MatchPort(dp, p[i].fr_dstPort)) {
323 p[i].fr_count++;
324 return p[i].fr_action;
325 }
326 }
327
328 assert(0); /* can't reach here */
329 return (FRA_ACCEPT);
330 }
331 /*----------------------------------------------------------------------------*/
332 static void
ApplyActionPerFlow(mctx_t mctx,int msock,int side,uint64_t events,filter_arg_t * arg)333 ApplyActionPerFlow(mctx_t mctx, int msock, int side,
334 uint64_t events, filter_arg_t *arg)
335
336 {
337 /* this function is called at the first SYN */
338 struct pkt_info p;
339 int opt;
340 FRAction action;
341
342 if (mtcp_getlastpkt(mctx, msock, side, &p) < 0)
343 EXIT_WITH_ERROR("Failed to get packet context!\n");
344
345 /* look up the firewall rules */
346 action = FWRLookup(p.iph->saddr, p.iph->daddr,
347 p.tcph->source, p.tcph->dest);
348
349 if (action == FRA_DROP) {
350 mtcp_setlastpkt(mctx, msock, side, 0, NULL, 0, MOS_DROP);
351 } else {
352 assert(action == FRA_ACCEPT);
353 /* no need to monitor this flow any more */
354 opt = MOS_SIDE_BOTH;
355 if (mtcp_setsockopt(mctx, msock, SOL_MONSOCKET,
356 MOS_STOP_MON, &opt, sizeof(opt)) < 0)
357 EXIT_WITH_ERROR("Failed to stop monitoring conn with sockid: %d\n",
358 msock);
359 }
360 }
361 /*----------------------------------------------------------------------------*/
362 static bool
CatchInitSYN(mctx_t mctx,int sockid,int side,uint64_t events,filter_arg_t * arg)363 CatchInitSYN(mctx_t mctx, int sockid,
364 int side, uint64_t events, filter_arg_t *arg)
365 {
366 struct pkt_info p;
367
368 if (mtcp_getlastpkt(mctx, sockid, side, &p) < 0)
369 EXIT_WITH_ERROR("Failed to get packet context!!!\n");
370
371 return (p.tcph->syn && !p.tcph->ack);
372 }
373 /*----------------------------------------------------------------------------*/
374 static void
CreateAndInitThreadContext(struct thread_context * ctx,int core,event_t udeForSYN)375 CreateAndInitThreadContext(struct thread_context* ctx,
376 int core, event_t udeForSYN)
377 {
378 struct timeval tv_1sec = { /* 1 second */
379 .tv_sec = 1,
380 .tv_usec = 0
381 };
382
383 ctx->mctx = mtcp_create_context(core);
384
385 /* create socket */
386 ctx->mon_listener = mtcp_socket(ctx->mctx, AF_INET,
387 MOS_SOCK_MONITOR_STREAM, 0);
388 if (ctx->mon_listener < 0)
389 EXIT_WITH_ERROR("Failed to create monitor listening socket!\n");
390
391 /* register callback */
392 if (mtcp_register_callback(ctx->mctx, ctx->mon_listener,
393 udeForSYN,
394 MOS_HK_SND,
395 ApplyActionPerFlow) == -1)
396 EXIT_WITH_ERROR("Failed to register callback func!\n");
397
398 /* CPU 0 is in charge of printing stats */
399 if (ctx->mctx->cpu == 0 &&
400 mtcp_settimer(ctx->mctx, ctx->mon_listener,
401 &tv_1sec, DumpFWRuleTable))
402 EXIT_WITH_ERROR("Failed to register timer callback func!\n");
403
404 }
405 /*----------------------------------------------------------------------------*/
406 static void
WaitAndCleanupThreadContext(struct thread_context * ctx)407 WaitAndCleanupThreadContext(struct thread_context* ctx)
408 {
409 /* wait for the TCP thread to finish */
410 mtcp_app_join(ctx->mctx);
411
412 /* close the monitoring socket */
413 mtcp_close(ctx->mctx, ctx->mon_listener);
414
415 /* tear down */
416 mtcp_destroy_context(ctx->mctx);
417 }
418 /*----------------------------------------------------------------------------*/
419 int
main(int argc,char ** argv)420 main(int argc, char **argv)
421 {
422 int ret, i;
423 char *fname = MOS_CONFIG_FILE; /* path to the default mos config file */
424 struct mtcp_conf mcfg;
425 char simple_firewall_file[1024] = "config/simple_firewall.conf";
426 struct thread_context ctx[MAX_CPUS] = {{0}}; /* init all fields to 0 */
427 event_t initSYNEvent;
428 int num_cpus;
429 int opt, rc;
430
431 /* get the total # of cpu cores */
432 num_cpus = GetNumCPUs();
433
434 while ((opt = getopt(argc, argv, "c:f:n:")) != -1) {
435 switch (opt) {
436 case 'c':
437 fname = optarg;
438 break;
439 case 'f':
440 strcpy(simple_firewall_file, optarg);
441 break;
442 case 'n':
443 if ((rc=atoi(optarg)) > num_cpus) {
444 EXIT_WITH_ERROR("Available number of CPU cores is %d "
445 "while requested cores is %d\n",
446 num_cpus, rc);
447 }
448 num_cpus = rc;
449 break;
450 default:
451 printf("Usage: %s [-c mos_config_file] "
452 "[-f simple_firewall_config_file]\n",
453 argv[0]);
454 return 0;
455 }
456 }
457
458 /* parse mos configuration file */
459 ret = mtcp_init(fname);
460 if (ret)
461 EXIT_WITH_ERROR("Failed to initialize mtcp.\n");
462
463 /* set the core limit */
464 mtcp_getconf(&mcfg);
465 mcfg.num_cores = num_cpus;
466 mtcp_setconf(&mcfg);
467
468 /* parse simple firewall-specfic startup file */
469 ParseConfigFile(simple_firewall_file);
470
471 /* populate local mos-specific mcfg struct for later usage */
472 mtcp_getconf(&mcfg);
473
474 /* event for the initial SYN packet */
475 initSYNEvent = mtcp_define_event(MOS_ON_PKT_IN, CatchInitSYN, NULL);
476 if (initSYNEvent == MOS_NULL_EVENT)
477 EXIT_WITH_ERROR("mtcp_define_event() failed!");
478
479 /* initialize monitor threads */
480 for (i = 0; i < mcfg.num_cores; i++)
481 CreateAndInitThreadContext(&ctx[i], i, initSYNEvent);
482
483 /* wait until all threads finish */
484 for (i = 0; i < mcfg.num_cores; i++) {
485 WaitAndCleanupThreadContext(&ctx[i]);
486 TRACE_INFO("Message test thread %d joined.\n", i);
487 }
488
489 mtcp_destroy();
490
491 return EXIT_SUCCESS;
492 }
493 /*----------------------------------------------------------------------------*/
494