#include <linux/init.h>
#include <linux/module.h>
#include <linux/capability.h>
#include <linux/netfilter.h>
#include <linux/netfilter_bridge.h>
#include <linux/list.h>
#include <linux/if_ether.h>
#include <linux/etherdevice.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/timer.h>
#include <linux/smp.h>
#include <linux/spinlock.h>
#include <linux/kobject.h>
#include <linux/uaccess.h>
#include <linux/version.h>
#include <net/netlink.h>

/* License string this module reports to the kernel. */
MODULE_LICENSE("Dual BSD/GPL");

/* Number of buckets in nt[]; must stay a power of two (index is masked). */
#define HASH_TABLE_SIZE 256
/*
 * XOR-fold the six bytes of a MAC address. Arguments are parenthesized so
 * expressions such as MAC_ADDR_HASH(buf + off) expand correctly.
 */
#define MAC_ADDR_HASH(addr) ((addr)[0] ^ (addr)[1] ^ (addr)[2] ^ \
			     (addr)[3] ^ (addr)[4] ^ (addr)[5])
/* Bucket index in nt[] for a MAC address. */
#define MAC_ADDR_HASH_INDEX(addr) (MAC_ADDR_HASH(addr) & (HASH_TABLE_SIZE - 1))

/* One learned source MAC and the net device it was last received on. */
struct map_nt_entry {
	struct list_head list;		/* link in bucket nt[MAC_ADDR_HASH_INDEX(mac)] */
	unsigned char mac[6];		/* learned source MAC address (lookup key) */
	unsigned char dev_addr[6];	/* MAC of the receiving net device */
	unsigned long updated;		/* jiffies when mac was last seen */
	char intf[IFNAMSIZ + 1];	/* receiving net device name; the extra byte is never written by strncpy(.., IFNAMSIZ) so it stays NUL */
};

/* Netlink protocol number of the lookup socket created in map_init(). */
#define MAP_NETLINK 25
/* Limit on the number of learned entries kept in nt[]. */
#define MAX_ENTRY_CNT 256
/* Payload size of each netlink reply built by send_msg(). */
#define MAX_MSGSIZE 1024
/* Entries not refreshed for this long are freed by nt_timeout_func(). */
#define NT_TIMEOUT (600 * HZ) /*10mins*/
/* Interval at which nt_timeout_func() re-arms itself. */
#define TIMER_EXCUTE_PERIOD (5*HZ)

/* Size of the 1905 CMDU header that precedes the TLV list. */
#define CMDU_HLEN	8
/* CMDU message type that gets the receiving-interface vendor TLV appended. */
#define TOPOLOGY_DISCOVERY	0x0000
/* TLV type that terminates a CMDU's TLV list. */
#define END_OF_TLV_TYPE 0

/* Hash buckets of learned source MACs, indexed by MAC_ADDR_HASH_INDEX(). */
struct list_head nt[HASH_TABLE_SIZE];
/* Taken (with spin_lock_bh) around every walk/modification of nt[] and nt_cnt updates. */
spinlock_t nt_lock;
/* Kernel netlink socket answering MAC lookups (input handler: recv_nlmsg()). */
struct sock *nl_sk;
/* Periodic aging timer running nt_timeout_func(). */
struct timer_list nt_timer;
/* Number of entries currently linked into nt[]. */
int nt_cnt;
/* AL MAC set through the MAP_SET_ALMAC sockopt; 1905 frames addressed to it are not forwarded. */
unsigned char almac[ETH_ALEN];
/* OUI written into the vendor-specific TLV added to topology discovery messages. */
unsigned char mtk_oui[3] = {0x00, 0x0C, 0xE7};

void hex_dump_all(char *str, unsigned char *pSrcBufVA, unsigned int SrcBufLen)
{
	unsigned char *pt;
	int x;

	pt = pSrcBufVA;
	pr_info("%s: %p, len = %d\n",str,  pSrcBufVA, SrcBufLen);
	for (x=0; x<SrcBufLen; x++)
	{
		if (x % 16 == 0)
			pr_info("0x%04x : ", x);
		pr_info("%02x ", ((unsigned char)pt[x]));
		if (x%16 == 15) pr_info("\n");
	}
    pr_info("\n");
}


