Compare commits
43 Commits
will/statu
...
jknodt/io-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f9021e287 | ||
|
|
9fd01334cf | ||
|
|
c5cb642376 | ||
|
|
87fc7aa6b0 | ||
|
|
34c5677308 | ||
|
|
6d10acc6dd | ||
|
|
61e3d919ef | ||
|
|
cc5c696834 | ||
|
|
33c0997447 | ||
|
|
08af39ae24 | ||
|
|
43ccdc8879 | ||
|
|
e0abf1b3dd | ||
|
|
7642d9fafd | ||
|
|
bf20f000fd | ||
|
|
4679379ebe | ||
|
|
f6b49d3e0e | ||
|
|
ae2f24ec4e | ||
|
|
ba49da429a | ||
|
|
83742afabf | ||
|
|
b668e5d185 | ||
|
|
5410042dcd | ||
|
|
75efd794a3 | ||
|
|
36a4741bc5 | ||
|
|
208e6eb0db | ||
|
|
fec66b4100 | ||
|
|
eff65381f2 | ||
|
|
a8df9fa7cc | ||
|
|
78fd2b7880 | ||
|
|
001dec84de | ||
|
|
a8a7208dbd | ||
|
|
f254f779b5 | ||
|
|
3d91c5b369 | ||
|
|
2a2ed7cd17 | ||
|
|
55c1ce00be | ||
|
|
4013c0edbb | ||
|
|
783d2d4327 | ||
|
|
71f35bda1a | ||
|
|
b83ac004f1 | ||
|
|
59512181b5 | ||
|
|
1ee40d1670 | ||
|
|
bbccf68a76 | ||
|
|
d7a7e2d17d | ||
|
|
f26c0fcbd5 |
@@ -18,6 +18,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/uring"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/filter"
|
||||
@@ -160,6 +161,17 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper {
|
||||
filterFlags: filter.LogAccepts | filter.LogDrops,
|
||||
}
|
||||
|
||||
if uring.Available() {
|
||||
uringTun, err := uring.NewTUN(tdev)
|
||||
name, _ := tdev.Name()
|
||||
if err != nil {
|
||||
logf("not using io_uring for TUN %v: %v", name, err)
|
||||
} else {
|
||||
logf("using uring for TUN %v", name)
|
||||
tdev = uringTun
|
||||
}
|
||||
}
|
||||
|
||||
go tun.poll()
|
||||
go tun.pumpEvents()
|
||||
// The buffer starts out consumed.
|
||||
@@ -519,7 +531,54 @@ func (t *Wrapper) Write(buf []byte, offset int) (int, error) {
|
||||
}
|
||||
|
||||
t.noteActivity()
|
||||
return t.tdev.Write(buf, offset)
|
||||
return t.write(buf, offset)
|
||||
}
|
||||
|
||||
func (t *Wrapper) write(buf []byte, offset int) (int, error) {
|
||||
if t.ring == nil {
|
||||
return t.tdev.Write(buf, offset)
|
||||
}
|
||||
|
||||
// below copied from wireguard-go NativeTUN.Write
|
||||
|
||||
// reserve space for header
|
||||
buf = buf[offset-4:]
|
||||
|
||||
// add packet information header
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
if buf[4]>>4 == ipv6.Version {
|
||||
buf[2] = 0x86
|
||||
buf[3] = 0xdd
|
||||
} else {
|
||||
buf[2] = 0x08
|
||||
buf[3] = 0x00
|
||||
}
|
||||
|
||||
n, err := t.ring.Write(buf)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (t *Wrapper) read(buf []byte, offset int) (n int, err error) {
|
||||
// TODO: upstream has graceful shutdown error handling here.
|
||||
buff := buf[offset-4:]
|
||||
if uring.URingAvailable() {
|
||||
n, err = t.ring.Read(buff[:])
|
||||
} else {
|
||||
n, err = t.tdev.(*wgtun.NativeTun).File().Read(buff[:])
|
||||
}
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
if n < 4 {
|
||||
n = 0
|
||||
} else {
|
||||
n -= 4
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Wrapper) GetFilter() *filter.Filter {
|
||||
|
||||
2
net/uring/.gitignore
vendored
Normal file
2
net/uring/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
liburing/
|
||||
*.so
|
||||
62
net/uring/all.go
Normal file
62
net/uring/all.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package uring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// This file contains code shared across all platforms.
|
||||
|
||||
// Available reports whether io_uring is available on this machine.
|
||||
// If Available reports false, no other package uring APIs should be called.
|
||||
func Available() bool {
|
||||
return runtime.GOOS == "linux" && *useIOURing
|
||||
}
|
||||
|
||||
var useIOURing = flag.Bool("use-io-uring", true, "attempt to use io_uring if available")
|
||||
|
||||
// NotSupportedError indicates an operation was attempted when io_uring is not supported.
|
||||
var NotSupportedError = errors.New("io_uring not supported")
|
||||
|
||||
// DisabledError indicates that io_uring was explicitly disabled.
|
||||
var DisabledError = errors.New("io_uring disabled")
|
||||
|
||||
type IORingOp = int
|
||||
|
||||
//https://unixism.net/loti/tutorial/probe_liburing.html
|
||||
const (
|
||||
IORING_OP_NOP IORingOp = iota
|
||||
IORING_OP_READV
|
||||
IORING_OP_WRITEV
|
||||
IORING_OP_FSYNC
|
||||
IORING_OP_READ_FIXED
|
||||
IORING_OP_WRITE_FIXED
|
||||
IORING_OP_POLL_ADD
|
||||
IORING_OP_POLL_REMOVE
|
||||
IORING_OP_SYNC_FILE_RANGE
|
||||
IORING_OP_SENDMSG
|
||||
IORING_OP_RECVMSG
|
||||
IORING_OP_TIMEOUT
|
||||
IORING_OP_TIMEOUT_REMOVE
|
||||
IORING_OP_ACCEPT
|
||||
IORING_OP_ASYNC_CANCEL
|
||||
IORING_OP_LINK_TIMEOUT
|
||||
IORING_OP_CONNECT
|
||||
IORING_OP_FALLOCATE
|
||||
IORING_OP_OPENAT
|
||||
IORING_OP_CLOSE
|
||||
IORING_OP_FILES_UPDATE
|
||||
IORING_OP_STATX
|
||||
IORING_OP_READ
|
||||
IORING_OP_WRITE
|
||||
IORING_OP_FADVISE
|
||||
IORING_OP_MADVISE
|
||||
IORING_OP_SEND
|
||||
IORING_OP_RECV
|
||||
IORING_OP_OPENAT2
|
||||
IORING_OP_EPOLL_CTL
|
||||
IORING_OP_SPLICE
|
||||
IORING_OP_PROVIDE_BUFFERS
|
||||
IORING_OP_REMOVE_BUFFERS
|
||||
)
|
||||
55
net/uring/file_test.go
Normal file
55
net/uring/file_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package uring
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
qt "github.com/frankban/quicktest"
|
||||
)
|
||||
|
||||
func TestFileRead(t *testing.T) {
|
||||
if !Available() {
|
||||
t.Skip("io_uring not available")
|
||||
}
|
||||
c := qt.New(t)
|
||||
|
||||
const path = "testdata/voltaire.txt"
|
||||
want, err := os.ReadFile(path)
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
f, err := os.Open(path)
|
||||
c.Assert(err, qt.IsNil)
|
||||
t.Cleanup(func() { f.Close() })
|
||||
|
||||
uf, err := newFile(f)
|
||||
if err != nil {
|
||||
t.Skipf("io_uring not available: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { uf.Close() })
|
||||
buf := make([]byte, len(want)+128)
|
||||
n, err := uf.Read(buf)
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Assert(buf[:n], qt.DeepEquals, want)
|
||||
}
|
||||
|
||||
func TestFileWrite(t *testing.T) {
|
||||
if !Available() {
|
||||
t.Skip("io_uring not available")
|
||||
}
|
||||
c := qt.New(t)
|
||||
tmpFile, err := ioutil.TempFile(".", "uring-test")
|
||||
c.Assert(err, qt.IsNil)
|
||||
t.Cleanup(func() {
|
||||
os.Remove(tmpFile.Name())
|
||||
})
|
||||
f, err := newFile(tmpFile)
|
||||
c.Assert(err, qt.IsNil)
|
||||
content := []byte("a test string to check writing works 😀 with non-unicode input")
|
||||
n, err := f.Write(content)
|
||||
if n != len(content) {
|
||||
t.Errorf("mismatch between written len and content len: want %d, got %d", len(content), n)
|
||||
}
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Assert(f.Close(), qt.IsNil)
|
||||
}
|
||||
202
net/uring/io_uring.c
Normal file
202
net/uring/io_uring.c
Normal file
@@ -0,0 +1,202 @@
|
||||
// +build linux
|
||||
|
||||
#if __has_include(<liburing.h>)
|
||||
|
||||
#include <arpa/inet.h> // debugging
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <errno.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <liburing.h>
|
||||
#include <linux/io_uring.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/udp.h>
|
||||
|
||||
// TODO: use fixed buffers? https://unixism.net/loti/tutorial/fixed_buffers.html
|
||||
|
||||
typedef struct io_uring go_uring;
|
||||
typedef struct msghdr go_msghdr;
|
||||
typedef struct iovec go_iovec;
|
||||
typedef struct sockaddr_in go_sockaddr_in;
|
||||
typedef struct io_uring_params go_io_uring_params;
|
||||
|
||||
// initialize sets up a 16-entry io_uring and registers fd as the only
// entry in the ring's fixed-file table (index 0, used with
// IOSQE_FIXED_FILE by the submit_* helpers below).
// Returns 0 on success, or the negative value returned by liburing.
static int initialize(struct io_uring *ring, int fd) {
    int ret = io_uring_queue_init(16, ring, 0); // 16: size of ring
    if (ret < 0) {
        return ret;
    }
    ret = io_uring_register_files(ring, &fd, 1);
    // TODO: Do we need to unregister files on close, or is Closing the uring enough?
    if (ret < 0) {
        // Label the call that actually failed (was mislabeled
        // "io_uring_queue_init").
        perror("io_uring_register_files");
        return ret;
    }
    return 0;
}
|
||||
|
||||
struct req {
|
||||
struct msghdr hdr;
|
||||
struct iovec iov;
|
||||
struct sockaddr_in sa;
|
||||
struct sockaddr_in6 sa6;
|
||||
// in_kernel indicates (by being non-zero) whether this request is sitting in the kernel
|
||||
// It is accessed atomically.
|
||||
int32_t in_kernel;
|
||||
char *buf;
|
||||
};
|
||||
|
||||
typedef struct req goreq;
|
||||
|
||||
static struct req *initializeReq(size_t sz, int ipVersion) {
|
||||
struct req *r = malloc(sizeof(struct req));
|
||||
memset(r, 0, sizeof(*r));
|
||||
r->buf = malloc(sz);
|
||||
memset(r->buf, 0, sz);
|
||||
r->iov.iov_base = r->buf;
|
||||
r->iov.iov_len = sz;
|
||||
r->hdr.msg_iov = &r->iov;
|
||||
r->hdr.msg_iovlen = 1;
|
||||
switch(ipVersion) {
|
||||
case 4:
|
||||
r->hdr.msg_name = &r->sa;
|
||||
r->hdr.msg_namelen = sizeof(r->sa);
|
||||
break;
|
||||
case 6:
|
||||
r->hdr.msg_name = &r->sa6;
|
||||
r->hdr.msg_namelen = sizeof(r->sa6);
|
||||
break;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
static void freeReq(struct req *r) {
|
||||
free(r->buf);
|
||||
free(r);
|
||||
}
|
||||
|
||||
// submit a recvmsg request via liburing
|
||||
// TODO: What recvfrom support arrives, maybe use that instead?
|
||||
static int submit_recvmsg_request(struct io_uring *ring, struct req *r, size_t idx) {
|
||||
struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
|
||||
if (!sqe) {
|
||||
return -1;
|
||||
}
|
||||
io_uring_prep_recvmsg(sqe, 0, &r->hdr, 0); // use the 0th file in the list of registered fds
|
||||
io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE);
|
||||
io_uring_sqe_set_data(sqe, (void *)(idx));
|
||||
io_uring_submit(ring);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// submit a recvmsg request via liburing
|
||||
// TODO: What recvfrom support arrives, maybe use that instead?
|
||||
static int submit_sendmsg_request(struct io_uring *ring, struct req *r, int buflen, size_t idx) {
|
||||
r->iov.iov_len = buflen;
|
||||
struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
|
||||
io_uring_prep_sendmsg(sqe, 0, &r->hdr, 0); // use the 0th file in the list of registered fds
|
||||
io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE);
|
||||
io_uring_sqe_set_data(sqe, (void *)(idx));
|
||||
io_uring_submit(ring);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void submit_nop_request(struct io_uring *ring) {
|
||||
struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
|
||||
io_uring_prep_nop(sqe);
|
||||
io_uring_sqe_set_data(sqe, (void *)(-1));
|
||||
io_uring_submit(ring);
|
||||
}
|
||||
|
||||
static void submit_cancel_request(struct io_uring *ring, size_t idx) {
|
||||
struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
|
||||
io_uring_prep_cancel(sqe, (void *)(idx), 0);
|
||||
io_uring_submit(ring);
|
||||
}
|
||||
|
||||
// submit a writev request via liburing
|
||||
static int submit_writev_request(struct io_uring *ring, struct req *r, int buflen, size_t idx) {
|
||||
r->iov.iov_len = buflen;
|
||||
struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
|
||||
io_uring_prep_writev(sqe, 0, &r->iov, 1, 0); // use the 0th file in the list of registered fds
|
||||
io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE);
|
||||
io_uring_sqe_set_data(sqe, (void *)(idx));
|
||||
int submitted = io_uring_submit(ring);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// submit a readv request via liburing
|
||||
static int submit_readv_request(struct io_uring *ring, struct req *r, size_t idx) {
|
||||
struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
|
||||
io_uring_prep_readv(sqe, 0, &r->iov, 1, 0); // use the 0th file in the list of registered fds
|
||||
io_uring_sqe_set_flags(sqe, IOSQE_FIXED_FILE);
|
||||
io_uring_sqe_set_data(sqe, (void *)(idx));
|
||||
int submitted = io_uring_submit(ring);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
struct completion_result {
|
||||
int err;
|
||||
int n;
|
||||
size_t idx;
|
||||
};
|
||||
|
||||
typedef struct completion_result go_completion_result;
|
||||
|
||||
static go_completion_result completion(struct io_uring *ring, int block) {
|
||||
struct io_uring_cqe *cqe;
|
||||
struct completion_result res;
|
||||
res.err = 0;
|
||||
res.n = 0;
|
||||
res.idx = 0;
|
||||
if (block) {
|
||||
res.err = io_uring_wait_cqe(ring, &cqe);
|
||||
} else {
|
||||
res.err = io_uring_peek_cqe(ring, &cqe);
|
||||
}
|
||||
if (res.err < 0) {
|
||||
return res;
|
||||
}
|
||||
res.idx = (size_t)io_uring_cqe_get_data(cqe);
|
||||
res.n = cqe->res;
|
||||
io_uring_cqe_seen(ring, cqe);
|
||||
return res;
|
||||
}
|
||||
|
||||
// set_deadline submits an IORING_OP_TIMEOUT request that completes
// after sec seconds + ns nanoseconds.
// Returns 0 on success, -1 if no submission queue entry is available.
static int set_deadline(struct io_uring *ring, int64_t sec, long long ns) {
    // The kernel reads the timespec asynchronously, potentially after
    // this function returns, so it must not live on this stack frame
    // (the original code's own TODO flagged this use-after-return).
    // A static has the right lifetime; note this makes set_deadline
    // non-reentrant — at most one pending deadline at a time.
    static struct __kernel_timespec ts;
    ts.tv_sec = sec;
    ts.tv_nsec = ns;
    struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
    if (!sqe) {
        return -1;
    }
    // Use the liburing helper instead of poking sqe fields directly:
    // it initializes the whole sqe, so no stale fields leak through.
    io_uring_prep_timeout(sqe, &ts, 1, 0);
    io_uring_submit(ring);
    return 0;
}
|
||||
|
||||
// has_capability reports whether the kernel supports io_uring opcode i
// (an index from https://unixism.net/loti/tutorial/probe_liburing.html).
static int has_capability(int i) {
    struct io_uring_probe *probe = io_uring_get_probe();
    if (!probe) {
        // Probing failed (e.g. kernel too old to support probes);
        // conservatively report the opcode as unsupported rather than
        // dereferencing NULL below.
        return 0;
    }
    int supported = io_uring_opcode_supported(probe, i);
    // liburing documents io_uring_free_probe() as the release function
    // for io_uring_get_probe(); plain free() is not guaranteed correct.
    io_uring_free_probe(probe);
    return supported;
}
|
||||
|
||||
#endif
|
||||
|
||||
// has_io_uring reports (decided at compile time) whether the liburing
// headers were present when this file was built: 1 if so, 0 otherwise.
static int has_io_uring(void) {
#if __has_include(<liburing.h>)
    return 1;
#else
    return 0;
#endif
}
|
||||
639
net/uring/io_uring_linux.go
Normal file
639
net/uring/io_uring_linux.go
Normal file
@@ -0,0 +1,639 @@
|
||||
package uring
|
||||
|
||||
// #cgo CFLAGS: -I${SRCDIR}/liburing/src/include
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/liburing/src/ -luring
|
||||
// #include "io_uring.c"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/util/endian"
|
||||
)
|
||||
|
||||
const bufferSize = device.MaxSegmentSize
|
||||
|
||||
func URingAvailable() bool { return *useIOURing && C.has_io_uring() > 0 }
|
||||
|
||||
// A UDPConn is a recv-only UDP fd manager.
|
||||
// We'd like to enqueue a bunch of recv calls and deqeueue them later,
|
||||
// but we have a problem with buffer management: We get our buffers just-in-time
|
||||
// from wireguard-go, which means we have to make copies.
|
||||
// That's OK for now, but later it could be a performance issue.
|
||||
// For now, keep it simple and enqueue/dequeue in a single step.
|
||||
type UDPConn struct {
|
||||
// We have two urings so that we don't have to demux completion events.
|
||||
|
||||
// recvRing is the uring for recvmsg calls.
|
||||
recvRing *C.go_uring
|
||||
// sendRing is the uring for sendmsg calls.
|
||||
sendRing *C.go_uring
|
||||
|
||||
// close ensures that connection closes occur exactly once.
|
||||
close sync.Once
|
||||
// closed is an atomic variable that indicates whether the connection has been closed.
|
||||
// TODO: Make an atomic bool type that we can use here.
|
||||
closed uint32
|
||||
|
||||
// local is the local address of this UDPConn.
|
||||
local net.Addr
|
||||
|
||||
// recvReqs is an array of re-usable UDP recvmsg requests.
|
||||
// We attempt to keep them all queued up for the kernel to fulfill.
|
||||
// The array length is tied to the size of the uring.
|
||||
recvReqs [8]*C.goreq
|
||||
// sendReqs is an array of re-usable UDP sendmsg requests.
|
||||
// We dispatch them to the kernel as writes are requested.
|
||||
// The array length is tied to the size of the uring.
|
||||
sendReqs [8]*C.goreq
|
||||
|
||||
// sendReqC is a channel containing indices into sendReqs
|
||||
// that are free to use (that is, not in the kernel).
|
||||
sendReqC chan int
|
||||
is4 bool
|
||||
// reads counts the number of outstanding read requests.
|
||||
// It is accessed atomically.
|
||||
reads int32
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
// checks capabilities available on this system
|
||||
capabilities map[IORingOp]bool
|
||||
)
|
||||
|
||||
func checkCapability(op IORingOp) bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if v, ok := capabilities[op]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
has_op := C.has_capability(C.int(op)) == 1
|
||||
capabilities[op] = has_op
|
||||
return has_op
|
||||
}
|
||||
|
||||
func NewUDPConn(pconn net.PacketConn) (*UDPConn, error) {
|
||||
if !*useIOURing {
|
||||
return nil, DisabledError
|
||||
}
|
||||
conn, ok := pconn.(*net.UDPConn)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot use io_uring with conn of type %T", pconn)
|
||||
}
|
||||
// this is dumb
|
||||
local := conn.LocalAddr()
|
||||
var ipp netaddr.IPPort
|
||||
switch l := local.(type) {
|
||||
case *net.UDPAddr:
|
||||
ip, ok := netaddr.FromStdIP(l.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to parse IP: %v", ip)
|
||||
}
|
||||
ipp = netaddr.IPPortFrom(ip, uint16(l.Port))
|
||||
default:
|
||||
var err error
|
||||
if ipp, err = netaddr.ParseIPPort(l.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse UDPConn local addr %s as IP: %w", local, err)
|
||||
}
|
||||
}
|
||||
ipVersion := 6
|
||||
if ipp.IP().Is4() {
|
||||
ipVersion = 4
|
||||
}
|
||||
// TODO: probe for system capabilities: https://unixism.net/loti/tutorial/probe_liburing.html
|
||||
file, err := conn.File()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// conn.File dup'd the conn's fd. We no longer need the original conn.
|
||||
conn.Close()
|
||||
recvRing := new(C.go_uring)
|
||||
sendRing := new(C.go_uring)
|
||||
|
||||
fd := file.Fd()
|
||||
for _, r := range []*C.go_uring{recvRing, sendRing} {
|
||||
ret := C.initialize(r, C.int(fd))
|
||||
if ret < 0 {
|
||||
// TODO: free recvRing if sendRing initialize failed
|
||||
return nil, fmt.Errorf("uring initialization failed: %d", ret)
|
||||
}
|
||||
}
|
||||
u := &UDPConn{
|
||||
recvRing: recvRing,
|
||||
sendRing: sendRing,
|
||||
local: conn.LocalAddr(),
|
||||
is4: ipVersion == 4,
|
||||
}
|
||||
|
||||
// Initialize buffers
|
||||
for _, reqs := range []*[8]*C.goreq{&u.recvReqs, &u.sendReqs} {
|
||||
for i := range reqs {
|
||||
reqs[i] = C.initializeReq(bufferSize, C.int(ipVersion))
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize recv half.
|
||||
for i := range u.recvReqs {
|
||||
if err := u.submitRecvRequest(i); err != nil {
|
||||
u.Close() // TODO: will this crash?
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Initialize send half.
|
||||
u.sendReqC = make(chan int, len(u.sendReqs))
|
||||
for i := range u.sendReqs {
|
||||
u.sendReqC <- i
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (u *UDPConn) submitRecvRequest(idx int) error {
|
||||
// TODO: make a C struct instead of a Go struct, and pass that in, to simplify call sites.
|
||||
errno := C.submit_recvmsg_request(u.recvRing, u.recvReqs[idx], C.size_t(idx))
|
||||
if errno < 0 {
|
||||
return fmt.Errorf("uring.submitRecvRequest failed: %w", syscall.Errno(-errno)) // TODO: Improve
|
||||
}
|
||||
atomic.AddInt32(u.recvReqInKernel(idx), 1) // TODO: CAS?
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UDPConn) recvReqInKernel(idx int) *int32 {
|
||||
return (*int32)(unsafe.Pointer(&u.recvReqs[idx].in_kernel))
|
||||
}
|
||||
|
||||
// TODO: replace with unsafe.Slice once we are using Go 1.17.
|
||||
|
||||
func sliceOf(ptr *C.char, n int) []byte {
|
||||
var b []byte
|
||||
h := (*reflect.SliceHeader)(unsafe.Pointer(&b))
|
||||
h.Data = uintptr(unsafe.Pointer(ptr))
|
||||
h.Len = n
|
||||
h.Cap = n
|
||||
return b
|
||||
}
|
||||
|
||||
func (u *UDPConn) ReadFromNetaddr(buf []byte) (int, netaddr.IPPort, error) {
|
||||
// Important: register that there is a read before checking whether the conn is closed.
|
||||
// Close assumes that once it has set u.closed to non-zero there are no "hidden" reads outstanding,
|
||||
// as their could be if we did this in the other order.
|
||||
atomic.AddInt32(&u.reads, 1)
|
||||
defer atomic.AddInt32(&u.reads, -1)
|
||||
if atomic.LoadUint32(&u.closed) != 0 {
|
||||
return 0, netaddr.IPPort{}, net.ErrClosed
|
||||
}
|
||||
n, idx, err := waitCompletion(u.recvRing)
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.ECANCELED) {
|
||||
atomic.AddInt32(u.recvReqInKernel(idx), -1)
|
||||
}
|
||||
// io_uring failed to run our syscall.
|
||||
return 0, netaddr.IPPort{}, fmt.Errorf("ReadFromNetaddr io_uring could not run syscall: %w", err)
|
||||
}
|
||||
atomic.AddInt32(u.recvReqInKernel(idx), -1)
|
||||
if n < 0 {
|
||||
// io_uring ran our syscall, which failed.
|
||||
// Best effort attempt not to leak idx.
|
||||
u.submitRecvRequest(int(idx))
|
||||
return 0, netaddr.IPPort{}, fmt.Errorf("ReadFromNetaddr syscall failed: %w", syscall.Errno(-n))
|
||||
}
|
||||
r := u.recvReqs[idx]
|
||||
var ip netaddr.IP
|
||||
var port uint16
|
||||
// TODO: native go endianness conversion routines so we don't have to call ntohl, etc.
|
||||
if u.is4 {
|
||||
ip = netaddr.IPFrom4(*(*[4]byte)((unsafe.Pointer)((&r.sa.sin_addr.s_addr))))
|
||||
port = endian.Ntoh16(uint16(r.sa.sin_port))
|
||||
} else {
|
||||
ip = netaddr.IPFrom16(*(*[16]byte)((unsafe.Pointer)((&r.sa6.sin6_addr))))
|
||||
port = endian.Ntoh16(uint16(r.sa6.sin6_port))
|
||||
}
|
||||
ipp := netaddr.IPPortFrom(ip, port)
|
||||
rbuf := sliceOf(r.buf, n)
|
||||
copy(buf, rbuf)
|
||||
// Queue up a new request.
|
||||
if err := u.submitRecvRequest(int(idx)); err != nil {
|
||||
// Aggressively return this error.
|
||||
// The error will bubble up and cause the entire conn to be closed down,
|
||||
// so it doesn't matter that we lost a packet here.
|
||||
return 0, netaddr.IPPort{}, err
|
||||
}
|
||||
return n, ipp, nil
|
||||
}
|
||||
|
||||
func (u *UDPConn) Close() error {
|
||||
u.close.Do(func() {
|
||||
// Announce to readers and writers that we are closing down.
|
||||
atomic.StoreUint32(&u.closed, 1)
|
||||
// It is now not possible for u.reads to reach zero without
|
||||
// all reads being unblocked.
|
||||
|
||||
// Busy loop until all reads are unblocked.
|
||||
// This is unpleasant, but I don't know of another way that
|
||||
// doesn't introduce significant synchronization overhead.
|
||||
// (The obvious alternative is to use a sync.RWMutex,
|
||||
// but that has a chicken-and-egg problem: Reads must take an rlock,
|
||||
// but we cannot take a wlock under all the rlocks are released,
|
||||
// but we cannot issue cancellations to release the rlocks without
|
||||
// first taking a wlock.)
|
||||
BusyLoop:
|
||||
for {
|
||||
for idx := range u.recvReqs {
|
||||
if atomic.LoadInt32(u.recvReqInKernel(idx)) != 0 {
|
||||
C.submit_cancel_request(u.recvRing, C.size_t(idx))
|
||||
}
|
||||
}
|
||||
reads := atomic.LoadInt32(&u.reads)
|
||||
if reads > 0 {
|
||||
time.Sleep(time.Millisecond)
|
||||
} else {
|
||||
break BusyLoop
|
||||
}
|
||||
}
|
||||
// TODO: block until no one else uses our rings.
|
||||
// (Or is that unnecessary now?)
|
||||
C.io_uring_queue_exit(u.recvRing)
|
||||
C.io_uring_queue_exit(u.sendRing)
|
||||
|
||||
// Free buffers
|
||||
for _, r := range u.recvReqs {
|
||||
C.freeReq(r)
|
||||
}
|
||||
for _, r := range u.sendReqs {
|
||||
C.freeReq(r)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Implement net.PacketConn, for convenience integrating with magicsock.
|
||||
|
||||
var _ net.PacketConn = (*UDPConn)(nil)
|
||||
|
||||
type udpAddr struct {
|
||||
ipp netaddr.IPPort
|
||||
}
|
||||
|
||||
func (u udpAddr) Network() string { return "udp4" } // TODO: ipv6
|
||||
func (u udpAddr) String() string { return u.ipp.String() }
|
||||
|
||||
func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, ipp, err := c.ReadFromNetaddr(p)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return n, udpAddr{ipp: ipp}, err
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if atomic.LoadUint32(&u.closed) != 0 {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("cannot WriteTo net.Addr of type %T", addr)
|
||||
}
|
||||
// If we need a buffer, get a buffer, potentially blocking.
|
||||
var idx int
|
||||
select {
|
||||
case idx = <-u.sendReqC:
|
||||
default:
|
||||
// No request available. Get one from the kernel.
|
||||
n, idx, err = waitCompletion(u.sendRing)
|
||||
if err != nil {
|
||||
// io_uring failed to issue the syscall.
|
||||
return 0, fmt.Errorf("WriteTo io_uring call failed: %w", err)
|
||||
}
|
||||
if n < 0 {
|
||||
// Past syscall failed.
|
||||
u.sendReqC <- idx // don't leak idx
|
||||
return 0, fmt.Errorf("previous WriteTo failed: %w", syscall.Errno(-n))
|
||||
}
|
||||
}
|
||||
r := u.sendReqs[idx]
|
||||
// Do the write.
|
||||
rbuf := sliceOf(r.buf, len(p))
|
||||
copy(rbuf, p)
|
||||
|
||||
if u.is4 {
|
||||
ipu32 := binary.BigEndian.Uint32(udpAddr.IP)
|
||||
r.sa.sin_addr.s_addr = C.uint32_t(endian.Hton32(ipu32))
|
||||
r.sa.sin_port = C.uint16_t(endian.Hton16(uint16(udpAddr.Port)))
|
||||
r.sa.sin_family = C.AF_INET
|
||||
} else {
|
||||
dst := (*[16]byte)((unsafe.Pointer)(&r.sa6.sin6_addr))
|
||||
src := (*[16]byte)((unsafe.Pointer)(&udpAddr.IP[0]))
|
||||
*dst = *src
|
||||
r.sa6.sin6_port = C.uint16_t(endian.Hton16(uint16(udpAddr.Port)))
|
||||
r.sa6.sin6_family = C.AF_INET6
|
||||
}
|
||||
C.submit_sendmsg_request(
|
||||
u.sendRing, // ring
|
||||
r,
|
||||
C.int(len(p)), // buffer len, ditto
|
||||
C.size_t(idx), // user data
|
||||
)
|
||||
// Get an extra buffer, if available.
|
||||
if idx, ok := peekCompletion(u.sendRing); ok {
|
||||
// Put the request buffer back in the usable queue.
|
||||
// Should never block, by construction.
|
||||
u.sendReqC <- idx
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *UDPConn) LocalAddr() net.Addr { return c.local }
|
||||
|
||||
func (c *UDPConn) SetDeadline(t time.Time) error { panic("not implemented") }
|
||||
func (c *UDPConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
|
||||
func (c *UDPConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
|
||||
|
||||
// Files!
|
||||
|
||||
// A File is a write-only file fd manager.
|
||||
// TODO: Support reads
|
||||
// TODO: all the todos from UDPConn
|
||||
type file struct {
|
||||
writeRing *C.go_uring
|
||||
readRing *C.go_uring
|
||||
close sync.Once
|
||||
file *os.File // must keep file from being GC'd
|
||||
fd uintptr
|
||||
readReqs [1]*C.goreq // Whoops! The kernel apparently cannot handle more than 1 concurrent preadv calls on a tun device!
|
||||
writeReqs [8]*C.goreq
|
||||
writeReqC chan int // indices into reqs
|
||||
}
|
||||
|
||||
func newFile(f *os.File) (*file, error) {
|
||||
fd := f.Fd()
|
||||
u := &file{
|
||||
file: f,
|
||||
fd: fd,
|
||||
}
|
||||
for _, ringPtr := range []**C.go_uring{&u.writeRing, &u.readRing} {
|
||||
r := new(C.go_uring)
|
||||
ret := C.initialize(r, C.int(fd))
|
||||
if ret < 0 {
|
||||
// TODO: handle unwinding partial initialization
|
||||
return nil, fmt.Errorf("uring initialization failed: %d", ret)
|
||||
}
|
||||
*ringPtr = r
|
||||
}
|
||||
|
||||
// Initialize buffers
|
||||
for i := range &u.readReqs {
|
||||
u.readReqs[i] = C.initializeReq(bufferSize, 0)
|
||||
}
|
||||
for i := range &u.writeReqs {
|
||||
u.writeReqs[i] = C.initializeReq(bufferSize, 0)
|
||||
}
|
||||
|
||||
// Initialize read half.
|
||||
for i := range u.readReqs {
|
||||
if err := u.submitReadvRequest(i); err != nil {
|
||||
u.Close() // TODO: will this crash?
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
u.writeReqC = make(chan int, len(u.writeReqs))
|
||||
for i := range u.writeReqs {
|
||||
u.writeReqC <- i
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (u *file) submitReadvRequest(idx int) error {
|
||||
// TODO: make a C struct instead of a Go struct, and pass that in, to simplify call sites.
|
||||
errno := C.submit_readv_request(u.readRing, u.readReqs[idx], C.size_t(idx))
|
||||
if errno < 0 {
|
||||
return fmt.Errorf("uring.submitReadvRequest failed: %v", errno) // TODO: Improve
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
noBlockForCompletion = 0
|
||||
blockForCompletion = 1
|
||||
)
|
||||
|
||||
// waitCompletion blocks until a completion on ring succeeds, or until *fd == 0.
|
||||
// If *fd == 0, that indicates that the ring is no loner valid, in which case waitCompletion returns net.ErrClosed.
|
||||
// Reads of *fd are atomic.
|
||||
func waitCompletion(ring *C.go_uring) (n, idx int, err error) {
|
||||
for {
|
||||
r := C.completion(ring, blockForCompletion)
|
||||
if syscall.Errno(-r.err) == syscall.EAGAIN {
|
||||
continue
|
||||
}
|
||||
var err error
|
||||
if r.err < 0 {
|
||||
err = syscall.Errno(-r.err)
|
||||
}
|
||||
return int(r.n), int(r.idx), err
|
||||
}
|
||||
}
|
||||
|
||||
func peekCompletion(ring *C.go_uring) (idx int, ok bool) {
|
||||
r := C.completion(ring, noBlockForCompletion)
|
||||
if r.err < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return int(r.idx), true
|
||||
}
|
||||
|
||||
type fileReq struct {
|
||||
iov C.go_iovec
|
||||
buf [device.MaxSegmentSize]byte
|
||||
}
|
||||
|
||||
// Read data into buf[offset:].
|
||||
// We are allowed to write junk into buf[offset-4:offset].
|
||||
func (u *file) Read(buf []byte) (n int, err error) { // read a packet from the device (without any additional headers)
|
||||
if u.fd == 0 { // TODO: review all uses of u.fd for atomic read/write
|
||||
return 0, errors.New("invalid uring.File")
|
||||
}
|
||||
n, idx, err := waitCompletion(u.readRing)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Read: io_uring failed to issue syscall: %w", err)
|
||||
}
|
||||
if n < 0 {
|
||||
// Syscall failed.
|
||||
u.submitReadvRequest(int(idx)) // best effort attempt not to leak idx
|
||||
return 0, fmt.Errorf("Read: syscall failed: %w", syscall.Errno(-n))
|
||||
}
|
||||
// Success.
|
||||
r := u.readReqs[idx]
|
||||
rbuf := sliceOf(r.buf, n)
|
||||
copy(buf, rbuf)
|
||||
// Queue up a new request.
|
||||
if err := u.submitReadvRequest(int(idx)); err != nil {
|
||||
// Aggressively return this error.
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (u *file) Write(buf []byte) (int, error) {
|
||||
if u.fd == 0 {
|
||||
return 0, errors.New("invalid uring.FileConn")
|
||||
}
|
||||
// If we need a buffer, get a buffer, potentially blocking.
|
||||
var idx int
|
||||
select {
|
||||
case idx = <-u.writeReqC:
|
||||
default:
|
||||
// No request available. Get one from the kernel.
|
||||
n, idx, err := waitCompletion(u.writeRing)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Write io_uring call failed: %w", err)
|
||||
}
|
||||
if n < 0 {
|
||||
// Past syscall failed.
|
||||
u.writeReqC <- idx // don't leak idx
|
||||
return 0, fmt.Errorf("previous Write failed: %w", syscall.Errno(-n))
|
||||
}
|
||||
}
|
||||
r := u.writeReqs[idx]
|
||||
// Do the write.
|
||||
rbuf := sliceOf(r.buf, len(buf))
|
||||
copy(rbuf, buf)
|
||||
C.submit_writev_request(u.writeRing, r, C.int(len(buf)), C.size_t(idx))
|
||||
// Get an extra buffer, if available.
|
||||
idx, ok := peekCompletion(u.writeRing)
|
||||
if ok {
|
||||
// Put the request buffer back in the usable queue.
|
||||
// Should never block, by construction.
|
||||
u.writeReqC <- idx
|
||||
}
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
// Close tears down the io_uring state exactly once: it invalidates the fd,
// closes the underlying file, exits both rings, and frees the C-allocated
// request buffers. It always reports nil.
func (u *file) Close() error {
	u.close.Do(func() {
		// Zero the fd first so concurrent Read/Write calls start failing
		// their validity check immediately.
		atomic.StoreUintptr(&u.fd, 0)
		u.file.Close()
		u.file = nil
		// TODO: bring the shutdown logic from UDPConn.Close here?
		// Or is closing the file above enough, unlike for UDP?
		C.io_uring_queue_exit(u.readRing)
		C.io_uring_queue_exit(u.writeRing)

		// Free buffers.
		// NOTE(review): this assumes no requests are still in flight on
		// either ring when we free — confirm Read/Write cannot race here.
		for _, r := range u.readReqs {
			C.freeReq(r)
		}
		for _, r := range u.writeReqs {
			C.freeReq(r)
		}
	})
	return nil
}
|
||||
|
||||
// Wrap files into TUN devices.
|
||||
|
||||
// NewTUN wraps a *tun.NativeTun in an io_uring-backed TUN device.
// It extracts the NativeTun's unexported "errors" channel via
// reflect+unsafe so asynchronous device errors can still be surfaced
// from (*TUN).Read.
func NewTUN(d tun.Device) (tun.Device, error) {
	nt, ok := d.(*tun.NativeTun)
	if !ok {
		return nil, fmt.Errorf("NewTUN only wraps *tun.NativeTun, got %T", d)
	}
	f, err := newFile(nt.File())
	if err != nil {
		return nil, err
	}
	// Locate the unexported errors field by name. This is fragile: it
	// breaks silently-at-runtime if wireguard-go renames or retypes it.
	v := reflect.ValueOf(nt)
	field, ok := v.Elem().Type().FieldByName("errors")
	if !ok {
		return nil, errors.New("could not find internal tun.NativeTun errors field")
	}
	// Compute the field address and reinterpret it as a chan error.
	ptr := unsafe.Pointer(nt)
	ptr = unsafe.Pointer(uintptr(ptr) + field.Offset) // TODO: switch to unsafe.Add with Go 1.17...as if that's the worst thing in this line
	c := *(*chan error)(ptr)
	return &TUN{d: nt, f: f, errors: c}, nil
}
|
||||
|
||||
// TUN is an io_uring-backed TUN device. Reads and writes go through the
// io_uring file wrapper; all other tun.Device methods delegate to the
// wrapped NativeTun. The device is used without IFF_NO_PI ("No nopi"),
// so Read/Write handle a 4-byte packet information header.
type TUN struct {
	d      *tun.NativeTun // wrapped device; used for metadata ops and Close
	f      *file          // io_uring wrapper around d's file descriptor
	errors chan error     // d's unexported error channel, extracted in NewTUN
}
|
||||
|
||||
// File returns the underlying os.File of the io_uring wrapper.
// NOTE(review): after Close this is nil (file.Close sets u.file = nil);
// confirm no caller uses File on a closed device.
func (t *TUN) File() *os.File {
	return t.f.file
}
|
||||
|
||||
func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
||||
select {
|
||||
case err := <-t.errors:
|
||||
return 0, err
|
||||
default:
|
||||
}
|
||||
// TODO: upstream has graceful shutdown error handling here.
|
||||
buff := buf[offset-4:]
|
||||
n, err := t.f.Read(buff[:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
if n < 4 {
|
||||
n = 0
|
||||
} else {
|
||||
n -= 4
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (t *TUN) Write(buf []byte, offset int) (int, error) {
|
||||
// below copied from wireguard-go NativeTun.Write
|
||||
|
||||
// reserve space for header
|
||||
buf = buf[offset-4:]
|
||||
|
||||
// add packet information header
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
if buf[4]>>4 == ipv6.Version {
|
||||
buf[2] = 0x86
|
||||
buf[3] = 0xdd
|
||||
} else {
|
||||
buf[2] = 0x08
|
||||
buf[3] = 0x00
|
||||
}
|
||||
|
||||
n, err := t.f.Write(buf)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// The remaining tun.Device methods delegate directly to the wrapped NativeTun.

func (t *TUN) Flush() error { return t.d.Flush() }
func (t *TUN) MTU() (int, error) { return t.d.MTU() }
func (t *TUN) Name() (string, error) { return t.d.Name() }
func (t *TUN) Events() chan tun.Event { return t.d.Events() }
|
||||
|
||||
func (t *TUN) Close() error {
|
||||
err1 := t.f.Close()
|
||||
err2 := t.d.Close()
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
3
net/uring/makefile
Normal file
3
net/uring/makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
# Fetch and build liburing. Use the https URL so the clone works without
# GitHub SSH credentials (the old git@github.com: form required a key).
get_liburing:
	git clone https://github.com/axboe/liburing.git
	cd liburing && make
|
||||
32
net/uring/stubs.go
Normal file
32
net/uring/stubs.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// +build !linux
|
||||
|
||||
package uring
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// This file contains stubs for platforms that are known at compile time not to support io_uring.
|
||||
|
||||
type UDPConn struct{}
|
||||
|
||||
func NewUDPConn(*net.UDPConn) (*UDPConn, error) { panic("io_uring unavailable") }
|
||||
func (u *UDPConn) ReadFromNetaddr([]byte) (int, netaddr.IPPort, error) { panic("io_uring unavailable") }
|
||||
func (u *UDPConn) Close() error { panic("io_uring unavailable") }
|
||||
func (c *UDPConn) ReadFrom([]byte) (int, net.Addr, error) { panic("io_uring unavailable") }
|
||||
func (u *UDPConn) WriteTo([]byte, net.Addr) (int, error) { panic("io_uring unavailable") }
|
||||
func (c *UDPConn) LocalAddr() net.Addr { panic("io_uring unavailable") }
|
||||
func (c *UDPConn) SetDeadline(time.Time) error { panic("io_uring unavailable") }
|
||||
func (c *UDPConn) SetReadDeadline(time.Time) error { panic("io_uring unavailable") }
|
||||
func (c *UDPConn) SetWriteDeadline(time.Time) error { panic("io_uring unavailable") }
|
||||
|
||||
type File struct{}
|
||||
|
||||
func NewFile(file *os.File) (*File, error) { panic("io_uring unavailable") }
|
||||
func (u *File) Read([]byte) (int, error) { panic("io_uring unavailable") }
|
||||
func (u *File) Write([]byte) (int, error) { panic("io_uring unavailable") }
|
||||
func (u *File) Close() error { panic("io_uring unavailable") }
|
||||
1
net/uring/testdata/voltaire.txt
vendored
Normal file
1
net/uring/testdata/voltaire.txt
vendored
Normal file
@@ -0,0 +1 @@
|
||||
If io_uring did not exist, it would be necessary to invent it.
|
||||
105
net/uring/udp_test.go
Normal file
105
net/uring/udp_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package uring
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
qt "github.com/frankban/quicktest"
|
||||
)
|
||||
|
||||
func TestUDPSendRecv(t *testing.T) {
|
||||
if !Available() {
|
||||
t.Skip("io_uring not available")
|
||||
}
|
||||
c := qt.New(t)
|
||||
|
||||
listen, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 9999})
|
||||
t.Cleanup(func() { listen.Close() })
|
||||
c.Assert(err, qt.IsNil)
|
||||
|
||||
conn, err := NewUDPConn(listen)
|
||||
t.Cleanup(func() { conn.Close() })
|
||||
if err != nil {
|
||||
t.Skipf("io_uring not available: %v", err)
|
||||
}
|
||||
addr := listen.LocalAddr()
|
||||
sendBuf := make([]byte, 200)
|
||||
for i := range sendBuf {
|
||||
sendBuf[i] = byte(i)
|
||||
}
|
||||
recvBuf := make([]byte, 200)
|
||||
|
||||
// Write one direction.
|
||||
_, err = conn.WriteTo(sendBuf, addr)
|
||||
c.Assert(err, qt.IsNil)
|
||||
n, ipp, err := conn.ReadFromNetaddr(recvBuf)
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Assert(recvBuf[:n], qt.DeepEquals, sendBuf)
|
||||
|
||||
// Write the other direction, to check that ipp is correct.
|
||||
_, err = conn.WriteTo(sendBuf, ipp.UDPAddr())
|
||||
c.Assert(err, qt.IsNil)
|
||||
n, _, err = conn.ReadFromNetaddr(recvBuf)
|
||||
c.Assert(err, qt.IsNil)
|
||||
c.Assert(recvBuf[:n], qt.DeepEquals, sendBuf)
|
||||
}
|
||||
|
||||
// TODO(jknodt): maybe delete the test below because it's redundant
|
||||
|
||||
// TestPort is the fixed UDP port the discard test server listens on.
// NOTE(review): a hard-coded port can collide with other processes on
// the test machine; consider binding port 0 instead.
const TestPort = 3636

// serverAddr is the wildcard address at TestPort that the test server
// binds and the client dials.
var serverAddr = &net.UDPAddr{
	Port: TestPort,
}
|
||||
|
||||
func NewUDPTestServer(t *testing.T) (closer func() error, err error) {
|
||||
conn, err := net.ListenUDP("udp", serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, 512)
|
||||
_, _, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read on server: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
return conn.Close, nil
|
||||
}
|
||||
|
||||
func TestUDPConn(t *testing.T) {
|
||||
if !Available() {
|
||||
t.Skip("io_uring not available")
|
||||
}
|
||||
c := qt.New(t)
|
||||
// TODO add a closer here
|
||||
closer, err := NewUDPTestServer(t)
|
||||
c.Assert(err, qt.IsNil)
|
||||
t.Cleanup(func() { closer() })
|
||||
udpConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||
c.Assert(err, qt.IsNil)
|
||||
defer udpConn.Close()
|
||||
|
||||
conn, err := NewUDPConn(udpConn)
|
||||
c.Assert(err, qt.IsNil)
|
||||
defer conn.Close()
|
||||
|
||||
content := []byte("a test string to check udpconn works 😀 with non-unicode input")
|
||||
n, err := conn.WriteTo(content, serverAddr)
|
||||
c.Assert(err, qt.IsNil)
|
||||
if n != len(content) {
|
||||
t.Errorf("written len mismatch: want %v, got %v", len(content), n)
|
||||
}
|
||||
|
||||
// Test many writes at once
|
||||
for i := 0; i < 256; i++ {
|
||||
n, err := conn.WriteTo(content, serverAddr)
|
||||
c.Assert(err, qt.IsNil)
|
||||
if n != len(content) {
|
||||
t.Errorf("written len mismatch: want %v, got %v", len(content), n)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,3 +13,12 @@ const Big = true
|
||||
|
||||
// Native is the platform's native byte order.
var Native = binary.BigEndian

// Ntoh16 converts network order into native/host.
// Network order is big-endian, so on this big-endian platform the
// conversion is the identity.
func Ntoh16(v uint16) uint16 { return v }

// Hton32 converts native/host uint32 order into network order.
// Identity on big-endian platforms.
func Hton32(v uint32) uint32 { return v }

// Hton16 converts native/host uint16 order into network order.
// Identity on big-endian platforms.
func Hton16(v uint16) uint16 { return v }
|
||||
|
||||
48
util/endian/encoding_test.go
Normal file
48
util/endian/encoding_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package endian
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func TestNtoh16(t *testing.T) {
|
||||
raw := uint16(0xABCD)
|
||||
rawBytes := toNativeBytes16(raw)
|
||||
big := binary.BigEndian.Uint16(rawBytes[:])
|
||||
if raw != Ntoh16(big) {
|
||||
t.Errorf("ntohs failed, want %v, got %v", raw, Ntoh16(big))
|
||||
}
|
||||
}
|
||||
|
||||
// toNativeBytes32 returns v's in-memory representation: its four bytes
// in the platform's native endianness.
func toNativeBytes32(v uint32) [4]byte {
	p := (*[4]byte)(unsafe.Pointer(&v))
	return *p
}
|
||||
|
||||
func TestHton32(t *testing.T) {
|
||||
raw := uint32(0xDEADBEEF)
|
||||
|
||||
networkOrder := Hton32(raw)
|
||||
bytes := toNativeBytes32(networkOrder)
|
||||
fromBig := binary.BigEndian.Uint32(bytes[:])
|
||||
|
||||
if fromBig != raw {
|
||||
t.Errorf("htonl failed, want %v, got %v", raw, fromBig)
|
||||
}
|
||||
}
|
||||
|
||||
// toNativeBytes16 returns v's in-memory representation: its two bytes
// in the platform's native endianness.
func toNativeBytes16(v uint16) [2]byte {
	p := (*[2]byte)(unsafe.Pointer(&v))
	return *p
}
|
||||
|
||||
func TestHton16(t *testing.T) {
|
||||
raw := uint16(0xBEEF)
|
||||
|
||||
networkOrder := Hton16(raw)
|
||||
bytes := toNativeBytes16(networkOrder)
|
||||
fromBig := binary.BigEndian.Uint16(bytes[:])
|
||||
|
||||
if fromBig != raw {
|
||||
t.Errorf("htonl failed, want %v, got %v", raw, fromBig)
|
||||
}
|
||||
}
|
||||
@@ -6,10 +6,22 @@
|
||||
|
||||
package endian
|
||||
|
||||
import "encoding/binary"
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// Big is whether the current platform is big endian.
|
||||
const Big = false
|
||||
|
||||
// Native is the platform's native byte order.
|
||||
var Native = binary.LittleEndian
|
||||
|
||||
// Ntoh16 converts a uint16 from network (big-endian) order into
// native/host order; on this little-endian platform that is a byte swap.
func Ntoh16(v uint16) uint16 {
	return bits.ReverseBytes16(v)
}

// Hton32 converts a native/host-order uint32 into network order.
func Hton32(v uint32) uint32 {
	return bits.ReverseBytes32(v)
}

// Hton16 converts a native/host-order uint16 into network order.
func Hton16(v uint16) uint16 {
	return bits.ReverseBytes16(v)
}
|
||||
|
||||
@@ -44,6 +44,7 @@ import (
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/portmapper"
|
||||
"tailscale.com/net/stun"
|
||||
"tailscale.com/net/uring"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstime"
|
||||
@@ -2690,6 +2691,15 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
|
||||
}
|
||||
// Success.
|
||||
ruc.pconn = pconn
|
||||
if uring.Available() {
|
||||
uringConn, err := uring.NewUDPConn(pconn)
|
||||
if err != nil {
|
||||
c.logf("not using io_uring for UDP %v: %v", pconn.LocalAddr(), err)
|
||||
} else {
|
||||
c.logf("using uring for UDP %v", pconn.LocalAddr())
|
||||
ruc.pconn = uringConn
|
||||
}
|
||||
}
|
||||
if network == "udp4" {
|
||||
health.SetUDP4Unbound(false)
|
||||
}
|
||||
@@ -2845,17 +2855,22 @@ func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort,
|
||||
for {
|
||||
pconn := c.currentConn()
|
||||
|
||||
// Optimization: Treat *net.UDPConn specially.
|
||||
// ReadFromUDP gets partially inlined, avoiding allocating a *net.UDPAddr,
|
||||
// Optimization: Treat a few pconn types specially.
|
||||
// For *net.UDPConn, ReadFromUDP gets partially inlined, avoiding allocating a *net.UDPAddr,
|
||||
// as long as pAddr itself doesn't escape.
|
||||
// The non-*net.UDPConn case works, but it allocates.
|
||||
// *uring.UDPConn can return netaddr.IPPorts directly.
|
||||
// The default case works, but it allocates.
|
||||
var pAddr *net.UDPAddr
|
||||
if udpConn, ok := pconn.(*net.UDPConn); ok {
|
||||
n, pAddr, err = udpConn.ReadFromUDP(b)
|
||||
} else {
|
||||
switch pconn := pconn.(type) {
|
||||
case *net.UDPConn:
|
||||
n, pAddr, err = pconn.ReadFromUDP(b)
|
||||
case *uring.UDPConn:
|
||||
n, ipp, err = pconn.ReadFromNetaddr(b)
|
||||
default:
|
||||
var addr net.Addr
|
||||
n, addr, err = pconn.ReadFrom(b)
|
||||
if addr != nil {
|
||||
var ok bool
|
||||
pAddr, ok = addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, netaddr.IPPort{}, fmt.Errorf("RebindingUDPConn.ReadFromNetaddr: underlying connection returned address of type %T, want *netaddr.UDPAddr", addr)
|
||||
@@ -2867,7 +2882,7 @@ func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort,
|
||||
if pconn != c.currentConn() {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
} else if pAddr != nil {
|
||||
// Convert pAddr to a netaddr.IPPort.
|
||||
// This prevents pAddr from escaping.
|
||||
var ok bool
|
||||
|
||||
Reference in New Issue
Block a user