]> www.pilppa.org Git - linux-2.6-omap-h63xx.git/blobdiff - net/ipv4/udp.c
udp: multicast packets need to check namespace
[linux-2.6-omap-h63xx.git] / net / ipv4 / udp.c
index 57e26fa66185affd152d8613800ee5a9e6230991..cf02701ced48091d9eecb6765fe80934a3dc73c8 100644 (file)
@@ -8,7 +8,7 @@
  * Authors:    Ross Biro
  *             Fred N. van Kempen, <waltje@uWalt.NL.Mugnet.ORG>
  *             Arnt Gulbrandsen, <agulbra@nvg.unit.no>
- *             Alan Cox, <Alan.Cox@linux.org>
+ *             Alan Cox, <alan@lxorguk.ukuu.org.uk>
  *             Hirokazu Takahashi, <taka@valinux.co.jp>
  *
  * Fixes:
  *     Snmp MIB for the UDP layer
  */
 
-DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly;
-EXPORT_SYMBOL(udp_stats_in6);
-
 struct hlist_head udp_hash[UDP_HTABLE_SIZE];
 DEFINE_RWLOCK(udp_hash_lock);
 
@@ -125,14 +122,23 @@ EXPORT_SYMBOL(sysctl_udp_wmem_min);
 atomic_t udp_memory_allocated;
 EXPORT_SYMBOL(udp_memory_allocated);
 
-static inline int __udp_lib_lport_inuse(struct net *net, __u16 num,
-                                       const struct hlist_head udptable[])
+static int udp_lib_lport_inuse(struct net *net, __u16 num,
+                              const struct hlist_head udptable[],
+                              struct sock *sk,
+                              int (*saddr_comp)(const struct sock *sk1,
+                                                const struct sock *sk2))
 {
-       struct sock *sk;
+       struct sock *sk2;
        struct hlist_node *node;
 
-       sk_for_each(sk, node, &udptable[udp_hashfn(net, num)])
-               if (net_eq(sock_net(sk), net) && sk->sk_hash == num)
+       sk_for_each(sk2, node, &udptable[udp_hashfn(net, num)])
+               if (net_eq(sock_net(sk2), net)                  &&
+                   sk2 != sk                                   &&
+                   sk2->sk_hash == num                         &&
+                   (!sk2->sk_reuse || !sk->sk_reuse)           &&
+                   (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
+                       || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
+                   (*saddr_comp)(sk, sk2))
                        return 1;
        return 0;
 }
@@ -149,83 +155,37 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                                         const struct sock *sk2 )    )
 {
        struct hlist_head *udptable = sk->sk_prot->h.udp_hash;
-       struct hlist_node *node;
-       struct hlist_head *head;
-       struct sock *sk2;
        int    error = 1;
        struct net *net = sock_net(sk);
 
        write_lock_bh(&udp_hash_lock);
 
        if (!snum) {
-               int i, low, high, remaining;
-               unsigned rover, best, best_size_so_far;
+               int low, high, remaining;
+               unsigned rand;
+               unsigned short first;
 
                inet_get_local_port_range(&low, &high);
                remaining = (high - low) + 1;
 
-               best_size_so_far = UINT_MAX;
-               best = rover = net_random() % remaining + low;
-
-               /* 1st pass: look for empty (or shortest) hash chain */
-               for (i = 0; i < UDP_HTABLE_SIZE; i++) {
-                       int size = 0;
-
-                       head = &udptable[udp_hashfn(net, rover)];
-                       if (hlist_empty(head))
-                               goto gotit;
-
-                       sk_for_each(sk2, node, head) {
-                               if (++size >= best_size_so_far)
-                                       goto next;
-                       }
-                       best_size_so_far = size;
-                       best = rover;
-               next:
-                       /* fold back if end of range */
-                       if (++rover > high)
-                               rover = low + ((rover - low)
-                                              & (UDP_HTABLE_SIZE - 1));
-
-
-               }
-
-               /* 2nd pass: find hole in shortest hash chain */
-               rover = best;
-               for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++) {
-                       if (! __udp_lib_lport_inuse(net, rover, udptable))
-                               goto gotit;
-                       rover += UDP_HTABLE_SIZE;
-                       if (rover > high)
-                               rover = low + ((rover - low)
-                                              & (UDP_HTABLE_SIZE - 1));
+               rand = net_random();
+               snum = first = rand % remaining + low;
+               rand |= 1;
+               while (udp_lib_lport_inuse(net, snum, udptable, sk,
+                                          saddr_comp)) {
+                       do {
+                               snum = snum + rand;
+                       } while (snum < low || snum > high);
+                       if (snum == first)
+                               goto fail;
                }
-
-
-               /* All ports in use! */
+       } else if (udp_lib_lport_inuse(net, snum, udptable, sk, saddr_comp))
                goto fail;
 
-gotit:
-               snum = rover;
-       } else {
-               head = &udptable[udp_hashfn(net, snum)];
-
-               sk_for_each(sk2, node, head)
-                       if (sk2->sk_hash == snum                             &&
-                           sk2 != sk                                        &&
-                           net_eq(sock_net(sk2), net)                       &&
-                           (!sk2->sk_reuse        || !sk->sk_reuse)         &&
-                           (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
-                            || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                           (*saddr_comp)(sk, sk2)                             )
-                               goto fail;
-       }
-
        inet_sk(sk)->num = snum;
        sk->sk_hash = snum;
        if (sk_unhashed(sk)) {
-               head = &udptable[udp_hashfn(net, snum)];
-               sk_add_node(sk, head);
+               sk_add_node(sk, &udptable[udp_hashfn(net, snum)]);
                sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
        }
        error = 0;
@@ -302,7 +262,29 @@ static struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
        return result;
 }
 