/*
 * get_1905_message_type - read the message type of a 1905 CMDU.
 * @msg: start of the CMDU header
 *
 * The message type is the 16-bit big-endian field at offset 2 of the header
 * (after the version and reserved octets). The caller must guarantee at
 * least 4 readable bytes.
 *
 * Return: the message type in host byte order.
 */
unsigned short get_1905_message_type(unsigned char *msg)
{
	const unsigned char *field = msg + 2;

	return (unsigned short)((field[0] << 8) | field[1]);
}

/*
 * get_cmdu_tlv_length - total on-wire size of one CMDU TLV.
 * @buf: start of the TLV (type octet)
 *
 * Octets 1..2 hold the big-endian length of the TLV value; the 3-octet
 * type/length header is added on top. The caller must guarantee at least
 * 3 readable bytes.
 *
 * Return: header plus value length, in bytes.
 */
int get_cmdu_tlv_length(unsigned char *buf)
{
	unsigned int value_len = ((unsigned int)buf[1] << 8) | buf[2];

	return (int)value_len + 3;
}

/*
 * add_discovery_inf_vstlv - tag a topology discovery CMDU with the MAC of
 * the receiving interface.
 * @skb:  linear skb whose data starts at the 1905 CMDU header
 * @rmac: 6-byte MAC of the net device the frame arrived on
 *
 * Walks the TLV list after the CMDU header until the End-of-message TLV.
 * If at least 18 bytes remain from there to the end of skb->data, the
 * End-of-message TLV is overwritten by a 15-byte vendor-specific TLV
 * followed by a new 3-byte End-of-message TLV. Otherwise the frame is left
 * unchanged and only a log line is emitted (skb expansion is not
 * implemented).
 *
 * Return: 0 when the frame was tagged or left unchanged, -1 when the TLV
 * list is malformed (shorter than the header or a TLV runs past the end).
 */
int add_discovery_inf_vstlv(struct sk_buff *skb, unsigned char *rmac)
{
	unsigned char *temp_buf = skb->data;
	int length = 0, left_len = 0;
	/*
	 * Vendor-specific TLV, 15 octets:
	 * tlvType(1) + tlvLength(2) + OUI(3) + subType(1) + subLength(2) + mac(6)
	 */
	unsigned char vsinfo[15] = {0};

	vsinfo[0] = 11;		/* Vendor-specific TLV type */
	vsinfo[1] = 0x00;	/* tlvLength = 12 (OUI..mac), big endian */
	vsinfo[2] = 0x0c;
	memcpy(&vsinfo[3], mtk_oui, 3);
	vsinfo[6] = 0x00;	/* subType: receiving interface mac */
	vsinfo[7] = 0x00;	/* subLength = 6, big endian */
	vsinfo[8] = 0x06;
	memcpy(&vsinfo[9], rmac, 6);

	/* Skip the 1905 CMDU header and walk to the End-of-message TLV. */
	temp_buf += CMDU_HLEN;
	left_len = (int)skb->len - CMDU_HLEN;
	while (left_len > 0) {
		if (*temp_buf == END_OF_TLV_TYPE)
			break;
		/*
		 * get_cmdu_tlv_length() reads octets 1..2 of the TLV; with fewer
		 * than 3 bytes left that would read past skb->data. Such a TLV is
		 * truncated, so the message is malformed.
		 */
		if (left_len < 3) {
			left_len = -1;
			break;
		}
		/* Extra TLVs before End-of-message are skipped untouched. */
		length = get_cmdu_tlv_length(temp_buf);
		temp_buf += length;
		left_len -= length;
	}

	if (left_len < 0) {
		pr_info("%s error discovery msg\n", __func__);
		return -1;
	}

	/*
	 * NOTE(review): this writes into skb->data in place; if the skb can be
	 * shared with clones (e.g. packet taps) it may need to be made writable
	 * first (skb_ensure_writable()/skb_make_writable()) — TODO confirm.
	 */

	/* vsinfo (15 octets) + End of message TLV (3 octets) */
	if ((left_len + skb_tailroom(skb)) < 18) {
		/* skb expansion is not implemented: frame left unchanged */
		pr_info("%s not enough room for vstlv & end tlv\n", __func__);
	} else if (left_len >= 18) {
		memcpy(temp_buf, vsinfo, 15);
		temp_buf += 15;
		memset(temp_buf, 0, 3);
	} else {
		pr_info("%s need move skb->tail\n", __func__);
	}

	return 0;
}

