/*
 * Copyright (c) 2005 Anton Antonov <aga@pikenet.ru>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $Id: compare_rules.c,v 1.10 2006/08/11 11:01:03 aga Exp $
 */

#include <sys/param.h>
#include <sys/socket.h>
#include <sys/sysctl.h>
#include <sys/time.h> 

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <net/if.h>
#include <netinet/in.h>
#include <netinet/ip_fw2.h>
#include <arpa/inet.h>

#include <syslog.h>
#include <fcntl.h>
#include "multilar.h"
#include "babolo/BLINflag.h"

#include <sys/types.h>
#include <sys/event.h>

/*
 * Step from rule r to the rule stored immediately after it in a packed
 * buffer: r's address advanced by RULESIZE(r) bytes.
 * Note: r is not parenthesized inside the expansion, so pass a plain
 * pointer variable, not a compound expression.
 */
#define NEXT(r)	((struct ip_fw *)((char *)r + RULESIZE(r)))

/* Scratch buffer in which log messages are formatted before writeLogs(). */
char errMessage[256];
/*
 * kq     - kqueue descriptor, created in compare_rules().
 * event  - return value of the most recent kevent() call.
 * nbytes - not used by the functions in this file (get_rules() declares
 *          its own local with the same name).
 * HZ     - set from clockinfo.tick in compare_rules(); get_rules() uses it
 *          to scale dummynet tick differences.
 */
int kq, event, nbytes=0, HZ;
/* Shared kevent record, reused for both registering and receiving events. */
struct kevent ke;
struct timespec timeout; /* kevent() wait limit used by do_cmd(); computed in get_rules() */
int s = -1;  /* the socket used for ipfw socket options; opened in compare_rules() */
/* Buffer filled by getsockopt(IP_FW_GET); grown with realloc() in get_rules(). */
void *kernel_rules; 

/*
 * Logging hook, defined outside this file.  It is called here with a
 * syslog priority (LOG_ALERT, LOG_EMERG) and a preformatted message.
 */
void 
writeLogs(int priority, const char *message);

/*
 * Issue one ipfw socket option on the global socket `s`, after waiting on
 * the global kqueue `kq` (for at most the global `timeout`) until the socket
 * is reported writable.
 *
 * optname  IP_FW_* option.  IP_FW_ADD is sent with getsockopt(); every
 *          other option is sent with setsockopt().
 * optval   option payload.
 * optlen   for IP_FW_ADD: the address of a socklen_t-sized length variable,
 *          carried as an integer; for every other option: the payload
 *          length in bytes.
 *
 * Returns the result of getsockopt()/setsockopt() (errno is the one left by
 * that call), or -1 when the kqueue handling itself failed or timed out
 * (a message has already been logged in that case).
 */
static int
do_cmd(int optname, void *optval, uintptr_t optlen) {
	int rc, saved_errno;

	/* Register interest in the socket becoming writable. */
	EV_SET(&ke, s, EVFILT_WRITE, EV_ADD, 0, 0, NULL);
	event = kevent(kq, &ke, 1, NULL, 0, NULL);
	if (event == -1) {
		snprintf(errMessage, sizeof errMessage, "kevent: set event(do_cmd): %s", strerror(errno) );
		writeLogs(LOG_ALERT, errMessage);
		return -1;
	}

	/* Wait for the event. */
	memset(&ke, 0, sizeof(ke));
	event = kevent(kq, NULL, 0, &ke, 1, &timeout);
	if (event < 0) {
		snprintf(errMessage, sizeof errMessage, "kevent: receive event(do_cmd): %s", strerror(errno) );
		writeLogs(LOG_ALERT, errMessage);
		return -1;
	}

	/*
	 * event == 0 means the wait timed out and `ke` was not filled in.
	 * Reject it explicitly: the zeroed ke.ident would otherwise pass for
	 * a real event whenever the socket happens to be descriptor 0.
	 */
	if (event == 0 || ke.ident != (uintptr_t)s) {
		snprintf(errMessage, sizeof errMessage, "kevent: bad event(do_cmd)");
		writeLogs(LOG_ALERT, errMessage);
		return -1;
	}

	if (optname == IP_FW_ADD)
		rc = getsockopt(s, IPPROTO_IP, optname, optval, (socklen_t *)optlen);
	else
		rc = setsockopt(s, IPPROTO_IP, optname, optval, optlen);
	/* Callers report strerror(errno); keep it safe across the cleanup. */
	saved_errno = errno;

	/* Remove the event from the kqueue again. */
	EV_SET(&ke, s, EVFILT_WRITE, EV_DELETE, 0, 0, NULL);
	event = kevent(kq, &ke, 1, NULL, 0, NULL);
	if (event == -1) {
		snprintf(errMessage, sizeof errMessage, "kevent: delete event(do_cmd): %s", strerror(errno) );
		writeLogs(LOG_ALERT, errMessage);
		return -1;
	}

	errno = saved_errno;
	return rc;
}

