diff options
| author | Raghuram Subramani <raghus2247@gmail.com> | 2024-10-17 17:33:46 +0530 | 
|---|---|---|
| committer | Raghuram Subramani <raghus2247@gmail.com> | 2024-10-17 17:33:46 +0530 | 
| commit | 321337c9e82f016a0cd64f81573c18b5731ffa8d (patch) | |
| tree | e9874bb042e851fec1e19bb8dfca694ef885456a /drivers/net/wireguard | |
| parent | cc57cb4ee3b7918b74d30604735d353b9a5fa23b (diff) | |
Merge remote-tracking branch 'msm8998/lineage-20' into lineage-20
Change-Id: I126075a330f305c85f8fe1b8c9d408f368be95d1
Diffstat (limited to 'drivers/net/wireguard')
19 files changed, 710 insertions, 407 deletions
| diff --git a/drivers/net/wireguard/compat/Makefile.include b/drivers/net/wireguard/compat/Makefile.include index 513dba444a37..df7670ae8d6c 100644 --- a/drivers/net/wireguard/compat/Makefile.include +++ b/drivers/net/wireguard/compat/Makefile.include @@ -6,11 +6,16 @@ kbuild-dir := $(if $(filter /%,$(src)),$(src),$(srctree)/$(src))  ccflags-y += -include $(kbuild-dir)/compat/compat.h  asflags-y += -include $(kbuild-dir)/compat/compat-asm.h +LINUXINCLUDE := -DCOMPAT_VERSION=$(VERSION) -DCOMPAT_PATCHLEVEL=$(PATCHLEVEL) -DCOMPAT_SUBLEVEL=$(SUBLEVEL) -I$(kbuild-dir)/compat/version $(LINUXINCLUDE)  ifeq ($(wildcard $(srctree)/include/linux/ptr_ring.h),)  ccflags-y += -I$(kbuild-dir)/compat/ptr_ring/include  endif +ifeq ($(wildcard $(srctree)/include/linux/skb_array.h),) +ccflags-y += -I$(kbuild-dir)/compat/skb_array/include +endif +  ifeq ($(wildcard $(srctree)/include/linux/siphash.h),)  ccflags-y += -I$(kbuild-dir)/compat/siphash/include  wireguard-y += compat/siphash/siphash.o @@ -64,6 +69,10 @@ ifeq ($(wildcard $(srctree)/arch/arm64/include/asm/neon.h)$(CONFIG_ARM64),y)  ccflags-y += -I$(kbuild-dir)/compat/neon-arm/include  endif +ifeq ($(wildcard $(srctree)/include/net/dst_metadata.h),) +ccflags-y += -I$(kbuild-dir)/compat/dstmetadata/include +endif +  ifeq ($(CONFIG_X86_64),y)  	ifeq ($(ssse3_instr),)  		ssse3_instr := $(call as-instr,pshufb %xmm0$(comma)%xmm0,-DCONFIG_AS_SSSE3=1) diff --git a/drivers/net/wireguard/compat/compat-asm.h b/drivers/net/wireguard/compat/compat-asm.h index fde21dabba4f..951fc1094470 100644 --- a/drivers/net/wireguard/compat/compat-asm.h +++ b/drivers/net/wireguard/compat/compat-asm.h @@ -15,14 +15,14 @@  #define ISRHEL7  #elif RHEL_MAJOR == 8  #define ISRHEL8 -#if RHEL_MINOR >= 4 +#if RHEL_MINOR >= 6  #define ISCENTOS8S  #endif  #endif  #endif  /* PaX compatibility */ -#if defined(RAP_PLUGIN) +#if defined(RAP_PLUGIN) && defined(RAP_ENTRY)  #undef ENTRY  #define ENTRY RAP_ENTRY  #endif @@ -51,7 +51,7 @@  #undef pull  #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 76) && !defined(ISCENTOS8S) +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 76) && !defined(ISRHEL8) && !defined(SYM_FUNC_START)  #define SYM_FUNC_START ENTRY  #define SYM_FUNC_END ENDPROC  #endif diff --git a/drivers/net/wireguard/compat/compat.h b/drivers/net/wireguard/compat/compat.h index 91d4388824ea..d166ac235a99 100644 --- a/drivers/net/wireguard/compat/compat.h +++ b/drivers/net/wireguard/compat/compat.h @@ -16,15 +16,13 @@  #define ISRHEL7  #elif RHEL_MAJOR == 8  #define ISRHEL8 -#if RHEL_MINOR >= 4 +#if RHEL_MINOR >= 6  #define ISCENTOS8S  #endif  #endif  #endif  #ifdef UTS_UBUNTU_RELEASE_ABI -#if LINUX_VERSION_CODE == KERNEL_VERSION(3, 13, 11) -#define ISUBUNTU1404 -#elif LINUX_VERSION_CODE < KERNEL_VERSION(4, 5, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 4, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 5, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 4, 0)  #define ISUBUNTU1604  #elif LINUX_VERSION_CODE < KERNEL_VERSION(4, 16, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0)  #define ISUBUNTU1804 @@ -219,7 +217,7 @@ static inline void skb_scrub_packet(struct sk_buff *skb, bool xnet)  #define skb_scrub_packet(a, b) skb_scrub_packet(a)  #endif -#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 63) || defined(ISUBUNTU1404)) && !defined(ISRHEL7) +#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 63)) && !defined(ISRHEL7)  #include <linux/random.h>  static inline u32 __compat_prandom_u32_max(u32 ep_ro)  { @@ -268,7 +266,7 @@ static inline u32 __compat_prandom_u32_max(u32 ep_ro)  #endif  #endif -#if (LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 3) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 16, 35) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 15, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 24) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0) && !defined(ISUBUNTU1404)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 33) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 11, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 10, 60) && !defined(ISRHEL7)) +#if (LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 3) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 16, 35) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 15, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 24) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 33) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 11, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 10, 60) && !defined(ISRHEL7))  static inline void memzero_explicit(void *s, size_t count)  {  	memset(s, 0, count); @@ -281,7 +279,7 @@ static const struct in6_addr __compat_in6addr_any = IN6ADDR_ANY_INIT;  #define in6addr_any __compat_in6addr_any  #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320))  #include <linux/completion.h>  #include <linux/random.h>  #include <linux/errno.h> @@ -325,7 +323,7 @@ static inline int wait_for_random_bytes(void)  }  #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) && !defined(ISRHEL8) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 14, 285)) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320)) && !defined(ISRHEL8)  #include <linux/random.h>  #include <linux/slab.h>  struct rng_is_initialized_callback { @@ -377,7 +375,7 @@ static inline bool rng_is_initialized(void)  }  #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320))  static inline int get_random_bytes_wait(void *buf, int nbytes)  {  	int ret = wait_for_random_bytes(); @@ -502,7 +500,7 @@ static inline void *__compat_kvzalloc(size_t size, gfp_t flags)  #define kvzalloc __compat_kvzalloc  #endif -#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 41)) && !defined(ISUBUNTU1404) +#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 41))  #include <linux/vmalloc.h>  #include <linux/mm.h>  static inline void __compat_kvfree(const void *addr) @@ -515,6 +513,28 @@ static inline void __compat_kvfree(const void *addr)  #define kvfree __compat_kvfree  #endif +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 12, 0) +#include <linux/vmalloc.h> +#include <linux/mm.h> +static inline void *__compat_kvmalloc_array(size_t n, size_t size, gfp_t flags) +{ +	if (n != 0 && SIZE_MAX / n < size) +		return NULL; +	return kvmalloc(n * size, flags); +} +#define kvmalloc_array __compat_kvmalloc_array +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 18, 0) +#include <linux/vmalloc.h> +#include <linux/mm.h> +static inline void *__compat_kvcalloc(size_t n, size_t size, gfp_t flags) +{ +        return kvmalloc_array(n, size, flags | __GFP_ZERO); +} +#define kvcalloc __compat_kvcalloc +#endif +  #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 11, 9)  #include <linux/netdevice.h>  #define priv_destructor destructor @@ -704,7 +724,7 @@ static inline void *skb_put_data(struct sk_buff *skb, const void *data, unsigned  #endif  #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 17, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 17, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 14, 285)) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320))  static inline void le32_to_cpu_array(u32 *buf, unsigned int words)  {  	while (words--) { @@ -757,7 +777,7 @@ static inline void crypto_xor_cpy(u8 *dst, const u8 *src1, const u8 *src2,  #define hlist_add_behind(a, b) hlist_add_after(b, a)  #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 0, 0) && !defined(ISCENTOS8S) +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 0, 0) && !defined(ISRHEL8)  #define totalram_pages() totalram_pages  #endif @@ -831,10 +851,16 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb)  #endif  #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && !defined(ISRHEL8) +#include <net/netlink.h> +#ifndef NLA_POLICY_EXACT_LEN  #define NLA_POLICY_EXACT_LEN(_len) { .type = NLA_UNSPEC, .len = _len }  #endif +#endif  #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 2, 0) && !defined(ISRHEL8) +#include <net/netlink.h> +#ifndef NLA_POLICY_MIN_LEN  #define NLA_POLICY_MIN_LEN(_len) { .type = NLA_UNSPEC, .len = _len } +#endif  #define COMPAT_CANNOT_INDIVIDUAL_NETLINK_OPS_POLICY  #endif @@ -849,7 +875,7 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb)  #endif  #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0) && !defined(ISRHEL8)  #define genl_dumpit_info(cb) ({ \  	struct { struct nlattr **attrs; } *a = (void *)((u8 *)cb->args + offsetofend(struct dump_ctx, next_allowedip)); \  	BUILD_BUG_ON(sizeof(cb->args) < offsetofend(struct dump_ctx, next_allowedip) + sizeof(*a)); \ @@ -869,11 +895,13 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb)  #endif  #endif -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0) +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 200) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 249)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 14, 285)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 9, 320))  #define blake2s_init zinc_blake2s_init  #define blake2s_init_key zinc_blake2s_init_key  #define blake2s_update zinc_blake2s_update  #define blake2s_final zinc_blake2s_final +#endif +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0)  #define blake2s_hmac zinc_blake2s_hmac  #define chacha20 zinc_chacha20  #define hchacha20 zinc_hchacha20 @@ -1096,6 +1124,37 @@ static const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tun  #endif  #endif +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0) +#include <net/dst_cache.h> +struct dst_cache_pcpu { +	unsigned long refresh_ts; +	struct dst_entry *dst; +	u32 cookie; +	union { +		struct in_addr in_saddr; +		struct in6_addr in6_saddr; +	}; +}; +#define COMPAT_HAS_DEFINED_DST_CACHE_PCPU +static inline void dst_cache_reset_now(struct dst_cache *dst_cache) +{ +	int i; + +	if (!dst_cache->cache) +		return; + +	dst_cache->reset_ts = jiffies; +	for_each_possible_cpu(i) { +		struct dst_cache_pcpu *idst = per_cpu_ptr(dst_cache->cache, i); +		struct dst_entry *dst = idst->dst; + +		idst->cookie = 0; +		idst->dst = NULL; +		dst_release(dst); +	} +} +#endif +  #if defined(ISUBUNTU1604) || defined(ISRHEL7)  #include <linux/siphash.h>  #ifndef _WG_LINUX_SIPHASH_H @@ -1127,7 +1186,7 @@ static const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tun  #undef __read_mostly  #define __read_mostly  #endif -#if (defined(RAP_PLUGIN) || defined(CONFIG_CFI_CLANG)) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) +#if (defined(CONFIG_PAX) || defined(CONFIG_CFI_CLANG)) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0)  #include <linux/timer.h>  #define wg_expired_retransmit_handshake(a) wg_expired_retransmit_handshake(unsigned long timer)  #define wg_expired_send_keepalive(a) wg_expired_send_keepalive(unsigned long timer) diff --git a/drivers/net/wireguard/compat/dst_cache/dst_cache.c b/drivers/net/wireguard/compat/dst_cache/dst_cache.c index 7ec22f768a8f..f74c43c550eb 100644 --- a/drivers/net/wireguard/compat/dst_cache/dst_cache.c +++ b/drivers/net/wireguard/compat/dst_cache/dst_cache.c @@ -27,6 +27,7 @@ static inline u32 rt6_get_cookie(const struct rt6_info *rt)  #endif  #include <uapi/linux/in.h> +#ifndef COMPAT_HAS_DEFINED_DST_CACHE_PCPU  struct dst_cache_pcpu {  	unsigned long refresh_ts;  	struct dst_entry *dst; @@ -36,6 +37,7 @@ struct dst_cache_pcpu {  		struct in6_addr in6_saddr;  	};  }; +#endif  static void dst_cache_per_cpu_dst_set(struct dst_cache_pcpu *dst_cache,  				      struct dst_entry *dst, u32 cookie) diff --git a/drivers/net/wireguard/compat/dstmetadata/include/net/dst_metadata.h b/drivers/net/wireguard/compat/dstmetadata/include/net/dst_metadata.h new file mode 100644 index 000000000000..995094d4f099 --- /dev/null +++ b/drivers/net/wireguard/compat/dstmetadata/include/net/dst_metadata.h @@ -0,0 +1,3 @@ +#ifndef skb_valid_dst +#define skb_valid_dst(skb) (!!skb_dst(skb)) +#endif diff --git a/drivers/net/wireguard/compat/siphash/include/linux/siphash.h b/drivers/net/wireguard/compat/siphash/include/linux/siphash.h index 1e5e337d15bf..3b30b3c47778 100644 --- a/drivers/net/wireguard/compat/siphash/include/linux/siphash.h +++ b/drivers/net/wireguard/compat/siphash/include/linux/siphash.h @@ -22,9 +22,7 @@ typedef struct {  } siphash_key_t;  u64 __siphash_aligned(const void *data, size_t len, const siphash_key_t *key); -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u64 __siphash_unaligned(const void *data, size_t len, const siphash_key_t *key); -#endif  u64 siphash_1u64(const u64 a, const siphash_key_t *key);  u64 siphash_2u64(const u64 a, const u64 b, const siphash_key_t *key); @@ -77,10 +75,9 @@ static inline u64 ___siphash_aligned(const __le64 *data, size_t len,  static inline u64 siphash(const void *data, size_t len,  			  const siphash_key_t *key)  { -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS -	if (!IS_ALIGNED((unsigned long)data, SIPHASH_ALIGNMENT)) +	if (IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) || +	    !IS_ALIGNED((unsigned long)data, SIPHASH_ALIGNMENT))  		return __siphash_unaligned(data, len, key); -#endif  	return ___siphash_aligned(data, len, key);  } @@ -91,10 +88,8 @@ typedef struct {  u32 __hsiphash_aligned(const void *data, size_t len,  		       const hsiphash_key_t *key); -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u32 __hsiphash_unaligned(const void *data, size_t len,  			 const hsiphash_key_t *key); -#endif  u32 hsiphash_1u32(const u32 a, const hsiphash_key_t *key);  u32 hsiphash_2u32(const u32 a, const u32 b, const hsiphash_key_t *key); @@ -130,10 +125,9 @@ static inline u32 ___hsiphash_aligned(const __le32 *data, size_t len,  static inline u32 hsiphash(const void *data, size_t len,  			   const hsiphash_key_t *key)  { -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS -	if (!IS_ALIGNED((unsigned long)data, HSIPHASH_ALIGNMENT)) +	if (IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) || +	    !IS_ALIGNED((unsigned long)data, HSIPHASH_ALIGNMENT))  		return __hsiphash_unaligned(data, len, key); -#endif  	return ___hsiphash_aligned(data, len, key);  } diff --git a/drivers/net/wireguard/compat/siphash/siphash.c b/drivers/net/wireguard/compat/siphash/siphash.c index 58855328e6e0..7dc72cb4a710 100644 --- a/drivers/net/wireguard/compat/siphash/siphash.c +++ b/drivers/net/wireguard/compat/siphash/siphash.c @@ -57,6 +57,7 @@  	SIPROUND; \  	return (v0 ^ v1) ^ (v2 ^ v3); +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u64 __siphash_aligned(const void *data, size_t len, const siphash_key_t *key)  {  	const u8 *end = data + len - (len % sizeof(u64)); @@ -76,19 +77,19 @@ u64 __siphash_aligned(const void *data, size_t len, const siphash_key_t *key)  						  bytemask_from_count(left)));  #else  	switch (left) { -	case 7: b |= ((u64)end[6]) << 48; -	case 6: b |= ((u64)end[5]) << 40; -	case 5: b |= ((u64)end[4]) << 32; +	case 7: b |= ((u64)end[6]) << 48; fallthrough; +	case 6: b |= ((u64)end[5]) << 40; fallthrough; +	case 5: b |= ((u64)end[4]) << 32; fallthrough;  	case 4: b |= le32_to_cpup(data); break; -	case 3: b |= ((u64)end[2]) << 16; +	case 3: b |= ((u64)end[2]) << 16; fallthrough;  	case 2: b |= le16_to_cpup(data); break;  	case 1: b |= end[0];  	}  #endif  	POSTAMBLE  } +#endif -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u64 __siphash_unaligned(const void *data, size_t len, const siphash_key_t *key)  {  	const u8 *end = data + len - (len % sizeof(u64)); @@ -108,18 +109,17 @@ u64 __siphash_unaligned(const void *data, size_t len, const siphash_key_t *key)  						  bytemask_from_count(left)));  #else  	switch (left) { -	case 7: b |= ((u64)end[6]) << 48; -	case 6: b |= ((u64)end[5]) << 40; -	case 5: b |= ((u64)end[4]) << 32; +	case 7: b |= ((u64)end[6]) << 48; fallthrough; +	case 6: b |= ((u64)end[5]) << 40; fallthrough; +	case 5: b |= ((u64)end[4]) << 32; fallthrough;  	case 4: b |= get_unaligned_le32(end); break; -	case 3: b |= ((u64)end[2]) << 16; +	case 3: b |= ((u64)end[2]) << 16; fallthrough;  	case 2: b |= get_unaligned_le16(end); break;  	case 1: b |= end[0];  	}  #endif  	POSTAMBLE  } -#endif  /**   * siphash_1u64 - compute 64-bit siphash PRF value of a u64 @@ -250,6 +250,7 @@ u64 siphash_3u32(const u32 first, const u32 second, const u32 third,  	HSIPROUND; \  	return (v0 ^ v1) ^ (v2 ^ v3); +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key)  {  	const u8 *end = data + len - (len % sizeof(u64)); @@ -268,19 +269,19 @@ u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key)  						  bytemask_from_count(left)));  #else  	switch (left) { -	case 7: b |= ((u64)end[6]) << 48; -	case 6: b |= ((u64)end[5]) << 40; -	case 5: b |= ((u64)end[4]) << 32; +	case 7: b |= ((u64)end[6]) << 48; fallthrough; +	case 6: b |= ((u64)end[5]) << 40; fallthrough; +	case 5: b |= ((u64)end[4]) << 32; fallthrough;  	case 4: b |= le32_to_cpup(data); break; -	case 3: b |= ((u64)end[2]) << 16; +	case 3: b |= ((u64)end[2]) << 16; fallthrough;  	case 2: b |= le16_to_cpup(data); break;  	case 1: b |= end[0];  	}  #endif  	HPOSTAMBLE  } +#endif -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u32 __hsiphash_unaligned(const void *data, size_t len,  			 const hsiphash_key_t *key)  { @@ -300,18 +301,17 @@ u32 __hsiphash_unaligned(const void *data, size_t len,  						  bytemask_from_count(left)));  #else  	switch (left) { -	case 7: b |= ((u64)end[6]) << 48; -	case 6: b |= ((u64)end[5]) << 40; -	case 5: b |= ((u64)end[4]) << 32; +	case 7: b |= ((u64)end[6]) << 48; fallthrough; +	case 6: b |= ((u64)end[5]) << 40; fallthrough; +	case 5: b |= ((u64)end[4]) << 32; fallthrough;  	case 4: b |= get_unaligned_le32(end); break; -	case 3: b |= ((u64)end[2]) << 16; +	case 3: b |= ((u64)end[2]) << 16; fallthrough;  	case 2: b |= get_unaligned_le16(end); break;  	case 1: b |= end[0];  	}  #endif  	HPOSTAMBLE  } -#endif  /**   * hsiphash_1u32 - compute 64-bit hsiphash PRF value of a u32 @@ -412,6 +412,7 @@ u32 hsiphash_4u32(const u32 first, const u32 second, const u32 third,  	HSIPROUND; \  	return v1 ^ v3; +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key)  {  	const u8 *end = data + len - (len % sizeof(u32)); @@ -425,14 +426,14 @@ u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key)  		v0 ^= m;  	}  	switch (left) { -	case 3: b |= ((u32)end[2]) << 16; +	case 3: b |= ((u32)end[2]) << 16; fallthrough;  	case 2: b |= le16_to_cpup(data); break;  	case 1: b |= end[0];  	}  	HPOSTAMBLE  } +#endif -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS  u32 __hsiphash_unaligned(const void *data, size_t len,  			 const hsiphash_key_t *key)  { @@ -447,13 +448,12 @@ u32 __hsiphash_unaligned(const void *data, size_t len,  		v0 ^= m;  	}  	switch (left) { -	case 3: b |= ((u32)end[2]) << 16; +	case 3: b |= ((u32)end[2]) << 16; fallthrough;  	case 2: b |= get_unaligned_le16(end); break;  	case 1: b |= end[0];  	}  	HPOSTAMBLE  } -#endif  /**   * hsiphash_1u32 - compute 32-bit hsiphash PRF value of a u32 diff --git a/drivers/net/wireguard/compat/skb_array/include/linux/skb_array.h b/drivers/net/wireguard/compat/skb_array/include/linux/skb_array.h new file mode 100644 index 000000000000..c91fedcdbfc6 --- /dev/null +++ b/drivers/net/wireguard/compat/skb_array/include/linux/skb_array.h @@ -0,0 +1,11 @@ +#ifndef _WG_SKB_ARRAY_H +#define _WG_SKB_ARRAY_H + +#include <linux/skbuff.h> + +static void __skb_array_destroy_skb(void *ptr) +{ +	kfree_skb(ptr); +} + +#endif diff --git a/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c b/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c index 9b8770ae7b3f..d287b917be84 100644 --- a/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c +++ b/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c @@ -38,9 +38,10 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg,  	struct socket *sock = NULL;  	struct sockaddr_in udp_addr; -	err = __sock_create(net, AF_INET, SOCK_DGRAM, 0, &sock, 1); +	err = sock_create_kern(AF_INET, SOCK_DGRAM, 0, &sock);  	if (err < 0)  		goto error; +	sk_change_net(sock->sk, net);  	udp_addr.sin_family = AF_INET;  	udp_addr.sin_addr = cfg->local_ip; @@ -72,7 +73,7 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg,  error:  	if (sock) {  		kernel_sock_shutdown(sock, SHUT_RDWR); -		sock_release(sock); +		sk_release_kernel(sock->sk);  	}  	*sockp = NULL;  	return err; @@ -229,7 +230,7 @@ void udp_tunnel_sock_release(struct socket *sock)  {  	rcu_assign_sk_user_data(sock->sk, NULL);  	kernel_sock_shutdown(sock, SHUT_RDWR); -	sock_release(sock); +	sk_release_kernel(sock->sk);  }  #if IS_ENABLED(CONFIG_IPV6) @@ -254,9 +255,10 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg,  	int err;  	struct socket *sock = NULL; -	err = __sock_create(net, AF_INET6, SOCK_DGRAM, 0, &sock, 1); +	err = sock_create_kern(AF_INET6, SOCK_DGRAM, 0, &sock);  	if (err < 0)  		goto error; +	sk_change_net(sock->sk, net);  	if (cfg->ipv6_v6only) {  		int val = 1; @@ -301,7 +303,7 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg,  error:  	if (sock) {  		kernel_sock_shutdown(sock, SHUT_RDWR); -		sock_release(sock); +		sk_release_kernel(sock->sk);  	}  	*sockp = NULL;  	return err; diff --git a/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c b/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c index 79716c425b0c..8b6872a2f0d0 100644 --- a/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c +++ b/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c @@ -34,11 +34,11 @@ static inline u64 add_scalar(u64 *out, const u64 *f1, u64 f2)  	asm volatile(  		/* Clear registers to propagate the carry bit */ -		"  xor %%r8, %%r8;" -		"  xor %%r9, %%r9;" -		"  xor %%r10, %%r10;" -		"  xor %%r11, %%r11;" -		"  xor %1, %1;" +		"  xor %%r8d, %%r8d;" +		"  xor %%r9d, %%r9d;" +		"  xor %%r10d, %%r10d;" +		"  xor %%r11d, %%r11d;" +		"  xor %k1, %k1;"  		/* Begin addition chain */  		"  addq 0(%3), %0;" @@ -52,10 +52,9 @@ static inline u64 add_scalar(u64 *out, const u64 *f1, u64 f2)  		/* Return the carry bit in a register */  		"  adcx %%r11, %1;" -	: "+&r" (f2), "=&r" (carry_r) -	: "r" (out), "r" (f1) -	: "%r8", "%r9", "%r10", "%r11", "memory", "cc" -	); +		: "+&r"(f2), "=&r"(carry_r) +		: "r"(out), "r"(f1) +		: "%r8", "%r9", "%r10", "%r11", "memory", "cc");  	return carry_r;  } @@ -82,7 +81,7 @@ static inline void fadd(u64 *out, const u64 *f1, const u64 *f2)  		"  cmovc %0, %%rax;"  		/* Step 2: Add carry*38 to the original sum */ -		"  xor %%rcx, %%rcx;" +		"  xor %%ecx, %%ecx;"  		"  add %%rax, %%r8;"  		"  adcx %%rcx, %%r9;"  		"  movq %%r9, 8(%1);" @@ -96,17 +95,16 @@ static inline void fadd(u64 *out, const u64 *f1, const u64 *f2)  		"  cmovc %0, %%rax;"  		"  add %%rax, %%r8;"  		"  movq %%r8, 0(%1);" -	: "+&r" (f2) -	: "r" (out), "r" (f1) -	: "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc" -	); +		: "+&r"(f2) +		: "r"(out), "r"(f1) +		: "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc");  } -/* Computes the field substraction of two field elements */ +/* Computes the field subtraction of two field elements */  static inline void fsub(u64 *out, const u64 *f1, const u64 *f2)  {  	asm volatile( -		/* Compute the raw substraction of f1-f2 */ +		/* Compute the raw subtraction of f1-f2 */  		"  movq 0(%1), %%r8;"  		"  subq 0(%2), %%r8;"  		"  movq 8(%1), %%r9;" @@ -123,7 +121,7 @@ static inline void fsub(u64 *out, const u64 *f1, const u64 *f2)  		"  mov $38, %%rcx;"  		"  cmovc %%rcx, %%rax;" -		/* Step 2: Substract carry*38 from the original difference */ +		/* Step 2: Subtract carry*38 from the original difference */  		"  sub %%rax, %%r8;"  		"  sbb $0, %%r9;"  		"  sbb $0, %%r10;" @@ -139,10 +137,9 @@ static inline void fsub(u64 *out, const u64 *f1, const u64 *f2)  		"  movq %%r9, 8(%0);"  		"  movq %%r10, 16(%0);"  		"  movq %%r11, 24(%0);" -	: -	: "r" (out), "r" (f1), "r" (f2) -	: "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc" -	); +		: +		: "r"(out), "r"(f1), "r"(f2) +		: "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc");  }  /* Computes a field multiplication: out <- f1 * f2 @@ -150,239 +147,400 @@ static inline void fsub(u64 *out, const u64 *f1, const u64 *f2)  static inline void fmul(u64 *out, const u64 *f1, const u64 *f2, u64 *tmp)  {  	asm volatile( +  		/* Compute the raw multiplication: tmp <- src1 * src2 */  		/* Compute src1[0] * src2 */ -		"  movq 0(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"     "  movq %%r8, 0(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  movq %%r10, 8(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;" +		"  movq 0(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  movq %%r8, 0(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  movq %%r10, 8(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +  		/* Compute src1[1] * src2 */ -		"  movq 8(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"     "  adcxq 8(%0), %%r8;"    "  movq %%r8, 8(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 16(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  mov $0, %%r8;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;" +		"  movq 8(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 8(%2), %%r8;" +		"  movq %%r8, 8(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 16(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  mov $0, %%r8;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +  		/* Compute src1[2] * src2 */ -		"  movq 16(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"    "  adcxq 16(%0), %%r8;"    "  movq %%r8, 16(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 24(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  mov $0, %%r8;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;" +		"  movq 16(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 16(%2), %%r8;" +		"  movq %%r8, 16(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 24(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  mov $0, %%r8;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +  		/* Compute src1[3] * src2 */ -		"  movq 24(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"    "  adcxq 24(%0), %%r8;"    "  movq %%r8, 24(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 32(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  movq %%rbx, 40(%0);"    "  mov $0, %%r8;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  movq %%r14, 48(%0);"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;"     "  movq %%rax, 56(%0);" +		"  movq 24(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 24(%2), %%r8;" +		"  movq %%r8, 24(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 32(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  movq %%rbx, 40(%2);" +		"  mov $0, %%r8;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  movq %%r14, 48(%2);" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +		"  movq %%rax, 56(%2);" +  		/* Line up pointers */ -		"  mov %0, %1;"  		"  mov %2, %0;" +		"  mov %3, %2;"  		/* Wrap the result back into the field */  		/* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */  		"  mov $38, %%rdx;" -		"  mulxq 32(%1), %%r8, %%r13;" -		"  xor %3, %3;" -		"  adoxq 0(%1), %%r8;" -		"  mulxq 40(%1), %%r9, %%rbx;" +		"  mulxq 32(%0), %%r8, %%r13;" +		"  xor %k1, %k1;" +		"  adoxq 0(%0), %%r8;" +		"  mulxq 40(%0), %%r9, %%rbx;"  		"  adcx %%r13, %%r9;" -		"  adoxq 8(%1), %%r9;" -		"  mulxq 48(%1), %%r10, %%r13;" +		"  adoxq 8(%0), %%r9;" +		"  mulxq 48(%0), %%r10, %%r13;"  		"  adcx %%rbx, %%r10;" -		"  adoxq 16(%1), %%r10;" -		"  mulxq 56(%1), %%r11, %%rax;" +		"  adoxq 16(%0), %%r10;" +		"  mulxq 56(%0), %%r11, %%rax;"  		"  adcx %%r13, %%r11;" -		"  adoxq 24(%1), %%r11;" -		"  adcx %3, %%rax;" -		"  adox %3, %%rax;" +		"  adoxq 24(%0), %%r11;" +		"  adcx %1, %%rax;" +		"  adox %1, %%rax;"  		"  imul %%rdx, %%rax;"  		/* Step 2: Fold the carry back into dst */  		"  add %%rax, %%r8;" -		"  adcx %3, %%r9;" -		"  movq %%r9, 8(%0);" -		"  adcx %3, %%r10;" -		"  movq %%r10, 16(%0);" -		"  adcx %3, %%r11;" -		"  movq %%r11, 24(%0);" +		"  adcx %1, %%r9;" +		"  movq %%r9, 8(%2);" +		"  adcx %1, %%r10;" +		"  movq %%r10, 16(%2);" +		"  adcx %1, %%r11;" +		"  movq %%r11, 24(%2);"  		/* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */  		"  mov $0, %%rax;"  		"  cmovc %%rdx, %%rax;"  		"  add %%rax, %%r8;" -		"  movq %%r8, 0(%0);" -	: "+&r" (tmp), "+&r" (f1), "+&r" (out), "+&r" (f2) -	: -	: "%rax", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "memory", "cc" -	); +		"  movq %%r8, 0(%2);" +		: "+&r"(f1), "+&r"(f2), "+&r"(tmp) +		: "r"(out) +		: "%rax", "%rbx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%r13", +		  "%r14", "memory", "cc");  }  /* Computes two field multiplications: - * out[0] <- f1[0] * f2[0] - * out[1] <- f1[1] * f2[1] - * Uses the 16-element buffer tmp for intermediate results. */ + *   out[0] <- f1[0] * f2[0] + *   out[1] <- f1[1] * f2[1] + * Uses the 16-element buffer tmp for intermediate results: */  static inline void fmul2(u64 *out, const u64 *f1, const u64 *f2, u64 *tmp)  {  	asm volatile( +  		/* Compute the raw multiplication tmp[0] <- f1[0] * f2[0] */  		/* Compute src1[0] * src2 */ -		"  movq 0(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"     "  movq %%r8, 0(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  movq %%r10, 8(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;" +		"  movq 0(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  movq %%r8, 0(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  movq %%r10, 8(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +  		/* Compute src1[1] * src2 */ -		"  movq 8(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"     "  adcxq 8(%0), %%r8;"    "  movq %%r8, 8(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 16(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  mov $0, %%r8;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;" +		"  movq 8(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 8(%2), %%r8;" +		"  movq %%r8, 8(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 16(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  mov $0, %%r8;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +  		/* Compute src1[2] * src2 */ -		"  movq 16(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"    "  adcxq 16(%0), %%r8;"    "  movq %%r8, 16(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 24(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  mov $0, %%r8;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;" +		"  movq 16(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 16(%2), %%r8;" +		"  movq %%r8, 16(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 24(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  mov $0, %%r8;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +  		/* Compute src1[3] * src2 */ -		"  movq 24(%1), %%rdx;" -		"  mulxq 0(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"    "  adcxq 24(%0), %%r8;"    "  movq %%r8, 24(%0);" -		"  mulxq 8(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 32(%0);" -		"  mulxq 16(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  movq %%rbx, 40(%0);"    "  mov $0, %%r8;" -		"  mulxq 24(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  movq %%r14, 48(%0);"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;"     "  movq %%rax, 56(%0);" +		"  movq 24(%0), %%rdx;" +		"  mulxq 0(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 24(%2), %%r8;" +		"  movq %%r8, 24(%2);" +		"  mulxq 8(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 32(%2);" +		"  mulxq 16(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  movq %%rbx, 40(%2);" +		"  mov $0, %%r8;" +		"  mulxq 24(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  movq %%r14, 48(%2);" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +		"  movq %%rax, 56(%2);"  		/* Compute the raw multiplication tmp[1] <- f1[1] * f2[1] */  		/* Compute src1[0] * src2 */ -		"  movq 32(%1), %%rdx;" -		"  mulxq 32(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"     "  movq %%r8, 64(%0);" -		"  mulxq 40(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  movq %%r10, 72(%0);" -		"  mulxq 48(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;" -		"  mulxq 56(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;" +		"  movq 32(%0), %%rdx;" +		"  mulxq 32(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  movq %%r8, 64(%2);" +		"  mulxq 40(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  movq %%r10, 72(%2);" +		"  mulxq 48(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  mulxq 56(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +  		/* Compute src1[1] * src2 */ -		"  movq 40(%1), %%rdx;" -		"  mulxq 32(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"     "  adcxq 72(%0), %%r8;"    "  movq %%r8, 72(%0);" -		"  mulxq 40(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 80(%0);" -		"  mulxq 48(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  mov $0, %%r8;" -		"  mulxq 56(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;" +		"  movq 40(%0), %%rdx;" +		"  mulxq 32(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 72(%2), %%r8;" +		"  movq %%r8, 72(%2);" +		"  mulxq 40(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 80(%2);" +		"  mulxq 48(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  mov $0, %%r8;" +		"  mulxq 56(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +  		/* Compute src1[2] * src2 */ -		"  movq 48(%1), %%rdx;" -		"  mulxq 32(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"    "  adcxq 80(%0), %%r8;"    "  movq %%r8, 80(%0);" -		"  mulxq 40(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 88(%0);" -		"  mulxq 48(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  mov $0, %%r8;" -		"  mulxq 56(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;" +		"  movq 48(%0), %%rdx;" +		"  mulxq 32(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 80(%2), %%r8;" +		"  movq %%r8, 80(%2);" +		"  mulxq 40(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 88(%2);" +		"  mulxq 48(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  mov $0, %%r8;" +		"  mulxq 56(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +  		/* Compute src1[3] * src2 */ -		"  movq 56(%1), %%rdx;" -		"  mulxq 32(%3), %%r8, %%r9;"       "  xor %%r10, %%r10;"    "  adcxq 88(%0), %%r8;"    "  movq %%r8, 88(%0);" -		"  mulxq 40(%3), %%r10, %%r11;"     "  adox %%r9, %%r10;"     "  adcx %%rbx, %%r10;"    "  movq %%r10, 96(%0);" -		"  mulxq 48(%3), %%rbx, %%r13;"    "  adox %%r11, %%rbx;"    "  adcx %%r14, %%rbx;"    "  movq %%rbx, 104(%0);"    "  mov $0, %%r8;" -		"  mulxq 56(%3), %%r14, %%rdx;"    "  adox %%r13, %%r14;"    "  adcx %%rax, %%r14;"    "  movq %%r14, 112(%0);"    "  mov $0, %%rax;" -		                                   "  adox %%rdx, %%rax;"    "  adcx %%r8, %%rax;"     "  movq %%rax, 120(%0);" +		"  movq 56(%0), %%rdx;" +		"  mulxq 32(%1), %%r8, %%r9;" +		"  xor %%r10d, %%r10d;" +		"  adcxq 88(%2), %%r8;" +		"  movq %%r8, 88(%2);" +		"  mulxq 40(%1), %%r10, %%r11;" +		"  adox %%r9, %%r10;" +		"  adcx %%rbx, %%r10;" +		"  movq %%r10, 96(%2);" +		"  mulxq 48(%1), %%rbx, %%r13;" +		"  adox %%r11, %%rbx;" +		"  adcx %%r14, %%rbx;" +		"  movq %%rbx, 104(%2);" +		"  mov $0, %%r8;" +		"  mulxq 56(%1), %%r14, %%rdx;" +		"  adox %%r13, %%r14;" +		"  adcx %%rax, %%r14;" +		"  movq %%r14, 112(%2);" +		"  mov $0, %%rax;" +		"  adox %%rdx, %%rax;" +		"  adcx %%r8, %%rax;" +		"  movq %%rax, 120(%2);" +  		/* Line up pointers */ -		"  mov %0, %1;"  		"  mov %2, %0;" +		"  mov %3, %2;"  		/* Wrap the results back into the field */  		/* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */  		"  mov $38, %%rdx;" -		"  mulxq 32(%1), %%r8, %%r13;" -		"  xor %3, %3;" -		"  adoxq 0(%1), %%r8;" -		"  mulxq 40(%1), %%r9, %%rbx;" +		"  mulxq 32(%0), %%r8, %%r13;" +		"  xor %k1, %k1;" +		"  adoxq 0(%0), %%r8;" +		"  mulxq 40(%0), %%r9, %%rbx;"  		"  adcx %%r13, %%r9;" -		"  adoxq 8(%1), %%r9;" -		"  mulxq 48(%1), %%r10, %%r13;" +		"  adoxq 8(%0), %%r9;" +		"  mulxq 48(%0), %%r10, %%r13;"  		"  adcx %%rbx, %%r10;" -		"  adoxq 16(%1), %%r10;" -		"  mulxq 56(%1), %%r11, %%rax;" +		"  adoxq 16(%0), %%r10;" +		"  mulxq 56(%0), %%r11, %%rax;"  		"  adcx %%r13, %%r11;" -		"  adoxq 24(%1), %%r11;" -		"  adcx %3, %%rax;" -		"  adox %3, %%rax;" +		"  adoxq 24(%0), %%r11;" +		"  adcx %1, %%rax;" +		"  adox %1, %%rax;"  		"  imul %%rdx, %%rax;"  		/* Step 2: Fold the carry back into dst */  		"  add %%rax, %%r8;" -		"  adcx %3, %%r9;" -		"  movq %%r9, 8(%0);" -		"  adcx %3, %%r10;" -		"  movq %%r10, 16(%0);" -		"  adcx %3, %%r11;" -		"  movq %%r11, 24(%0);" +		"  adcx %1, %%r9;" +		"  movq %%r9, 8(%2);" +		"  adcx %1, %%r10;" +		"  movq %%r10, 16(%2);" +		"  adcx %1, %%r11;" +		"  movq %%r11, 24(%2);"  		/* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */  		"  mov $0, %%rax;"  		"  cmovc %%rdx, %%rax;"  		"  add %%rax, %%r8;" -		"  movq %%r8, 0(%0);" +		"  movq %%r8, 0(%2);"  		/* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */  		"  mov $38, %%rdx;" -		"  mulxq 96(%1), %%r8, %%r13;" -		"  xor %3, %3;" -		"  adoxq 64(%1), %%r8;" -		"  mulxq 104(%1), %%r9, %%rbx;" +		"  mulxq 96(%0), %%r8, %%r13;" +		"  xor %k1, %k1;" +		"  adoxq 64(%0), %%r8;" +		"  mulxq 104(%0), %%r9, %%rbx;"  		"  adcx %%r13, %%r9;" -		"  adoxq 72(%1), %%r9;" -		"  mulxq 112(%1), %%r10, %%r13;" +		"  adoxq 72(%0), %%r9;" +		"  mulxq 112(%0), %%r10, %%r13;"  		"  adcx %%rbx, %%r10;" -		"  adoxq 80(%1), %%r10;" -		"  mulxq 120(%1), %%r11, %%rax;" +		"  adoxq 80(%0), %%r10;" +		"  mulxq 120(%0), %%r11, %%rax;"  		"  adcx %%r13, %%r11;" -		"  adoxq 88(%1), %%r11;" -		"  adcx %3, %%rax;" -		"  adox %3, %%rax;" +		"  adoxq 88(%0), %%r11;" +		"  adcx %1, %%rax;" +		"  adox %1, %%rax;"  		"  imul %%rdx, %%rax;"  		/* Step 2: Fold the carry back into dst */  		"  add %%rax, %%r8;" -		"  adcx %3, %%r9;" -		"  movq %%r9, 40(%0);" -		"  adcx %3, %%r10;" -		"  movq %%r10, 48(%0);" -		"  adcx %3, %%r11;" -		"  movq %%r11, 56(%0);" +		"  adcx %1, %%r9;" +		"  movq %%r9, 40(%2);" +		"  adcx %1, %%r10;" +		"  movq %%r10, 48(%2);" +		"  adcx %1, %%r11;" +		"  movq %%r11, 56(%2);"  		/* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */  		"  mov $0, %%rax;"  		"  cmovc %%rdx, %%rax;"  		"  add %%rax, %%r8;" -		"  movq %%r8, 32(%0);" -	: "+&r" (tmp), "+&r" (f1), "+&r" (out), "+&r" (f2) -	: -	: "%rax", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "memory", "cc" -	); +		"  movq %%r8, 32(%2);" +		: "+&r"(f1), "+&r"(f2), "+&r"(tmp) +		: "r"(out) +		: "%rax", "%rbx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%r13", +		  "%r14", "memory", "cc");  } -/* Computes the field multiplication of four-element f1 with value in f2 */ +/* Computes the field multiplication of four-element f1 with value in f2 + * Requires f2 to be smaller than 2^17 */  static inline void fmul_scalar(u64 *out, const u64 *f1, u64 f2)  {  	register u64 f2_r asm("rdx") = f2;  	asm volatile(  		/* Compute the raw multiplication of f1*f2 */ -		"  mulxq 0(%2), %%r8, %%rcx;"      /* f1[0]*f2 */ -		"  mulxq 8(%2), %%r9, %%rbx;"      /* f1[1]*f2 */ +		"  mulxq 0(%2), %%r8, %%rcx;" /* f1[0]*f2 */ +		"  mulxq 8(%2), %%r9, %%rbx;" /* f1[1]*f2 */  		"  add %%rcx, %%r9;"  		"  mov $0, %%rcx;" -		"  mulxq 16(%2), %%r10, %%r13;"    /* f1[2]*f2 */ +		"  mulxq 16(%2), %%r10, %%r13;" /* f1[2]*f2 */  		"  adcx %%rbx, %%r10;" -		"  mulxq 24(%2), %%r11, %%rax;"    /* f1[3]*f2 */ +		"  mulxq 24(%2), %%r11, %%rax;" /* f1[3]*f2 */  		"  adcx %%r13, %%r11;"  		"  adcx %%rcx, %%rax;" @@ -406,17 +564,17 @@ static inline void fmul_scalar(u64 *out, const u64 *f1, u64 f2)  		"  cmovc %%rdx, %%rax;"  		"  add %%rax, %%r8;"  		"  movq %%r8, 0(%1);" -	: "+&r" (f2_r) -	: "r" (out), "r" (f1) -	: "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "memory", "cc" -	); +		: "+&r"(f2_r) +		: "r"(out), "r"(f1) +		: "%rax", "%rbx", "%rcx", "%r8", "%r9", "%r10", "%r11", "%r13", +		  "memory", "cc");  }  /* Computes p1 <- bit ? p2 : p1 in constant time */  static inline void cswap2(u64 bit, const u64 *p1, const u64 *p2)  {  	asm volatile( -		/* Invert the polarity of bit to match cmov expectations */ +		/* Transfer bit into CF flag */  		"  add $18446744073709551615, %0;"  		/* cswap p1[0], p2[0] */ @@ -490,10 +648,9 @@ static inline void cswap2(u64 bit, const u64 *p1, const u64 *p2)  		"  cmovc %%r10, %%r9;"  		"  movq %%r8, 56(%1);"  		"  movq %%r9, 56(%2);" -	: "+&r" (bit) -	: "r" (p1), "r" (p2) -	: "%r8", "%r9", "%r10", "memory", "cc" -	); +		: "+&r"(bit) +		: "r"(p1), "r"(p2) +		: "%r8", "%r9", "%r10", "memory", "cc");  }  /* Computes the square of a field element: out <- f * f @@ -504,18 +661,25 @@ static inline void fsqr(u64 *out, const u64 *f, u64 *tmp)  		/* Compute the raw multiplication: tmp <- f * f */  		/* Step 1: Compute all partial products */ -		"  movq 0(%1), %%rdx;"                                       /* f[0] */ -		"  mulxq 8(%1), %%r8, %%r14;"      "  xor %%r15, %%r15;"     /* f[1]*f[0] */ -		"  mulxq 16(%1), %%r9, %%r10;"     "  adcx %%r14, %%r9;"     /* f[2]*f[0] */ -		"  mulxq 24(%1), %%rax, %%rcx;"    "  adcx %%rax, %%r10;"    /* f[3]*f[0] */ -		"  movq 24(%1), %%rdx;"                                      /* f[3] */ -		"  mulxq 8(%1), %%r11, %%rbx;"     "  adcx %%rcx, %%r11;"    /* f[1]*f[3] */ -		"  mulxq 16(%1), %%rax, %%r13;"    "  adcx %%rax, %%rbx;"    /* f[2]*f[3] */ -		"  movq 8(%1), %%rdx;"             "  adcx %%r15, %%r13;"    /* f1 */ -		"  mulxq 16(%1), %%rax, %%rcx;"    "  mov $0, %%r14;"        /* f[2]*f[1] */ +		"  movq 0(%0), %%rdx;" /* f[0] */ +		"  mulxq 8(%0), %%r8, %%r14;" +		"  xor %%r15d, %%r15d;" /* f[1]*f[0] */ +		"  mulxq 16(%0), %%r9, %%r10;" +		"  adcx %%r14, %%r9;" /* f[2]*f[0] */ +		"  mulxq 24(%0), %%rax, %%rcx;" +		"  adcx %%rax, %%r10;" /* f[3]*f[0] */ +		"  movq 24(%0), %%rdx;" /* f[3] */ +		"  mulxq 8(%0), %%r11, %%rbx;" +		"  adcx %%rcx, %%r11;" /* f[1]*f[3] */ +		"  mulxq 16(%0), %%rax, %%r13;" +		"  adcx %%rax, %%rbx;" /* f[2]*f[3] */ +		"  movq 8(%0), %%rdx;" +		"  adcx %%r15, %%r13;" /* f1 */ +		"  mulxq 16(%0), %%rax, %%rcx;" +		"  mov $0, %%r14;" /* f[2]*f[1] */  		/* Step 2: Compute two parallel carry chains */ -		"  xor %%r15, %%r15;" +		"  xor %%r15d, %%r15d;"  		"  adox %%rax, %%r10;"  		"  adcx %%r8, %%r8;"  		"  adox %%rcx, %%r11;" @@ -530,39 +694,50 @@ static inline void fsqr(u64 *out, const u64 *f, u64 *tmp)  		"  adcx %%r14, %%r14;"  		/* Step 3: Compute intermediate squares */ -		"  movq 0(%1), %%rdx;"     "  mulx %%rdx, %%rax, %%rcx;"    /* f[0]^2 */ -		                           "  movq %%rax, 0(%0);" -		"  add %%rcx, %%r8;"       "  movq %%r8, 8(%0);" -		"  movq 8(%1), %%rdx;"     "  mulx %%rdx, %%rax, %%rcx;"    /* f[1]^2 */ -		"  adcx %%rax, %%r9;"      "  movq %%r9, 16(%0);" -		"  adcx %%rcx, %%r10;"     "  movq %%r10, 24(%0);" -		"  movq 16(%1), %%rdx;"    "  mulx %%rdx, %%rax, %%rcx;"    /* f[2]^2 */ -		"  adcx %%rax, %%r11;"     "  movq %%r11, 32(%0);" -		"  adcx %%rcx, %%rbx;"     "  movq %%rbx, 40(%0);" -		"  movq 24(%1), %%rdx;"    "  mulx %%rdx, %%rax, %%rcx;"    /* f[3]^2 */ -		"  adcx %%rax, %%r13;"     "  movq %%r13, 48(%0);" -		"  adcx %%rcx, %%r14;"     "  movq %%r14, 56(%0);" +		"  movq 0(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ +		"  movq %%rax, 0(%1);" +		"  add %%rcx, %%r8;" +		"  movq %%r8, 8(%1);" +		"  movq 8(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ +		"  adcx %%rax, %%r9;" +		"  movq %%r9, 16(%1);" +		"  adcx %%rcx, %%r10;" +		"  movq %%r10, 24(%1);" +		"  movq 16(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ +		"  adcx %%rax, %%r11;" +		"  movq %%r11, 32(%1);" +		"  adcx %%rcx, %%rbx;" +		"  movq %%rbx, 40(%1);" +		"  movq 24(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ +		"  adcx %%rax, %%r13;" +		"  movq %%r13, 48(%1);" +		"  adcx %%rcx, %%r14;" +		"  movq %%r14, 56(%1);"  		/* Line up pointers */ -		"  mov %0, %1;" -		"  mov %2, %0;" +		"  mov %1, %0;" +		"  mov %2, %1;"  		/* Wrap the result back into the field */  		/* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */  		"  mov $38, %%rdx;" -		"  mulxq 32(%1), %%r8, %%r13;" -		"  xor %%rcx, %%rcx;" -		"  adoxq 0(%1), %%r8;" -		"  mulxq 40(%1), %%r9, %%rbx;" +		"  mulxq 32(%0), %%r8, %%r13;" +		"  xor %%ecx, %%ecx;" +		"  adoxq 0(%0), %%r8;" +		"  mulxq 40(%0), %%r9, %%rbx;"  		"  adcx %%r13, %%r9;" -		"  adoxq 8(%1), %%r9;" -		"  mulxq 48(%1), %%r10, %%r13;" +		"  adoxq 8(%0), %%r9;" +		"  mulxq 48(%0), %%r10, %%r13;"  		"  adcx %%rbx, %%r10;" -		"  adoxq 16(%1), %%r10;" -		"  mulxq 56(%1), %%r11, %%rax;" +		"  adoxq 16(%0), %%r10;" +		"  mulxq 56(%0), %%r11, %%rax;"  		"  adcx %%r13, %%r11;" -		"  adoxq 24(%1), %%r11;" +		"  adoxq 24(%0), %%r11;"  		"  adcx %%rcx, %%rax;"  		"  adox %%rcx, %%rax;"  		"  imul %%rdx, %%rax;" @@ -570,43 +745,50 @@ static inline void fsqr(u64 *out, const u64 *f, u64 *tmp)  		/* Step 2: Fold the carry back into dst */  		"  add %%rax, %%r8;"  		"  adcx %%rcx, %%r9;" -		"  movq %%r9, 8(%0);" +		"  movq %%r9, 8(%1);"  		"  adcx %%rcx, %%r10;" -		"  movq %%r10, 16(%0);" +		"  movq %%r10, 16(%1);"  		"  adcx %%rcx, %%r11;" -		"  movq %%r11, 24(%0);" +		"  movq %%r11, 24(%1);"  		/* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */  		"  mov $0, %%rax;"  		"  cmovc %%rdx, %%rax;"  		"  add %%rax, %%r8;" -		"  movq %%r8, 0(%0);" -	: "+&r" (tmp), "+&r" (f), "+&r" (out) -	: -	: "%rax", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "%r15", "memory", "cc" -	); +		"  movq %%r8, 0(%1);" +		: "+&r,&r"(f), "+&r,&r"(tmp) +		: "r,m"(out) +		: "%rax", "%rbx", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", +		  "%r13", "%r14", "%r15", "memory", "cc");  }  /* Computes two field squarings: - * out[0] <- f[0] * f[0] - * out[1] <- f[1] * f[1] + *   out[0] <- f[0] * f[0] + *   out[1] <- f[1] * f[1]   * Uses the 16-element buffer tmp for intermediate results */  static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp)  {  	asm volatile(  		/* Step 1: Compute all partial products */ -		"  movq 0(%1), %%rdx;"                                       /* f[0] */ -		"  mulxq 8(%1), %%r8, %%r14;"      "  xor %%r15, %%r15;"     /* f[1]*f[0] */ -		"  mulxq 16(%1), %%r9, %%r10;"     "  adcx %%r14, %%r9;"     /* f[2]*f[0] */ -		"  mulxq 24(%1), %%rax, %%rcx;"    "  adcx %%rax, %%r10;"    /* f[3]*f[0] */ -		"  movq 24(%1), %%rdx;"                                      /* f[3] */ -		"  mulxq 8(%1), %%r11, %%rbx;"     "  adcx %%rcx, %%r11;"    /* f[1]*f[3] */ -		"  mulxq 16(%1), %%rax, %%r13;"    "  adcx %%rax, %%rbx;"    /* f[2]*f[3] */ -		"  movq 8(%1), %%rdx;"             "  adcx %%r15, %%r13;"    /* f1 */ -		"  mulxq 16(%1), %%rax, %%rcx;"    "  mov $0, %%r14;"        /* f[2]*f[1] */ +		"  movq 0(%0), %%rdx;" /* f[0] */ +		"  mulxq 8(%0), %%r8, %%r14;" +		"  xor %%r15d, %%r15d;" /* f[1]*f[0] */ +		"  mulxq 16(%0), %%r9, %%r10;" +		"  adcx %%r14, %%r9;" /* f[2]*f[0] */ +		"  mulxq 24(%0), %%rax, %%rcx;" +		"  adcx %%rax, %%r10;" /* f[3]*f[0] */ +		"  movq 24(%0), %%rdx;" /* f[3] */ +		"  mulxq 8(%0), %%r11, %%rbx;" +		"  adcx %%rcx, %%r11;" /* f[1]*f[3] */ +		"  mulxq 16(%0), %%rax, %%r13;" +		"  adcx %%rax, %%rbx;" /* f[2]*f[3] */ +		"  movq 8(%0), %%rdx;" +		"  adcx %%r15, %%r13;" /* f1 */ +		"  mulxq 16(%0), %%rax, %%rcx;" +		"  mov $0, %%r14;" /* f[2]*f[1] */  		/* Step 2: Compute two parallel carry chains */ -		"  xor %%r15, %%r15;" +		"  xor %%r15d, %%r15d;"  		"  adox %%rax, %%r10;"  		"  adcx %%r8, %%r8;"  		"  adox %%rcx, %%r11;" @@ -621,32 +803,50 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp)  		"  adcx %%r14, %%r14;"  		/* Step 3: Compute intermediate squares */ -		"  movq 0(%1), %%rdx;"     "  mulx %%rdx, %%rax, %%rcx;"    /* f[0]^2 */ -		                           "  movq %%rax, 0(%0);" -		"  add %%rcx, %%r8;"       "  movq %%r8, 8(%0);" -		"  movq 8(%1), %%rdx;"     "  mulx %%rdx, %%rax, %%rcx;"    /* f[1]^2 */ -		"  adcx %%rax, %%r9;"      "  movq %%r9, 16(%0);" -		"  adcx %%rcx, %%r10;"     "  movq %%r10, 24(%0);" -		"  movq 16(%1), %%rdx;"    "  mulx %%rdx, %%rax, %%rcx;"    /* f[2]^2 */ -		"  adcx %%rax, %%r11;"     "  movq %%r11, 32(%0);" -		"  adcx %%rcx, %%rbx;"     "  movq %%rbx, 40(%0);" -		"  movq 24(%1), %%rdx;"    "  mulx %%rdx, %%rax, %%rcx;"    /* f[3]^2 */ -		"  adcx %%rax, %%r13;"     "  movq %%r13, 48(%0);" -		"  adcx %%rcx, %%r14;"     "  movq %%r14, 56(%0);" +		"  movq 0(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ +		"  movq %%rax, 0(%1);" +		"  add %%rcx, %%r8;" +		"  movq %%r8, 8(%1);" +		"  movq 8(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ +		"  adcx %%rax, %%r9;" +		"  movq %%r9, 16(%1);" +		"  adcx %%rcx, %%r10;" +		"  movq %%r10, 24(%1);" +		"  movq 16(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ +		"  adcx %%rax, %%r11;" +		"  movq %%r11, 32(%1);" +		"  adcx %%rcx, %%rbx;" +		"  movq %%rbx, 40(%1);" +		"  movq 24(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ +		"  adcx %%rax, %%r13;" +		"  movq %%r13, 48(%1);" +		"  adcx %%rcx, %%r14;" +		"  movq %%r14, 56(%1);"  		/* Step 1: Compute all partial products */ -		"  movq 32(%1), %%rdx;"                                       /* f[0] */ -		"  mulxq 40(%1), %%r8, %%r14;"      "  xor %%r15, %%r15;"     /* f[1]*f[0] */ -		"  mulxq 48(%1), %%r9, %%r10;"     "  adcx %%r14, %%r9;"     /* f[2]*f[0] */ -		"  mulxq 56(%1), %%rax, %%rcx;"    "  adcx %%rax, %%r10;"    /* f[3]*f[0] */ -		"  movq 56(%1), %%rdx;"                                      /* f[3] */ -		"  mulxq 40(%1), %%r11, %%rbx;"     "  adcx %%rcx, %%r11;"    /* f[1]*f[3] */ -		"  mulxq 48(%1), %%rax, %%r13;"    "  adcx %%rax, %%rbx;"    /* f[2]*f[3] */ -		"  movq 40(%1), %%rdx;"             "  adcx %%r15, %%r13;"    /* f1 */ -		"  mulxq 48(%1), %%rax, %%rcx;"    "  mov $0, %%r14;"        /* f[2]*f[1] */ +		"  movq 32(%0), %%rdx;" /* f[0] */ +		"  mulxq 40(%0), %%r8, %%r14;" +		"  xor %%r15d, %%r15d;" /* f[1]*f[0] */ +		"  mulxq 48(%0), %%r9, %%r10;" +		"  adcx %%r14, %%r9;" /* f[2]*f[0] */ +		"  mulxq 56(%0), %%rax, %%rcx;" +		"  adcx %%rax, %%r10;" /* f[3]*f[0] */ +		"  movq 56(%0), %%rdx;" /* f[3] */ +		"  mulxq 40(%0), %%r11, %%rbx;" +		"  adcx %%rcx, %%r11;" /* f[1]*f[3] */ +		"  mulxq 48(%0), %%rax, %%r13;" +		"  adcx %%rax, %%rbx;" /* f[2]*f[3] */ +		"  movq 40(%0), %%rdx;" +		"  adcx %%r15, %%r13;" /* f1 */ +		"  mulxq 48(%0), %%rax, %%rcx;" +		"  mov $0, %%r14;" /* f[2]*f[1] */  		/* Step 2: Compute two parallel carry chains */ -		"  xor %%r15, %%r15;" +		"  xor %%r15d, %%r15d;"  		"  adox %%rax, %%r10;"  		"  adcx %%r8, %%r8;"  		"  adox %%rcx, %%r11;" @@ -661,37 +861,48 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp)  		"  adcx %%r14, %%r14;"  		/* Step 3: Compute intermediate squares */ -		"  movq 32(%1), %%rdx;"     "  mulx %%rdx, %%rax, %%rcx;"    /* f[0]^2 */ -		                           "  movq %%rax, 64(%0);" -		"  add %%rcx, %%r8;"       "  movq %%r8, 72(%0);" -		"  movq 40(%1), %%rdx;"     "  mulx %%rdx, %%rax, %%rcx;"    /* f[1]^2 */ -		"  adcx %%rax, %%r9;"      "  movq %%r9, 80(%0);" -		"  adcx %%rcx, %%r10;"     "  movq %%r10, 88(%0);" -		"  movq 48(%1), %%rdx;"    "  mulx %%rdx, %%rax, %%rcx;"    /* f[2]^2 */ -		"  adcx %%rax, %%r11;"     "  movq %%r11, 96(%0);" -		"  adcx %%rcx, %%rbx;"     "  movq %%rbx, 104(%0);" -		"  movq 56(%1), %%rdx;"    "  mulx %%rdx, %%rax, %%rcx;"    /* f[3]^2 */ -		"  adcx %%rax, %%r13;"     "  movq %%r13, 112(%0);" -		"  adcx %%rcx, %%r14;"     "  movq %%r14, 120(%0);" +		"  movq 32(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ +		"  movq %%rax, 64(%1);" +		"  add %%rcx, %%r8;" +		"  movq %%r8, 72(%1);" +		"  movq 40(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ +		"  adcx %%rax, %%r9;" +		"  movq %%r9, 80(%1);" +		"  adcx %%rcx, %%r10;" +		"  movq %%r10, 88(%1);" +		"  movq 48(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ +		"  adcx %%rax, %%r11;" +		"  movq %%r11, 96(%1);" +		"  adcx %%rcx, %%rbx;" +		"  movq %%rbx, 104(%1);" +		"  movq 56(%0), %%rdx;" +		"  mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ +		"  adcx %%rax, %%r13;" +		"  movq %%r13, 112(%1);" +		"  adcx %%rcx, %%r14;" +		"  movq %%r14, 120(%1);"  		/* Line up pointers */ -		"  mov %0, %1;" -		"  mov %2, %0;" +		"  mov %1, %0;" +		"  mov %2, %1;"  		/* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */  		"  mov $38, %%rdx;" -		"  mulxq 32(%1), %%r8, %%r13;" -		"  xor %%rcx, %%rcx;" -		"  adoxq 0(%1), %%r8;" -		"  mulxq 40(%1), %%r9, %%rbx;" +		"  mulxq 32(%0), %%r8, %%r13;" +		"  xor %%ecx, %%ecx;" +		"  adoxq 0(%0), %%r8;" +		"  mulxq 40(%0), %%r9, %%rbx;"  		"  adcx %%r13, %%r9;" -		"  adoxq 8(%1), %%r9;" -		"  mulxq 48(%1), %%r10, %%r13;" +		"  adoxq 8(%0), %%r9;" +		"  mulxq 48(%0), %%r10, %%r13;"  		"  adcx %%rbx, %%r10;" -		"  adoxq 16(%1), %%r10;" -		"  mulxq 56(%1), %%r11, %%rax;" +		"  adoxq 16(%0), %%r10;" +		"  mulxq 56(%0), %%r11, %%rax;"  		"  adcx %%r13, %%r11;" -		"  adoxq 24(%1), %%r11;" +		"  adoxq 24(%0), %%r11;"  		"  adcx %%rcx, %%rax;"  		"  adox %%rcx, %%rax;"  		"  imul %%rdx, %%rax;" @@ -699,32 +910,32 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp)  		/* Step 2: Fold the carry back into dst */  		"  add %%rax, %%r8;"  		"  adcx %%rcx, %%r9;" -		"  movq %%r9, 8(%0);" +		"  movq %%r9, 8(%1);"  		"  adcx %%rcx, %%r10;" -		"  movq %%r10, 16(%0);" +		"  movq %%r10, 16(%1);"  		"  adcx %%rcx, %%r11;" -		"  movq %%r11, 24(%0);" +		"  movq %%r11, 24(%1);"  		/* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */  		"  mov $0, %%rax;"  		"  cmovc %%rdx, %%rax;"  		"  add %%rax, %%r8;" -		"  movq %%r8, 0(%0);" +		"  movq %%r8, 0(%1);"  		/* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */  		"  mov $38, %%rdx;" -		"  mulxq 96(%1), %%r8, %%r13;" -		"  xor %%rcx, %%rcx;" -		"  adoxq 64(%1), %%r8;" -		"  mulxq 104(%1), %%r9, %%rbx;" +		"  mulxq 96(%0), %%r8, %%r13;" +		"  xor %%ecx, %%ecx;" +		"  adoxq 64(%0), %%r8;" +		"  mulxq 104(%0), %%r9, %%rbx;"  		"  adcx %%r13, %%r9;" -		"  adoxq 72(%1), %%r9;" -		"  mulxq 112(%1), %%r10, %%r13;" +		"  adoxq 72(%0), %%r9;" +		"  mulxq 112(%0), %%r10, %%r13;"  		"  adcx %%rbx, %%r10;" -		"  adoxq 80(%1), %%r10;" -		"  mulxq 120(%1), %%r11, %%rax;" +		"  adoxq 80(%0), %%r10;" +		"  mulxq 120(%0), %%r11, %%rax;"  		"  adcx %%r13, %%r11;" -		"  adoxq 88(%1), %%r11;" +		"  adoxq 88(%0), %%r11;"  		"  adcx %%rcx, %%rax;"  		"  adox %%rcx, %%rax;"  		"  imul %%rdx, %%rax;" @@ -732,21 +943,21 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp)  		/* Step 2: Fold the carry back into dst */  		"  add %%rax, %%r8;"  		"  adcx %%rcx, %%r9;" -		"  movq %%r9, 40(%0);" +		"  movq %%r9, 40(%1);"  		"  adcx %%rcx, %%r10;" -		"  movq %%r10, 48(%0);" +		"  movq %%r10, 48(%1);"  		"  adcx %%rcx, %%r11;" -		"  movq %%r11, 56(%0);" +		"  movq %%r11, 56(%1);"  		/* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */  		"  mov $0, %%rax;"  		"  cmovc %%rdx, %%rax;"  		"  add %%rax, %%r8;" -		"  movq %%r8, 32(%0);" -	: "+&r" (tmp), "+&r" (f), "+&r" (out) -	: -	: "%rax", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "%r15", "memory", "cc" -	); +		"  movq %%r8, 32(%1);" +		: "+&r,&r"(f), "+&r,&r"(tmp) +		: "r,m"(out) +		: "%rax", "%rbx", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", +		  "%r13", "%r14", "%r15", "memory", "cc");  }  static void point_add_and_double(u64 *q, u64 *p01_tmp1, u64 *tmp2) diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index b8c2390b0a35..062490f1b8a7 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -19,6 +19,7 @@  #include <linux/if_arp.h>  #include <linux/icmp.h>  #include <linux/suspend.h> +#include <net/dst_metadata.h>  #include <net/icmp.h>  #include <net/rtnetlink.h>  #include <net/ip_tunnels.h> @@ -106,6 +107,7 @@ static int wg_stop(struct net_device *dev)  {  	struct wg_device *wg = netdev_priv(dev);  	struct wg_peer *peer; +	struct sk_buff *skb;  	mutex_lock(&wg->device_update_lock);  	list_for_each_entry(peer, &wg->peer_list, peer_list) { @@ -116,7 +118,9 @@ static int wg_stop(struct net_device *dev)  		wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);  	}  	mutex_unlock(&wg->device_update_lock); -	skb_queue_purge(&wg->incoming_handshakes); +	while ((skb = ptr_ring_consume(&wg->handshake_queue.ring)) != NULL) +		kfree_skb(skb); +	atomic_set(&wg->handshake_queue_len, 0);  	wg_socket_reinit(wg, NULL, NULL);  	return 0;  } @@ -157,7 +161,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)  		goto err_peer;  	} -	mtu = skb_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; +	mtu = skb_valid_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu;  	__skb_queue_head_init(&packets);  	if (!skb_is_gso(skb)) { @@ -243,14 +247,13 @@ static void wg_destruct(struct net_device *dev)  	destroy_workqueue(wg->handshake_receive_wq);  	destroy_workqueue(wg->handshake_send_wq);  	destroy_workqueue(wg->packet_crypt_wq); -	wg_packet_queue_free(&wg->decrypt_queue); -	wg_packet_queue_free(&wg->encrypt_queue); +	wg_packet_queue_free(&wg->handshake_queue, true); +	wg_packet_queue_free(&wg->decrypt_queue, false); +	wg_packet_queue_free(&wg->encrypt_queue, false);  	rcu_barrier(); /* Wait for all the peers to be actually freed. */  	wg_ratelimiter_uninit();  	memzero_explicit(&wg->static_identity, sizeof(wg->static_identity)); -	skb_queue_purge(&wg->incoming_handshakes);  	free_percpu(dev->tstats); -	free_percpu(wg->incoming_handshakes_worker);  	kvfree(wg->index_hashtable);  	kvfree(wg->peer_hashtable);  	mutex_unlock(&wg->device_update_lock); @@ -312,7 +315,6 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,  	init_rwsem(&wg->static_identity.lock);  	mutex_init(&wg->socket_update_lock);  	mutex_init(&wg->device_update_lock); -	skb_queue_head_init(&wg->incoming_handshakes);  	wg_allowedips_init(&wg->peer_allowedips);  	wg_cookie_checker_init(&wg->cookie_checker, wg);  	INIT_LIST_HEAD(&wg->peer_list); @@ -330,16 +332,10 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,  	if (!dev->tstats)  		goto err_free_index_hashtable; -	wg->incoming_handshakes_worker = -		wg_packet_percpu_multicore_worker_alloc( -				wg_packet_handshake_receive_worker, wg); -	if (!wg->incoming_handshakes_worker) -		goto err_free_tstats; -  	wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s",  			WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name);  	if (!wg->handshake_receive_wq) -		goto err_free_incoming_handshakes; +		goto err_free_tstats;  	wg->handshake_send_wq = alloc_workqueue("wg-kex-%s",  			WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); @@ -361,10 +357,15 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,  	if (ret < 0)  		goto err_free_encrypt_queue; -	ret = wg_ratelimiter_init(); +	ret = wg_packet_queue_init(&wg->handshake_queue, wg_packet_handshake_receive_worker, +				   MAX_QUEUED_INCOMING_HANDSHAKES);  	if (ret < 0)  		goto err_free_decrypt_queue; +	ret = wg_ratelimiter_init(); +	if (ret < 0) +		goto err_free_handshake_queue; +  	ret = register_netdevice(dev);  	if (ret < 0)  		goto err_uninit_ratelimiter; @@ -381,18 +382,18 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,  err_uninit_ratelimiter:  	wg_ratelimiter_uninit(); +err_free_handshake_queue: +	wg_packet_queue_free(&wg->handshake_queue, false);  err_free_decrypt_queue: -	wg_packet_queue_free(&wg->decrypt_queue); +	wg_packet_queue_free(&wg->decrypt_queue, false);  err_free_encrypt_queue: -	wg_packet_queue_free(&wg->encrypt_queue); +	wg_packet_queue_free(&wg->encrypt_queue, false);  err_destroy_packet_crypt:  	destroy_workqueue(wg->packet_crypt_wq);  err_destroy_handshake_send:  	destroy_workqueue(wg->handshake_send_wq);  err_destroy_handshake_receive:  	destroy_workqueue(wg->handshake_receive_wq); -err_free_incoming_handshakes: -	free_percpu(wg->incoming_handshakes_worker);  err_free_tstats:  	free_percpu(dev->tstats);  err_free_index_hashtable: @@ -412,6 +413,7 @@ static struct rtnl_link_ops link_ops __read_mostly = {  static void wg_netns_pre_exit(struct net *net)  {  	struct wg_device *wg; +	struct wg_peer *peer;  	rtnl_lock();  	list_for_each_entry(wg, &device_list, device_list) { @@ -421,6 +423,8 @@ static void wg_netns_pre_exit(struct net *net)  			mutex_lock(&wg->device_update_lock);  			rcu_assign_pointer(wg->creating_net, NULL);  			wg_socket_reinit(wg, NULL, NULL); +			list_for_each_entry(peer, &wg->peer_list, peer_list) +				wg_socket_clear_peer_endpoint_src(peer);  			mutex_unlock(&wg->device_update_lock);  		}  	} diff --git a/drivers/net/wireguard/device.h b/drivers/net/wireguard/device.h index 854bc3d97150..43c7cebbf50b 100644 --- a/drivers/net/wireguard/device.h +++ b/drivers/net/wireguard/device.h @@ -39,21 +39,18 @@ struct prev_queue {  struct wg_device {  	struct net_device *dev; -	struct crypt_queue encrypt_queue, decrypt_queue; +	struct crypt_queue encrypt_queue, decrypt_queue, handshake_queue;  	struct sock __rcu *sock4, *sock6;  	struct net __rcu *creating_net;  	struct noise_static_identity static_identity; -	struct workqueue_struct *handshake_receive_wq, *handshake_send_wq; -	struct workqueue_struct *packet_crypt_wq; -	struct sk_buff_head incoming_handshakes; -	int incoming_handshake_cpu; -	struct multicore_worker __percpu *incoming_handshakes_worker; +	struct workqueue_struct *packet_crypt_wq,*handshake_receive_wq, *handshake_send_wq;  	struct cookie_checker cookie_checker;  	struct pubkey_hashtable *peer_hashtable;  	struct index_hashtable *index_hashtable;  	struct allowedips peer_allowedips;  	struct mutex device_update_lock, socket_update_lock;  	struct list_head device_list, peer_list; +	atomic_t handshake_queue_len;  	unsigned int num_peers, device_update_gen;  	u32 fwmark;  	u16 incoming_port; diff --git a/drivers/net/wireguard/main.c b/drivers/net/wireguard/main.c index 9b8bbe27999e..a6714ce13a65 100644 --- a/drivers/net/wireguard/main.c +++ b/drivers/net/wireguard/main.c @@ -18,7 +18,7 @@  #include <linux/genetlink.h>  #include <net/rtnetlink.h> -static int __init mod_init(void) +static int __init wg_mod_init(void)  {  	int ret; @@ -66,7 +66,7 @@ err_allowedips:  	return ret;  } -static void __exit mod_exit(void) +static void __exit wg_mod_exit(void)  {  	wg_genetlink_uninit();  	wg_device_uninit(); @@ -74,8 +74,8 @@ static void __exit mod_exit(void)  	wg_allowedips_slab_uninit();  } -module_init(mod_init); -module_exit(mod_exit); +module_init(wg_mod_init); +module_exit(wg_mod_exit);  MODULE_LICENSE("GPL v2");  MODULE_DESCRIPTION("WireGuard secure network tunnel");  MODULE_AUTHOR("Jason A. Donenfeld <Jason@zx2c4.com>"); diff --git a/drivers/net/wireguard/queueing.c b/drivers/net/wireguard/queueing.c index 48e7b982a307..8084e7408c0a 100644 --- a/drivers/net/wireguard/queueing.c +++ b/drivers/net/wireguard/queueing.c @@ -4,6 +4,7 @@   */  #include "queueing.h" +#include <linux/skb_array.h>  struct multicore_worker __percpu *  wg_packet_percpu_multicore_worker_alloc(work_func_t function, void *ptr) @@ -38,11 +39,11 @@ int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function,  	return 0;  } -void wg_packet_queue_free(struct crypt_queue *queue) +void wg_packet_queue_free(struct crypt_queue *queue, bool purge)  {  	free_percpu(queue->worker); -	WARN_ON(!__ptr_ring_empty(&queue->ring)); -	ptr_ring_cleanup(&queue->ring, NULL); +	WARN_ON(!purge && !__ptr_ring_empty(&queue->ring)); +	ptr_ring_cleanup(&queue->ring, purge ? __skb_array_destroy_skb : NULL);  }  #define NEXT(skb) ((skb)->prev) diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h index b6ccf650c738..03850c43ebaf 100644 --- a/drivers/net/wireguard/queueing.h +++ b/drivers/net/wireguard/queueing.h @@ -23,7 +23,7 @@ struct sk_buff;  /* queueing.c APIs: */  int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function,  			 unsigned int len); -void wg_packet_queue_free(struct crypt_queue *queue); +void wg_packet_queue_free(struct crypt_queue *queue, bool purge);  struct multicore_worker __percpu *  wg_packet_percpu_multicore_worker_alloc(work_func_t function, void *ptr); diff --git a/drivers/net/wireguard/ratelimiter.c b/drivers/net/wireguard/ratelimiter.c index e33ec72a9642..ecee41f528a5 100644 --- a/drivers/net/wireguard/ratelimiter.c +++ b/drivers/net/wireguard/ratelimiter.c @@ -188,12 +188,12 @@ int wg_ratelimiter_init(void)  			(1U << 14) / sizeof(struct hlist_head)));  	max_entries = table_size * 8; -	table_v4 = kvzalloc(table_size * sizeof(*table_v4), GFP_KERNEL); +	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);  	if (unlikely(!table_v4))  		goto err_kmemcache;  #if IS_ENABLED(CONFIG_IPV6) -	table_v6 = kvzalloc(table_size * sizeof(*table_v6), GFP_KERNEL); +	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);  	if (unlikely(!table_v6)) {  		kvfree(table_v4);  		goto err_kmemcache; diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c index 07147ff0522d..214889edb48e 100644 --- a/drivers/net/wireguard/receive.c +++ b/drivers/net/wireguard/receive.c @@ -117,8 +117,8 @@ static void wg_receive_handshake_packet(struct wg_device *wg,  		return;  	} -	under_load = skb_queue_len(&wg->incoming_handshakes) >= -		     MAX_QUEUED_INCOMING_HANDSHAKES / 8; +	under_load = atomic_read(&wg->handshake_queue_len) >= +			MAX_QUEUED_INCOMING_HANDSHAKES / 8;  	if (under_load) {  		last_under_load = ktime_get_coarse_boottime_ns();  	} else if (last_under_load) { @@ -213,13 +213,14 @@ static void wg_receive_handshake_packet(struct wg_device *wg,  void wg_packet_handshake_receive_worker(struct work_struct *work)  { -	struct wg_device *wg = container_of(work, struct multicore_worker, -					    work)->ptr; +	struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr; +	struct wg_device *wg = container_of(queue, struct wg_device, handshake_queue);  	struct sk_buff *skb; -	while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) { +	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {  		wg_receive_handshake_packet(wg, skb);  		dev_kfree_skb(skb); +		atomic_dec(&wg->handshake_queue_len);  		cond_resched();  	}  } @@ -562,22 +563,28 @@ void wg_packet_receive(struct wg_device *wg, struct sk_buff *skb)  	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION):  	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE):  	case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): { -		int cpu; - -		if (skb_queue_len(&wg->incoming_handshakes) > -			    MAX_QUEUED_INCOMING_HANDSHAKES || -		    unlikely(!rng_is_initialized())) { +		int cpu, ret = -EBUSY; + +		if (unlikely(!rng_is_initialized())) +			goto drop; +		if (atomic_read(&wg->handshake_queue_len) > MAX_QUEUED_INCOMING_HANDSHAKES / 2) { +			if (spin_trylock_bh(&wg->handshake_queue.ring.producer_lock)) { +				ret = __ptr_ring_produce(&wg->handshake_queue.ring, skb); +				spin_unlock_bh(&wg->handshake_queue.ring.producer_lock); +			} +		} else +			ret = ptr_ring_produce_bh(&wg->handshake_queue.ring, skb); +		if (ret) { +	drop:  			net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n",  						wg->dev->name, skb);  			goto err;  		} -		skb_queue_tail(&wg->incoming_handshakes, skb); -		/* Queues up a call to packet_process_queued_handshake_ -		 * packets(skb): -		 */ -		cpu = wg_cpumask_next_online(&wg->incoming_handshake_cpu); +		atomic_inc(&wg->handshake_queue_len); +		cpu = wg_cpumask_next_online(&wg->handshake_queue.last_cpu); +		/* Queues up a call to packet_process_queued_handshake_packets(skb): */  		queue_work_on(cpu, wg->handshake_receive_wq, -			&per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); +			      &per_cpu_ptr(wg->handshake_queue.worker, cpu)->work);  		break;  	}  	case cpu_to_le32(MESSAGE_DATA): diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c index 04739763e303..0414d7a6ce74 100644 --- a/drivers/net/wireguard/socket.c +++ b/drivers/net/wireguard/socket.c @@ -49,7 +49,7 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,  		rt = dst_cache_get_ip4(cache, &fl.saddr);  	if (!rt) { -		security_sk_classify_flow(sock, flowi4_to_flowi(&fl)); +		security_sk_classify_flow(sock, flowi4_to_flowi_common(&fl));  		if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0,  						fl.saddr, RT_SCOPE_HOST))) {  			endpoint->src4.s_addr = 0; @@ -129,7 +129,7 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,  		dst = dst_cache_get_ip6(cache, &fl.saddr);  	if (!dst) { -		security_sk_classify_flow(sock, flowi6_to_flowi(&fl)); +		security_sk_classify_flow(sock, flowi6_to_flowi_common(&fl));  		if (unlikely(!ipv6_addr_any(&fl.saddr) &&  			     !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) {  			endpoint->src6 = fl.saddr = in6addr_any; @@ -160,6 +160,7 @@ out:  	rcu_read_unlock_bh();  	return ret;  #else +	kfree_skb(skb);  	return -EAFNOSUPPORT;  #endif  } @@ -241,7 +242,7 @@ int wg_socket_endpoint_from_skb(struct endpoint *endpoint,  		endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr;  		endpoint->src4.s_addr = ip_hdr(skb)->daddr;  		endpoint->src_if4 = skb->skb_iif; -	} else if (skb->protocol == htons(ETH_P_IPV6)) { +	} else if (IS_ENABLED(CONFIG_IPV6) && skb->protocol == htons(ETH_P_IPV6)) {  		endpoint->addr6.sin6_family = AF_INET6;  		endpoint->addr6.sin6_port = udp_hdr(skb)->source;  		endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr; @@ -284,7 +285,7 @@ void wg_socket_set_peer_endpoint(struct wg_peer *peer,  		peer->endpoint.addr4 = endpoint->addr4;  		peer->endpoint.src4 = endpoint->src4;  		peer->endpoint.src_if4 = endpoint->src_if4; -	} else if (endpoint->addr.sa_family == AF_INET6) { +	} else if (IS_ENABLED(CONFIG_IPV6) && endpoint->addr.sa_family == AF_INET6) {  		peer->endpoint.addr6 = endpoint->addr6;  		peer->endpoint.src6 = endpoint->src6;  	} else { @@ -308,7 +309,7 @@ void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer)  {  	write_lock_bh(&peer->endpoint_lock);  	memset(&peer->endpoint.src6, 0, sizeof(peer->endpoint.src6)); -	dst_cache_reset(&peer->endpoint_cache); +	dst_cache_reset_now(&peer->endpoint_cache);  	write_unlock_bh(&peer->endpoint_lock);  } diff --git a/drivers/net/wireguard/version.h b/drivers/net/wireguard/version.h index 35ef576765c9..c7f9028f0177 100644 --- a/drivers/net/wireguard/version.h +++ b/drivers/net/wireguard/version.h @@ -1 +1,3 @@ -#define WIREGUARD_VERSION "1.0.20210606" +#ifndef WIREGUARD_VERSION +#define WIREGUARD_VERSION "1.0.20220627" +#endif | 
