diff options
| author | Raghuram Subramani <raghus2247@gmail.com> | 2024-10-13 13:31:01 +0530 |
|---|---|---|
| committer | Raghuram Subramani <raghus2247@gmail.com> | 2024-10-13 13:31:01 +0530 |
| commit | 444876b4151429efeb64f91bbf5a5d7059484f8a (patch) | |
| tree | 92d81b12bfa532f17954fbb9d4a5c0df50ae7bca | |
| parent | 64568b6872b018e604423d683df261f3caaee3ca (diff) | |
Bpf patch
bpf, arm64: use more scalable stadd over ldxr / stxr loop in xadd
bpf, arm64: remove prefetch insn in xadd mapping
bpf, arm64: use separate register for state in stxr
bpf, arm64: implement jiting of BPF_J{LT, LE, SLT, SLE}
bpf, arm64: implement jiting of BPF_XADD
bpf: add BPF_J{LT,LE,SLT,SLE} instructions
lib/test_bpf: Add tests for unsigned BPF_JGT
arm64: fix endianness annotation for 'struct jit_ctx' and friends
bpf: free up BPF_JMP | BPF_CALL | BPF_X opcode
bpf: remove stubs for cBPF from arch code
bpf: split HAVE_BPF_JIT into cBPF and eBPF variant
{nl,mac}80211: add rssi to mesh candidates
mac80211: mesh: drop new node with weak power
wifi: cfg80211: avoid leaking stack data into trace
UPSTREAM: netpoll: Fix device name check in netpoll_setup()
tracing: Avoid adding tracer option before update_tracer_options
sched_getaffinity: don't assume 'cpumask_size()' is fully initialized
thread_info: Remove superflous struct decls
USB: core: Prevent nested device-reset calls
USB: core: Don't hold device lock while reading the "descriptors" sysfs file
tty: fix deadlock caused by calling printk() under tty_port->lock
flow: fix object-size-mismatch warning in flowi{4,6}_to_flowi_common()
lsm,selinux: pass flowi_common instead of flowi to the LSM hooks
leds: leds-qpnp: Fix uninitialized local variable
qcacld-3.0: Avoid possible array OOB
ASoC: msm-pcm-q6-v2: Add dsp buf check
asoc: Update copy_to_user to requested buffer size
asoc: msm-pcm-q6-v2: Update memset for period size
asoc: Reset the buffer if size is partial or zero
msm: adsprpc: Handle UAF in fastrpc internal munmap
msm: adsprpc: Handle UAF in fastrpc debugfs read
msm: adsprpc: Add missing spin_lock in `fastrpc_debugfs_read`
msm: ADSPRPC: Protect global remote heap maps
msm: adsprpc: Avoid race condition during map creation and free
adsprpc: update mmap list nodes before mmap free
sched: deadline: Add missing WALT code
sched: Reinstantiate EAS check_for_migration() implementation
sched: Remove left-over CPU-query from __migrate_task
BACKPORT: net: ipv6: Fix processing of RAs in presence of VRF
wifi: cfg80211: Fix use after free for wext
cfg80211: allow connect keys only with default (TX) key
nl80211: Update bss channel on channel switch for P2P_CLIENT
ALSA: oss: Fix potential deadlock at unregistration
Revert "ALSA: rawmidi: Fix racy buffer resize under concurrent accesses"
ALSA: rawmidi: Drop register_mutex in snd_rawmidi_free()
ALSA: rawmidi: Avoid OOB access to runtime buffer
HID: check empty report_list in hid_validate_values()
HID: core: Provide new max_buffer_size attribute to over-ride the default
HID: core: fix shift-out-of-bounds in hid_report_raw_event
tty: use new tty_insert_flip_string_and_push_buffer() in pty_write()
tty: extract tty_flip_buffer_commit() from tty_flip_buffer_push()
tracing: Fix infinite loop in tracing_read_pipe on overflowed print_trace_line
tracing: Fix memleak due to race between current_tracer and trace
tracing: Ensure trace buffer is at least 4096 bytes large
tracing: Fix tp_printk option related with tp_printk_stop_on_boot
blktrace: Fix output non-blktrace event when blk_classic option enabled
msm: kgsl: Prevent wrap around during user address mapping
iommu: Fix missing return check of arm_lpae_init_pte
q6asm: validate payload size before access
dsp: afe: Add check for sidetone iir config copy size.
q6core: Avoid OOB access in q6core
q6voice: Add buf size check for cvs cal data.
ASoC: msm-pcm-host-voice: Handle OOB access in hpcm_start.
Asoc: check for invalid voice session id
kconfig: display recursive dependency resolution hint just once
wireguard: version: bump
compat: handle backported rng and blake2s
qemu: set panic_on_warn=1 from cmdline
qemu: use vports on arm
device: check for metadata_dst with skb_valid_dst()
qemu: enable ACPI for SMP
socket: ignore v6 endpoints when ipv6 is disabled
socket: free skb in send6 when ipv6 is disabled
queueing: use CFI-safe ptr_ring cleanup function
crypto: curve25519-x86_64: use in/out register constraints more precisely
compat: drop Ubuntu 14.04
fixup! compat: redefine version constants for sublevel>=256
wireguard: version: bump
Makefile: strip prefixed v from version.h
crypto: curve25519-x86_64: solve register constraints with reserved registers
compat: udp_tunnel: don't take reference to non-init namespace
compat: siphash: use _unaligned version by default
ratelimiter: use kvcalloc() instead of kvzalloc()
receive: drop handshakes if queue lock is contended
receive: use ring buffer for incoming handshakes
device: reset peer src endpoint when netns exits
main: rename 'mod_init' & 'mod_exit' functions to be module-specific
netns: actually test for routing loops
compat: update for RHEL 8.5
compat: account for grsecurity backports and changes
compat: account for latest c8s backports
29 files changed, 878 insertions, 479 deletions
diff --git a/drivers/char/adsprpc.c b/drivers/char/adsprpc.c index 69bfaa0bc6f4..1662707f0d1a 100644 --- a/drivers/char/adsprpc.c +++ b/drivers/char/adsprpc.c @@ -1,5 +1,6 @@ /* * Copyright (c) 2012-2021, The Linux Foundation. All rights reserved. + * Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved. * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 and @@ -298,7 +299,8 @@ struct fastrpc_mmap { int uncached; int secure; uintptr_t attr; - bool is_filemap; /*flag to indicate map used in process init*/ + bool is_filemap; /* flag to indicate map used in process init */ + unsigned int ctx_refs; /* Indicates reference count for context map */ }; struct fastrpc_perf { @@ -334,6 +336,7 @@ struct fastrpc_file { struct fastrpc_perf perf; struct dentry *debugfs_file; struct mutex map_mutex; + struct mutex internal_map_mutex; char *debug_buf; /* Identifies the device (MINOR_NUM_DEV / MINOR_NUM_SECURE_DEV) */ int dev_minor; @@ -473,9 +476,7 @@ static void fastrpc_mmap_add(struct fastrpc_mmap *map) } else { struct fastrpc_file *fl = map->fl; - spin_lock(&fl->hlock); hlist_add_head(&map->hn, &fl->maps); - spin_unlock(&fl->hlock); } } @@ -506,21 +507,17 @@ static int fastrpc_mmap_find(struct fastrpc_file *fl, int fd, uintptr_t va, } spin_unlock(&me->hlock); } else { - spin_lock(&fl->hlock); hlist_for_each_entry_safe(map, n, &fl->maps, hn) { if (va >= map->va && va + len <= map->va + map->len && map->fd == fd) { - if (map->refs + 1 == INT_MAX) { - spin_unlock(&fl->hlock); + if (map->refs + 1 == INT_MAX) return -ETOOMANYREFS; - } map->refs++; match = map; break; } } - spin_unlock(&fl->hlock); } if (match) { *ppmap = match; @@ -560,7 +557,7 @@ static int fastrpc_mmap_remove(struct fastrpc_file *fl, uintptr_t va, hlist_for_each_entry_safe(map, n, &me->maps, hn) { if (map->refs == 1 && map->raddr == va && map->raddr + map->len == va + len && - /*Remove map if not used in process initialization*/ + /* Remove map if not used in process initialization */ !map->is_filemap) { match = map; hlist_del_init(&map->hn); @@ -572,18 +569,17 @@ static int fastrpc_mmap_remove(struct fastrpc_file *fl, uintptr_t va, *ppmap = match; return 0; } - spin_lock(&fl->hlock); hlist_for_each_entry_safe(map, n, &fl->maps, hn) { - if (map->refs == 1 && map->raddr == va && + /* Remove if only one reference map and no context map */ + if (map->refs == 1 && !map->ctx_refs && map->raddr == va && map->raddr + map->len == va + len && - /*Remove map if not used in process initialization*/ + /* Remove map if not used in process initialization */ !map->is_filemap) { match = map; hlist_del_init(&map->hn); break; } } - spin_unlock(&fl->hlock); if (match) { *ppmap = match; return 0; @@ -614,17 +610,13 @@ static void fastrpc_mmap_free(struct fastrpc_mmap *map) } if (map->flags == ADSP_MMAP_HEAP_ADDR || map->flags == ADSP_MMAP_REMOTE_HEAP_ADDR) { - spin_lock(&me->hlock); map->refs--; - if (!map->refs) + if (!map->refs && !map->ctx_refs) hlist_del_init(&map->hn); - spin_unlock(&me->hlock); } else { - spin_lock(&fl->hlock); map->refs--; - if (!map->refs) + if (!map->refs && !map->ctx_refs) hlist_del_init(&map->hn); - spin_unlock(&fl->hlock); } if (map->refs > 0) return; @@ -716,6 +708,7 @@ static int fastrpc_mmap_create(struct fastrpc_file *fl, int fd, unsigned attr, map->fd = fd; map->attr = attr; map->is_filemap = false; + map->ctx_refs = 0; if (mflags == ADSP_MMAP_HEAP_ADDR || mflags == ADSP_MMAP_REMOTE_HEAP_ADDR) { DEFINE_DMA_ATTRS(rh_attrs); @@ -1133,8 +1126,13 @@ static void context_free(struct smq_invoke_ctx *ctx) spin_lock(&ctx->fl->hlock); hlist_del_init(&ctx->hn); spin_unlock(&ctx->fl->hlock); - for (i = 0; i < nbufs; ++i) + mutex_lock(&ctx->fl->map_mutex); + for (i = 0; i < nbufs; ++i) { + if (ctx->maps[i] && ctx->maps[i]->ctx_refs) + ctx->maps[i]->ctx_refs--; fastrpc_mmap_free(ctx->maps[i]); + } + mutex_unlock(&ctx->fl->map_mutex); fastrpc_buf_free(ctx->buf, 1); fastrpc_buf_free(ctx->lbuf, 1); ctx->magic = 0; @@ -1274,10 +1272,14 @@ static int get_args(uint32_t kernel, struct smq_invoke_ctx *ctx) uintptr_t buf = (uintptr_t)lpra[i].buf.pv; size_t len = lpra[i].buf.len; + mutex_lock(&ctx->fl->map_mutex); if (ctx->fds[i] && (ctx->fds[i] != -1)) fastrpc_mmap_create(ctx->fl, ctx->fds[i], ctx->attrs[i], buf, len, mflags, &ctx->maps[i]); + if (ctx->maps[i]) + ctx->maps[i]->ctx_refs++; + mutex_unlock(&ctx->fl->map_mutex); ipage += 1; } metalen = copylen = (size_t)&ipage[0]; @@ -1494,7 +1496,11 @@ static int put_args(uint32_t kernel, struct smq_invoke_ctx *ctx, if (err) goto bail; } else { + mutex_lock(&ctx->fl->map_mutex); + if (ctx->maps[i]->ctx_refs) + ctx->maps[i]->ctx_refs--; fastrpc_mmap_free(ctx->maps[i]); + mutex_unlock(&ctx->fl->map_mutex); ctx->maps[i] = NULL; } } @@ -1903,10 +1909,12 @@ static int fastrpc_init_process(struct fastrpc_file *fl, init->filelen)) goto bail; if (init->filelen) { + mutex_lock(&fl->map_mutex); VERIFY(err, !fastrpc_mmap_create(fl, init->filefd, 0, init->file, init->filelen, mflags, &file)); if (file) file->is_filemap = true; + mutex_unlock(&fl->map_mutex); if (err) goto bail; } @@ -1996,9 +2004,11 @@ static int fastrpc_init_process(struct fastrpc_file *fl, inbuf.pageslen = 0; if (!me->staticpd_flags) { inbuf.pageslen = 1; + mutex_lock(&fl->map_mutex); VERIFY(err, !fastrpc_mmap_create(fl, -1, 0, init->mem, init->memlen, ADSP_MMAP_REMOTE_HEAP_ADDR, &mem)); + mutex_unlock(&fl->map_mutex); if (err) goto bail; phys = mem->phys; @@ -2050,10 +2060,15 @@ bail: if (mem->flags == ADSP_MMAP_REMOTE_HEAP_ADDR) hyp_assign_phys(mem->phys, (uint64_t)mem->size, destVM, 1, srcVM, hlosVMperm, 1); + mutex_lock(&fl->map_mutex); fastrpc_mmap_free(mem); + mutex_unlock(&fl->map_mutex); } - if (file) + if (file) { + mutex_lock(&fl->map_mutex); fastrpc_mmap_free(file); + mutex_unlock(&fl->map_mutex); + } return err; } @@ -2309,7 +2324,7 @@ static int fastrpc_internal_munmap(struct fastrpc_file *fl, struct fastrpc_buf *rbuf = NULL, *free = NULL; struct hlist_node *n; - mutex_lock(&fl->map_mutex); + mutex_lock(&fl->internal_map_mutex); spin_lock(&fl->hlock); hlist_for_each_entry_safe(rbuf, n, &fl->remote_bufs, hn_rem) { if (rbuf->raddr && (rbuf->flags == ADSP_MMAP_ADD_PAGES)) { @@ -2328,11 +2343,13 @@ static int fastrpc_internal_munmap(struct fastrpc_file *fl, if (err) goto bail; fastrpc_buf_free(rbuf, 0); - mutex_unlock(&fl->map_mutex); + mutex_unlock(&fl->internal_map_mutex); return err; } + mutex_lock(&fl->map_mutex); VERIFY(err, !fastrpc_mmap_remove(fl, ud->vaddrout, ud->size, &map)); + mutex_unlock(&fl->map_mutex); if (err) goto bail; if (map) { @@ -2340,12 +2357,17 @@ static int fastrpc_internal_munmap(struct fastrpc_file *fl, map->phys, map->size, map->flags)); if (err) goto bail; + mutex_lock(&fl->map_mutex); fastrpc_mmap_free(map); + mutex_unlock(&fl->map_mutex); } bail: - if (err && map) + if (err && map) { + mutex_lock(&fl->map_mutex); fastrpc_mmap_add(map); - mutex_unlock(&fl->map_mutex); + mutex_unlock(&fl->map_mutex); + } + mutex_unlock(&fl->internal_map_mutex); return err; } @@ -2358,7 +2380,7 @@ static int fastrpc_internal_mmap(struct fastrpc_file *fl, uintptr_t raddr = 0; int err = 0; - mutex_lock(&fl->map_mutex); + mutex_lock(&fl->internal_map_mutex); if (ud->flags == ADSP_MMAP_ADD_PAGES) { DEFINE_DMA_ATTRS(dma_attr); @@ -2385,9 +2407,11 @@ static int fastrpc_internal_mmap(struct fastrpc_file *fl, } else { uintptr_t va_to_dsp; + mutex_lock(&fl->map_mutex); VERIFY(err, !fastrpc_mmap_create(fl, ud->fd, 0, (uintptr_t)ud->vaddrin, ud->size, ud->flags, &map)); + mutex_unlock(&fl->map_mutex); if (err) goto bail; @@ -2404,9 +2428,16 @@ static int fastrpc_internal_mmap(struct fastrpc_file *fl, } ud->vaddrout = raddr; bail: - if (err && map) - fastrpc_mmap_free(map); - mutex_unlock(&fl->map_mutex); + if (err) { + if (map) { + mutex_lock(&fl->map_mutex); + fastrpc_mmap_free(map); + mutex_unlock(&fl->map_mutex); + } + if (!IS_ERR_OR_NULL(rbuf)) + fastrpc_buf_free(rbuf, 0); + } + mutex_unlock(&fl->internal_map_mutex); return err; } @@ -2562,8 +2593,8 @@ static void fastrpc_session_free(struct fastrpc_channel_ctx *chan, static int fastrpc_file_free(struct fastrpc_file *fl) { - struct hlist_node *n; - struct fastrpc_mmap *map = NULL; + struct hlist_node *n = NULL; + struct fastrpc_mmap *map = NULL, *lmap = NULL; int cid; if (!fl) @@ -2587,9 +2618,18 @@ static int fastrpc_file_free(struct fastrpc_file *fl) fastrpc_buf_free(fl->init_mem, 0); fastrpc_context_list_dtor(fl); fastrpc_cached_buf_list_free(fl); - hlist_for_each_entry_safe(map, n, &fl->maps, hn) { - fastrpc_mmap_free(map); - } + mutex_lock(&fl->map_mutex); + do { + lmap = NULL; + hlist_for_each_entry_safe(map, n, &fl->maps, hn) { + hlist_del_init(&map->hn); + lmap = map; + break; + } + fastrpc_mmap_free(lmap); + } while (lmap); + mutex_unlock(&fl->map_mutex); + if (fl->ssrcount == fl->apps->channel[cid].ssrcount) kref_put_mutex(&fl->apps->channel[cid].kref, fastrpc_channel_close, &fl->apps->smd_mutex); @@ -2600,6 +2640,7 @@ static int fastrpc_file_free(struct fastrpc_file *fl) bail: fastrpc_remote_buf_list_free(fl); mutex_destroy(&fl->map_mutex); + mutex_destroy(&fl->internal_map_mutex); kfree(fl); return 0; } @@ -2611,7 +2652,6 @@ static int fastrpc_device_release(struct inode *inode, struct file *file) if (fl) { if (fl->debugfs_file != NULL) debugfs_remove(fl->debugfs_file); - fastrpc_file_free(fl); file->private_data = NULL; } @@ -2914,6 +2954,7 @@ static ssize_t fastrpc_debugfs_read(struct file *filp, char __user *buffer, map->secure, map->attr); } mutex_unlock(&fl->map_mutex); + spin_lock(&fl->hlock); len += scnprintf(fileinfo + len, DEBUGFS_SIZE - len, "\n%s %s %s\n", title, " LIST OF PENDING SMQCONTEXTS ", title); @@ -3025,8 +3066,10 @@ static int fastrpc_channel_open(struct fastrpc_file *fl) } if (cid == 0 && me->channel[cid].ssrcount != me->channel[cid].prevssrcount) { + mutex_lock(&fl->map_mutex); if (fastrpc_mmap_remove_ssr(fl)) pr_err("ADSPRPC: SSR: Failed to unmap remote heap\n"); + mutex_unlock(&fl->map_mutex); me->channel[cid].prevssrcount = me->channel[cid].ssrcount; } @@ -3091,6 +3134,7 @@ static int fastrpc_device_open(struct inode *inode, struct file *filp) fl->debugfs_file = debugfs_file; memset(&fl->perf, 0, sizeof(fl->perf)); filp->private_data = fl; + mutex_init(&fl->internal_map_mutex); mutex_init(&fl->map_mutex); spin_lock(&me->hlock); hlist_add_head(&fl->hn, &me->drivers); diff --git a/drivers/gpu/msm/kgsl_iommu.c b/drivers/gpu/msm/kgsl_iommu.c index 31e8a7ea5f65..1f10cb4c1568 100644 --- a/drivers/gpu/msm/kgsl_iommu.c +++ b/drivers/gpu/msm/kgsl_iommu.c @@ -1,5 +1,5 @@ /* Copyright (c) 2011-2021, The Linux Foundation. All rights reserved. - * Copyright (c) 2022 Qualcomm Innovation Center, Inc. All rights reserved. + * Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved. * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 and @@ -2397,14 +2397,18 @@ static uint64_t kgsl_iommu_find_svm_region(struct kgsl_pagetable *pagetable, static bool iommu_addr_in_svm_ranges(struct kgsl_iommu_pt *pt, u64 gpuaddr, u64 size) { + u64 end = gpuaddr + size; + + /* Make sure size is not zero and we don't wrap around */ + if (end <= gpuaddr) + return false; + if ((gpuaddr >= pt->compat_va_start && gpuaddr < pt->compat_va_end) && - ((gpuaddr + size) > pt->compat_va_start && - (gpuaddr + size) <= pt->compat_va_end)) + (end > pt->compat_va_start && end <= pt->compat_va_end)) return true; if ((gpuaddr >= pt->svm_start && gpuaddr < pt->svm_end) && - ((gpuaddr + size) > pt->svm_start && - (gpuaddr + size) <= pt->svm_end)) + (end > pt->svm_start && end <= pt->svm_end)) return true; return false; diff --git a/drivers/hid/hid-core.c b/drivers/hid/hid-core.c index 9a5be0ca4342..c0f276e27561 100644 --- a/drivers/hid/hid-core.c +++ b/drivers/hid/hid-core.c @@ -246,6 +246,7 @@ static int hid_add_field(struct hid_parser *parser, unsigned report_type, unsign unsigned usages; unsigned offset; unsigned i; + unsigned int max_buffer_size = HID_MAX_BUFFER_SIZE; report = hid_register_report(parser->device, report_type, parser->global.report_id); if (!report) { @@ -269,8 +270,11 @@ static int hid_add_field(struct hid_parser *parser, unsigned report_type, unsign offset = report->size; report->size += parser->global.report_size * parser->global.report_count; + if (parser->device->ll_driver->max_buffer_size) + max_buffer_size = parser->device->ll_driver->max_buffer_size; + /* Total size check: Allow for possible report index byte */ - if (report->size > (HID_MAX_BUFFER_SIZE - 1) << 3) { + if (report->size > (max_buffer_size - 1) << 3) { hid_err(parser->device, "report is too long\n"); return -1; } @@ -964,8 +968,8 @@ struct hid_report *hid_validate_values(struct hid_device *hid, * Validating on id 0 means we should examine the first * report in the list. */ - report = list_entry( - hid->report_enum[type].report_list.next, + report = list_first_entry_or_null( + &hid->report_enum[type].report_list, struct hid_report, list); } else { report = hid->report_enum[type].report_id_hash[id]; @@ -1112,6 +1116,9 @@ static s32 snto32(__u32 value, unsigned n) if (!value || !n) return 0; + if (n > 32) + n = 32; + switch (n) { case 8: return ((__s8)value); case 16: return ((__s16)value); @@ -1506,6 +1513,7 @@ int hid_report_raw_event(struct hid_device *hid, int type, u8 *data, u32 size, struct hid_report_enum *report_enum = hid->report_enum + type; struct hid_report *report; struct hid_driver *hdrv; + int max_buffer_size = HID_MAX_BUFFER_SIZE; unsigned int a; u32 rsize, csize = size; u8 *cdata = data; @@ -1522,10 +1530,13 @@ int hid_report_raw_event(struct hid_device *hid, int type, u8 *data, u32 size, rsize = hid_compute_report_size(report); - if (report_enum->numbered && rsize >= HID_MAX_BUFFER_SIZE) - rsize = HID_MAX_BUFFER_SIZE - 1; - else if (rsize > HID_MAX_BUFFER_SIZE) - rsize = HID_MAX_BUFFER_SIZE; + if (hid->ll_driver->max_buffer_size) + max_buffer_size = hid->ll_driver->max_buffer_size; + + if (report_enum->numbered && rsize >= max_buffer_size) + rsize = max_buffer_size - 1; + else if (rsize > max_buffer_size) + rsize = max_buffer_size; if (csize < rsize) { dbg_hid("report %d is too short, (%d < %d)\n", report->id, diff --git a/drivers/iommu/io-pgtable-arm.c b/drivers/iommu/io-pgtable-arm.c index 3f1617ca2fc0..137062b22ca9 100644 --- a/drivers/iommu/io-pgtable-arm.c +++ b/drivers/iommu/io-pgtable-arm.c @@ -642,9 +642,11 @@ static int arm_lpae_map_sg(struct io_pgtable_ops *ops, unsigned long iova, arm_lpae_iopte *ptep = ms.pgtable + ARM_LPAE_LVL_IDX(iova, MAP_STATE_LVL, data); - arm_lpae_init_pte( + ret = arm_lpae_init_pte( data, iova, phys, prot, MAP_STATE_LVL, ptep, ms.prev_pgtable, false); + if (ret) + goto out_err; ms.num_pte++; } else { ret = __arm_lpae_map(data, iova, phys, pgsize, diff --git a/drivers/leds/leds-qpnp.c b/drivers/leds/leds-qpnp.c index deec2c4e246a..0eec6d0f52d4 100644 --- a/drivers/leds/leds-qpnp.c +++ b/drivers/leds/leds-qpnp.c @@ -2819,7 +2819,7 @@ static ssize_t rgb_blink_store(struct device *dev, const char *buf, size_t count) { struct rgb_sync *rgb_sync; - struct qpnp_led_data *led; + struct qpnp_led_data *led = NULL; unsigned long blinking; struct led_classdev *led_cdev = dev_get_drvdata(dev); ssize_t rc = -EINVAL, i; diff --git a/drivers/net/wireguard/compat/Makefile.include b/drivers/net/wireguard/compat/Makefile.include index 513dba444a37..df7670ae8d6c 100644 --- a/drivers/net/wireguard/compat/Makefile.include +++ b/drivers/net/wireguard/compat/Makefile.include @@ -6,11 +6,16 @@ kbuild-dir := $(if $(filter /%,$(src)),$(src),$(srctree)/$(src)) ccflags-y += -include $(kbuild-dir)/compat/compat.h asflags-y += -include $(kbuild-dir)/compat/compat-asm.h +LINUXINCLUDE := -DCOMPAT_VERSION=$(VERSION) -DCOMPAT_PATCHLEVEL=$(PATCHLEVEL) -DCOMPAT_SUBLEVEL=$(SUBLEVEL) -I$(kbuild-dir)/compat/version $(LINUXINCLUDE) ifeq ($(wildcard $(srctree)/include/linux/ptr_ring.h),) ccflags-y += -I$(kbuild-dir)/compat/ptr_ring/include endif +ifeq ($(wildcard $(srctree)/include/linux/skb_array.h),) +ccflags-y += -I$(kbuild-dir)/compat/skb_array/include +endif + ifeq ($(wildcard $(srctree)/include/linux/siphash.h),) ccflags-y += -I$(kbuild-dir)/compat/siphash/include wireguard-y += compat/siphash/siphash.o @@ -64,6 +69,10 @@ ifeq ($(wildcard $(srctree)/arch/arm64/include/asm/neon.h)$(CONFIG_ARM64),y) ccflags-y += -I$(kbuild-dir)/compat/neon-arm/include endif +ifeq ($(wildcard $(srctree)/include/net/dst_metadata.h),) +ccflags-y += -I$(kbuild-dir)/compat/dstmetadata/include +endif + ifeq ($(CONFIG_X86_64),y) ifeq ($(ssse3_instr),) ssse3_instr := $(call as-instr,pshufb %xmm0$(comma)%xmm0,-DCONFIG_AS_SSSE3=1) diff --git a/drivers/net/wireguard/compat/compat-asm.h b/drivers/net/wireguard/compat/compat-asm.h index fde21dabba4f..951fc1094470 100644 --- a/drivers/net/wireguard/compat/compat-asm.h +++ b/drivers/net/wireguard/compat/compat-asm.h @@ -15,14 +15,14 @@ #define ISRHEL7 #elif RHEL_MAJOR == 8 #define ISRHEL8 -#if RHEL_MINOR >= 4 +#if RHEL_MINOR >= 6 #define ISCENTOS8S #endif #endif #endif /* PaX compatibility */ -#if defined(RAP_PLUGIN) +#if defined(RAP_PLUGIN) && defined(RAP_ENTRY) #undef ENTRY #define ENTRY RAP_ENTRY #endif @@ -51,7 +51,7 @@ #undef pull #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 76) && !defined(ISCENTOS8S) +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 4, 76) && !defined(ISRHEL8) && !defined(SYM_FUNC_START) #define SYM_FUNC_START ENTRY #define SYM_FUNC_END ENDPROC #endif diff --git a/drivers/net/wireguard/compat/compat.h b/drivers/net/wireguard/compat/compat.h index 91d4388824ea..d166ac235a99 100644 --- a/drivers/net/wireguard/compat/compat.h +++ b/drivers/net/wireguard/compat/compat.h @@ -16,15 +16,13 @@ #define ISRHEL7 #elif RHEL_MAJOR == 8 #define ISRHEL8 -#if RHEL_MINOR >= 4 +#if RHEL_MINOR >= 6 #define ISCENTOS8S #endif #endif #endif #ifdef UTS_UBUNTU_RELEASE_ABI -#if LINUX_VERSION_CODE == KERNEL_VERSION(3, 13, 11) -#define ISUBUNTU1404 -#elif LINUX_VERSION_CODE < KERNEL_VERSION(4, 5, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 4, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 5, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 4, 0) #define ISUBUNTU1604 #elif LINUX_VERSION_CODE < KERNEL_VERSION(4, 16, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0) #define ISUBUNTU1804 @@ -219,7 +217,7 @@ static inline void skb_scrub_packet(struct sk_buff *skb, bool xnet) #define skb_scrub_packet(a, b) skb_scrub_packet(a) #endif -#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 63) || defined(ISUBUNTU1404)) && !defined(ISRHEL7) +#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 63)) && !defined(ISRHEL7) #include <linux/random.h> static inline u32 __compat_prandom_u32_max(u32 ep_ro) { @@ -268,7 +266,7 @@ static inline u32 __compat_prandom_u32_max(u32 ep_ro) #endif #endif -#if (LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 3) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 16, 35) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 15, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 24) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0) && !defined(ISUBUNTU1404)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 33) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 11, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 10, 60) && !defined(ISRHEL7)) +#if (LINUX_VERSION_CODE < KERNEL_VERSION(3, 17, 3) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 16, 35) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 15, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 14, 24) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 33) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 11, 0)) || (LINUX_VERSION_CODE < KERNEL_VERSION(3, 10, 60) && !defined(ISRHEL7)) static inline void memzero_explicit(void *s, size_t count) { memset(s, 0, count); @@ -281,7 +279,7 @@ static const struct in6_addr __compat_in6addr_any = IN6ADDR_ANY_INIT; #define in6addr_any __compat_in6addr_any #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320)) #include <linux/completion.h> #include <linux/random.h> #include <linux/errno.h> @@ -325,7 +323,7 @@ static inline int wait_for_random_bytes(void) } #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) && !defined(ISRHEL8) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 19, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 14, 285)) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320)) && !defined(ISRHEL8) #include <linux/random.h> #include <linux/slab.h> struct rng_is_initialized_callback { @@ -377,7 +375,7 @@ static inline bool rng_is_initialized(void) } #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 13, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320)) static inline int get_random_bytes_wait(void *buf, int nbytes) { int ret = wait_for_random_bytes(); @@ -502,7 +500,7 @@ static inline void *__compat_kvzalloc(size_t size, gfp_t flags) #define kvzalloc __compat_kvzalloc #endif -#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 41)) && !defined(ISUBUNTU1404) +#if ((LINUX_VERSION_CODE < KERNEL_VERSION(3, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 13, 0)) || LINUX_VERSION_CODE < KERNEL_VERSION(3, 12, 41)) #include <linux/vmalloc.h> #include <linux/mm.h> static inline void __compat_kvfree(const void *addr) @@ -515,6 +513,28 @@ static inline void __compat_kvfree(const void *addr) #define kvfree __compat_kvfree #endif +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 12, 0) +#include <linux/vmalloc.h> +#include <linux/mm.h> +static inline void *__compat_kvmalloc_array(size_t n, size_t size, gfp_t flags) +{ + if (n != 0 && SIZE_MAX / n < size) + return NULL; + return kvmalloc(n * size, flags); +} +#define kvmalloc_array __compat_kvmalloc_array +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 18, 0) +#include <linux/vmalloc.h> +#include <linux/mm.h> +static inline void *__compat_kvcalloc(size_t n, size_t size, gfp_t flags) +{ + return kvmalloc_array(n, size, flags | __GFP_ZERO); +} +#define kvcalloc __compat_kvcalloc +#endif + #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 11, 9) #include <linux/netdevice.h> #define priv_destructor destructor @@ -704,7 +724,7 @@ static inline void *skb_put_data(struct sk_buff *skb, const void *data, unsigned #endif #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 17, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 17, 0) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 14, 285)) && (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 10, 0) || LINUX_VERSION_CODE < KERNEL_VERSION(4, 9, 320)) static inline void le32_to_cpu_array(u32 *buf, unsigned int words) { while (words--) { @@ -757,7 +777,7 @@ static inline void crypto_xor_cpy(u8 *dst, const u8 *src1, const u8 *src2, #define hlist_add_behind(a, b) hlist_add_after(b, a) #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 0, 0) && !defined(ISCENTOS8S) +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 0, 0) && !defined(ISRHEL8) #define totalram_pages() totalram_pages #endif @@ -831,10 +851,16 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb) #endif #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && !defined(ISRHEL8) +#include <net/netlink.h> +#ifndef NLA_POLICY_EXACT_LEN #define NLA_POLICY_EXACT_LEN(_len) { .type = NLA_UNSPEC, .len = _len } #endif +#endif #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 2, 0) && !defined(ISRHEL8) +#include <net/netlink.h> +#ifndef NLA_POLICY_MIN_LEN #define NLA_POLICY_MIN_LEN(_len) { .type = NLA_UNSPEC, .len = _len } +#endif #define COMPAT_CANNOT_INDIVIDUAL_NETLINK_OPS_POLICY #endif @@ -849,7 +875,7 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb) #endif #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0) && !defined(ISRHEL8) #define genl_dumpit_info(cb) ({ \ struct { struct nlattr **attrs; } *a = (void *)((u8 *)cb->args + offsetofend(struct dump_ctx, next_allowedip)); \ BUILD_BUG_ON(sizeof(cb->args) < offsetofend(struct dump_ctx, next_allowedip) + sizeof(*a)); \ @@ -869,11 +895,13 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb) #endif #endif -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0) +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 200) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 249)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 14, 285)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 9, 320)) #define blake2s_init zinc_blake2s_init #define blake2s_init_key zinc_blake2s_init_key #define blake2s_update zinc_blake2s_update #define blake2s_final zinc_blake2s_final +#endif +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0) #define blake2s_hmac zinc_blake2s_hmac #define chacha20 zinc_chacha20 #define hchacha20 zinc_hchacha20 @@ -1096,6 +1124,37 @@ static const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tun #endif #endif +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0) +#include <net/dst_cache.h> +struct dst_cache_pcpu { + unsigned long refresh_ts; + struct dst_entry *dst; + u32 cookie; + union { + struct in_addr in_saddr; + struct in6_addr in6_saddr; + }; +}; +#define COMPAT_HAS_DEFINED_DST_CACHE_PCPU +static inline void dst_cache_reset_now(struct dst_cache *dst_cache) +{ + int i; + + if (!dst_cache->cache) + return; + + dst_cache->reset_ts = jiffies; + for_each_possible_cpu(i) { + struct dst_cache_pcpu *idst = per_cpu_ptr(dst_cache->cache, i); + struct dst_entry *dst = idst->dst; + + idst->cookie = 0; + idst->dst = NULL; + dst_release(dst); + } +} +#endif + #if defined(ISUBUNTU1604) || defined(ISRHEL7) #include <linux/siphash.h> #ifndef _WG_LINUX_SIPHASH_H @@ -1127,7 +1186,7 @@ static const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tun #undef __read_mostly #define __read_mostly #endif -#if (defined(RAP_PLUGIN) || defined(CONFIG_CFI_CLANG)) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) +#if (defined(CONFIG_PAX) || defined(CONFIG_CFI_CLANG)) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) #include <linux/timer.h> #define wg_expired_retransmit_handshake(a) wg_expired_retransmit_handshake(unsigned long timer) #define wg_expired_send_keepalive(a) wg_expired_send_keepalive(unsigned long timer) diff --git a/drivers/net/wireguard/compat/dst_cache/dst_cache.c b/drivers/net/wireguard/compat/dst_cache/dst_cache.c index 7ec22f768a8f..f74c43c550eb 100644 --- a/drivers/net/wireguard/compat/dst_cache/dst_cache.c +++ b/drivers/net/wireguard/compat/dst_cache/dst_cache.c @@ -27,6 +27,7 @@ static inline u32 rt6_get_cookie(const struct rt6_info *rt) #endif #include <uapi/linux/in.h> +#ifndef COMPAT_HAS_DEFINED_DST_CACHE_PCPU struct dst_cache_pcpu { unsigned long refresh_ts; struct dst_entry *dst; @@ -36,6 +37,7 @@ struct dst_cache_pcpu { struct in6_addr in6_saddr; }; }; +#endif static void dst_cache_per_cpu_dst_set(struct dst_cache_pcpu *dst_cache, struct dst_entry *dst, u32 cookie) diff --git a/drivers/net/wireguard/compat/dstmetadata/include/net/dst_metadata.h b/drivers/net/wireguard/compat/dstmetadata/include/net/dst_metadata.h new file mode 100644 index 000000000000..995094d4f099 --- /dev/null +++ b/drivers/net/wireguard/compat/dstmetadata/include/net/dst_metadata.h @@ -0,0 +1,3 @@ +#ifndef skb_valid_dst +#define skb_valid_dst(skb) (!!skb_dst(skb)) +#endif diff --git a/drivers/net/wireguard/compat/siphash/include/linux/siphash.h b/drivers/net/wireguard/compat/siphash/include/linux/siphash.h index 1e5e337d15bf..3b30b3c47778 100644 --- a/drivers/net/wireguard/compat/siphash/include/linux/siphash.h +++ b/drivers/net/wireguard/compat/siphash/include/linux/siphash.h @@ -22,9 +22,7 @@ typedef struct { } siphash_key_t; u64 __siphash_aligned(const void *data, size_t len, const siphash_key_t *key); -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u64 __siphash_unaligned(const void *data, size_t len, const siphash_key_t *key); -#endif u64 siphash_1u64(const u64 a, const siphash_key_t *key); u64 siphash_2u64(const u64 a, const u64 b, const siphash_key_t *key); @@ -77,10 +75,9 @@ static inline u64 ___siphash_aligned(const __le64 *data, size_t len, static inline u64 siphash(const void *data, size_t len, const siphash_key_t *key) { -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS - if (!IS_ALIGNED((unsigned long)data, SIPHASH_ALIGNMENT)) + if (IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) || + !IS_ALIGNED((unsigned long)data, SIPHASH_ALIGNMENT)) return __siphash_unaligned(data, len, key); -#endif return ___siphash_aligned(data, len, key); } @@ -91,10 +88,8 @@ typedef struct { u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key); -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u32 __hsiphash_unaligned(const void *data, size_t len, const hsiphash_key_t *key); -#endif u32 hsiphash_1u32(const u32 a, const hsiphash_key_t *key); u32 hsiphash_2u32(const u32 a, const u32 b, const hsiphash_key_t *key); @@ -130,10 +125,9 @@ static inline u32 ___hsiphash_aligned(const __le32 *data, size_t len, static inline u32 hsiphash(const void *data, size_t len, const hsiphash_key_t *key) { -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS - if (!IS_ALIGNED((unsigned long)data, HSIPHASH_ALIGNMENT)) + if (IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) || + !IS_ALIGNED((unsigned long)data, HSIPHASH_ALIGNMENT)) return __hsiphash_unaligned(data, len, key); -#endif return ___hsiphash_aligned(data, len, key); } diff --git a/drivers/net/wireguard/compat/siphash/siphash.c b/drivers/net/wireguard/compat/siphash/siphash.c index 58855328e6e0..7dc72cb4a710 100644 --- a/drivers/net/wireguard/compat/siphash/siphash.c +++ b/drivers/net/wireguard/compat/siphash/siphash.c @@ -57,6 +57,7 @@ SIPROUND; \ return (v0 ^ v1) ^ (v2 ^ v3); +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u64 __siphash_aligned(const void *data, size_t len, const siphash_key_t *key) { const u8 *end = data + len - (len % sizeof(u64)); @@ -76,19 +77,19 @@ u64 __siphash_aligned(const void *data, size_t len, const siphash_key_t *key) bytemask_from_count(left))); #else switch (left) { - case 7: b |= ((u64)end[6]) << 48; - case 6: b |= ((u64)end[5]) << 40; - case 5: b |= ((u64)end[4]) << 32; + case 7: b |= ((u64)end[6]) << 48; fallthrough; + case 6: b |= ((u64)end[5]) << 40; fallthrough; + case 5: b |= ((u64)end[4]) << 32; fallthrough; case 4: b |= le32_to_cpup(data); break; - case 3: b |= ((u64)end[2]) << 16; + case 3: b |= ((u64)end[2]) << 16; fallthrough; case 2: b |= le16_to_cpup(data); break; case 1: b |= end[0]; } #endif POSTAMBLE } +#endif -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u64 __siphash_unaligned(const void *data, size_t len, const siphash_key_t *key) { const u8 *end = data + len - (len % sizeof(u64)); @@ -108,18 +109,17 @@ u64 __siphash_unaligned(const void *data, size_t len, const siphash_key_t *key) bytemask_from_count(left))); #else switch (left) { - case 7: b |= ((u64)end[6]) << 48; - case 6: b |= ((u64)end[5]) << 40; - case 5: b |= ((u64)end[4]) << 32; + case 7: b |= ((u64)end[6]) << 48; fallthrough; + case 6: b |= ((u64)end[5]) << 40; fallthrough; + case 5: b |= ((u64)end[4]) << 32; fallthrough; case 4: b |= get_unaligned_le32(end); break; - case 3: b |= ((u64)end[2]) << 16; + case 3: b |= ((u64)end[2]) << 16; fallthrough; case 2: b |= get_unaligned_le16(end); break; case 1: b |= end[0]; } #endif POSTAMBLE } -#endif /** * siphash_1u64 - compute 64-bit siphash PRF value of a u64 @@ -250,6 +250,7 @@ u64 siphash_3u32(const u32 first, const u32 second, const u32 third, HSIPROUND; \ return (v0 ^ v1) ^ (v2 ^ v3); +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key) { const u8 *end = data + len - (len % sizeof(u64)); @@ -268,19 +269,19 @@ u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key) bytemask_from_count(left))); #else switch (left) { - case 7: b |= ((u64)end[6]) << 48; - case 6: b |= ((u64)end[5]) << 40; - case 5: b |= ((u64)end[4]) << 32; + case 7: b |= ((u64)end[6]) << 48; fallthrough; + case 6: b |= ((u64)end[5]) << 40; fallthrough; + case 5: b |= ((u64)end[4]) << 32; fallthrough; case 4: b |= le32_to_cpup(data); break; - case 3: b |= ((u64)end[2]) << 16; + case 3: b |= ((u64)end[2]) << 16; fallthrough; case 2: b |= le16_to_cpup(data); break; case 1: b |= end[0]; } #endif HPOSTAMBLE } +#endif -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u32 __hsiphash_unaligned(const void *data, size_t len, const hsiphash_key_t *key) { @@ -300,18 +301,17 @@ u32 __hsiphash_unaligned(const void *data, size_t len, bytemask_from_count(left))); #else switch (left) { - case 7: b |= ((u64)end[6]) << 48; - case 6: b |= ((u64)end[5]) << 40; - case 5: b |= ((u64)end[4]) << 32; + case 7: b |= ((u64)end[6]) << 48; fallthrough; + case 6: b |= ((u64)end[5]) << 40; fallthrough; + case 5: b |= ((u64)end[4]) << 32; fallthrough; case 4: b |= get_unaligned_le32(end); break; - case 3: b |= ((u64)end[2]) << 16; + case 3: b |= ((u64)end[2]) << 16; fallthrough; case 2: b |= get_unaligned_le16(end); break; case 1: b |= end[0]; } #endif HPOSTAMBLE } -#endif /** * hsiphash_1u32 - compute 64-bit hsiphash PRF value of a u32 @@ -412,6 +412,7 @@ u32 hsiphash_4u32(const u32 first, const u32 second, const u32 third, HSIPROUND; \ return v1 ^ v3; +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key) { const u8 *end = data + len - (len % sizeof(u32)); @@ -425,14 +426,14 @@ u32 __hsiphash_aligned(const void *data, size_t len, const hsiphash_key_t *key) v0 ^= m; } switch (left) { - case 3: b |= ((u32)end[2]) << 16; + case 3: b |= ((u32)end[2]) << 16; fallthrough; case 2: b |= le16_to_cpup(data); break; case 1: b |= end[0]; } HPOSTAMBLE } +#endif -#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS u32 __hsiphash_unaligned(const void *data, size_t len, const hsiphash_key_t *key) { @@ -447,13 +448,12 @@ u32 __hsiphash_unaligned(const void *data, size_t len, v0 ^= m; } switch (left) { - case 3: b |= ((u32)end[2]) << 16; + case 3: b |= ((u32)end[2]) << 16; fallthrough; case 2: b |= get_unaligned_le16(end); break; case 1: b |= end[0]; } HPOSTAMBLE } -#endif /** * hsiphash_1u32 - compute 32-bit hsiphash PRF value of a u32 diff --git a/drivers/net/wireguard/compat/skb_array/include/linux/skb_array.h b/drivers/net/wireguard/compat/skb_array/include/linux/skb_array.h new file mode 100644 index 000000000000..c91fedcdbfc6 --- /dev/null +++ b/drivers/net/wireguard/compat/skb_array/include/linux/skb_array.h @@ -0,0 +1,11 @@ +#ifndef _WG_SKB_ARRAY_H +#define _WG_SKB_ARRAY_H + +#include <linux/skbuff.h> + +static void __skb_array_destroy_skb(void *ptr) +{ + kfree_skb(ptr); +} + +#endif diff --git a/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c b/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c index 9b8770ae7b3f..d287b917be84 100644 --- a/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c +++ b/drivers/net/wireguard/compat/udp_tunnel/udp_tunnel.c @@ -38,9 +38,10 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, struct socket *sock = NULL; struct sockaddr_in udp_addr; - err = __sock_create(net, AF_INET, SOCK_DGRAM, 0, &sock, 1); + err = sock_create_kern(AF_INET, SOCK_DGRAM, 0, &sock); if (err < 0) goto error; + sk_change_net(sock->sk, net); udp_addr.sin_family = AF_INET; udp_addr.sin_addr = cfg->local_ip; @@ -72,7 +73,7 @@ int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, error: if (sock) { kernel_sock_shutdown(sock, SHUT_RDWR); - sock_release(sock); + sk_release_kernel(sock->sk); } *sockp = NULL; return err; @@ -229,7 +230,7 @@ void udp_tunnel_sock_release(struct socket *sock) { rcu_assign_sk_user_data(sock->sk, NULL); kernel_sock_shutdown(sock, SHUT_RDWR); - sock_release(sock); + sk_release_kernel(sock->sk); } #if IS_ENABLED(CONFIG_IPV6) @@ -254,9 +255,10 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg, int err; struct socket *sock = NULL; - err = __sock_create(net, AF_INET6, SOCK_DGRAM, 0, &sock, 1); + err = sock_create_kern(AF_INET6, SOCK_DGRAM, 0, &sock); if (err < 0) goto error; + sk_change_net(sock->sk, net); if (cfg->ipv6_v6only) { int val = 1; @@ -301,7 +303,7 @@ int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg, error: if (sock) { kernel_sock_shutdown(sock, SHUT_RDWR); - sock_release(sock); + sk_release_kernel(sock->sk); } *sockp = NULL; return err; diff --git a/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c b/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c index 79716c425b0c..8b6872a2f0d0 100644 --- a/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c +++ b/drivers/net/wireguard/crypto/zinc/curve25519/curve25519-x86_64.c @@ -34,11 +34,11 @@ static inline u64 add_scalar(u64 *out, const u64 *f1, u64 f2) asm volatile( /* Clear registers to propagate the carry bit */ - " xor %%r8, %%r8;" - " xor %%r9, %%r9;" - " xor %%r10, %%r10;" - " xor %%r11, %%r11;" - " xor %1, %1;" + " xor %%r8d, %%r8d;" + " xor %%r9d, %%r9d;" + " xor %%r10d, %%r10d;" + " xor %%r11d, %%r11d;" + " xor %k1, %k1;" /* Begin addition chain */ " addq 0(%3), %0;" @@ -52,10 +52,9 @@ static inline u64 add_scalar(u64 *out, const u64 *f1, u64 f2) /* Return the carry bit in a register */ " adcx %%r11, %1;" - : "+&r" (f2), "=&r" (carry_r) - : "r" (out), "r" (f1) - : "%r8", "%r9", "%r10", "%r11", "memory", "cc" - ); + : "+&r"(f2), "=&r"(carry_r) + : "r"(out), "r"(f1) + : "%r8", "%r9", "%r10", "%r11", "memory", "cc"); return carry_r; } @@ -82,7 +81,7 @@ static inline void fadd(u64 *out, const u64 *f1, const u64 *f2) " cmovc %0, %%rax;" /* Step 2: Add carry*38 to the original sum */ - " xor %%rcx, %%rcx;" + " xor %%ecx, %%ecx;" " add %%rax, %%r8;" " adcx %%rcx, %%r9;" " movq %%r9, 8(%1);" @@ -96,17 +95,16 @@ static inline void fadd(u64 *out, const u64 *f1, const u64 *f2) " cmovc %0, %%rax;" " add %%rax, %%r8;" " movq %%r8, 0(%1);" - : "+&r" (f2) - : "r" (out), "r" (f1) - : "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc" - ); + : "+&r"(f2) + : "r"(out), "r"(f1) + : "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc"); } -/* Computes the field substraction of two field elements */ +/* Computes the field subtraction of two field elements */ static inline void fsub(u64 *out, const u64 *f1, const u64 *f2) { asm volatile( - /* Compute the raw substraction of f1-f2 */ + /* Compute the raw subtraction of f1-f2 */ " movq 0(%1), %%r8;" " subq 0(%2), %%r8;" " movq 8(%1), %%r9;" @@ -123,7 +121,7 @@ static inline void fsub(u64 *out, const u64 *f1, const u64 *f2) " mov $38, %%rcx;" " cmovc %%rcx, %%rax;" - /* Step 2: Substract carry*38 from the original difference */ + /* Step 2: Subtract carry*38 from the original difference */ " sub %%rax, %%r8;" " sbb $0, %%r9;" " sbb $0, %%r10;" @@ -139,10 +137,9 @@ static inline void fsub(u64 *out, const u64 *f1, const u64 *f2) " movq %%r9, 8(%0);" " movq %%r10, 16(%0);" " movq %%r11, 24(%0);" - : - : "r" (out), "r" (f1), "r" (f2) - : "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc" - ); + : + : "r"(out), "r"(f1), "r"(f2) + : "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "memory", "cc"); } /* Computes a field multiplication: out <- f1 * f2 @@ -150,239 +147,400 @@ static inline void fsub(u64 *out, const u64 *f1, const u64 *f2) static inline void fmul(u64 *out, const u64 *f1, const u64 *f2, u64 *tmp) { asm volatile( + /* Compute the raw multiplication: tmp <- src1 * src2 */ /* Compute src1[0] * src2 */ - " movq 0(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " movq %%r8, 0(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " movq %%r10, 8(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" + " movq 0(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " movq %%r8, 0(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " movq %%r10, 8(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + /* Compute src1[1] * src2 */ - " movq 8(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 8(%0), %%r8;" " movq %%r8, 8(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 16(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " mov $0, %%r8;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" + " movq 8(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 8(%2), %%r8;" + " movq %%r8, 8(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 16(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " mov $0, %%r8;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + /* Compute src1[2] * src2 */ - " movq 16(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 16(%0), %%r8;" " movq %%r8, 16(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 24(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " mov $0, %%r8;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" + " movq 16(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 16(%2), %%r8;" + " movq %%r8, 16(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 24(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " mov $0, %%r8;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + /* Compute src1[3] * src2 */ - " movq 24(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 24(%0), %%r8;" " movq %%r8, 24(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 32(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " movq %%rbx, 40(%0);" " mov $0, %%r8;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " movq %%r14, 48(%0);" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" " movq %%rax, 56(%0);" + " movq 24(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 24(%2), %%r8;" + " movq %%r8, 24(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 32(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " movq %%rbx, 40(%2);" + " mov $0, %%r8;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " movq %%r14, 48(%2);" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + " movq %%rax, 56(%2);" + /* Line up pointers */ - " mov %0, %1;" " mov %2, %0;" + " mov %3, %2;" /* Wrap the result back into the field */ /* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */ " mov $38, %%rdx;" - " mulxq 32(%1), %%r8, %%r13;" - " xor %3, %3;" - " adoxq 0(%1), %%r8;" - " mulxq 40(%1), %%r9, %%rbx;" + " mulxq 32(%0), %%r8, %%r13;" + " xor %k1, %k1;" + " adoxq 0(%0), %%r8;" + " mulxq 40(%0), %%r9, %%rbx;" " adcx %%r13, %%r9;" - " adoxq 8(%1), %%r9;" - " mulxq 48(%1), %%r10, %%r13;" + " adoxq 8(%0), %%r9;" + " mulxq 48(%0), %%r10, %%r13;" " adcx %%rbx, %%r10;" - " adoxq 16(%1), %%r10;" - " mulxq 56(%1), %%r11, %%rax;" + " adoxq 16(%0), %%r10;" + " mulxq 56(%0), %%r11, %%rax;" " adcx %%r13, %%r11;" - " adoxq 24(%1), %%r11;" - " adcx %3, %%rax;" - " adox %3, %%rax;" + " adoxq 24(%0), %%r11;" + " adcx %1, %%rax;" + " adox %1, %%rax;" " imul %%rdx, %%rax;" /* Step 2: Fold the carry back into dst */ " add %%rax, %%r8;" - " adcx %3, %%r9;" - " movq %%r9, 8(%0);" - " adcx %3, %%r10;" - " movq %%r10, 16(%0);" - " adcx %3, %%r11;" - " movq %%r11, 24(%0);" + " adcx %1, %%r9;" + " movq %%r9, 8(%2);" + " adcx %1, %%r10;" + " movq %%r10, 16(%2);" + " adcx %1, %%r11;" + " movq %%r11, 24(%2);" /* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */ " mov $0, %%rax;" " cmovc %%rdx, %%rax;" " add %%rax, %%r8;" - " movq %%r8, 0(%0);" - : "+&r" (tmp), "+&r" (f1), "+&r" (out), "+&r" (f2) - : - : "%rax", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "memory", "cc" - ); + " movq %%r8, 0(%2);" + : "+&r"(f1), "+&r"(f2), "+&r"(tmp) + : "r"(out) + : "%rax", "%rbx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%r13", + "%r14", "memory", "cc"); } /* Computes two field multiplications: - * out[0] <- f1[0] * f2[0] - * out[1] <- f1[1] * f2[1] - * Uses the 16-element buffer tmp for intermediate results. */ + * out[0] <- f1[0] * f2[0] + * out[1] <- f1[1] * f2[1] + * Uses the 16-element buffer tmp for intermediate results: */ static inline void fmul2(u64 *out, const u64 *f1, const u64 *f2, u64 *tmp) { asm volatile( + /* Compute the raw multiplication tmp[0] <- f1[0] * f2[0] */ /* Compute src1[0] * src2 */ - " movq 0(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " movq %%r8, 0(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " movq %%r10, 8(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" + " movq 0(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " movq %%r8, 0(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " movq %%r10, 8(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + /* Compute src1[1] * src2 */ - " movq 8(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 8(%0), %%r8;" " movq %%r8, 8(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 16(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " mov $0, %%r8;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" + " movq 8(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 8(%2), %%r8;" + " movq %%r8, 8(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 16(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " mov $0, %%r8;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + /* Compute src1[2] * src2 */ - " movq 16(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 16(%0), %%r8;" " movq %%r8, 16(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 24(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " mov $0, %%r8;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" + " movq 16(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 16(%2), %%r8;" + " movq %%r8, 16(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 24(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " mov $0, %%r8;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + /* Compute src1[3] * src2 */ - " movq 24(%1), %%rdx;" - " mulxq 0(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 24(%0), %%r8;" " movq %%r8, 24(%0);" - " mulxq 8(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 32(%0);" - " mulxq 16(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " movq %%rbx, 40(%0);" " mov $0, %%r8;" - " mulxq 24(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " movq %%r14, 48(%0);" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" " movq %%rax, 56(%0);" + " movq 24(%0), %%rdx;" + " mulxq 0(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 24(%2), %%r8;" + " movq %%r8, 24(%2);" + " mulxq 8(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 32(%2);" + " mulxq 16(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " movq %%rbx, 40(%2);" + " mov $0, %%r8;" + " mulxq 24(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " movq %%r14, 48(%2);" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + " movq %%rax, 56(%2);" /* Compute the raw multiplication tmp[1] <- f1[1] * f2[1] */ /* Compute src1[0] * src2 */ - " movq 32(%1), %%rdx;" - " mulxq 32(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " movq %%r8, 64(%0);" - " mulxq 40(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " movq %%r10, 72(%0);" - " mulxq 48(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" - " mulxq 56(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" + " movq 32(%0), %%rdx;" + " mulxq 32(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " movq %%r8, 64(%2);" + " mulxq 40(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " movq %%r10, 72(%2);" + " mulxq 48(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " mulxq 56(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + /* Compute src1[1] * src2 */ - " movq 40(%1), %%rdx;" - " mulxq 32(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 72(%0), %%r8;" " movq %%r8, 72(%0);" - " mulxq 40(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 80(%0);" - " mulxq 48(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " mov $0, %%r8;" - " mulxq 56(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" + " movq 40(%0), %%rdx;" + " mulxq 32(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 72(%2), %%r8;" + " movq %%r8, 72(%2);" + " mulxq 40(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 80(%2);" + " mulxq 48(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " mov $0, %%r8;" + " mulxq 56(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + /* Compute src1[2] * src2 */ - " movq 48(%1), %%rdx;" - " mulxq 32(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 80(%0), %%r8;" " movq %%r8, 80(%0);" - " mulxq 40(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 88(%0);" - " mulxq 48(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " mov $0, %%r8;" - " mulxq 56(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" + " movq 48(%0), %%rdx;" + " mulxq 32(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 80(%2), %%r8;" + " movq %%r8, 80(%2);" + " mulxq 40(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 88(%2);" + " mulxq 48(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " mov $0, %%r8;" + " mulxq 56(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + /* Compute src1[3] * src2 */ - " movq 56(%1), %%rdx;" - " mulxq 32(%3), %%r8, %%r9;" " xor %%r10, %%r10;" " adcxq 88(%0), %%r8;" " movq %%r8, 88(%0);" - " mulxq 40(%3), %%r10, %%r11;" " adox %%r9, %%r10;" " adcx %%rbx, %%r10;" " movq %%r10, 96(%0);" - " mulxq 48(%3), %%rbx, %%r13;" " adox %%r11, %%rbx;" " adcx %%r14, %%rbx;" " movq %%rbx, 104(%0);" " mov $0, %%r8;" - " mulxq 56(%3), %%r14, %%rdx;" " adox %%r13, %%r14;" " adcx %%rax, %%r14;" " movq %%r14, 112(%0);" " mov $0, %%rax;" - " adox %%rdx, %%rax;" " adcx %%r8, %%rax;" " movq %%rax, 120(%0);" + " movq 56(%0), %%rdx;" + " mulxq 32(%1), %%r8, %%r9;" + " xor %%r10d, %%r10d;" + " adcxq 88(%2), %%r8;" + " movq %%r8, 88(%2);" + " mulxq 40(%1), %%r10, %%r11;" + " adox %%r9, %%r10;" + " adcx %%rbx, %%r10;" + " movq %%r10, 96(%2);" + " mulxq 48(%1), %%rbx, %%r13;" + " adox %%r11, %%rbx;" + " adcx %%r14, %%rbx;" + " movq %%rbx, 104(%2);" + " mov $0, %%r8;" + " mulxq 56(%1), %%r14, %%rdx;" + " adox %%r13, %%r14;" + " adcx %%rax, %%r14;" + " movq %%r14, 112(%2);" + " mov $0, %%rax;" + " adox %%rdx, %%rax;" + " adcx %%r8, %%rax;" + " movq %%rax, 120(%2);" + /* Line up pointers */ - " mov %0, %1;" " mov %2, %0;" + " mov %3, %2;" /* Wrap the results back into the field */ /* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */ " mov $38, %%rdx;" - " mulxq 32(%1), %%r8, %%r13;" - " xor %3, %3;" - " adoxq 0(%1), %%r8;" - " mulxq 40(%1), %%r9, %%rbx;" + " mulxq 32(%0), %%r8, %%r13;" + " xor %k1, %k1;" + " adoxq 0(%0), %%r8;" + " mulxq 40(%0), %%r9, %%rbx;" " adcx %%r13, %%r9;" - " adoxq 8(%1), %%r9;" - " mulxq 48(%1), %%r10, %%r13;" + " adoxq 8(%0), %%r9;" + " mulxq 48(%0), %%r10, %%r13;" " adcx %%rbx, %%r10;" - " adoxq 16(%1), %%r10;" - " mulxq 56(%1), %%r11, %%rax;" + " adoxq 16(%0), %%r10;" + " mulxq 56(%0), %%r11, %%rax;" " adcx %%r13, %%r11;" - " adoxq 24(%1), %%r11;" - " adcx %3, %%rax;" - " adox %3, %%rax;" + " adoxq 24(%0), %%r11;" + " adcx %1, %%rax;" + " adox %1, %%rax;" " imul %%rdx, %%rax;" /* Step 2: Fold the carry back into dst */ " add %%rax, %%r8;" - " adcx %3, %%r9;" - " movq %%r9, 8(%0);" - " adcx %3, %%r10;" - " movq %%r10, 16(%0);" - " adcx %3, %%r11;" - " movq %%r11, 24(%0);" + " adcx %1, %%r9;" + " movq %%r9, 8(%2);" + " adcx %1, %%r10;" + " movq %%r10, 16(%2);" + " adcx %1, %%r11;" + " movq %%r11, 24(%2);" /* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */ " mov $0, %%rax;" " cmovc %%rdx, %%rax;" " add %%rax, %%r8;" - " movq %%r8, 0(%0);" + " movq %%r8, 0(%2);" /* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */ " mov $38, %%rdx;" - " mulxq 96(%1), %%r8, %%r13;" - " xor %3, %3;" - " adoxq 64(%1), %%r8;" - " mulxq 104(%1), %%r9, %%rbx;" + " mulxq 96(%0), %%r8, %%r13;" + " xor %k1, %k1;" + " adoxq 64(%0), %%r8;" + " mulxq 104(%0), %%r9, %%rbx;" " adcx %%r13, %%r9;" - " adoxq 72(%1), %%r9;" - " mulxq 112(%1), %%r10, %%r13;" + " adoxq 72(%0), %%r9;" + " mulxq 112(%0), %%r10, %%r13;" " adcx %%rbx, %%r10;" - " adoxq 80(%1), %%r10;" - " mulxq 120(%1), %%r11, %%rax;" + " adoxq 80(%0), %%r10;" + " mulxq 120(%0), %%r11, %%rax;" " adcx %%r13, %%r11;" - " adoxq 88(%1), %%r11;" - " adcx %3, %%rax;" - " adox %3, %%rax;" + " adoxq 88(%0), %%r11;" + " adcx %1, %%rax;" + " adox %1, %%rax;" " imul %%rdx, %%rax;" /* Step 2: Fold the carry back into dst */ " add %%rax, %%r8;" - " adcx %3, %%r9;" - " movq %%r9, 40(%0);" - " adcx %3, %%r10;" - " movq %%r10, 48(%0);" - " adcx %3, %%r11;" - " movq %%r11, 56(%0);" + " adcx %1, %%r9;" + " movq %%r9, 40(%2);" + " adcx %1, %%r10;" + " movq %%r10, 48(%2);" + " adcx %1, %%r11;" + " movq %%r11, 56(%2);" /* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */ " mov $0, %%rax;" " cmovc %%rdx, %%rax;" " add %%rax, %%r8;" - " movq %%r8, 32(%0);" - : "+&r" (tmp), "+&r" (f1), "+&r" (out), "+&r" (f2) - : - : "%rax", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "memory", "cc" - ); + " movq %%r8, 32(%2);" + : "+&r"(f1), "+&r"(f2), "+&r"(tmp) + : "r"(out) + : "%rax", "%rbx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%r13", + "%r14", "memory", "cc"); } -/* Computes the field multiplication of four-element f1 with value in f2 */ +/* Computes the field multiplication of four-element f1 with value in f2 + * Requires f2 to be smaller than 2^17 */ static inline void fmul_scalar(u64 *out, const u64 *f1, u64 f2) { register u64 f2_r asm("rdx") = f2; asm volatile( /* Compute the raw multiplication of f1*f2 */ - " mulxq 0(%2), %%r8, %%rcx;" /* f1[0]*f2 */ - " mulxq 8(%2), %%r9, %%rbx;" /* f1[1]*f2 */ + " mulxq 0(%2), %%r8, %%rcx;" /* f1[0]*f2 */ + " mulxq 8(%2), %%r9, %%rbx;" /* f1[1]*f2 */ " add %%rcx, %%r9;" " mov $0, %%rcx;" - " mulxq 16(%2), %%r10, %%r13;" /* f1[2]*f2 */ + " mulxq 16(%2), %%r10, %%r13;" /* f1[2]*f2 */ " adcx %%rbx, %%r10;" - " mulxq 24(%2), %%r11, %%rax;" /* f1[3]*f2 */ + " mulxq 24(%2), %%r11, %%rax;" /* f1[3]*f2 */ " adcx %%r13, %%r11;" " adcx %%rcx, %%rax;" @@ -406,17 +564,17 @@ static inline void fmul_scalar(u64 *out, const u64 *f1, u64 f2) " cmovc %%rdx, %%rax;" " add %%rax, %%r8;" " movq %%r8, 0(%1);" - : "+&r" (f2_r) - : "r" (out), "r" (f1) - : "%rax", "%rcx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "memory", "cc" - ); + : "+&r"(f2_r) + : "r"(out), "r"(f1) + : "%rax", "%rbx", "%rcx", "%r8", "%r9", "%r10", "%r11", "%r13", + "memory", "cc"); } /* Computes p1 <- bit ? p2 : p1 in constant time */ static inline void cswap2(u64 bit, const u64 *p1, const u64 *p2) { asm volatile( - /* Invert the polarity of bit to match cmov expectations */ + /* Transfer bit into CF flag */ " add $18446744073709551615, %0;" /* cswap p1[0], p2[0] */ @@ -490,10 +648,9 @@ static inline void cswap2(u64 bit, const u64 *p1, const u64 *p2) " cmovc %%r10, %%r9;" " movq %%r8, 56(%1);" " movq %%r9, 56(%2);" - : "+&r" (bit) - : "r" (p1), "r" (p2) - : "%r8", "%r9", "%r10", "memory", "cc" - ); + : "+&r"(bit) + : "r"(p1), "r"(p2) + : "%r8", "%r9", "%r10", "memory", "cc"); } /* Computes the square of a field element: out <- f * f @@ -504,18 +661,25 @@ static inline void fsqr(u64 *out, const u64 *f, u64 *tmp) /* Compute the raw multiplication: tmp <- f * f */ /* Step 1: Compute all partial products */ - " movq 0(%1), %%rdx;" /* f[0] */ - " mulxq 8(%1), %%r8, %%r14;" " xor %%r15, %%r15;" /* f[1]*f[0] */ - " mulxq 16(%1), %%r9, %%r10;" " adcx %%r14, %%r9;" /* f[2]*f[0] */ - " mulxq 24(%1), %%rax, %%rcx;" " adcx %%rax, %%r10;" /* f[3]*f[0] */ - " movq 24(%1), %%rdx;" /* f[3] */ - " mulxq 8(%1), %%r11, %%rbx;" " adcx %%rcx, %%r11;" /* f[1]*f[3] */ - " mulxq 16(%1), %%rax, %%r13;" " adcx %%rax, %%rbx;" /* f[2]*f[3] */ - " movq 8(%1), %%rdx;" " adcx %%r15, %%r13;" /* f1 */ - " mulxq 16(%1), %%rax, %%rcx;" " mov $0, %%r14;" /* f[2]*f[1] */ + " movq 0(%0), %%rdx;" /* f[0] */ + " mulxq 8(%0), %%r8, %%r14;" + " xor %%r15d, %%r15d;" /* f[1]*f[0] */ + " mulxq 16(%0), %%r9, %%r10;" + " adcx %%r14, %%r9;" /* f[2]*f[0] */ + " mulxq 24(%0), %%rax, %%rcx;" + " adcx %%rax, %%r10;" /* f[3]*f[0] */ + " movq 24(%0), %%rdx;" /* f[3] */ + " mulxq 8(%0), %%r11, %%rbx;" + " adcx %%rcx, %%r11;" /* f[1]*f[3] */ + " mulxq 16(%0), %%rax, %%r13;" + " adcx %%rax, %%rbx;" /* f[2]*f[3] */ + " movq 8(%0), %%rdx;" + " adcx %%r15, %%r13;" /* f1 */ + " mulxq 16(%0), %%rax, %%rcx;" + " mov $0, %%r14;" /* f[2]*f[1] */ /* Step 2: Compute two parallel carry chains */ - " xor %%r15, %%r15;" + " xor %%r15d, %%r15d;" " adox %%rax, %%r10;" " adcx %%r8, %%r8;" " adox %%rcx, %%r11;" @@ -530,39 +694,50 @@ static inline void fsqr(u64 *out, const u64 *f, u64 *tmp) " adcx %%r14, %%r14;" /* Step 3: Compute intermediate squares */ - " movq 0(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ - " movq %%rax, 0(%0);" - " add %%rcx, %%r8;" " movq %%r8, 8(%0);" - " movq 8(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ - " adcx %%rax, %%r9;" " movq %%r9, 16(%0);" - " adcx %%rcx, %%r10;" " movq %%r10, 24(%0);" - " movq 16(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ - " adcx %%rax, %%r11;" " movq %%r11, 32(%0);" - " adcx %%rcx, %%rbx;" " movq %%rbx, 40(%0);" - " movq 24(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ - " adcx %%rax, %%r13;" " movq %%r13, 48(%0);" - " adcx %%rcx, %%r14;" " movq %%r14, 56(%0);" + " movq 0(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ + " movq %%rax, 0(%1);" + " add %%rcx, %%r8;" + " movq %%r8, 8(%1);" + " movq 8(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ + " adcx %%rax, %%r9;" + " movq %%r9, 16(%1);" + " adcx %%rcx, %%r10;" + " movq %%r10, 24(%1);" + " movq 16(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ + " adcx %%rax, %%r11;" + " movq %%r11, 32(%1);" + " adcx %%rcx, %%rbx;" + " movq %%rbx, 40(%1);" + " movq 24(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ + " adcx %%rax, %%r13;" + " movq %%r13, 48(%1);" + " adcx %%rcx, %%r14;" + " movq %%r14, 56(%1);" /* Line up pointers */ - " mov %0, %1;" - " mov %2, %0;" + " mov %1, %0;" + " mov %2, %1;" /* Wrap the result back into the field */ /* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */ " mov $38, %%rdx;" - " mulxq 32(%1), %%r8, %%r13;" - " xor %%rcx, %%rcx;" - " adoxq 0(%1), %%r8;" - " mulxq 40(%1), %%r9, %%rbx;" + " mulxq 32(%0), %%r8, %%r13;" + " xor %%ecx, %%ecx;" + " adoxq 0(%0), %%r8;" + " mulxq 40(%0), %%r9, %%rbx;" " adcx %%r13, %%r9;" - " adoxq 8(%1), %%r9;" - " mulxq 48(%1), %%r10, %%r13;" + " adoxq 8(%0), %%r9;" + " mulxq 48(%0), %%r10, %%r13;" " adcx %%rbx, %%r10;" - " adoxq 16(%1), %%r10;" - " mulxq 56(%1), %%r11, %%rax;" + " adoxq 16(%0), %%r10;" + " mulxq 56(%0), %%r11, %%rax;" " adcx %%r13, %%r11;" - " adoxq 24(%1), %%r11;" + " adoxq 24(%0), %%r11;" " adcx %%rcx, %%rax;" " adox %%rcx, %%rax;" " imul %%rdx, %%rax;" @@ -570,43 +745,50 @@ static inline void fsqr(u64 *out, const u64 *f, u64 *tmp) /* Step 2: Fold the carry back into dst */ " add %%rax, %%r8;" " adcx %%rcx, %%r9;" - " movq %%r9, 8(%0);" + " movq %%r9, 8(%1);" " adcx %%rcx, %%r10;" - " movq %%r10, 16(%0);" + " movq %%r10, 16(%1);" " adcx %%rcx, %%r11;" - " movq %%r11, 24(%0);" + " movq %%r11, 24(%1);" /* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */ " mov $0, %%rax;" " cmovc %%rdx, %%rax;" " add %%rax, %%r8;" - " movq %%r8, 0(%0);" - : "+&r" (tmp), "+&r" (f), "+&r" (out) - : - : "%rax", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "%r15", "memory", "cc" - ); + " movq %%r8, 0(%1);" + : "+&r,&r"(f), "+&r,&r"(tmp) + : "r,m"(out) + : "%rax", "%rbx", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", + "%r13", "%r14", "%r15", "memory", "cc"); } /* Computes two field squarings: - * out[0] <- f[0] * f[0] - * out[1] <- f[1] * f[1] + * out[0] <- f[0] * f[0] + * out[1] <- f[1] * f[1] * Uses the 16-element buffer tmp for intermediate results */ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp) { asm volatile( /* Step 1: Compute all partial products */ - " movq 0(%1), %%rdx;" /* f[0] */ - " mulxq 8(%1), %%r8, %%r14;" " xor %%r15, %%r15;" /* f[1]*f[0] */ - " mulxq 16(%1), %%r9, %%r10;" " adcx %%r14, %%r9;" /* f[2]*f[0] */ - " mulxq 24(%1), %%rax, %%rcx;" " adcx %%rax, %%r10;" /* f[3]*f[0] */ - " movq 24(%1), %%rdx;" /* f[3] */ - " mulxq 8(%1), %%r11, %%rbx;" " adcx %%rcx, %%r11;" /* f[1]*f[3] */ - " mulxq 16(%1), %%rax, %%r13;" " adcx %%rax, %%rbx;" /* f[2]*f[3] */ - " movq 8(%1), %%rdx;" " adcx %%r15, %%r13;" /* f1 */ - " mulxq 16(%1), %%rax, %%rcx;" " mov $0, %%r14;" /* f[2]*f[1] */ + " movq 0(%0), %%rdx;" /* f[0] */ + " mulxq 8(%0), %%r8, %%r14;" + " xor %%r15d, %%r15d;" /* f[1]*f[0] */ + " mulxq 16(%0), %%r9, %%r10;" + " adcx %%r14, %%r9;" /* f[2]*f[0] */ + " mulxq 24(%0), %%rax, %%rcx;" + " adcx %%rax, %%r10;" /* f[3]*f[0] */ + " movq 24(%0), %%rdx;" /* f[3] */ + " mulxq 8(%0), %%r11, %%rbx;" + " adcx %%rcx, %%r11;" /* f[1]*f[3] */ + " mulxq 16(%0), %%rax, %%r13;" + " adcx %%rax, %%rbx;" /* f[2]*f[3] */ + " movq 8(%0), %%rdx;" + " adcx %%r15, %%r13;" /* f1 */ + " mulxq 16(%0), %%rax, %%rcx;" + " mov $0, %%r14;" /* f[2]*f[1] */ /* Step 2: Compute two parallel carry chains */ - " xor %%r15, %%r15;" + " xor %%r15d, %%r15d;" " adox %%rax, %%r10;" " adcx %%r8, %%r8;" " adox %%rcx, %%r11;" @@ -621,32 +803,50 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp) " adcx %%r14, %%r14;" /* Step 3: Compute intermediate squares */ - " movq 0(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ - " movq %%rax, 0(%0);" - " add %%rcx, %%r8;" " movq %%r8, 8(%0);" - " movq 8(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ - " adcx %%rax, %%r9;" " movq %%r9, 16(%0);" - " adcx %%rcx, %%r10;" " movq %%r10, 24(%0);" - " movq 16(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ - " adcx %%rax, %%r11;" " movq %%r11, 32(%0);" - " adcx %%rcx, %%rbx;" " movq %%rbx, 40(%0);" - " movq 24(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ - " adcx %%rax, %%r13;" " movq %%r13, 48(%0);" - " adcx %%rcx, %%r14;" " movq %%r14, 56(%0);" + " movq 0(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ + " movq %%rax, 0(%1);" + " add %%rcx, %%r8;" + " movq %%r8, 8(%1);" + " movq 8(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ + " adcx %%rax, %%r9;" + " movq %%r9, 16(%1);" + " adcx %%rcx, %%r10;" + " movq %%r10, 24(%1);" + " movq 16(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ + " adcx %%rax, %%r11;" + " movq %%r11, 32(%1);" + " adcx %%rcx, %%rbx;" + " movq %%rbx, 40(%1);" + " movq 24(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ + " adcx %%rax, %%r13;" + " movq %%r13, 48(%1);" + " adcx %%rcx, %%r14;" + " movq %%r14, 56(%1);" /* Step 1: Compute all partial products */ - " movq 32(%1), %%rdx;" /* f[0] */ - " mulxq 40(%1), %%r8, %%r14;" " xor %%r15, %%r15;" /* f[1]*f[0] */ - " mulxq 48(%1), %%r9, %%r10;" " adcx %%r14, %%r9;" /* f[2]*f[0] */ - " mulxq 56(%1), %%rax, %%rcx;" " adcx %%rax, %%r10;" /* f[3]*f[0] */ - " movq 56(%1), %%rdx;" /* f[3] */ - " mulxq 40(%1), %%r11, %%rbx;" " adcx %%rcx, %%r11;" /* f[1]*f[3] */ - " mulxq 48(%1), %%rax, %%r13;" " adcx %%rax, %%rbx;" /* f[2]*f[3] */ - " movq 40(%1), %%rdx;" " adcx %%r15, %%r13;" /* f1 */ - " mulxq 48(%1), %%rax, %%rcx;" " mov $0, %%r14;" /* f[2]*f[1] */ + " movq 32(%0), %%rdx;" /* f[0] */ + " mulxq 40(%0), %%r8, %%r14;" + " xor %%r15d, %%r15d;" /* f[1]*f[0] */ + " mulxq 48(%0), %%r9, %%r10;" + " adcx %%r14, %%r9;" /* f[2]*f[0] */ + " mulxq 56(%0), %%rax, %%rcx;" + " adcx %%rax, %%r10;" /* f[3]*f[0] */ + " movq 56(%0), %%rdx;" /* f[3] */ + " mulxq 40(%0), %%r11, %%rbx;" + " adcx %%rcx, %%r11;" /* f[1]*f[3] */ + " mulxq 48(%0), %%rax, %%r13;" + " adcx %%rax, %%rbx;" /* f[2]*f[3] */ + " movq 40(%0), %%rdx;" + " adcx %%r15, %%r13;" /* f1 */ + " mulxq 48(%0), %%rax, %%rcx;" + " mov $0, %%r14;" /* f[2]*f[1] */ /* Step 2: Compute two parallel carry chains */ - " xor %%r15, %%r15;" + " xor %%r15d, %%r15d;" " adox %%rax, %%r10;" " adcx %%r8, %%r8;" " adox %%rcx, %%r11;" @@ -661,37 +861,48 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp) " adcx %%r14, %%r14;" /* Step 3: Compute intermediate squares */ - " movq 32(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ - " movq %%rax, 64(%0);" - " add %%rcx, %%r8;" " movq %%r8, 72(%0);" - " movq 40(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ - " adcx %%rax, %%r9;" " movq %%r9, 80(%0);" - " adcx %%rcx, %%r10;" " movq %%r10, 88(%0);" - " movq 48(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ - " adcx %%rax, %%r11;" " movq %%r11, 96(%0);" - " adcx %%rcx, %%rbx;" " movq %%rbx, 104(%0);" - " movq 56(%1), %%rdx;" " mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ - " adcx %%rax, %%r13;" " movq %%r13, 112(%0);" - " adcx %%rcx, %%r14;" " movq %%r14, 120(%0);" + " movq 32(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[0]^2 */ + " movq %%rax, 64(%1);" + " add %%rcx, %%r8;" + " movq %%r8, 72(%1);" + " movq 40(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[1]^2 */ + " adcx %%rax, %%r9;" + " movq %%r9, 80(%1);" + " adcx %%rcx, %%r10;" + " movq %%r10, 88(%1);" + " movq 48(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[2]^2 */ + " adcx %%rax, %%r11;" + " movq %%r11, 96(%1);" + " adcx %%rcx, %%rbx;" + " movq %%rbx, 104(%1);" + " movq 56(%0), %%rdx;" + " mulx %%rdx, %%rax, %%rcx;" /* f[3]^2 */ + " adcx %%rax, %%r13;" + " movq %%r13, 112(%1);" + " adcx %%rcx, %%r14;" + " movq %%r14, 120(%1);" /* Line up pointers */ - " mov %0, %1;" - " mov %2, %0;" + " mov %1, %0;" + " mov %2, %1;" /* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */ " mov $38, %%rdx;" - " mulxq 32(%1), %%r8, %%r13;" - " xor %%rcx, %%rcx;" - " adoxq 0(%1), %%r8;" - " mulxq 40(%1), %%r9, %%rbx;" + " mulxq 32(%0), %%r8, %%r13;" + " xor %%ecx, %%ecx;" + " adoxq 0(%0), %%r8;" + " mulxq 40(%0), %%r9, %%rbx;" " adcx %%r13, %%r9;" - " adoxq 8(%1), %%r9;" - " mulxq 48(%1), %%r10, %%r13;" + " adoxq 8(%0), %%r9;" + " mulxq 48(%0), %%r10, %%r13;" " adcx %%rbx, %%r10;" - " adoxq 16(%1), %%r10;" - " mulxq 56(%1), %%r11, %%rax;" + " adoxq 16(%0), %%r10;" + " mulxq 56(%0), %%r11, %%rax;" " adcx %%r13, %%r11;" - " adoxq 24(%1), %%r11;" + " adoxq 24(%0), %%r11;" " adcx %%rcx, %%rax;" " adox %%rcx, %%rax;" " imul %%rdx, %%rax;" @@ -699,32 +910,32 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp) /* Step 2: Fold the carry back into dst */ " add %%rax, %%r8;" " adcx %%rcx, %%r9;" - " movq %%r9, 8(%0);" + " movq %%r9, 8(%1);" " adcx %%rcx, %%r10;" - " movq %%r10, 16(%0);" + " movq %%r10, 16(%1);" " adcx %%rcx, %%r11;" - " movq %%r11, 24(%0);" + " movq %%r11, 24(%1);" /* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */ " mov $0, %%rax;" " cmovc %%rdx, %%rax;" " add %%rax, %%r8;" - " movq %%r8, 0(%0);" + " movq %%r8, 0(%1);" /* Step 1: Compute dst + carry == tmp_hi * 38 + tmp_lo */ " mov $38, %%rdx;" - " mulxq 96(%1), %%r8, %%r13;" - " xor %%rcx, %%rcx;" - " adoxq 64(%1), %%r8;" - " mulxq 104(%1), %%r9, %%rbx;" + " mulxq 96(%0), %%r8, %%r13;" + " xor %%ecx, %%ecx;" + " adoxq 64(%0), %%r8;" + " mulxq 104(%0), %%r9, %%rbx;" " adcx %%r13, %%r9;" - " adoxq 72(%1), %%r9;" - " mulxq 112(%1), %%r10, %%r13;" + " adoxq 72(%0), %%r9;" + " mulxq 112(%0), %%r10, %%r13;" " adcx %%rbx, %%r10;" - " adoxq 80(%1), %%r10;" - " mulxq 120(%1), %%r11, %%rax;" + " adoxq 80(%0), %%r10;" + " mulxq 120(%0), %%r11, %%rax;" " adcx %%r13, %%r11;" - " adoxq 88(%1), %%r11;" + " adoxq 88(%0), %%r11;" " adcx %%rcx, %%rax;" " adox %%rcx, %%rax;" " imul %%rdx, %%rax;" @@ -732,21 +943,21 @@ static inline void fsqr2(u64 *out, const u64 *f, u64 *tmp) /* Step 2: Fold the carry back into dst */ " add %%rax, %%r8;" " adcx %%rcx, %%r9;" - " movq %%r9, 40(%0);" + " movq %%r9, 40(%1);" " adcx %%rcx, %%r10;" - " movq %%r10, 48(%0);" + " movq %%r10, 48(%1);" " adcx %%rcx, %%r11;" - " movq %%r11, 56(%0);" + " movq %%r11, 56(%1);" /* Step 3: Fold the carry bit back in; guaranteed not to carry at this point */ " mov $0, %%rax;" " cmovc %%rdx, %%rax;" " add %%rax, %%r8;" - " movq %%r8, 32(%0);" - : "+&r" (tmp), "+&r" (f), "+&r" (out) - : - : "%rax", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", "%rbx", "%r13", "%r14", "%r15", "memory", "cc" - ); + " movq %%r8, 32(%1);" + : "+&r,&r"(f), "+&r,&r"(tmp) + : "r,m"(out) + : "%rax", "%rbx", "%rcx", "%rdx", "%r8", "%r9", "%r10", "%r11", + "%r13", "%r14", "%r15", "memory", "cc"); } static void point_add_and_double(u64 *q, u64 *p01_tmp1, u64 *tmp2) diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index b8c2390b0a35..062490f1b8a7 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -19,6 +19,7 @@ #include <linux/if_arp.h> #include <linux/icmp.h> #include <linux/suspend.h> +#include <net/dst_metadata.h> #include <net/icmp.h> #include <net/rtnetlink.h> #include <net/ip_tunnels.h> @@ -106,6 +107,7 @@ static int wg_stop(struct net_device *dev) { struct wg_device *wg = netdev_priv(dev); struct wg_peer *peer; + struct sk_buff *skb; mutex_lock(&wg->device_update_lock); list_for_each_entry(peer, &wg->peer_list, peer_list) { @@ -116,7 +118,9 @@ static int wg_stop(struct net_device *dev) wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake); } mutex_unlock(&wg->device_update_lock); - skb_queue_purge(&wg->incoming_handshakes); + while ((skb = ptr_ring_consume(&wg->handshake_queue.ring)) != NULL) + kfree_skb(skb); + atomic_set(&wg->handshake_queue_len, 0); wg_socket_reinit(wg, NULL, NULL); return 0; } @@ -157,7 +161,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev) goto err_peer; } - mtu = skb_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; + mtu = skb_valid_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; __skb_queue_head_init(&packets); if (!skb_is_gso(skb)) { @@ -243,14 +247,13 @@ static void wg_destruct(struct net_device *dev) destroy_workqueue(wg->handshake_receive_wq); destroy_workqueue(wg->handshake_send_wq); destroy_workqueue(wg->packet_crypt_wq); - wg_packet_queue_free(&wg->decrypt_queue); - wg_packet_queue_free(&wg->encrypt_queue); + wg_packet_queue_free(&wg->handshake_queue, true); + wg_packet_queue_free(&wg->decrypt_queue, false); + wg_packet_queue_free(&wg->encrypt_queue, false); rcu_barrier(); /* Wait for all the peers to be actually freed. */ wg_ratelimiter_uninit(); memzero_explicit(&wg->static_identity, sizeof(wg->static_identity)); - skb_queue_purge(&wg->incoming_handshakes); free_percpu(dev->tstats); - free_percpu(wg->incoming_handshakes_worker); kvfree(wg->index_hashtable); kvfree(wg->peer_hashtable); mutex_unlock(&wg->device_update_lock); @@ -312,7 +315,6 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, init_rwsem(&wg->static_identity.lock); mutex_init(&wg->socket_update_lock); mutex_init(&wg->device_update_lock); - skb_queue_head_init(&wg->incoming_handshakes); wg_allowedips_init(&wg->peer_allowedips); wg_cookie_checker_init(&wg->cookie_checker, wg); INIT_LIST_HEAD(&wg->peer_list); @@ -330,16 +332,10 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, if (!dev->tstats) goto err_free_index_hashtable; - wg->incoming_handshakes_worker = - wg_packet_percpu_multicore_worker_alloc( - wg_packet_handshake_receive_worker, wg); - if (!wg->incoming_handshakes_worker) - goto err_free_tstats; - wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); if (!wg->handshake_receive_wq) - goto err_free_incoming_handshakes; + goto err_free_tstats; wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); @@ -361,10 +357,15 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, if (ret < 0) goto err_free_encrypt_queue; - ret = wg_ratelimiter_init(); + ret = wg_packet_queue_init(&wg->handshake_queue, wg_packet_handshake_receive_worker, + MAX_QUEUED_INCOMING_HANDSHAKES); if (ret < 0) goto err_free_decrypt_queue; + ret = wg_ratelimiter_init(); + if (ret < 0) + goto err_free_handshake_queue; + ret = register_netdevice(dev); if (ret < 0) goto err_uninit_ratelimiter; @@ -381,18 +382,18 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, err_uninit_ratelimiter: wg_ratelimiter_uninit(); +err_free_handshake_queue: + wg_packet_queue_free(&wg->handshake_queue, false); err_free_decrypt_queue: - wg_packet_queue_free(&wg->decrypt_queue); + wg_packet_queue_free(&wg->decrypt_queue, false); err_free_encrypt_queue: - wg_packet_queue_free(&wg->encrypt_queue); + wg_packet_queue_free(&wg->encrypt_queue, false); err_destroy_packet_crypt: destroy_workqueue(wg->packet_crypt_wq); err_destroy_handshake_send: destroy_workqueue(wg->handshake_send_wq); err_destroy_handshake_receive: destroy_workqueue(wg->handshake_receive_wq); -err_free_incoming_handshakes: - free_percpu(wg->incoming_handshakes_worker); err_free_tstats: free_percpu(dev->tstats); err_free_index_hashtable: @@ -412,6 +413,7 @@ static struct rtnl_link_ops link_ops __read_mostly = { static void wg_netns_pre_exit(struct net *net) { struct wg_device *wg; + struct wg_peer *peer; rtnl_lock(); list_for_each_entry(wg, &device_list, device_list) { @@ -421,6 +423,8 @@ static void wg_netns_pre_exit(struct net *net) mutex_lock(&wg->device_update_lock); rcu_assign_pointer(wg->creating_net, NULL); wg_socket_reinit(wg, NULL, NULL); + list_for_each_entry(peer, &wg->peer_list, peer_list) + wg_socket_clear_peer_endpoint_src(peer); mutex_unlock(&wg->device_update_lock); } } diff --git a/drivers/net/wireguard/device.h b/drivers/net/wireguard/device.h index 854bc3d97150..43c7cebbf50b 100644 --- a/drivers/net/wireguard/device.h +++ b/drivers/net/wireguard/device.h @@ -39,21 +39,18 @@ struct prev_queue { struct wg_device { struct net_device *dev; - struct crypt_queue encrypt_queue, decrypt_queue; + struct crypt_queue encrypt_queue, decrypt_queue, handshake_queue; struct sock __rcu *sock4, *sock6; struct net __rcu *creating_net; struct noise_static_identity static_identity; - struct workqueue_struct *handshake_receive_wq, *handshake_send_wq; - struct workqueue_struct *packet_crypt_wq; - struct sk_buff_head incoming_handshakes; - int incoming_handshake_cpu; - struct multicore_worker __percpu *incoming_handshakes_worker; + struct workqueue_struct *packet_crypt_wq,*handshake_receive_wq, *handshake_send_wq; struct cookie_checker cookie_checker; struct pubkey_hashtable *peer_hashtable; struct index_hashtable *index_hashtable; struct allowedips peer_allowedips; struct mutex device_update_lock, socket_update_lock; struct list_head device_list, peer_list; + atomic_t handshake_queue_len; unsigned int num_peers, device_update_gen; u32 fwmark; u16 incoming_port; diff --git a/drivers/net/wireguard/main.c b/drivers/net/wireguard/main.c index 9b8bbe27999e..a6714ce13a65 100644 --- a/drivers/net/wireguard/main.c +++ b/drivers/net/wireguard/main.c @@ -18,7 +18,7 @@ #include <linux/genetlink.h> #include <net/rtnetlink.h> -static int __init mod_init(void) +static int __init wg_mod_init(void) { int ret; @@ -66,7 +66,7 @@ err_allowedips: return ret; } -static void __exit mod_exit(void) +static void __exit wg_mod_exit(void) { wg_genetlink_uninit(); wg_device_uninit(); @@ -74,8 +74,8 @@ static void __exit mod_exit(void) wg_allowedips_slab_uninit(); } -module_init(mod_init); -module_exit(mod_exit); +module_init(wg_mod_init); +module_exit(wg_mod_exit); MODULE_LICENSE("GPL v2"); MODULE_DESCRIPTION("WireGuard secure network tunnel"); MODULE_AUTHOR("Jason A. Donenfeld <Jason@zx2c4.com>"); diff --git a/drivers/net/wireguard/queueing.c b/drivers/net/wireguard/queueing.c index 48e7b982a307..8084e7408c0a 100644 --- a/drivers/net/wireguard/queueing.c +++ b/drivers/net/wireguard/queueing.c @@ -4,6 +4,7 @@ */ #include "queueing.h" +#include <linux/skb_array.h> struct multicore_worker __percpu * wg_packet_percpu_multicore_worker_alloc(work_func_t function, void *ptr) @@ -38,11 +39,11 @@ int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function, return 0; } -void wg_packet_queue_free(struct crypt_queue *queue) +void wg_packet_queue_free(struct crypt_queue *queue, bool purge) { free_percpu(queue->worker); - WARN_ON(!__ptr_ring_empty(&queue->ring)); - ptr_ring_cleanup(&queue->ring, NULL); + WARN_ON(!purge && !__ptr_ring_empty(&queue->ring)); + ptr_ring_cleanup(&queue->ring, purge ? __skb_array_destroy_skb : NULL); } #define NEXT(skb) ((skb)->prev) diff --git a/drivers/net/wireguard/queueing.h b/drivers/net/wireguard/queueing.h index b6ccf650c738..03850c43ebaf 100644 --- a/drivers/net/wireguard/queueing.h +++ b/drivers/net/wireguard/queueing.h @@ -23,7 +23,7 @@ struct sk_buff; /* queueing.c APIs: */ int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function, unsigned int len); -void wg_packet_queue_free(struct crypt_queue *queue); +void wg_packet_queue_free(struct crypt_queue *queue, bool purge); struct multicore_worker __percpu * wg_packet_percpu_multicore_worker_alloc(work_func_t function, void *ptr); diff --git a/drivers/net/wireguard/ratelimiter.c b/drivers/net/wireguard/ratelimiter.c index e33ec72a9642..ecee41f528a5 100644 --- a/drivers/net/wireguard/ratelimiter.c +++ b/drivers/net/wireguard/ratelimiter.c @@ -188,12 +188,12 @@ int wg_ratelimiter_init(void) (1U << 14) / sizeof(struct hlist_head))); max_entries = table_size * 8; - table_v4 = kvzalloc(table_size * sizeof(*table_v4), GFP_KERNEL); + table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL); if (unlikely(!table_v4)) goto err_kmemcache; #if IS_ENABLED(CONFIG_IPV6) - table_v6 = kvzalloc(table_size * sizeof(*table_v6), GFP_KERNEL); + table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL); if (unlikely(!table_v6)) { kvfree(table_v4); goto err_kmemcache; diff --git a/drivers/net/wireguard/receive.c b/drivers/net/wireguard/receive.c index 07147ff0522d..214889edb48e 100644 --- a/drivers/net/wireguard/receive.c +++ b/drivers/net/wireguard/receive.c @@ -117,8 +117,8 @@ static void wg_receive_handshake_packet(struct wg_device *wg, return; } - under_load = skb_queue_len(&wg->incoming_handshakes) >= - MAX_QUEUED_INCOMING_HANDSHAKES / 8; + under_load = atomic_read(&wg->handshake_queue_len) >= + MAX_QUEUED_INCOMING_HANDSHAKES / 8; if (under_load) { last_under_load = ktime_get_coarse_boottime_ns(); } else if (last_under_load) { @@ -213,13 +213,14 @@ static void wg_receive_handshake_packet(struct wg_device *wg, void wg_packet_handshake_receive_worker(struct work_struct *work) { - struct wg_device *wg = container_of(work, struct multicore_worker, - work)->ptr; + struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr; + struct wg_device *wg = container_of(queue, struct wg_device, handshake_queue); struct sk_buff *skb; - while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) { + while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) { wg_receive_handshake_packet(wg, skb); dev_kfree_skb(skb); + atomic_dec(&wg->handshake_queue_len); cond_resched(); } } @@ -562,22 +563,28 @@ void wg_packet_receive(struct wg_device *wg, struct sk_buff *skb) case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): { - int cpu; - - if (skb_queue_len(&wg->incoming_handshakes) > - MAX_QUEUED_INCOMING_HANDSHAKES || - unlikely(!rng_is_initialized())) { + int cpu, ret = -EBUSY; + + if (unlikely(!rng_is_initialized())) + goto drop; + if (atomic_read(&wg->handshake_queue_len) > MAX_QUEUED_INCOMING_HANDSHAKES / 2) { + if (spin_trylock_bh(&wg->handshake_queue.ring.producer_lock)) { + ret = __ptr_ring_produce(&wg->handshake_queue.ring, skb); + spin_unlock_bh(&wg->handshake_queue.ring.producer_lock); + } + } else + ret = ptr_ring_produce_bh(&wg->handshake_queue.ring, skb); + if (ret) { + drop: net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n", wg->dev->name, skb); goto err; } - skb_queue_tail(&wg->incoming_handshakes, skb); - /* Queues up a call to packet_process_queued_handshake_ - * packets(skb): - */ - cpu = wg_cpumask_next_online(&wg->incoming_handshake_cpu); + atomic_inc(&wg->handshake_queue_len); + cpu = wg_cpumask_next_online(&wg->handshake_queue.last_cpu); + /* Queues up a call to packet_process_queued_handshake_packets(skb): */ queue_work_on(cpu, wg->handshake_receive_wq, - &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); + &per_cpu_ptr(wg->handshake_queue.worker, cpu)->work); break; } case cpu_to_le32(MESSAGE_DATA): diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c index 04739763e303..0414d7a6ce74 100644 --- a/drivers/net/wireguard/socket.c +++ b/drivers/net/wireguard/socket.c @@ -49,7 +49,7 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, rt = dst_cache_get_ip4(cache, &fl.saddr); if (!rt) { - security_sk_classify_flow(sock, flowi4_to_flowi(&fl)); + security_sk_classify_flow(sock, flowi4_to_flowi_common(&fl)); if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, fl.saddr, RT_SCOPE_HOST))) { endpoint->src4.s_addr = 0; @@ -129,7 +129,7 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, dst = dst_cache_get_ip6(cache, &fl.saddr); if (!dst) { - security_sk_classify_flow(sock, flowi6_to_flowi(&fl)); + security_sk_classify_flow(sock, flowi6_to_flowi_common(&fl)); if (unlikely(!ipv6_addr_any(&fl.saddr) && !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) { endpoint->src6 = fl.saddr = in6addr_any; @@ -160,6 +160,7 @@ out: rcu_read_unlock_bh(); return ret; #else + kfree_skb(skb); return -EAFNOSUPPORT; #endif } @@ -241,7 +242,7 @@ int wg_socket_endpoint_from_skb(struct endpoint *endpoint, endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr; endpoint->src4.s_addr = ip_hdr(skb)->daddr; endpoint->src_if4 = skb->skb_iif; - } else if (skb->protocol == htons(ETH_P_IPV6)) { + } else if (IS_ENABLED(CONFIG_IPV6) && skb->protocol == htons(ETH_P_IPV6)) { endpoint->addr6.sin6_family = AF_INET6; endpoint->addr6.sin6_port = udp_hdr(skb)->source; endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr; @@ -284,7 +285,7 @@ void wg_socket_set_peer_endpoint(struct wg_peer *peer, peer->endpoint.addr4 = endpoint->addr4; peer->endpoint.src4 = endpoint->src4; peer->endpoint.src_if4 = endpoint->src_if4; - } else if (endpoint->addr.sa_family == AF_INET6) { + } else if (IS_ENABLED(CONFIG_IPV6) && endpoint->addr.sa_family == AF_INET6) { peer->endpoint.addr6 = endpoint->addr6; peer->endpoint.src6 = endpoint->src6; } else { @@ -308,7 +309,7 @@ void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer) { write_lock_bh(&peer->endpoint_lock); memset(&peer->endpoint.src6, 0, sizeof(peer->endpoint.src6)); - dst_cache_reset(&peer->endpoint_cache); + dst_cache_reset_now(&peer->endpoint_cache); write_unlock_bh(&peer->endpoint_lock); } diff --git a/drivers/net/wireguard/version.h b/drivers/net/wireguard/version.h index 35ef576765c9..c7f9028f0177 100644 --- a/drivers/net/wireguard/version.h +++ b/drivers/net/wireguard/version.h @@ -1 +1,3 @@ -#define WIREGUARD_VERSION "1.0.20210606" +#ifndef WIREGUARD_VERSION +#define WIREGUARD_VERSION "1.0.20220627" +#endif diff --git a/drivers/power/reset/msm-poweroff.c b/drivers/power/reset/msm-poweroff.c index f267ec9d80c0..50b5d76d7df0 100644 --- a/drivers/power/reset/msm-poweroff.c +++ b/drivers/power/reset/msm-poweroff.c @@ -84,7 +84,7 @@ static struct notifier_block panic_blk = { #endif static int dload_type = SCM_DLOAD_FULLDUMP; -static int download_mode; +static int download_mode = 1; static struct kobject dload_kobj; static void *dload_mode_addr, *dload_type_addr; static bool dload_mode_enabled; diff --git a/drivers/tty/pty.c b/drivers/tty/pty.c index 8ee146b14aae..6a024b5cfb0f 100644 --- a/drivers/tty/pty.c +++ b/drivers/tty/pty.c @@ -106,21 +106,11 @@ static void pty_unthrottle(struct tty_struct *tty) static int pty_write(struct tty_struct *tty, const unsigned char *buf, int c) { struct tty_struct *to = tty->link; - unsigned long flags; - if (tty->stopped) + if (tty->stopped || !c) return 0; - if (c > 0) { - spin_lock_irqsave(&to->port->lock, flags); - /* Stuff the data into the input queue of the other end */ - c = tty_insert_flip_string(to->port, buf, c); - spin_unlock_irqrestore(&to->port->lock, flags); - /* And shovel */ - if (c) - tty_flip_buffer_push(to->port); - } - return c; + return tty_insert_flip_string_and_push_buffer(to->port, buf, c); } /** diff --git a/drivers/tty/tty_buffer.c b/drivers/tty/tty_buffer.c index 8f1f668c4532..49c6cbab984f 100644 --- a/drivers/tty/tty_buffer.c +++ b/drivers/tty/tty_buffer.c @@ -168,7 +168,8 @@ static struct tty_buffer *tty_buffer_alloc(struct tty_port *port, size_t size) have queued and recycle that ? */ if (atomic_read(&port->buf.mem_used) > port->buf.mem_limit) return NULL; - p = kmalloc(sizeof(struct tty_buffer) + 2 * size, GFP_ATOMIC); + p = kmalloc(sizeof(struct tty_buffer) + 2 * size, + GFP_ATOMIC | __GFP_NOWARN); if (p == NULL) return NULL; @@ -389,6 +390,15 @@ int __tty_insert_flip_char(struct tty_port *port, unsigned char ch, char flag) } EXPORT_SYMBOL(__tty_insert_flip_char); +static inline void tty_flip_buffer_commit(struct tty_buffer *tail) +{ + /* + * Paired w/ acquire in flush_to_ldisc(); ensures flush_to_ldisc() sees + * buffer data. + */ + smp_store_release(&tail->commit, tail->used); +} + /** * tty_schedule_flip - push characters to ldisc * @port: tty port to push from @@ -402,10 +412,7 @@ void tty_schedule_flip(struct tty_port *port) { struct tty_bufhead *buf = &port->buf; - /* paired w/ acquire in flush_to_ldisc(); ensures - * flush_to_ldisc() sees buffer data. - */ - smp_store_release(&buf->tail->commit, buf->tail->used); + tty_flip_buffer_commit(buf->tail); queue_kthread_work(&port->worker, &buf->work); } EXPORT_SYMBOL(tty_schedule_flip); @@ -549,6 +556,37 @@ void tty_flip_buffer_push(struct tty_port *port) EXPORT_SYMBOL(tty_flip_buffer_push); /** + * tty_insert_flip_string_and_push_buffer - add characters to the tty buffer and + * push + * @port: tty port + * @chars: characters + * @size: size + * + * The function combines tty_insert_flip_string() and tty_flip_buffer_push() + * with the exception of properly holding the @port->lock. + * + * To be used only internally (by pty currently). + * + * Returns: the number added. + */ +int tty_insert_flip_string_and_push_buffer(struct tty_port *port, + const unsigned char *chars, size_t size) +{ + struct tty_bufhead *buf = &port->buf; + unsigned long flags; + + spin_lock_irqsave(&port->lock, flags); + size = tty_insert_flip_string(port, chars, size); + if (size) + tty_flip_buffer_commit(buf->tail); + spin_unlock_irqrestore(&port->lock, flags); + + queue_kthread_work(&port->worker, &buf->work); + + return size; +} + +/** * tty_buffer_init - prepare a tty buffer structure * @tty: tty to initialise * diff --git a/drivers/usb/core/hub.c b/drivers/usb/core/hub.c index afd509564e2e..143c33cbc3d8 100644 --- a/drivers/usb/core/hub.c +++ b/drivers/usb/core/hub.c @@ -2257,9 +2257,8 @@ static int usb_enumerate_device_otg(struct usb_device *udev) * usb_enumerate_device - Read device configs/intfs/otg (usbcore-internal) * @udev: newly addressed device (in ADDRESS state) * - * This is only called by usb_new_device() and usb_authorize_device() - * and FIXME -- all comments that apply to them apply here wrt to - * environment. + * This is only called by usb_new_device() -- all comments that apply there + * apply here wrt to environment. * * If the device is WUSB and not authorized, we don't attempt to read * the string descriptors, as they will be errored out by the device @@ -5621,6 +5620,11 @@ re_enumerate_no_bos: * the reset is over (using their post_reset method). * * Return: The same as for usb_reset_and_verify_device(). + * However, if a reset is already in progress (for instance, if a + * driver doesn't have pre_ or post_reset() callbacks, and while + * being unbound or re-bound during the ongoing reset its disconnect() + * or probe() routine tries to perform a second, nested reset), the + * routine returns -EINPROGRESS. * * Note: * The caller must own the device lock. For example, it's safe to use @@ -5654,6 +5658,10 @@ int usb_reset_device(struct usb_device *udev) return -EISDIR; } + if (udev->reset_in_progress) + return -EINPROGRESS; + udev->reset_in_progress = 1; + port_dev = hub->ports[udev->portnum - 1]; /* @@ -5718,6 +5726,7 @@ int usb_reset_device(struct usb_device *udev) usb_autosuspend_device(udev); memalloc_noio_restore(noio_flag); + udev->reset_in_progress = 0; return ret; } EXPORT_SYMBOL_GPL(usb_reset_device); diff --git a/drivers/usb/core/sysfs.c b/drivers/usb/core/sysfs.c index 6dc0f4e25cf3..c7c32326d50c 100644 --- a/drivers/usb/core/sysfs.c +++ b/drivers/usb/core/sysfs.c @@ -825,7 +825,6 @@ read_descriptors(struct file *filp, struct kobject *kobj, * Following that are the raw descriptor entries for all the * configurations (config plus subsidiary descriptors). */ - usb_lock_device(udev); for (cfgno = -1; cfgno < udev->descriptor.bNumConfigurations && nleft > 0; ++cfgno) { if (cfgno < 0) { @@ -846,7 +845,6 @@ read_descriptors(struct file *filp, struct kobject *kobj, off -= srclen; } } - usb_unlock_device(udev); return count - nleft; } |