/*
 * Read the current ipfw rule list from the kernel into the global buffer
 * `kernel_rules` and derive the global kevent `timeout` from how long the
 * read took.
 *
 * The buffer is grown (nalloc * 2 + 200 bytes each round) and the
 * IP_FW_GET request repeated until the kernel reports fewer bytes than the
 * buffer holds.  Around every request the sysctl
 * net.inet.ip.dummynet.curr_time is sampled; the tick difference is scaled
 * by the global HZ, averaged over the attempts and multiplied by 100, with
 * a lower bound of 100 * HZ.  The result is treated as microseconds when it
 * is stored into `timeout`.
 *
 * Returns 0 on success, -1 on a kevent/sysctl/getsockopt failure and -2
 * when memory could not be allocated (the previous buffer, if any, is kept).
 */
int 
get_rules(void) {
	/* Zeroed so a sysctl that writes fewer bytes than a long leaves no garbage. */
	long tick_before = 0, tick_after = 0, elapsed = 0;
	size_t ticklen;			/* sysctlbyname() takes a size_t *, not an int * */
	socklen_t got;			/* in: buffer size, out: bytes returned */
	void *grown;
	int attempts = 0, rc;
	int nalloc = 1024;		/* start somewhere... */

	/*
	 * Register the socket with the kqueue.
	 * NOTE(review): the event is neither waited for nor removed in this
	 * function; do_cmd() registers and removes it again on its own.
	 */
	EV_SET(&ke, s, EVFILT_WRITE, EV_ADD, 0, 0, NULL);
	event = kevent(kq, &ke, 1, NULL, 0, NULL);
	if (event == -1) {
		snprintf(errMessage, sizeof errMessage, "kevent: set event(get_rules): %s", strerror(errno) );
		writeLogs(LOG_ALERT, errMessage);
		return -1;
	}

	/* Get rules, resizing the buffer as necessary. */
	do {
		nalloc = nalloc * 2 + 200;

		/* Go through a temporary so the old buffer is not lost on failure. */
		grown = realloc(kernel_rules, nalloc);
		if (grown == NULL) {
			snprintf(errMessage, sizeof errMessage, "get_rules: realloc: %s", strerror(errno) );
			writeLogs(LOG_EMERG, errMessage);
			return -2;
		}
		kernel_rules = grown;

		memset(&ke, 0, sizeof(ke));
		attempts++;

		ticklen = sizeof tick_before;
		if (sysctlbyname("net.inet.ip.dummynet.curr_time", &tick_before, &ticklen, 0, 0) == -1) {
			snprintf(errMessage, sizeof errMessage, "sysctlbyname: net.inet.ip.dummynet.curr_time: %s", strerror(errno) );
			writeLogs(LOG_EMERG, errMessage);
			return -1;
		}

		got = (socklen_t)nalloc;
		rc = getsockopt(s, IPPROTO_IP, IP_FW_GET, kernel_rules, &got);

		ticklen = sizeof tick_after;
		if (sysctlbyname("net.inet.ip.dummynet.curr_time", &tick_after, &ticklen, 0, 0) == -1) {
			snprintf(errMessage, sizeof errMessage, "sysctlbyname: net.inet.ip.dummynet.curr_time: %s", strerror(errno) );
			writeLogs(LOG_EMERG, errMessage);
			return -1;
		}

		/* labs(), not abs(): the operands are long. */
		elapsed += labs(tick_after - tick_before) * HZ;

		if (rc < 0) {
			snprintf(errMessage, sizeof errMessage, "getsockopt(IP_FW_GET): %s", strerror(errno) );
			writeLogs(LOG_ALERT, errMessage);
			return -1;
		}
	} while (got >= (socklen_t)nalloc);	/* buffer completely filled: retry with a larger one */

	/* (Average time) * 100, but never less than 100 * HZ. */
	elapsed = 100 * elapsed / attempts;
	if (elapsed < 100 * HZ)
		elapsed = 100 * HZ;

	/* Always write both fields so a stale tv_sec from an earlier call cannot survive. */
	timeout.tv_sec = (time_t)(elapsed / 1000000);
	timeout.tv_nsec = (elapsed % 1000000) * 1000;

	return 0;
}