void nt_timeout_func(unsigned long arg)
{
	int i;
	struct map_nt_entry *pos, *n;

	spin_lock_bh(&nt_lock);
	for (i = 0; i < HASH_TABLE_SIZE; i++) {
		list_for_each_entry_safe(pos, n, &nt[i], list) {
			if (time_after(jiffies, pos->updated+NT_TIMEOUT)) {
				pr_info("timeout! free %02x:%02x:%02x:%02x:%02x:%02x\n",
					pos->mac[0], pos->mac[1], pos->mac[2],
					pos->mac[3], pos->mac[4], pos->mac[5]);
				list_del(&pos->list);
				nt_cnt--;
				kfree(pos);
			}
		}
	}
	spin_unlock_bh(&nt_lock);
	mod_timer(&nt_timer, jiffies + TIMER_EXCUTE_PERIOD);

	return;
}

/*
 * nt_get_intf - translate a learned source MAC into its receiving device MAC.
 * @addr: in: 6-byte source MAC to look up; out: MAC of the net device the
 *        address was last seen on, or all zeros when it is not in the table.
 *        Passed to ether_addr_equal(), so it should be u16-aligned.
 *
 * Return: 0 when found, -1 otherwise.
 */
int nt_get_intf(char *addr)
{
	struct list_head *bucket = &nt[MAC_ADDR_HASH_INDEX(addr)];
	struct map_nt_entry *entry;
	int found = 0;

	spin_lock_bh(&nt_lock);
	list_for_each_entry(entry, bucket, list) {
		if (!ether_addr_equal(entry->mac, addr))
			continue;
		memcpy(addr, entry->dev_addr, ETH_ALEN);
		found = 1;
		break;
	}
	spin_unlock_bh(&nt_lock);

	if (found)
		return 0;

	memset(addr, 0, ETH_ALEN);
	return -1;
}

#if (LINUX_VERSION_CODE >= KERNEL_VERSION(4,1,0))
unsigned int map_hook_fn_pre_rout(void *priv,
			       struct sk_buff *skb,
			       const struct nf_hook_state *state)
#else
unsigned int map_hook_fn_pre_rout(unsigned int hooknum,
	struct sk_buff *skb,
	const struct net_device *in,
	const struct net_device *out,
	int (*okfn)(struct sk_buff *))
#endif
{
	struct net_device *indev = NULL;
	struct ethhdr *hdr = eth_hdr(skb);
	int hash_idx = MAC_ADDR_HASH_INDEX(hdr->h_source);
	struct map_nt_entry *pos;
	unsigned short mtype = 0;
	int ret = 0;
	
	/*step 1: if it's not a 1905 packet, ignore it.*/
	if (likely(skb->protocol != htons(0x893A)))
		return NF_ACCEPT;

#if (LINUX_VERSION_CODE >= KERNEL_VERSION(4,1,0))
	indev = state->in;
#else
	indev = in;
#endif

	/*only handle skb without nonlinear memory*/
	if (likely(skb->data_len == 0)) {
		/*step 1.1: check if the pkts is a discovery message
		* need add vendor specific tlv specify receiving net device mac
		*/
		mtype = get_1905_message_type(skb->data);
		if (mtype == TOPOLOGY_DISCOVERY) {
			//pr_info("%s receive TOPOLOGY_DISCOVERY msg\n", __func__);
			//hex_dump_all("discovery",skb->data,skb->len);
			ret = add_discovery_inf_vstlv(skb, indev->dev_addr);
			if (ret < 0) {
				pr_info("%s drop error msg\n", __func__);
				return NF_DROP;
			}
		}
	}

	spin_lock_bh(&nt_lock);
	/*step 2: if the source address was add to the table, update it.*/
	list_for_each_entry(pos, &nt[hash_idx], list) {
		if (ether_addr_equal(pos->mac, hdr->h_source)) {
			strncpy(pos->intf, indev->name, IFNAMSIZ);
			memcpy(pos->dev_addr, indev->dev_addr, 6);
			pos->updated = jiffies;
			break;
		}
	}

	/*step 3: if the source address has not been added to the table.*/
	if (&pos->list == &nt[hash_idx] && nt_cnt <= MAX_ENTRY_CNT) {
		/*add a new entry to table.*/
		pos = kmalloc(sizeof(struct map_nt_entry), GFP_ATOMIC);
		if (pos == NULL)
			goto out;

		memset(pos, 0, sizeof(struct map_nt_entry));
		memcpy(pos->mac, hdr->h_source, 6);
		memcpy(pos->dev_addr, indev->dev_addr, 6);
		strncpy(pos->intf, indev->name, IFNAMSIZ);
		pos->updated = jiffies;

		pr_info("alloc new entry for %02x:%02x:%02x:%02x:%02x:%02x, interface:%s\n",
			hdr->h_source[0], hdr->h_source[1], hdr->h_source[2],
			hdr->h_source[3], hdr->h_source[4], hdr->h_source[5],
			indev->name);
		pr_info("recv intf mac %02x:%02x:%02x:%02x:%02x:%02x\n",
			indev->dev_addr[0], indev->dev_addr[1], indev->dev_addr[2],
			indev->dev_addr[3], indev->dev_addr[4], indev->dev_addr[5]);
		list_add_tail(&pos->list, &nt[hash_idx]);
		nt_cnt++;

	}
out:
	spin_unlock_bh(&nt_lock);
	return NF_ACCEPT;
}

