diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
| -rw-r--r-- | net/netlink/af_netlink.c | 178 | 
1 files changed, 168 insertions, 10 deletions
| diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 57ee84d21470..0c61b59175dc 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -57,6 +57,7 @@  #include <linux/audit.h>  #include <linux/mutex.h>  #include <linux/vmalloc.h> +#include <linux/if_arp.h>  #include <asm/cacheflush.h>  #include <net/net_namespace.h> @@ -101,6 +102,9 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);  static ATOMIC_NOTIFIER_HEAD(netlink_chain); +static DEFINE_SPINLOCK(netlink_tap_lock); +static struct list_head netlink_tap_all __read_mostly; +  static inline u32 netlink_group_mask(u32 group)  {  	return group ? 1 << (group - 1) : 0; @@ -111,6 +115,100 @@ static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u  	return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];  } +int netlink_add_tap(struct netlink_tap *nt) +{ +	if (unlikely(nt->dev->type != ARPHRD_NETLINK)) +		return -EINVAL; + +	spin_lock(&netlink_tap_lock); +	list_add_rcu(&nt->list, &netlink_tap_all); +	spin_unlock(&netlink_tap_lock); + +	if (nt->module) +		__module_get(nt->module); + +	return 0; +} +EXPORT_SYMBOL_GPL(netlink_add_tap); + +int __netlink_remove_tap(struct netlink_tap *nt) +{ +	bool found = false; +	struct netlink_tap *tmp; + +	spin_lock(&netlink_tap_lock); + +	list_for_each_entry(tmp, &netlink_tap_all, list) { +		if (nt == tmp) { +			list_del_rcu(&nt->list); +			found = true; +			goto out; +		} +	} + +	pr_warn("__netlink_remove_tap: %p not found\n", nt); +out: +	spin_unlock(&netlink_tap_lock); + +	if (found && nt->module) +		module_put(nt->module); + +	return found ? 0 : -ENODEV; +} +EXPORT_SYMBOL_GPL(__netlink_remove_tap); + +int netlink_remove_tap(struct netlink_tap *nt) +{ +	int ret; + +	ret = __netlink_remove_tap(nt); +	synchronize_net(); + +	return ret; +} +EXPORT_SYMBOL_GPL(netlink_remove_tap); + +static int __netlink_deliver_tap_skb(struct sk_buff *skb, +				     struct net_device *dev) +{ +	struct sk_buff *nskb; +	int ret = -ENOMEM; + +	dev_hold(dev); +	nskb = skb_clone(skb, GFP_ATOMIC); +	if (nskb) { +		nskb->dev = dev; +		ret = dev_queue_xmit(nskb); +		if (unlikely(ret > 0)) +			ret = net_xmit_errno(ret); +	} + +	dev_put(dev); +	return ret; +} + +static void __netlink_deliver_tap(struct sk_buff *skb) +{ +	int ret; +	struct netlink_tap *tmp; + +	list_for_each_entry_rcu(tmp, &netlink_tap_all, list) { +		ret = __netlink_deliver_tap_skb(skb, tmp->dev); +		if (unlikely(ret)) +			break; +	} +} + +static void netlink_deliver_tap(struct sk_buff *skb) +{ +	rcu_read_lock(); + +	if (unlikely(!list_empty(&netlink_tap_all))) +		__netlink_deliver_tap(skb); + +	rcu_read_unlock(); +} +  static void netlink_overrun(struct sock *sk)  {  	struct netlink_sock *nlk = nlk_sk(sk); @@ -750,6 +848,13 @@ static void netlink_skb_destructor(struct sk_buff *skb)  		skb->head = NULL;  	}  #endif +	if (is_vmalloc_addr(skb->head)) { +		if (!skb->cloned || +		    !atomic_dec_return(&(skb_shinfo(skb)->dataref))) +			vfree(skb->head); + +		skb->head = NULL; +	}  	if (skb->sk != NULL)  		sock_rfree(skb);  } @@ -854,16 +959,23 @@ netlink_unlock_table(void)  		wake_up(&nl_table_wait);  } +static bool netlink_compare(struct net *net, struct sock *sk) +{ +	return net_eq(sock_net(sk), net); +} +  static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)  { -	struct nl_portid_hash *hash = &nl_table[protocol].hash; +	struct netlink_table *table = &nl_table[protocol]; +	struct nl_portid_hash *hash = &table->hash;  	struct hlist_head *head;  	struct sock *sk;  	read_lock(&nl_table_lock);  	head = nl_portid_hashfn(hash, portid);  	sk_for_each(sk, head) { -		if (net_eq(sock_net(sk), net) && (nlk_sk(sk)->portid == portid)) { +		if (table->compare(net, sk) && +		    (nlk_sk(sk)->portid == portid)) {  			sock_hold(sk);  			goto found;  		} @@ -976,7 +1088,8 @@ netlink_update_listeners(struct sock *sk)  static int netlink_insert(struct sock *sk, struct net *net, u32 portid)  { -	struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash; +	struct netlink_table *table = &nl_table[sk->sk_protocol]; +	struct nl_portid_hash *hash = &table->hash;  	struct hlist_head *head;  	int err = -EADDRINUSE;  	struct sock *osk; @@ -986,7 +1099,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)  	head = nl_portid_hashfn(hash, portid);  	len = 0;  	sk_for_each(osk, head) { -		if (net_eq(sock_net(osk), net) && (nlk_sk(osk)->portid == portid)) +		if (table->compare(net, osk) && +		    (nlk_sk(osk)->portid == portid))  			break;  		len++;  	} @@ -1183,7 +1297,8 @@ static int netlink_autobind(struct socket *sock)  {  	struct sock *sk = sock->sk;  	struct net *net = sock_net(sk); -	struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash; +	struct netlink_table *table = &nl_table[sk->sk_protocol]; +	struct nl_portid_hash *hash = &table->hash;  	struct hlist_head *head;  	struct sock *osk;  	s32 portid = task_tgid_vnr(current); @@ -1195,7 +1310,7 @@ retry:  	netlink_table_grab();  	head = nl_portid_hashfn(hash, portid);  	sk_for_each(osk, head) { -		if (!net_eq(sock_net(osk), net)) +		if (!table->compare(net, osk))  			continue;  		if (nlk_sk(osk)->portid == portid) {  			/* Bind collision, search negative portid values. */ @@ -1420,6 +1535,33 @@ struct sock *netlink_getsockbyfilp(struct file *filp)  	return sock;  } +static struct sk_buff *netlink_alloc_large_skb(unsigned int size, +					       int broadcast) +{ +	struct sk_buff *skb; +	void *data; + +	if (size <= NLMSG_GOODSIZE || broadcast) +		return alloc_skb(size, GFP_KERNEL); + +	size = SKB_DATA_ALIGN(size) + +	       SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + +	data = vmalloc(size); +	if (data == NULL) +		return NULL; + +	skb = build_skb(data, size); +	if (skb == NULL) +		vfree(data); +	else { +		skb->head_frag = 0; +		skb->destructor = netlink_skb_destructor; +	} + +	return skb; +} +  /*   * Attach a skb to a netlink socket.   * The caller must hold a reference to the destination socket. On error, the @@ -1475,6 +1617,8 @@ static int __netlink_sendskb(struct sock *sk, struct sk_buff *skb)  {  	int len = skb->len; +	netlink_deliver_tap(skb); +  #ifdef CONFIG_NETLINK_MMAP  	if (netlink_skb_is_mmaped(skb))  		netlink_queue_mmaped_skb(sk, skb); @@ -1510,7 +1654,7 @@ static struct sk_buff *netlink_trim(struct sk_buff *skb, gfp_t allocation)  		return skb;  	delta = skb->end - skb->tail; -	if (delta * 2 < skb->truesize) +	if (is_vmalloc_addr(skb->head) || delta * 2 < skb->truesize)  		return skb;  	if (skb_shared(skb)) { @@ -1535,6 +1679,11 @@ static int netlink_unicast_kernel(struct sock *sk, struct sk_buff *skb,  	ret = -ECONNREFUSED;  	if (nlk->netlink_rcv != NULL) { +		/* We could do a netlink_deliver_tap(skb) here as well +		 * but since this is intended for the kernel only, we +		 * should rather let it stay under the hood. +		 */ +  		ret = skb->len;  		netlink_skb_set_owner_r(skb, sk);  		NETLINK_CB(skb).sk = ssk; @@ -2096,7 +2245,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,  	if (len > sk->sk_sndbuf - 32)  		goto out;  	err = -ENOBUFS; -	skb = alloc_skb(len, GFP_KERNEL); +	skb = netlink_alloc_large_skb(len, dst_group);  	if (skb == NULL)  		goto out; @@ -2285,6 +2434,8 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,  		if (cfg) {  			nl_table[unit].bind = cfg->bind;  			nl_table[unit].flags = cfg->flags; +			if (cfg->compare) +				nl_table[unit].compare = cfg->compare;  		}  		nl_table[unit].registered = 1;  	} else { @@ -2707,6 +2858,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)  {  	struct sock *s;  	struct nl_seq_iter *iter; +	struct net *net;  	int i, j;  	++*pos; @@ -2714,11 +2866,12 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)  	if (v == SEQ_START_TOKEN)  		return netlink_seq_socket_idx(seq, 0); +	net = seq_file_net(seq);  	iter = seq->private;  	s = v;  	do {  		s = sk_next(s); -	} while (s && sock_net(s) != seq_file_net(seq)); +	} while (s && !nl_table[s->sk_protocol].compare(net, s));  	if (s)  		return s; @@ -2730,7 +2883,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)  		for (; j <= hash->mask; j++) {  			s = sk_head(&hash->table[j]); -			while (s && sock_net(s) != seq_file_net(seq)) + +			while (s && !nl_table[s->sk_protocol].compare(net, s))  				s = sk_next(s);  			if (s) {  				iter->link = i; @@ -2923,8 +3077,12 @@ static int __init netlink_proto_init(void)  		hash->shift = 0;  		hash->mask = 0;  		hash->rehash_time = jiffies; + +		nl_table[i].compare = netlink_compare;  	} +	INIT_LIST_HEAD(&netlink_tap_all); +  	netlink_add_usersock_entry();  	sock_register(&netlink_family_ops); | 