int 
compare_r(struct ip_fw *rule, struct ip_fw *kern_rule) {

	int r = 1; /*   */
	ipfw_insn *cmd_kern, *cmd_rule;
	int len_rule, len_kern, k;
	uint16_t *port_rule, *port_kern;
	uint32_t *rul_par, *kern_par, *d_rule, *d_kern;

	if (rule->set == kern_rule->set && rule->cmd_len == kern_rule->cmd_len &&
	    rule->act_ofs == kern_rule->act_ofs && rule->cmd_len == kern_rule->cmd_len) { 
		r = 0; /*   */
		//printf("rule %d\n", rule->rulenum);
		cmd_rule = rule->cmd;
		cmd_kern = kern_rule->cmd;
		len_rule = rule->cmd_len;
		len_kern = kern_rule->cmd_len;
		while (len_rule > 0 && r == 0) {
			if(cmd_rule->opcode == cmd_kern->opcode && cmd_rule->arg1 == cmd_kern->arg1 &&
			   cmd_rule->len == cmd_kern->len && (len_rule -= F_LEN(cmd_rule)) == (len_kern -= F_LEN(cmd_kern))) {
				switch(cmd_rule->opcode) {
				case O_ACCEPT: 
				case O_CHECK_STATE:
				case O_PROBE_STATE:
				case O_COUNT:
				case O_DENY:
				case O_IP_SRC_ME:
				case O_IP_DST_ME:
				case O_LAYER2:
				case O_ESTAB:
				case O_FRAG:
				case O_IN:
				case O_KEEP_STATE:
				case O_VERREVPATH:
				case O_DIVERT:
				case O_TEE:
				case O_PIPE:
				case O_QUEUE:
				case O_SKIPTO:
				case O_REJECT:
				case O_PROTO: 
				case O_IPVER:
				case O_TCPWIN:
				case O_IPPRECEDENCE:
				case O_TCPFLAGS: 
				case O_TCPOPTS:
				case O_IPTOS:
				case O_IPOPT:
				break;
				case O_LOG: 
					if (((ipfw_insn_log *)cmd_rule)->max_log != ((ipfw_insn_log *)cmd_kern)->max_log)
						r = 1;
				break;
				case O_TCPSEQ:
	     			case O_TCPACK:
				case O_UID:
				case O_GID:
				case O_PROB:
				case O_ICMPTYPE:
				case O_IP_SRC_LOOKUP:
				case O_IP_DST_LOOKUP:
				      if (((ipfw_insn_u32 *)cmd_rule)->d[0] != ((ipfw_insn_u32 *)cmd_kern)->d[0]) 
					      r = 1;
				break;
				case O_RECV:
				case O_XMIT:
				case O_VIA: 
					if (strcmp(((ipfw_insn_if *)cmd_rule)->name, ((ipfw_insn_if *)cmd_kern)->name) != 0 ||
					    ((ipfw_insn_if *)cmd_rule)->p.unit != ((ipfw_insn_if *)cmd_kern)->p.unit ||
					    strcmp(inet_ntoa(((ipfw_insn_if *)cmd_rule)->p.ip), inet_ntoa(((ipfw_insn_if *)cmd_kern)->p.ip)) != 0)
					    	r = 1;
				break;
				case O_FORWARD_IP: 
					if (((ipfw_insn_sa *)cmd_rule)->sa.sin_len != ((ipfw_insn_sa *)cmd_kern)->sa.sin_len ||
					    ((ipfw_insn_sa *)cmd_rule)->sa.sin_family != ((ipfw_insn_sa *)cmd_kern)->sa.sin_family ||
					    ((ipfw_insn_sa *)cmd_rule)->sa.sin_port != ((ipfw_insn_sa *)cmd_kern)->sa.sin_port ||
					    strcmp(inet_ntoa(((ipfw_insn_sa *)cmd_rule)->sa.sin_addr), inet_ntoa(((ipfw_insn_sa *)cmd_kern)->sa.sin_addr)))
					    	r = 1;
				break;
				case O_IP_SRCPORT:
	    			case O_IP_DSTPORT:
	     			case O_IPID:
	     			case O_IPLEN:
	     			case O_IPTTL:
					port_rule = ((ipfw_insn_u16 *)cmd_rule)->ports;
					port_kern = ((ipfw_insn_u16 *)cmd_kern)->ports;
					for (k = F_LEN(cmd_rule) - 1; k > 0; k--) {
						if (port_rule[0] != port_kern [0] || port_rule[1] != port_kern[1]) 
							{r = 1; break;}
						port_rule += 2;
						port_kern += 2;
					}
				break;
				case O_LIMIT:
					if (((ipfw_insn_limit *)cmd_rule)->conn_limit != ((ipfw_insn_limit *)cmd_kern)->conn_limit ||
					    ((ipfw_insn_limit *)cmd_rule)->limit_mask != ((ipfw_insn_limit *)cmd_kern)->limit_mask ||
					    ((ipfw_insn_limit *)cmd_rule)->_pad != ((ipfw_insn_limit *)cmd_kern)->_pad) 
					    	r = 1;
				break;
				case O_IP_SRC_SET: 
				case O_IP_DST_SET:  
					rul_par = (uint32_t *)&((ipfw_insn_ip *)cmd_rule)->addr;
					kern_par = (uint32_t *)&((ipfw_insn_ip *)cmd_kern)->addr;
					if (strcmp(inet_ntoa(*(struct in_addr *)&rul_par), inet_ntoa(*(struct in_addr *)&kern_par)) != 0)
						{r = 1; break;}
					rul_par = (uint32_t *)&((ipfw_insn_ip *)cmd_rule)->mask;
					kern_par = (uint32_t *)&((ipfw_insn_ip *)cmd_kern)->mask;
					for (k = 0; k < cmd_kern->arg1; k++) 
						if (rul_par[k/32] != kern_par[k/32])
							{r = 1; break;}	
				case O_IP_SRC:
				case O_IP_DST:
				case O_IP_SRC_MASK:
				case O_IP_DST_MASK: 
					d_rule = ((ipfw_insn_u32 *)cmd_rule)->d;
					d_kern = ((ipfw_insn_u32 *)cmd_kern)->d;
					for (k = cmd_rule->len; k > 1; k -=2) {
						if (d_rule[0] != d_kern[0] || d_rule[1] != d_kern[1]) 
							{r = 1; break;}
						d_rule += 2;
						d_kern += 2;	
					}
				break;
				default: r = 1;
				} /* end of switch */
			} else r = 1;
			cmd_rule += F_LEN(cmd_rule);
			cmd_kern += F_LEN(cmd_kern);
		} /* end of while */
	}
	return r;
}