/*
 * map_hook_fn_foward - bridge FORWARD hook that keeps 1905 frames local.
 *
 * For frames with ethertype 0x893A: multicast destinations and frames
 * addressed to almac are dropped instead of being forwarded. Everything
 * else is accepted.
 */
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(4,1,0))
unsigned int map_hook_fn_foward(void *priv,
			       struct sk_buff *skb,
			       const struct nf_hook_state *state)
#else
unsigned int map_hook_fn_foward(unsigned int hooknum,
		struct sk_buff *skb,
		const struct net_device *in,
		const struct net_device *out,
		int (*okfn)(struct sk_buff *))
#endif
{
	struct ethhdr *hdr = eth_hdr(skb);

	/* Non-1905 traffic is never filtered here. */
	if (skb->protocol != htons(0x893A))
		return NF_ACCEPT;

	/* Group bit set in the destination: do not forward 1905 multicast. */
	if (is_multicast_ether_addr(hdr->h_dest))
		return NF_DROP;

	/* Destination is the configured AL MAC: do not forward it either. */
	if (memcmp(hdr->h_dest, almac, ETH_ALEN) == 0)
		return NF_DROP;

	return NF_ACCEPT;
}

void send_msg(char *msg, int pid)
{
	struct sk_buff *skb;
	struct nlmsghdr *nlh;
	int len = NLMSG_SPACE(MAX_MSGSIZE);

	if (!msg || !nl_sk)
		return;

	skb = alloc_skb(len, GFP_KERNEL);
	if (!skb) {
		pr_info("send_msg:alloc_skb error\n");
		return;
	}
	nlh = nlmsg_put(skb, 0, 0, 0, MAX_MSGSIZE, 0);
	NETLINK_CB(skb).portid = 0;
	NETLINK_CB(skb).dst_group = 0;
	memcpy(NLMSG_DATA(nlh), msg, 6);
	netlink_unicast(nl_sk, skb, pid, MSG_DONTWAIT);
}

void recv_nlmsg(struct sk_buff *skb)
{
	int pid;
	struct nlmsghdr *nlh = nlmsg_hdr(skb);
	char *msg = NULL;

	if (nlh->nlmsg_len < NLMSG_HDRLEN || skb->len < nlh->nlmsg_len)
		return;

	msg = (char *)NLMSG_DATA(nlh);
	pr_info("address %02x:%02x:%02x:%02x:%02x:%02x is at ",
		msg[0], msg[1], msg[2], msg[3], msg[4], msg[5]);
	pid = nlh->nlmsg_pid;

	nt_get_intf(msg);
	pr_info("%02x:%02x:%02x:%02x:%02x:%02x\n",
		msg[0], msg[1], msg[2], msg[3], msg[4], msg[5]);
	send_msg(msg, pid);
}