-static inline struct sock *udp_v4_mcast_next(struct sock *sk,
+static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
+                                                __be16 sport, __be16 dport,
+                                                struct hlist_head udptable[])
+{
+       struct sock *sk;
+       const struct iphdr *iph = ip_hdr(skb);
+
+       if (unlikely(sk = skb_steal_sock(skb)))
+               return sk;
+       else
+               return __udp4_lib_lookup(dev_net(skb->dst->dev), iph->saddr, sport,
+                                        iph->daddr, dport, inet_iif(skb),
+                                        udptable);
+}
+
+struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
+                            __be32 daddr, __be16 dport, int dif)
+{
+       return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, udp_hash);
+}
+EXPORT_SYMBOL_GPL(udp4_lib_lookup);
+
+static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
                                             __be16 loc_port, __be32 loc_addr,
                                             __be16 rmt_port, __be32 rmt_addr,
                                             int dif)
@@ -314,7 +296,8 @@ static inline struct sock *udp_v4_mcast_next(struct sock *sk,
        sk_for_each_from(s, node) {
                struct inet_sock *inet = inet_sk(s);
 
-               if (s->sk_hash != hnum                                  ||
+               if (!net_eq(sock_net(s), net)                           ||
+                   s->sk_hash != hnum                                  ||
                    (inet->daddr && inet->daddr != rmt_addr)            ||
                    (inet->dport != rmt_port && inet->dport)            ||
                    (inet->rcv_saddr && inet->rcv_saddr != loc_addr)    ||
@@ -1097,15 +1080,16 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
        read_lock(&udp_hash_lock);
        sk = sk_head(&udptable[udp_hashfn(net, ntohs(uh->dest))]);
        dif = skb->dev->ifindex;
-       sk = udp_v4_mcast_next(sk, uh->dest, daddr, uh->source, saddr, dif);
+       sk = udp_v4_mcast_next(net, sk, uh->dest, daddr, uh->source, saddr, dif);
        if (sk) {
                struct sock *sknext = NULL;
 
                do {
                        struct sk_buff *skb1 = skb;
 
-                       sknext = udp_v4_mcast_next(sk_next(sk), uh->dest, daddr,
-                                                  uh->source, saddr, dif);
+                       sknext = udp_v4_mcast_next(net, sk_next(sk), uh->dest,
+                                                  daddr, uh->source, saddr,
+                                                  dif);
                        if (sknext)
                                skb1 = skb_clone(skb, GFP_ATOMIC);
 
@@ -1201,8 +1185,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct hlist_head udptable[],
                return __udp4_lib_mcast_deliver(net, skb, uh,
                                saddr, daddr, udptable);
 
-       sk = __udp4_lib_lookup(net, saddr, uh->source, daddr,
-                       uh->dest, inet_iif(skb), udptable);
+       sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
 
        if (sk != NULL) {
                int ret = udp_queue_rcv_skb(sk, skb);