int 
compare_rules(mular_descriptor **mular, int *counter, void *rules_begin) {
	
	struct ip_fw *rule = NULL; /* rule from new set */
	struct ip_fw *kern_rule;   /* rule from kernel */
	ipfw_insn  *cmd_rule;
	long int *tmp_ptr;
	int i, j, len_rule, flag;
	int par;
	uint32_t del_num;
	
	int len=sizeof(int);
    	struct clockinfo cl_info;

    	/* Socket */
    	if ((s = socket(AF_INET, SOCK_RAW, IPPROTO_RAW)) < 0) {
        	snprintf(errMessage,256,"socket: %s", strerror(errno) );
        	writeLogs(LOG_EMERG, errMessage);
        	return -1;
    	}

    	/* kqueue */
    	if ((kq = kqueue()) == -1) {
        	snprintf(errMessage,256,"kqueue: %s", strerror(errno) );
        	writeLogs(LOG_EMERG, errMessage);
        	return -1;
    	}

    	len = sizeof(struct clockinfo);
    	if (sysctlbyname("kern.clockrate", &cl_info, &len, 0, 0) < 0) {
        	snprintf(errMessage, 256, "sysctlbyname: kern.clockrate: %s", strerror(errno) );
        	writeLogs(LOG_EMERG, errMessage);
        	return -1;
    	} else HZ = cl_info.tick;
	
	if (get_rules() < 0) return -1; /* get rules from kernel */
	
	kern_rule = (struct ip_fw *)kernel_rules;
	
	for (i = 0; i < 65534; i++) {
		if (counter[i] == 0 && kern_rule->rulenum == i + 1) {
			/*     i + 1 */
			del_num = ((i + 1) & 0xffff) | (0 << 24);
			if(do_cmd(IP_FW_DEL, &del_num, sizeof del_num) != 0) {
				snprintf(errMessage, 256, "do_cmd: can not delete rule %d from kernel %s", i + 1, strerror(errno));
				writeLogs(LOG_ALERT, errMessage);
				return -1;
			} //else printf("Delete: %d\n", i + 1);
			/*    kernel_rule    */
			while(kern_rule->rulenum == i + 1) kern_rule = NEXT(kern_rule);
		} else if (counter[i] > 0 && kern_rule->rulenum != i + 1) 
			/*        */
			for(j = 0; j < counter[i]; j++) {
				tmp_ptr = mular_getix(mular[i], j);
				rule = (struct ip_fw *)(rules_begin + (u_int32_t) *tmp_ptr);
				len_rule = rule->cmd_len;
		   		cmd_rule = (ipfw_insn *)rule->cmd;
		   		while(len_rule > 0) {
		   			len_rule -= F_LEN(cmd_rule);
					cmd_rule += F_LEN(cmd_rule);
	     	   		}
				par = (char *)cmd_rule-(char *)rule;
		   		if(do_cmd(IP_FW_ADD, rule, (uintptr_t)&par) !=0) {
		   			snprintf(errMessage, 256, "do_cmd: can not add rule %d to kernel %s", rule->rulenum, strerror(errno));
                			writeLogs(LOG_ALERT, errMessage);
                			return -1;
		   		} //else printf("Add: %d\n", i + 1);
			}
		  else if (counter[i] > 0 && kern_rule->rulenum == i + 1) {
			/*    */
			flag = 0;
			for(j = 0; j < counter[i]; j++) {
				tmp_ptr = mular_getix(mular[i], j);
				rule = (struct ip_fw *)(rules_begin + (u_int32_t) *tmp_ptr);
				if (compare_r(rule, kern_rule) != 0) {
					flag = 1; break;
				}
				kern_rule = NEXT(kern_rule);
				/*    */
				if (j == counter[i] - 1 && kern_rule->rulenum == i + 1) {
					flag = 1; break;
				}
			}
			if (flag == 1) {
				/*     */
				del_num = ((i + 1) & 0xffff) | (0 << 24);
				if(do_cmd(IP_FW_DEL, &del_num, sizeof del_num) != 0) {
					snprintf(errMessage, 256, "do_cmd: can not delete rule %d from kernel %s", i + 1, strerror(errno));
					writeLogs(LOG_ALERT, errMessage);
					return -1;
				} //else printf("Delete: %d\n", i + 1);
				/*    */
				for(j = 0; j < counter[i]; j++) {
					tmp_ptr = mular_getix(mular[i], j);
					rule = (struct ip_fw *)(rules_begin + (u_int32_t) *tmp_ptr);
					len_rule = rule->cmd_len;
		   			cmd_rule = (ipfw_insn *)rule->cmd;
		   			while(len_rule > 0) {
		   				len_rule -= F_LEN(cmd_rule);
						cmd_rule += F_LEN(cmd_rule);
	     	   			}
					par = (char *)cmd_rule-(char *)rule;
		   			if(do_cmd(IP_FW_ADD, rule, (uintptr_t)&par) != 0) {
		   				snprintf(errMessage, 256, "do_cmd: can not add rule %d to kernel %s", rule->rulenum, strerror(errno));
                				writeLogs(LOG_ALERT, errMessage);
                				return -1;
		   			} //else printf("Add: %d\n", i + 1);
				}
				/*    kernel_rule    */
				while(kern_rule->rulenum == i + 1) kern_rule = NEXT(kern_rule);
			}
		}
	}

	return 0;
}	

	