/*
 * Configuration of the kernel-side MAP_NETLINK socket: every incoming
 * message is handed to recv_nlmsg(). Members not named here (groups,
 * flags, cb_mutex, bind) are zero/NULL by designated-initializer rules.
 */
struct netlink_kernel_cfg nl_kernel_cfg = {
	.input = recv_nlmsg,
};

/*
 * Bridge netfilter hooks registered by map_init():
 *  - NF_BR_PRE_ROUTING -> map_hook_fn_pre_rout(): learn the receiving
 *    device of each 1905 source MAC and tag discovery messages.
 *  - NF_BR_FORWARD     -> map_hook_fn_foward(): drop 1905 multicast frames
 *    and frames addressed to almac instead of forwarding them.
 */
static struct nf_hook_ops map_ops[] = {
	{
		/*hook for 1905 daemon to find out which interface
		  a frame came from.*/
		.hook		= map_hook_fn_pre_rout,
		.pf		= NFPROTO_BRIDGE,
		.hooknum	= NF_BR_PRE_ROUTING,
		.priority	= NF_BR_PRI_BRNF,
		//.owner		= THIS_MODULE,
	},
	{
		/*hook for NOT forwarding 1905 MC frames.*/
		.hook		= map_hook_fn_foward,
		.pf		= NFPROTO_BRIDGE,
		.hooknum	= NF_BR_FORWARD,
		.priority	= NF_BR_PRI_BRNF,
		//.owner		= THIS_MODULE,
	}
};

/*
 * Netfilter sockopt numbers used by map_sockopts (PF_INET). The get and set
 * paths are dispatched to separate handlers, so MAP_GET_MAC_BY_SRC and
 * MAP_SET_ALMAC share the value MAP_SO_BASE.
 */
#define MAP_SO_BASE 1905
#define MAP_GET_MAC_BY_SRC MAP_SO_BASE
#define MAP_SET_ALMAC MAP_SO_BASE
/* Used as set_optmax/get_optmax; presumably an exclusive bound so only MAP_SO_BASE is accepted — confirm. */
#define MAP_SO_MAX (MAP_SO_BASE + 1)

/*
 * do_map_set_ctl - setsockopt handler.
 * @sk:   socket the option was set on (unused)
 * @cmd:  option number; only MAP_SET_ALMAC is supported
 * @user: user buffer holding the new 6-byte AL MAC
 * @len:  size of @user in bytes
 *
 * Return: 0 on success, -EPERM without CAP_NET_ADMIN, -EINVAL for an
 * unknown option or a short buffer, -EFAULT if the copy fails.
 */
static int do_map_set_ctl(struct sock *sk, int cmd,
		void __user *user, unsigned int len)
{
	unsigned char new_almac[ETH_ALEN];
	int ret = 0;

	pr_info("do_map_set_ctl==>cmd(%d)\n", cmd);

	/*
	 * almac decides which 1905 frames the bridge refuses to forward, so
	 * changing it requires CAP_NET_ADMIN (as other netfilter set-ops do).
	 */
	if (!capable(CAP_NET_ADMIN))
		return -EPERM;

	switch (cmd) {
	case MAP_SET_ALMAC:
		if (len < ETH_ALEN) {
			ret = -EINVAL;
			break;
		}
		/* copy to a local first so a failed copy leaves almac intact */
		if (copy_from_user(new_almac, user, ETH_ALEN) != 0) {
			ret = -EFAULT;
			pr_info("do_map_set_ctl==>copy_from_user fail\n");
			break;
		}
		memcpy(almac, new_almac, ETH_ALEN);
		pr_info("do_map_set_ctl==>almac(%02x:%02x:%02x:%02x:%02x:%02x)\n",
			almac[0],almac[1],almac[2],almac[3],almac[4],almac[5]);
		break;
	default:
		ret = -EINVAL;
		break;
	}

	return ret;
}

/*
 * do_map_get_ctl - getsockopt handler.
 * @sk:   socket the option was queried on (unused)
 * @cmd:  option number; only MAP_GET_MAC_BY_SRC is supported
 * @user: in: 6-byte source MAC; out: MAC of the device it was received on
 *        (all zeros if unknown)
 * @len:  size of @user in bytes
 *
 * Return: 0 on success, -EINVAL for an unknown option or a short buffer,
 * -EFAULT if copying to/from user space fails.
 */
static int do_map_get_ctl(struct sock *sk, int cmd, void __user *user, int *len)
{
	int ret = 0;

	switch (cmd) {
	case MAP_GET_MAC_BY_SRC:
	{
		/*
		 * nt_get_intf() compares with ether_addr_equal(), which requires
		 * u16-aligned addresses; a plain char array has no such guarantee.
		 */
		char addr[ETH_ALEN] __aligned(2) = {0};

		if (*len < ETH_ALEN) {
			ret = -EINVAL;
			break;
		}

		if (copy_from_user(addr, user, ETH_ALEN) != 0) {
			/* was a positive EFAULT, which callers do not treat as an error */
			ret = -EFAULT;
			break;
		}

		nt_get_intf(addr);
		if (copy_to_user(user, addr, ETH_ALEN) != 0)
			ret = -EFAULT;
		break;
	}
	default:
		ret = -EINVAL;
	}

	return ret;
}
/*
 * setsockopt/getsockopt registration on PF_INET sockets for option range
 * [MAP_SO_BASE, MAP_SO_MAX): set -> do_map_set_ctl(), get -> do_map_get_ctl().
 * No compat handlers are provided; members not named here are NULL.
 */
static struct nf_sockopt_ops map_sockopts = {
	.pf		= PF_INET,
	.owner		= THIS_MODULE,

	.set_optmin	= MAP_SO_BASE,
	.set_optmax	= MAP_SO_MAX,
	.set		= do_map_set_ctl,

	.get_optmin	= MAP_SO_BASE,
	.get_optmax	= MAP_SO_MAX,
	.get		= do_map_get_ctl,
};


/*
 * map_nt_show - sysfs "nt_show": one line per learned entry with its bucket
 * index, MAC address and receiving interface name.
 *
 * Output is bounded by the PAGE_SIZE sysfs buffer: with up to MAX_ENTRY_CNT
 * entries of ~50 bytes each, the old unbounded sprintf() could overflow it.
 * Entries that do not fit are silently truncated.
 *
 * Return: number of bytes written to @buf.
 */
static ssize_t map_nt_show(struct kobject *kobj,
		struct kobj_attribute *attr, char *buf)
{
	int i, cnt = 0;
	struct map_nt_entry *pos;

	spin_lock_bh(&nt_lock);
	for (i = 0; i < HASH_TABLE_SIZE; i++) {
		list_for_each_entry(pos, &nt[i], list) {
			/* scnprintf() returns bytes actually written, never past the buffer */
			cnt += scnprintf(buf + cnt, PAGE_SIZE - cnt,
				"idx: %d\t%02x:%02x:%02x:%02x:%02x:%02x is at %s\n",
				i, pos->mac[0], pos->mac[1], pos->mac[2],
				pos->mac[3], pos->mac[4], pos->mac[5],
				pos->intf);
		}
	}
	spin_unlock_bh(&nt_lock);

	return cnt;
}

/*
 * map_nt_cnt_show - sysfs "nt_cnt_show": current number of learned entries,
 * followed by a newline.
 */
static ssize_t map_nt_cnt_show(struct kobject *kobj,
		struct kobj_attribute *attr, char *buf)
{
	int entries = nt_cnt;	/* snapshot, read without nt_lock */

	return sprintf(buf, "%d\n", entries);
}

/* Read-only (0444) attributes placed in the "mapfilter" kobject directory. */
static struct kobj_attribute map_sysfs_nt_show =
	__ATTR(nt_show, 0444, map_nt_show, NULL);
static struct kobj_attribute map_sysfs_nt_cnt_show =
	__ATTR(nt_cnt_show, 0444, map_nt_cnt_show, NULL);

static struct attribute *map_sysfs[] = {
	&map_sysfs_nt_show.attr,
	&map_sysfs_nt_cnt_show.attr,
	NULL,	/* list terminator */
};

static struct attribute_group map_attr_group = {
	.attrs = map_sysfs,
};

/* The "mapfilter" kobject; created in map_init(), released in map_exit(). */
struct kobject *map_kobj;

/*
 * map_init - module load.
 *
 * Initialises the learned-MAC table, then registers in order: the bridge
 * netfilter hooks, the MAP_NETLINK socket, the sockopt handlers and the
 * "mapfilter" sysfs directory, and finally starts the aging timer. On
 * failure everything registered so far is undone in reverse order.
 *
 * Return: 0 on success or a negative errno.
 */
static int __init map_init(void)
{
	int ret, i;
	struct map_nt_entry *pos, *n;

	spin_lock_init(&nt_lock);
	for (i = 0; i < HASH_TABLE_SIZE; i++)
		INIT_LIST_HEAD(&nt[i]);

	ret = nf_register_hooks(&map_ops[0], ARRAY_SIZE(map_ops));
	if (ret < 0) {
		pr_info("register nf hook fail, ret = %d\n", ret);
		goto error1;
	}

	nl_sk = netlink_kernel_create(&init_net, MAP_NETLINK, &nl_kernel_cfg);
	if (!nl_sk) {
		pr_info("create netlink socket error.\n");
		ret = -ENOMEM;
		goto error2;
	}

	ret = nf_register_sockopt(&map_sockopts);
	if (ret < 0)
		goto error3;

	map_kobj = kobject_create_and_add("mapfilter", NULL);
	if (!map_kobj) {
		ret = -ENOMEM;
		goto error4;
	}

	ret = sysfs_create_group(map_kobj, &map_attr_group);
	if (ret)
		goto error5;

	init_timer(&nt_timer);
	nt_timer.function = nt_timeout_func;
	/* expires is an absolute jiffies value: first run one period from now */
	nt_timer.expires = jiffies + TIMER_EXCUTE_PERIOD;
	add_timer(&nt_timer);
	return 0;

error5:
	kobject_put(map_kobj);
	map_kobj = NULL;
error4:
	nf_unregister_sockopt(&map_sockopts);
error3:
	/* counterpart of netlink_kernel_create() */
	netlink_kernel_release(nl_sk);
	nl_sk = NULL;
error2:
	nf_unregister_hooks(&map_ops[0], ARRAY_SIZE(map_ops));
	/* the pre-routing hook may already have learned entries: free them */
	for (i = 0; i < HASH_TABLE_SIZE; i++) {
		list_for_each_entry_safe(pos, n, &nt[i], list) {
			list_del(&pos->list);
			kfree(pos);
		}
	}
	nt_cnt = 0;
error1:
	return ret;
}

/*
 * map_exit - module unload.
 *
 * Stops the aging timer, removes the sysfs files, the sockopt handlers,
 * the netlink socket and the bridge hooks, then frees all learned entries.
 */
static void __exit map_exit(void)
{
	int i;
	struct map_nt_entry *pos, *n;

	/* nt_timeout_func() re-arms itself; stop it before anything else */
	del_timer_sync(&nt_timer);

	/*
	 * Remove the attribute group while map_kobj is still valid:
	 * kobject_put() drops the reference taken at creation and may free the
	 * kobject, so the previous put-then-remove order was a use-after-free.
	 */
	sysfs_remove_group(map_kobj, &map_attr_group);
	kobject_put(map_kobj);

	nf_unregister_sockopt(&map_sockopts);

	/* counterpart of netlink_kernel_create() */
	if (nl_sk != NULL)
		netlink_kernel_release(nl_sk);

	nf_unregister_hooks(&map_ops[0], ARRAY_SIZE(map_ops));

	/* timer, sysfs, sockopt, netlink and hooks are torn down above */
	for (i = 0; i < HASH_TABLE_SIZE; i++) {
		list_for_each_entry_safe(pos, n, &nt[i], list) {
			pr_info("exit! free %02x:%02x:%02x:%02x:%02x:%02x\n",
				pos->mac[0], pos->mac[1], pos->mac[2],
				pos->mac[3], pos->mac[4], pos->mac[5]);
			list_del(&pos->list);
			nt_cnt--;
			kfree(pos);
		}
	}
}

/* Module load/unload entry points. */
module_init(map_init);
module_exit(map_exit);

