Compare commits

...

12 Commits

Author SHA1 Message Date
Andrew Lytvynov
776ab357b1 cmd/tailscale: add --json-docs flag
This prints all command and flag docs as JSON. To be used for generating
the contents of https://tailscale.com/kb/1080/cli.

Updates https://github.com/tailscale/tailscale-www/issues/4722
2024-08-08 08:06:18 -07:00
Jordan Whited
a93dc6cdb1 wgengine/magicsock: refactor batchingUDPConn to batchingConn interface (#13042)
This commit adds a batchingConn interface, and renames batchingUDPConn
to linuxBatchingConn. tryUpgradeToBatchingConn() may return a platform-
specific implementation of batchingConn. So far only a Linux
implementation of this interface exists, but this refactor is being
done in anticipation of a Windows implementation.

Updates tailscale/corp#21874

Signed-off-by: Jordan Whited <jordan@tailscale.com>
2024-08-06 09:00:28 -07:00
Anton Tolchanov
7bac5dffcb control/controlhttp: extract the last network connection
The same context we use for the HTTP request here might be re-used by
the dialer, which could result in `GotConn` being called multiple times.
We only care about the last one.

Fixes #13009

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2024-08-06 11:42:06 +01:00
Anton Tolchanov
b3fc345aba cmd/derpprobe: use a status page from the prober library
Updates tailscale/corp#20583

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2024-08-06 11:27:59 +01:00
Anton Tolchanov
9106187a95 prober: support JSON response in RunHandler
Updates tailscale/corp#20583

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2024-08-06 11:27:59 +01:00
Anton Tolchanov
9b08399d9e prober: add a status page handler
This change adds an HTTP handler with a table showing a list of all
probes, their status, and a button that allows triggering a specific
probe.

Updates tailscale/corp#20583

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2024-08-06 11:27:59 +01:00
Anton Tolchanov
153a476957 prober: add an HTTP endpoint for triggering a probe
- Keep track of the last 10 probe results and successful probe
  latencies;
- Add an HTTP handler that triggers a given probe by name and returns it
  result as a plaintext HTML page, showing recent probe results as a
  baseline

Updates tailscale/corp#20583

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2024-08-06 11:27:59 +01:00
Anton Tolchanov
227509547f {control,net}: close idle connections of custom transports
I noticed a few places with custom http.Transport where we are not
closing idle connections when transport is no longer used.

Updates tailscale/corp#21609

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2024-08-05 17:28:15 +01:00
VimT
e3f047618b net/socks5: support UDP
Updates #7581

Signed-off-by: VimT <me@vimt.me>
2024-08-05 09:25:24 -07:00
Kot C
91d2e1772d words: raccoon dog, dog with the raccoon in 'im
Signed-off-by: Kot C <kot@yukata.dev>
2024-08-05 09:24:33 -07:00
License Updater
3b6849e362 licenses: update license notices
Signed-off-by: License Updater <noreply+license-updater@tailscale.com>
2024-08-05 08:45:07 -07:00
Anton Tolchanov
0fd73746dd cmd/tailscale/cli: fix revoke-keys command name in CLI output
During review of #8644 the `recover-compromised-key` command was renamed
to `revoke-key`, but the old name remained in some messages printed by
the command.

Fixes tailscale/corp#19446

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
2024-08-05 14:49:48 +01:00
27 changed files with 2011 additions and 834 deletions

View File

@@ -7,8 +7,6 @@ package main
import (
"flag"
"fmt"
"html"
"io"
"log"
"net/http"
"sort"
@@ -70,8 +68,13 @@ func main() {
}
mux := http.NewServeMux()
tsweb.Debugger(mux)
mux.HandleFunc("/", http.HandlerFunc(serveFunc(p)))
d := tsweb.Debugger(mux)
d.Handle("probe-run", "Run a probe", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{Logf: log.Printf}))
mux.Handle("/", tsweb.StdHandler(p.StatusHandler(
prober.WithTitle("DERP Prober"),
prober.WithPageLink("Prober metrics", "/debug/varz"),
prober.WithProbeLink("Run Probe", "/debug/probe-run?name={{.Name}}"),
), tsweb.HandlerOptions{Logf: log.Printf}))
log.Printf("Listening on %s", *listen)
log.Fatal(http.ListenAndServe(*listen, mux))
}
@@ -105,26 +108,3 @@ func getOverallStatus(p *prober.Prober) (o overallStatus) {
sort.Strings(o.good)
return
}
func serveFunc(p *prober.Prober) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
st := getOverallStatus(p)
summary := "All good"
if (float64(len(st.bad)) / float64(len(st.bad)+len(st.good))) > 0.25 {
// Returning a 500 allows monitoring this server externally and configuring
// an alert on HTTP response code.
w.WriteHeader(500)
summary = fmt.Sprintf("%d problems", len(st.bad))
}
io.WriteString(w, "<html><head><style>.bad { font-weight: bold; color: #700; }</style></head>\n")
fmt.Fprintf(w, "<body><h1>derp probe</h1>\n%s:<ul>", summary)
for _, s := range st.bad {
fmt.Fprintf(w, "<li class=bad>%s</li>\n", html.EscapeString(s))
}
for _, s := range st.good {
fmt.Fprintf(w, "<li>%s</li>\n", html.EscapeString(s))
}
io.WriteString(w, "</ul></body></html>\n")
}
}

View File

@@ -7,6 +7,7 @@ package cli
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
@@ -159,8 +160,10 @@ func newRootCmd() *ffcli.Command {
return nil
})
rootfs.Lookup("socket").DefValue = localClient.Socket
jsonDocs := rootfs.Bool("json-docs", false, hidden+"print JSON-encoded docs for all subcommands and flags")
rootCmd := &ffcli.Command{
var rootCmd *ffcli.Command
rootCmd = &ffcli.Command{
Name: "tailscale",
ShortUsage: "tailscale [flags] <subcommand> [command flags]",
ShortHelp: "The easiest, most secure way to use WireGuard.",
@@ -202,6 +205,9 @@ change in the future.
},
FlagSet: rootfs,
Exec: func(ctx context.Context, args []string) error {
if *jsonDocs {
return printJSONDocs(rootCmd)
}
if len(args) > 0 {
return fmt.Errorf("tailscale: unknown subcommand: %s", args[0])
}
@@ -401,3 +407,54 @@ func colorableOutput() (w io.Writer, ok bool) {
}
return colorable.NewColorableStdout(), true
}
type commandDoc struct {
Name string
Desc string
Subcommands []commandDoc `json:",omitempty"`
Flags []flagDoc `json:",omitempty"`
}
type flagDoc struct {
Name string
Desc string
}
func printJSONDocs(root *ffcli.Command) error {
docs := jsonDocsWalk(root)
return json.NewEncoder(os.Stdout).Encode(docs)
}
func jsonDocsWalk(cmd *ffcli.Command) *commandDoc {
res := &commandDoc{
Name: cmd.Name,
}
if cmd.LongHelp != "" {
res.Desc = cmd.LongHelp
} else if cmd.ShortHelp != "" {
res.Desc = cmd.ShortHelp
} else {
res.Desc = cmd.ShortUsage
}
if strings.HasPrefix(res.Desc, hidden) {
return nil
}
if cmd.FlagSet != nil {
cmd.FlagSet.VisitAll(func(f *flag.Flag) {
if strings.HasPrefix(f.Usage, hidden) {
return
}
res.Flags = append(res.Flags, flagDoc{
Name: f.Name,
Desc: f.Usage,
})
})
}
for _, sub := range cmd.Subcommands {
subj := jsonDocsWalk(sub)
if subj != nil {
res.Subcommands = append(res.Subcommands, *subj)
}
}
return res
}

View File

@@ -789,7 +789,7 @@ func runNetworkLockRevokeKeys(ctx context.Context, args []string) error {
}
fmt.Printf(`Run the following command on another machine with a trusted tailnet lock key:
%s lock recover-compromised-key --cosign %X
%s lock revoke-keys --cosign %X
`, os.Args[0], aumBytes)
return nil
}
@@ -813,10 +813,10 @@ func runNetworkLockRevokeKeys(ctx context.Context, args []string) error {
fmt.Printf(`Co-signing completed successfully.
To accumulate an additional signature, run the following command on another machine with a trusted tailnet lock key:
%s lock recover-compromised-key --cosign %X
%s lock revoke-keys --cosign %X
Alternatively if you are done with co-signing, complete recovery by running the following command:
%s lock recover-compromised-key --finish %X
%s lock revoke-keys --finish %X
`, os.Args[0], aumBytes, os.Args[0], aumBytes)
}

View File

@@ -333,6 +333,9 @@ func (c *Direct) Close() error {
}
}
c.noiseClient = nil
if tr, ok := c.httpc.Transport.(*http.Transport); ok {
tr.CloseIdleConnections()
}
return nil
}

View File

@@ -46,6 +46,7 @@ import (
"tailscale.com/net/sockstats"
"tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tstime"
"tailscale.com/util/multierr"
@@ -497,11 +498,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
tr.DisableCompression = true
// (mis)use httptrace to extract the underlying net.Conn from the
// transport. We make exactly 1 request using this transport, so
// there will be exactly 1 GotConn call. Additionally, the
// transport handles 101 Switching Protocols correctly, such that
// the Conn will not be reused or kept alive by the transport once
// the response has been handed back from RoundTrip.
// transport. The transport handles 101 Switching Protocols correctly,
// such that the Conn will not be reused or kept alive by the transport
// once the response has been handed back from RoundTrip.
//
// In theory, the machinery of net/http should make it such that
// the trace callback happens-before we get the response, but
@@ -517,10 +516,16 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
// unexpected EOFs...), and we're bound to forget someday and
// introduce a protocol optimization at a higher level that starts
// eagerly transmitting from the server.
connCh := make(chan net.Conn, 1)
var lastConn syncs.AtomicValue[net.Conn]
trace := httptrace.ClientTrace{
// Even though we only make a single HTTP request which should
// require a single connection, the context (with the attached
// trace configuration) might be used by our custom dialer to
// make other HTTP requests (e.g. BootstrapDNS). We only care
// about the last connection made, which should be the one to
// the control server.
GotConn: func(info httptrace.GotConnInfo) {
connCh <- info.Conn
lastConn.Store(info.Conn)
},
}
ctx = httptrace.WithClientTrace(ctx, &trace)
@@ -548,11 +553,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
// is still a read buffer attached to it within resp.Body. So, we
// must direct I/O through resp.Body, but we can still use the
// underlying net.Conn for stuff like deadlines.
var switchedConn net.Conn
select {
case switchedConn = <-connCh:
default:
}
switchedConn := lastConn.Load()
if switchedConn == nil {
resp.Body.Close()
return nil, fmt.Errorf("httptrace didn't provide a connection")

View File

@@ -11,10 +11,12 @@ import (
"log"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"runtime"
"slices"
"strconv"
"sync"
"testing"
@@ -41,6 +43,8 @@ type httpTestParam struct {
makeHTTPHangAfterUpgrade bool
doEarlyWrite bool
httpInDial bool
}
func TestControlHTTP(t *testing.T) {
@@ -120,6 +124,12 @@ func TestControlHTTP(t *testing.T) {
name: "early_write",
doEarlyWrite: true,
},
// Dialer needed to make another HTTP request along the way (e.g. to
// resolve the hostname via BootstrapDNS).
{
name: "http_request_in_dial",
httpInDial: true,
},
}
for _, test := range tests {
@@ -217,6 +227,29 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
Clock: clock,
}
if param.httpInDial {
// Spin up a separate server to get a different port on localhost.
secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return }))
defer secondServer.Close()
prev := a.Dialer
a.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", secondServer.URL, nil)
if err != nil {
t.Errorf("http.NewRequest: %v", err)
}
r, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("http.Get: %v", err)
}
r.Body.Close()
return prev(ctx, network, addr)
}
}
if proxy != nil {
proxyEnv := proxy.Start(t)
defer proxy.Close()
@@ -238,6 +271,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
t.Fatalf("dialing controlhttp: %v", err)
}
defer conn.Close()
si := <-sch
if si.conn != nil {
defer si.conn.Close()
@@ -266,6 +300,19 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
}
}
// When no proxy is used, the RemoteAddr of the returned connection should match
// one of the listeners of the test server.
if proxy == nil {
var expectedAddrs []string
for _, ln := range []net.Listener{httpLn, httpsLn} {
expectedAddrs = append(expectedAddrs, fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port))
expectedAddrs = append(expectedAddrs, fmt.Sprintf("[::1]:%d", ln.Addr().(*net.TCPAddr).Port))
}
if !slices.Contains(expectedAddrs, conn.RemoteAddr().String()) {
t.Errorf("unexpected remote addr: %s, want %s", conn.RemoteAddr(), expectedAddrs)
}
}
}
type serverResult struct {

View File

@@ -65,8 +65,8 @@ See also the dependencies in the [Tailscale CLI][].
- [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE))
- [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/cabfb018fe85/LICENSE))
- [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE))
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/2f5d148bcfe1/LICENSE))
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/62b9a7c569f9/LICENSE))
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/71393c576b98/LICENSE))
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE))
- [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE))
- [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE))
- [github.com/vishvananda/netlink/nl](https://pkg.go.dev/github.com/vishvananda/netlink/nl) ([Apache-2.0](https://github.com/vishvananda/netlink/blob/v1.2.1-beta.2/LICENSE))
@@ -82,7 +82,7 @@ See also the dependencies in the [Tailscale CLI][].
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE))
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/ee1e1f6070e3/LICENSE))
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE))
- [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) ([ISC](https://github.com/nhooyr/websocket/blob/v1.8.10/LICENSE.txt))
- [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE))

View File

@@ -84,8 +84,8 @@ Some packages may only be included on certain architectures or operating systems
- [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE))
- [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/5db17b287bf1/LICENSE))
- [github.com/tailscale/wf](https://pkg.go.dev/github.com/tailscale/wf) ([BSD-3-Clause](https://github.com/tailscale/wf/blob/6fbb0a674ee6/LICENSE))
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/2f5d148bcfe1/LICENSE))
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/62b9a7c569f9/LICENSE))
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/71393c576b98/LICENSE))
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE))
- [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE))
- [github.com/toqueteos/webbrowser](https://pkg.go.dev/github.com/toqueteos/webbrowser) ([MIT](https://github.com/toqueteos/webbrowser/blob/v1.2.0/LICENSE.md))
- [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.12.0/LICENSE))
@@ -95,19 +95,19 @@ Some packages may only be included on certain architectures or operating systems
- [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE))
- [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/4f986261bf13/LICENSE))
- [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE))
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.24.0:LICENSE))
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
- [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/1b970713:LICENSE))
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.26.0:LICENSE))
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
- [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.16.0:LICENSE))
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE))
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.21.0:LICENSE))
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.21.0:LICENSE))
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE))
- [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2))
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3))
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/ee1e1f6070e3/LICENSE))
- [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.30.1/LICENSE))
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE))
- [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.30.3/LICENSE))
- [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) ([ISC](https://github.com/nhooyr/websocket/blob/v1.8.10/LICENSE.txt))
- [sigs.k8s.io/yaml](https://pkg.go.dev/sigs.k8s.io/yaml) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/LICENSE))
- [sigs.k8s.io/yaml/goyaml.v2](https://pkg.go.dev/sigs.k8s.io/yaml/goyaml.v2) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/goyaml.v2/LICENSE))

View File

@@ -57,9 +57,9 @@ Windows][]. See also the dependencies in the [Tailscale CLI][].
- [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE))
- [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE))
- [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/cabfb018fe85/LICENSE))
- [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/7601212d8e23/LICENSE))
- [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/4327221bd339/LICENSE))
- [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/6580b55d49ca/LICENSE))
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/62b9a7c569f9/LICENSE))
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE))
- [github.com/tc-hib/winres](https://pkg.go.dev/github.com/tc-hib/winres) ([0BSD](https://github.com/tc-hib/winres/blob/v0.2.1/LICENSE))
- [github.com/vishvananda/netlink/nl](https://pkg.go.dev/github.com/vishvananda/netlink/nl) ([Apache-2.0](https://github.com/vishvananda/netlink/blob/v1.2.1-beta.2/LICENSE))
- [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE))
@@ -69,7 +69,7 @@ Windows][]. See also the dependencies in the [Tailscale CLI][].
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
- [golang.org/x/exp/constraints](https://pkg.go.dev/golang.org/x/exp/constraints) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE))
- [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE))
- [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.18.0:LICENSE))
- [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE))
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE))
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))

View File

@@ -281,6 +281,7 @@ func lookup(ctx context.Context, host string, logf logger.Logf, ht *health.Track
func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr, queryName string, logf logger.Logf, ht *health.Tracker, netMon *netmon.Monitor) (dnsMap, error) {
dialer := netns.NewDialer(logf, netMon)
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true // This transport is meant to be used once.
tr.Proxy = tshttpproxy.ProxyFromEnvironment
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443"))

View File

@@ -13,8 +13,10 @@
package socks5
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
@@ -121,7 +123,7 @@ func (s *Server) Serve(l net.Listener) error {
}
go func() {
defer c.Close()
conn := &Conn{clientConn: c, srv: s}
conn := &Conn{logf: s.Logf, clientConn: c, srv: s}
err := conn.Run()
if err != nil {
s.logf("client connection failed: %v", err)
@@ -136,9 +138,12 @@ type Conn struct {
// The struct is filled by each of the internal
// methods in turn as the transaction progresses.
logf logger.Logf
srv *Server
clientConn net.Conn
request *request
udpClientAddr net.Addr
}
// Run starts the new connection.
@@ -172,58 +177,59 @@ func (c *Conn) Run() error {
func (c *Conn) handleRequest() error {
req, err := parseClientRequest(c.clientConn)
if err != nil {
res := &response{reply: generalFailure}
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
if req.command != connect {
res := &response{reply: commandNotSupported}
c.request = req
switch req.command {
case connect:
return c.handleTCP()
case udpAssociate:
return c.handleUDP()
default:
res := errorResponse(commandNotSupported)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return fmt.Errorf("unsupported command %v", req.command)
}
c.request = req
}
func (c *Conn) handleTCP() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
srv, err := c.srv.dial(
ctx,
"tcp",
net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))),
c.request.destination.hostPort(),
)
if err != nil {
res := &response{reply: generalFailure}
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer srv.Close()
serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
localAddr := srv.LocalAddr().String()
serverAddr, serverPort, err := splitHostPort(localAddr)
if err != nil {
return err
}
serverPort, _ := strconv.Atoi(serverPortStr)
var bindAddrType addrType
if ip := net.ParseIP(serverAddr); ip != nil {
if ip.To4() != nil {
bindAddrType = ipv4
} else {
bindAddrType = ipv6
}
} else {
bindAddrType = domainName
}
res := &response{
reply: success,
bindAddrType: bindAddrType,
bindAddr: serverAddr,
bindPort: uint16(serverPort),
reply: success,
bindAddr: socksAddr{
addrType: getAddrType(serverAddr),
addr: serverAddr,
port: serverPort,
},
}
buf, err := res.marshal()
if err != nil {
res = &response{reply: generalFailure}
res = errorResponse(generalFailure)
buf, _ = res.marshal()
}
c.clientConn.Write(buf)
@@ -246,6 +252,208 @@ func (c *Conn) handleRequest() error {
return <-errc
}
func (c *Conn) handleUDP() error {
// The DST.ADDR and DST.PORT fields contain the address and port that
// the client expects to use to send UDP datagrams on for the
// association. The server MAY use this information to limit access
// to the association.
// @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
//
// We do NOT limit the access from the client currently in this implementation.
_ = c.request.destination
addr := c.clientConn.LocalAddr()
host, _, err := net.SplitHostPort(addr.String())
if err != nil {
return err
}
clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0"))
if err != nil {
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer clientUDPConn.Close()
serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
if err != nil {
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer serverUDPConn.Close()
bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
if err != nil {
return err
}
res := &response{
reply: success,
bindAddr: socksAddr{
addrType: getAddrType(bindAddr),
addr: bindAddr,
port: bindPort,
},
}
buf, err := res.marshal()
if err != nil {
res = errorResponse(generalFailure)
buf, _ = res.marshal()
}
c.clientConn.Write(buf)
return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
}
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
const bufferSize = 8 * 1024
const readTimeout = 5 * time.Second
// client -> target
go func() {
defer cancel()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) {
return
}
c.logf("udp transfer: handle udp request fail: %v", err)
}
}
}
}()
// target -> client
go func() {
defer cancel()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) {
return
}
c.logf("udp transfer: handle udp response fail: %v", err)
}
}
}
}()
// A UDP association terminates when the TCP connection that the UDP
// ASSOCIATE request arrived on terminates. RFC1928
_, err := io.Copy(io.Discard, associatedTCP)
if err != nil {
err = fmt.Errorf("udp associated tcp conn: %w", err)
}
return err
}
func (c *Conn) handleUDPRequest(
clientConn net.PacketConn,
targetConn net.PacketConn,
buf []byte,
readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
n, addr, err := clientConn.ReadFrom(buf)
if err != nil {
return fmt.Errorf("read from client: %w", err)
}
c.udpClientAddr = addr
req, data, err := parseUDPRequest(buf[:n])
if err != nil {
return fmt.Errorf("parse udp request: %w", err)
}
targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
if err != nil {
c.logf("resolve target addr fail: %v", err)
}
nn, err := targetConn.WriteTo(data, targetAddr)
if err != nil {
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
}
if nn != len(data) {
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
}
return nil
}
func (c *Conn) handleUDPResponse(
targetConn net.PacketConn,
clientConn net.PacketConn,
buf []byte,
readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
n, addr, err := targetConn.ReadFrom(buf)
if err != nil {
return fmt.Errorf("read from target: %w", err)
}
host, port, err := splitHostPort(addr.String())
if err != nil {
return fmt.Errorf("split host port: %w", err)
}
hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
pkt, err := hdr.marshal()
if err != nil {
return fmt.Errorf("marshal udp request: %w", err)
}
data := append(pkt, buf[:n]...)
// use addr from client to send back
nn, err := clientConn.WriteTo(data, c.udpClientAddr)
if err != nil {
return fmt.Errorf("write to client: %w", err)
}
if nn != len(data) {
return fmt.Errorf("write to client: %w", io.ErrShortWrite)
}
return nil
}
func isTimeout(err error) bool {
terr, ok := errors.Unwrap(err).(interface{ Timeout() bool })
return ok && terr.Timeout()
}
func splitHostPort(hostport string) (host string, port uint16, err error) {
host, portStr, err := net.SplitHostPort(hostport)
if err != nil {
return "", 0, err
}
portInt, err := strconv.Atoi(portStr)
if err != nil {
return "", 0, err
}
if portInt < 0 || portInt > 65535 {
return "", 0, fmt.Errorf("invalid port number %d", portInt)
}
return host, uint16(portInt), nil
}
// parseClientGreeting parses a request initiation packet.
func parseClientGreeting(r io.Reader, authMethod byte) error {
var hdr [2]byte
@@ -295,114 +503,118 @@ func parseClientAuth(r io.Reader) (usr, pwd string, err error) {
return string(usrBytes), string(pwdBytes), nil
}
func getAddrType(addr string) addrType {
if ip := net.ParseIP(addr); ip != nil {
if ip.To4() != nil {
return ipv4
}
return ipv6
}
return domainName
}
// request represents data contained within a SOCKS5
// connection request packet.
type request struct {
command commandType
destination string
port uint16
destAddrType addrType
command commandType
destination socksAddr
}
// parseClientRequest converts raw packet bytes into a
// SOCKS5Request struct.
func parseClientRequest(r io.Reader) (*request, error) {
var hdr [4]byte
var hdr [3]byte
_, err := io.ReadFull(r, hdr[:])
if err != nil {
return nil, fmt.Errorf("could not read packet header")
}
cmd := hdr[1]
destAddrType := addrType(hdr[3])
destination, err := parseSocksAddr(r)
return &request{
command: commandType(cmd),
destination: destination,
}, err
}
type socksAddr struct {
addrType addrType
addr string
port uint16
}
var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
func parseSocksAddr(r io.Reader) (addr socksAddr, err error) {
var addrTypeData [1]byte
_, err = io.ReadFull(r, addrTypeData[:])
if err != nil {
return socksAddr{}, fmt.Errorf("could not read address type")
}
dstAddrType := addrType(addrTypeData[0])
var destination string
var port uint16
if destAddrType == ipv4 {
switch dstAddrType {
case ipv4:
var ip [4]byte
_, err = io.ReadFull(r, ip[:])
if err != nil {
return nil, fmt.Errorf("could not read IPv4 address")
return socksAddr{}, fmt.Errorf("could not read IPv4 address")
}
destination = net.IP(ip[:]).String()
} else if destAddrType == domainName {
case domainName:
var dstSizeByte [1]byte
_, err = io.ReadFull(r, dstSizeByte[:])
if err != nil {
return nil, fmt.Errorf("could not read domain name size")
return socksAddr{}, fmt.Errorf("could not read domain name size")
}
dstSize := int(dstSizeByte[0])
domainName := make([]byte, dstSize)
_, err = io.ReadFull(r, domainName)
if err != nil {
return nil, fmt.Errorf("could not read domain name")
return socksAddr{}, fmt.Errorf("could not read domain name")
}
destination = string(domainName)
} else if destAddrType == ipv6 {
case ipv6:
var ip [16]byte
_, err = io.ReadFull(r, ip[:])
if err != nil {
return nil, fmt.Errorf("could not read IPv6 address")
return socksAddr{}, fmt.Errorf("could not read IPv6 address")
}
destination = net.IP(ip[:]).String()
} else {
return nil, fmt.Errorf("unsupported address type")
default:
return socksAddr{}, fmt.Errorf("unsupported address type")
}
var portBytes [2]byte
_, err = io.ReadFull(r, portBytes[:])
if err != nil {
return nil, fmt.Errorf("could not read port")
return socksAddr{}, fmt.Errorf("could not read port")
}
port = binary.BigEndian.Uint16(portBytes[:])
return &request{
command: commandType(cmd),
destination: destination,
port: port,
destAddrType: destAddrType,
port := binary.BigEndian.Uint16(portBytes[:])
return socksAddr{
addrType: dstAddrType,
addr: destination,
port: port,
}, nil
}
// response contains the contents of
// a response packet sent from the proxy
// to the client.
type response struct {
reply replyCode
bindAddrType addrType
bindAddr string
bindPort uint16
}
// marshal converts a SOCKS5Response struct into
// a packet. If res.reply == Success, it may throw an error on
// receiving an invalid bind address. Otherwise, it will not throw.
func (res *response) marshal() ([]byte, error) {
pkt := make([]byte, 4)
pkt[0] = socks5Version
pkt[1] = byte(res.reply)
pkt[2] = 0 // null reserved byte
pkt[3] = byte(res.bindAddrType)
if res.reply != success {
return pkt, nil
}
func (s socksAddr) marshal() ([]byte, error) {
var addr []byte
switch res.bindAddrType {
switch s.addrType {
case ipv4:
addr = net.ParseIP(res.bindAddr).To4()
addr = net.ParseIP(s.addr).To4()
if addr == nil {
return nil, fmt.Errorf("invalid IPv4 address for binding")
}
case domainName:
if len(res.bindAddr) > 255 {
if len(s.addr) > 255 {
return nil, fmt.Errorf("invalid domain name for binding")
}
addr = make([]byte, 0, len(res.bindAddr)+1)
addr = append(addr, byte(len(res.bindAddr)))
addr = append(addr, []byte(res.bindAddr)...)
addr = make([]byte, 0, len(s.addr)+1)
addr = append(addr, byte(len(s.addr)))
addr = append(addr, []byte(s.addr)...)
case ipv6:
addr = net.ParseIP(res.bindAddr).To16()
addr = net.ParseIP(s.addr).To16()
if addr == nil {
return nil, fmt.Errorf("invalid IPv6 address for binding")
}
@@ -410,8 +622,86 @@ func (res *response) marshal() ([]byte, error) {
return nil, fmt.Errorf("unsupported address type")
}
pkt := []byte{byte(s.addrType)}
pkt = append(pkt, addr...)
pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort))
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
return pkt, nil
}
func (s socksAddr) hostPort() string {
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
}
// response contains the contents of
// a response packet sent from the proxy
// to the client.
type response struct {
reply replyCode
bindAddr socksAddr
}
func errorResponse(code replyCode) *response {
return &response{reply: code, bindAddr: zeroSocksAddr}
}
// marshal converts a SOCKS5Response struct into
// a packet. If res.reply == Success, it may throw an error on
// receiving an invalid bind address. Otherwise, it will not throw.
func (res *response) marshal() ([]byte, error) {
pkt := make([]byte, 3)
pkt[0] = socks5Version
pkt[1] = byte(res.reply)
pkt[2] = 0 // null reserved byte
addrPkt, err := res.bindAddr.marshal()
if err != nil {
return nil, err
}
return append(pkt, addrPkt...), nil
}
type udpRequest struct {
frag byte
addr socksAddr
}
// +----+------+------+----------+----------+----------+
// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
// +----+------+------+----------+----------+----------+
// | 2 | 1 | 1 | Variable | 2 | Variable |
// +----+------+------+----------+----------+----------+
func parseUDPRequest(data []byte) (*udpRequest, []byte, error) {
if len(data) < 4 {
return nil, nil, fmt.Errorf("invalid packet length")
}
// reserved bytes
if !(data[0] == 0 && data[1] == 0) {
return nil, nil, fmt.Errorf("invalid udp request header")
}
frag := data[2]
reader := bytes.NewReader(data[3:])
addr, err := parseSocksAddr(reader)
bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length
body := data[len(data)-bodyLen:]
return &udpRequest{
frag: frag,
addr: addr,
}, body, err
}
func (u *udpRequest) marshal() ([]byte, error) {
pkt := make([]byte, 3)
pkt[0] = 0
pkt[1] = 0
pkt[2] = u.frag
addrPkt, err := u.addr.marshal()
if err != nil {
return nil, err
}
return append(pkt, addrPkt...), nil
}

View File

@@ -4,6 +4,7 @@
package socks5
import (
"bytes"
"errors"
"fmt"
"io"
@@ -32,6 +33,19 @@ func backendServer(listener net.Listener) {
listener.Close()
}
func udpEchoServer(conn net.PacketConn) {
var buf [1024]byte
n, addr, err := conn.ReadFrom(buf[:])
if err != nil {
panic(err)
}
_, err = conn.WriteTo(buf[:n], addr)
if err != nil {
panic(err)
}
conn.Close()
}
func TestRead(t *testing.T) {
// backend server which we'll use SOCKS5 to connect to
listener, err := net.Listen("tcp", ":0")
@@ -152,3 +166,102 @@ func TestReadPassword(t *testing.T) {
t.Fatal(err)
}
}
func TestUDP(t *testing.T) {
// backend UDP server which we'll use SOCKS5 to connect to
listener, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatal(err)
}
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
go udpEchoServer(listener)
// SOCKS5 server
socks5, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
socks5Port := socks5.Addr().(*net.TCPAddr).Port
go socks5Server(socks5)
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
if err != nil {
t.Fatal(err)
}
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 1024)
n, err := conn.Read(buf) // server hello
if err != nil {
t.Fatal(err)
}
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
}
targetAddr := socksAddr{
addrType: domainName,
addr: "localhost",
port: uint16(backendServerPort),
}
targetAddrPkt, err := targetAddr.marshal()
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
if err != nil {
t.Fatal(err)
}
n, err = conn.Read(buf) // server response
if err != nil {
t.Fatal(err)
}
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
}
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
if err != nil {
t.Fatal(err)
}
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
if err != nil {
t.Fatal(err)
}
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
if err != nil {
t.Fatal(err)
}
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
if err != nil {
t.Fatal(err)
}
udpPayload = append(udpPayload, []byte("Test")...)
_, err = udpConn.Write(udpPayload) // send udp package
if err != nil {
t.Fatal(err)
}
n, _, err = udpConn.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
if err != nil {
t.Fatal(err)
}
if string(responseBody) != "Test" {
t.Fatalf("got: %q want: Test", responseBody)
}
err = udpConn.Close()
if err != nil {
t.Fatal(err)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -166,6 +166,7 @@ func (d *Dialer) Close() error {
c.Close()
}
d.activeSysConns = nil
d.PeerAPITransport().CloseIdleConnections()
return nil
}

View File

@@ -7,19 +7,26 @@
package prober
import (
"container/ring"
"context"
"errors"
"encoding/json"
"fmt"
"hash/fnv"
"log"
"maps"
"math/rand"
"net/http"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"tailscale.com/tsweb"
)
// recentHistSize is the number of recent probe results and latencies to keep
// in memory.
const recentHistSize = 10
// ProbeClass defines a probe of a specific type: a probing function that will
// be regularly ran, and metric labels that will be added automatically to all
// probes using this class.
@@ -106,6 +113,14 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob
l[k] = v
}
probe := newProbe(p, name, interval, l, pc)
p.probes[name] = probe
go probe.loop()
return probe
}
// newProbe creates a new Probe with the given parameters, but does not start it.
func newProbe(p *Prober, name string, interval time.Duration, l prometheus.Labels, pc ProbeClass) *Probe {
ctx, cancel := context.WithCancel(context.Background())
probe := &Probe{
prober: p,
@@ -117,6 +132,9 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob
probeClass: pc,
interval: interval,
initialDelay: initialDelay(name, interval),
successHist: ring.New(recentHistSize),
latencyHist: ring.New(recentHistSize),
metrics: prometheus.NewRegistry(),
metricLabels: l,
mInterval: prometheus.NewDesc("interval_secs", "Probe interval in seconds", nil, l),
@@ -131,15 +149,14 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob
Name: "seconds_total", Help: "Total amount of time spent executing the probe", ConstLabels: l,
}, []string{"status"}),
}
prometheus.WrapRegistererWithPrefix(p.namespace+"_", p.metrics).MustRegister(probe.metrics)
if p.metrics != nil {
prometheus.WrapRegistererWithPrefix(p.namespace+"_", p.metrics).MustRegister(probe.metrics)
}
probe.metrics.MustRegister(probe)
p.probes[name] = probe
go probe.loop()
return probe
}
// unregister removes a probe from the prober's internal state.
func (p *Prober) unregister(probe *Probe) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -206,6 +223,7 @@ type Probe struct {
ctx context.Context
cancel context.CancelFunc // run to initiate shutdown
stopped chan struct{} // closed when shutdown is complete
runMu sync.Mutex // ensures only one probe runs at a time
name string
probeClass ProbeClass
@@ -232,6 +250,10 @@ type Probe struct {
latency time.Duration // last successful probe latency
succeeded bool // whether the last doProbe call succeeded
lastErr error
// History of recent probe results and latencies.
successHist *ring.Ring
latencyHist *ring.Ring
}
// Close shuts down the Probe and unregisters it from its Prober.
@@ -278,13 +300,17 @@ func (p *Probe) loop() {
}
}
// run invokes fun and records the results.
// run invokes the probe function and records the result. It returns the probe
// result and an error if the probe failed.
//
// fun is invoked with a timeout slightly less than interval, so that
// the probe either succeeds or fails before the next cycle is
// scheduled to start.
func (p *Probe) run() {
start := p.recordStart()
// The probe function is invoked with a timeout slightly less than interval, so
// that the probe either succeeds or fails before the next cycle is scheduled to
// start.
func (p *Probe) run() (pi ProbeInfo, err error) {
p.runMu.Lock()
defer p.runMu.Unlock()
p.recordStart()
defer func() {
// Prevent a panic within one probe function from killing the
// entire prober, so that a single buggy probe doesn't destroy
@@ -293,29 +319,30 @@ func (p *Probe) run() {
// alert for debugging.
if r := recover(); r != nil {
log.Printf("probe %s panicked: %v", p.name, r)
p.recordEnd(start, errors.New("panic"))
err = fmt.Errorf("panic: %v", r)
p.recordEnd(err)
}
}()
timeout := time.Duration(float64(p.interval) * 0.8)
ctx, cancel := context.WithTimeout(p.ctx, timeout)
defer cancel()
err := p.probeClass.Probe(ctx)
p.recordEnd(start, err)
err = p.probeClass.Probe(ctx)
p.recordEnd(err)
if err != nil {
log.Printf("probe %s: %v", p.name, err)
}
pi = p.probeInfoLocked()
return
}
func (p *Probe) recordStart() time.Time {
st := p.prober.now()
func (p *Probe) recordStart() {
p.mu.Lock()
defer p.mu.Unlock()
p.start = st
return st
p.start = p.prober.now()
p.mu.Unlock()
}
func (p *Probe) recordEnd(start time.Time, err error) {
func (p *Probe) recordEnd(err error) {
end := p.prober.now()
p.mu.Lock()
defer p.mu.Unlock()
@@ -327,22 +354,55 @@ func (p *Probe) recordEnd(start time.Time, err error) {
p.latency = latency
p.mAttempts.WithLabelValues("ok").Inc()
p.mSeconds.WithLabelValues("ok").Add(latency.Seconds())
p.latencyHist.Value = latency
p.latencyHist = p.latencyHist.Next()
} else {
p.latency = 0
p.mAttempts.WithLabelValues("fail").Inc()
p.mSeconds.WithLabelValues("fail").Add(latency.Seconds())
}
p.successHist.Value = p.succeeded
p.successHist = p.successHist.Next()
}
// ProbeInfo is the state of a Probe.
// ProbeInfo is a snapshot of the configuration and state of a Probe.
type ProbeInfo struct {
Start time.Time
End time.Time
Latency string
Result bool
Error string
Name string
Class string
Interval time.Duration
Labels map[string]string
Start time.Time
End time.Time
Latency time.Duration
Result bool
Error string
RecentResults []bool
RecentLatencies []time.Duration
}
// RecentSuccessRatio returns the success ratio of the probe in the recent history.
func (pb ProbeInfo) RecentSuccessRatio() float64 {
if len(pb.RecentResults) == 0 {
return 0
}
var sum int
for _, r := range pb.RecentResults {
if r {
sum++
}
}
return float64(sum) / float64(len(pb.RecentResults))
}
// RecentMedianLatency returns the median latency of the probe in the recent history.
func (pb ProbeInfo) RecentMedianLatency() time.Duration {
if len(pb.RecentLatencies) == 0 {
return 0
}
return pb.RecentLatencies[len(pb.RecentLatencies)/2]
}
// ProbeInfo returns the state of all probes.
func (p *Prober) ProbeInfo() map[string]ProbeInfo {
out := map[string]ProbeInfo{}
@@ -352,26 +412,100 @@ func (p *Prober) ProbeInfo() map[string]ProbeInfo {
probes = append(probes, probe)
}
p.mu.Unlock()
for _, probe := range probes {
probe.mu.Lock()
inf := ProbeInfo{
Start: probe.start,
End: probe.end,
Result: probe.succeeded,
}
if probe.lastErr != nil {
inf.Error = probe.lastErr.Error()
}
if probe.latency > 0 {
inf.Latency = probe.latency.String()
}
out[probe.name] = inf
out[probe.name] = probe.probeInfoLocked()
probe.mu.Unlock()
}
return out
}
// probeInfoLocked returns the state of the probe.
func (probe *Probe) probeInfoLocked() ProbeInfo {
inf := ProbeInfo{
Name: probe.name,
Class: probe.probeClass.Class,
Interval: probe.interval,
Labels: probe.metricLabels,
Start: probe.start,
End: probe.end,
Result: probe.succeeded,
}
if probe.lastErr != nil {
inf.Error = probe.lastErr.Error()
}
if probe.latency > 0 {
inf.Latency = probe.latency
}
probe.latencyHist.Do(func(v any) {
if l, ok := v.(time.Duration); ok {
inf.RecentLatencies = append(inf.RecentLatencies, l)
}
})
probe.successHist.Do(func(v any) {
if r, ok := v.(bool); ok {
inf.RecentResults = append(inf.RecentResults, r)
}
})
return inf
}
// RunHandlerResponse is the JSON response format for the RunHandler.
type RunHandlerResponse struct {
ProbeInfo ProbeInfo
PreviousSuccessRatio float64
PreviousMedianLatency time.Duration
}
// RunHandler runs a probe by name and returns the result as an HTTP response.
func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error {
// Look up prober by name.
name := r.FormValue("name")
if name == "" {
return tsweb.Error(http.StatusBadRequest, "missing name parameter", nil)
}
p.mu.Lock()
probe, ok := p.probes[name]
p.mu.Unlock()
if !ok {
return tsweb.Error(http.StatusNotFound, fmt.Sprintf("unknown probe %q", name), nil)
}
probe.mu.Lock()
prevInfo := probe.probeInfoLocked()
probe.mu.Unlock()
info, err := probe.run()
respStatus := http.StatusOK
if err != nil {
respStatus = http.StatusFailedDependency
}
// Return serialized JSON response if the client requested JSON
if r.Header.Get("Accept") == "application/json" {
resp := &RunHandlerResponse{
ProbeInfo: info,
PreviousSuccessRatio: prevInfo.RecentSuccessRatio(),
PreviousMedianLatency: prevInfo.RecentMedianLatency(),
}
w.WriteHeader(respStatus)
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
return tsweb.Error(http.StatusInternalServerError, "error encoding JSON response", err)
}
return nil
}
stats := fmt.Sprintf("Previous runs: success rate %d%%, median latency %v",
int(prevInfo.RecentSuccessRatio()*100), prevInfo.RecentMedianLatency())
if err != nil {
return tsweb.Error(respStatus, fmt.Sprintf("Probe failed: %s\n%s", err.Error(), stats), err)
}
w.WriteHeader(respStatus)
w.Write([]byte(fmt.Sprintf("Probe succeeded in %v\n%s", info.Latency, stats)))
return nil
}
// Describe implements prometheus.Collector.
func (p *Probe) Describe(ch chan<- *prometheus.Desc) {
ch <- p.mInterval

View File

@@ -5,16 +5,22 @@ package prober
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/prometheus/client_golang/prometheus/testutil"
"tailscale.com/tstest"
"tailscale.com/tsweb"
)
const (
@@ -292,6 +298,254 @@ func TestOnceMode(t *testing.T) {
}
}
func TestProberProbeInfo(t *testing.T) {
clk := newFakeTime()
p := newForTest(clk.Now, clk.NewTicker).WithOnce(true)
p.Run("probe1", probeInterval, nil, FuncProbe(func(context.Context) error {
clk.Advance(500 * time.Millisecond)
return nil
}))
p.Run("probe2", probeInterval, nil, FuncProbe(func(context.Context) error { return fmt.Errorf("error2") }))
p.Wait()
info := p.ProbeInfo()
wantInfo := map[string]ProbeInfo{
"probe1": {
Name: "probe1",
Interval: probeInterval,
Labels: map[string]string{"class": "", "name": "probe1"},
Latency: 500 * time.Millisecond,
Result: true,
RecentResults: []bool{true},
RecentLatencies: []time.Duration{500 * time.Millisecond},
},
"probe2": {
Name: "probe2",
Interval: probeInterval,
Labels: map[string]string{"class": "", "name": "probe2"},
Error: "error2",
RecentResults: []bool{false},
RecentLatencies: nil, // no latency for failed probes
},
}
if diff := cmp.Diff(wantInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End")); diff != "" {
t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff)
}
}
func TestProbeInfoRecent(t *testing.T) {
type probeResult struct {
latency time.Duration
err error
}
tests := []struct {
name string
results []probeResult
wantProbeInfo ProbeInfo
wantRecentSuccessRatio float64
wantRecentMedianLatency time.Duration
}{
{
name: "no_runs",
wantProbeInfo: ProbeInfo{},
wantRecentSuccessRatio: 0,
wantRecentMedianLatency: 0,
},
{
name: "single_success",
results: []probeResult{{latency: 100 * time.Millisecond, err: nil}},
wantProbeInfo: ProbeInfo{
Latency: 100 * time.Millisecond,
Result: true,
RecentResults: []bool{true},
RecentLatencies: []time.Duration{100 * time.Millisecond},
},
wantRecentSuccessRatio: 1,
wantRecentMedianLatency: 100 * time.Millisecond,
},
{
name: "single_failure",
results: []probeResult{{latency: 100 * time.Millisecond, err: errors.New("error123")}},
wantProbeInfo: ProbeInfo{
Result: false,
RecentResults: []bool{false},
RecentLatencies: nil,
Error: "error123",
},
wantRecentSuccessRatio: 0,
wantRecentMedianLatency: 0,
},
{
name: "recent_mix",
results: []probeResult{
{latency: 10 * time.Millisecond, err: errors.New("error1")},
{latency: 20 * time.Millisecond, err: nil},
{latency: 30 * time.Millisecond, err: nil},
{latency: 40 * time.Millisecond, err: errors.New("error4")},
{latency: 50 * time.Millisecond, err: nil},
{latency: 60 * time.Millisecond, err: nil},
{latency: 70 * time.Millisecond, err: errors.New("error7")},
{latency: 80 * time.Millisecond, err: nil},
},
wantProbeInfo: ProbeInfo{
Result: true,
Latency: 80 * time.Millisecond,
RecentResults: []bool{false, true, true, false, true, true, false, true},
RecentLatencies: []time.Duration{
20 * time.Millisecond,
30 * time.Millisecond,
50 * time.Millisecond,
60 * time.Millisecond,
80 * time.Millisecond,
},
},
wantRecentSuccessRatio: 0.625,
wantRecentMedianLatency: 50 * time.Millisecond,
},
{
name: "only_last_10",
results: []probeResult{
{latency: 10 * time.Millisecond, err: errors.New("old_error")},
{latency: 20 * time.Millisecond, err: nil},
{latency: 30 * time.Millisecond, err: nil},
{latency: 40 * time.Millisecond, err: nil},
{latency: 50 * time.Millisecond, err: nil},
{latency: 60 * time.Millisecond, err: nil},
{latency: 70 * time.Millisecond, err: nil},
{latency: 80 * time.Millisecond, err: nil},
{latency: 90 * time.Millisecond, err: nil},
{latency: 100 * time.Millisecond, err: nil},
{latency: 110 * time.Millisecond, err: nil},
},
wantProbeInfo: ProbeInfo{
Result: true,
Latency: 110 * time.Millisecond,
RecentResults: []bool{true, true, true, true, true, true, true, true, true, true},
RecentLatencies: []time.Duration{
20 * time.Millisecond,
30 * time.Millisecond,
40 * time.Millisecond,
50 * time.Millisecond,
60 * time.Millisecond,
70 * time.Millisecond,
80 * time.Millisecond,
90 * time.Millisecond,
100 * time.Millisecond,
110 * time.Millisecond,
},
},
wantRecentSuccessRatio: 1,
wantRecentMedianLatency: 70 * time.Millisecond,
},
}
clk := newFakeTime()
p := newForTest(clk.Now, clk.NewTicker).WithOnce(true)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
probe := newProbe(p, "", probeInterval, nil, FuncProbe(func(context.Context) error { return nil }))
for _, r := range tt.results {
probe.recordStart()
clk.Advance(r.latency)
probe.recordEnd(r.err)
}
info := probe.probeInfoLocked()
if diff := cmp.Diff(tt.wantProbeInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Interval")); diff != "" {
t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff)
}
if got := info.RecentSuccessRatio(); got != tt.wantRecentSuccessRatio {
t.Errorf("recentSuccessRatio() = %v, want %v", got, tt.wantRecentSuccessRatio)
}
if got := info.RecentMedianLatency(); got != tt.wantRecentMedianLatency {
t.Errorf("recentMedianLatency() = %v, want %v", got, tt.wantRecentMedianLatency)
}
})
}
}
func TestProberRunHandler(t *testing.T) {
clk := newFakeTime()
tests := []struct {
name string
probeFunc func(context.Context) error
wantResponseCode int
wantJSONResponse RunHandlerResponse
wantPlaintextResponse string
}{
{
name: "success",
probeFunc: func(context.Context) error { return nil },
wantResponseCode: 200,
wantJSONResponse: RunHandlerResponse{
ProbeInfo: ProbeInfo{
Name: "success",
Interval: probeInterval,
Result: true,
RecentResults: []bool{true, true},
},
PreviousSuccessRatio: 1,
},
wantPlaintextResponse: "Probe succeeded",
},
{
name: "failure",
probeFunc: func(context.Context) error { return fmt.Errorf("error123") },
wantResponseCode: 424,
wantJSONResponse: RunHandlerResponse{
ProbeInfo: ProbeInfo{
Name: "failure",
Interval: probeInterval,
Result: false,
Error: "error123",
RecentResults: []bool{false, false},
},
},
wantPlaintextResponse: "Probe failed",
},
}
for _, tt := range tests {
for _, reqJSON := range []bool{true, false} {
t.Run(fmt.Sprintf("%s_json-%v", tt.name, reqJSON), func(t *testing.T) {
p := newForTest(clk.Now, clk.NewTicker).WithOnce(true)
probe := p.Run(tt.name, probeInterval, nil, FuncProbe(tt.probeFunc))
defer probe.Close()
<-probe.stopped // wait for the first run.
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/prober/run/?name="+tt.name, nil)
if reqJSON {
req.Header.Set("Accept", "application/json")
}
tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{}).ServeHTTP(w, req)
if w.Result().StatusCode != tt.wantResponseCode {
t.Errorf("unexpected response code: got %d, want %d", w.Code, tt.wantResponseCode)
}
if reqJSON {
var gotJSON RunHandlerResponse
if err := json.Unmarshal(w.Body.Bytes(), &gotJSON); err != nil {
t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, w.Body.String())
}
if diff := cmp.Diff(tt.wantJSONResponse, gotJSON, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Labels", "RecentLatencies")); diff != "" {
t.Errorf("unexpected JSON response (-want +got):\n%s", diff)
}
} else {
body, _ := io.ReadAll(w.Result().Body)
if !strings.Contains(string(body), tt.wantPlaintextResponse) {
t.Errorf("unexpected response body: got %q, want to contain %q", body, tt.wantPlaintextResponse)
}
}
})
}
}
}
type fakeTicker struct {
ch chan time.Time
interval time.Duration

124
prober/status.go Normal file
View File

@@ -0,0 +1,124 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package prober
import (
"embed"
"fmt"
"html/template"
"net/http"
"strings"
"time"
"tailscale.com/tsweb"
"tailscale.com/util/mak"
)
//go:embed status.html
var statusFiles embed.FS
var statusTpl = template.Must(template.ParseFS(statusFiles, "status.html"))
type statusHandlerOpt func(*statusHandlerParams)
type statusHandlerParams struct {
title string
pageLinks map[string]string
probeLinks map[string]string
}
// WithTitle sets the title of the status page.
func WithTitle(title string) statusHandlerOpt {
return func(opts *statusHandlerParams) {
opts.title = title
}
}
// WithPageLink adds a top-level link to the status page.
func WithPageLink(text, url string) statusHandlerOpt {
return func(opts *statusHandlerParams) {
mak.Set(&opts.pageLinks, text, url)
}
}
// WithProbeLink adds a link to each probe on the status page.
// The textTpl and urlTpl are Go templates that will be rendered
// with the respective ProbeInfo struct as the data.
func WithProbeLink(textTpl, urlTpl string) statusHandlerOpt {
return func(opts *statusHandlerParams) {
mak.Set(&opts.probeLinks, textTpl, urlTpl)
}
}
// StatusHandler is a handler for the probe overview HTTP endpoint.
// It shows a list of probes and their current status.
func (p *Prober) StatusHandler(opts ...statusHandlerOpt) tsweb.ReturnHandlerFunc {
params := &statusHandlerParams{
title: "Prober Status",
}
for _, opt := range opts {
opt(params)
}
return func(w http.ResponseWriter, r *http.Request) error {
type probeStatus struct {
ProbeInfo
TimeSinceLast time.Duration
Links map[string]template.URL
}
vars := struct {
Title string
Links map[string]template.URL
TotalProbes int64
UnhealthyProbes int64
Probes map[string]probeStatus
}{
Title: params.title,
}
for text, url := range params.pageLinks {
mak.Set(&vars.Links, text, template.URL(url))
}
for name, info := range p.ProbeInfo() {
vars.TotalProbes++
if !info.Result {
vars.UnhealthyProbes++
}
s := probeStatus{ProbeInfo: info}
if !info.End.IsZero() {
s.TimeSinceLast = time.Since(info.End)
}
for textTpl, urlTpl := range params.probeLinks {
text, err := renderTemplate(textTpl, info)
if err != nil {
return tsweb.Error(500, err.Error(), err)
}
url, err := renderTemplate(urlTpl, info)
if err != nil {
return tsweb.Error(500, err.Error(), err)
}
mak.Set(&s.Links, text, template.URL(url))
}
mak.Set(&vars.Probes, name, s)
}
if err := statusTpl.ExecuteTemplate(w, "status", vars); err != nil {
return tsweb.HTTPError{Code: 500, Err: err, Msg: "error rendering status page"}
}
return nil
}
}
// renderTemplate renders the given Go template with the provided data
// and returns the result as a string.
func renderTemplate(tpl string, data any) (string, error) {
t, err := template.New("").Parse(tpl)
if err != nil {
return "", fmt.Errorf("error parsing template %q: %w", tpl, err)
}
var buf strings.Builder
if err := t.ExecuteTemplate(&buf, "", data); err != nil {
return "", fmt.Errorf("error rendering template %q with data %v: %w", tpl, data, err)
}
return buf.String(), nil
}

132
prober/status.html Normal file
View File

@@ -0,0 +1,132 @@
{{define "status"}}
<html>
<head><title>{{.Title}}</title></head>
<style>
body {
/* max-width: 60rem; */
margin-left: auto;
margin-right: auto;
padding: 3rem 1rem 8rem;
line-height: 1.4;
font-size: 1rem;
font-weight: 400;
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, Arial, Noto Sans, sans-serif, Apple Color Emoji, Segoe UI Emoji, Segoe UI Symbol, Noto Color Emoji;
text-rendering: optimizeLegibility;
}
.small {
font-size: 0.7rem;
}
h1 {
font-weight: 500;
letter-spacing: -.025em;
}
a { color: rgb(74 125 221); }
a:hover { color: rgb(73 100 149); }
ul {
list-style: none;
margin: 0;
padding: 0;
}
ul>li::before {
position: absolute;
top: .625rem;
left: .125rem;
height: .375rem;
width: .375rem;
border-radius: 9999px;
background-color: currentColor;
opacity: .4;
content: "";
}
ul>li {
position: relative;
padding-left: 1.25rem;
}
th, td {
padding: 5px;
text-align: left;
background: #eeeeee;
}
.error {
color: red;
}
</style>
<body>
<h1>{{.Title}}</h1>
<ul>
<li>Prober Status:
{{if .UnhealthyProbes }}
<span class="error">{{.UnhealthyProbes}}</span>
out of {{.TotalProbes}} probes failed or never ran.
{{else}}
All {{.TotalProbes}} probes are healthy
{{end}}
</li>
{{ range $text, $url := .Links }}
<li><a href="{{$url}}">{{$text}}</a></li>
{{end}}
</ul>
<h1>Probes:</h1>
<table class="sortable">
<thead><tr>
<th>Name</th>
<th>Class & Labels</th>
<th>Interval</th>
<th>Result</th>
<th>Success</th>
<th>Latency</th>
<th>Error</th>
</tr></thead>
<tbody>
{{range $name, $probeInfo := .Probes}}
<tr>
<td>
{{$name}}
{{range $text, $url := $probeInfo.Links}}
<br/>
<button onclick="location.href='{{$url}}';" type="button">
{{$text}}
</button>
{{end}}
</td>
<td>{{$probeInfo.Class}}<br/>
<div class="small">
{{range $label, $value := $probeInfo.Labels}}
{{$label}}={{$value}}<br/>
{{end}}
</div>
</td>
<td>{{$probeInfo.Interval}}</td>
<td data-sort="{{$probeInfo.TimeSinceLast.Milliseconds}}">
{{if $probeInfo.TimeSinceLast}}
{{$probeInfo.TimeSinceLast.String}}<br/>
<span class="small">{{$probeInfo.End}}</span>
{{else}}
Never
{{end}}
</td>
<td>
{{if $probeInfo.Result}}
{{$probeInfo.Result}}
{{else}}
<span class="error">{{$probeInfo.Result}}</span>
{{end}}<br/>
<div class="small">Recent: {{$probeInfo.RecentResults}}</div>
<div class="small">Mean: {{$probeInfo.RecentSuccessRatio}}</div>
</td>
<td data-sort="{{$probeInfo.Latency.Milliseconds}}">
{{$probeInfo.Latency.String}}
<div class="small">Recent: {{$probeInfo.RecentLatencies}}</div>
<div class="small">Median: {{$probeInfo.RecentMedianLatency}}</div>
</td>
<td class="small">{{$probeInfo.Error}}</td>
</tr>
{{end}}
</tbody>
</table>
<link href="https://cdn.jsdelivr.net/gh/tofsjonas/sortable@latest/sortable-base.min.css" rel="stylesheet" />
<script src="https://cdn.jsdelivr.net/gh/tofsjonas/sortable@latest/sortable.min.js"></script>
</body>
</html>
{{end}}

View File

@@ -4,200 +4,22 @@
package magicsock
import (
"errors"
"net"
"net/netip"
"sync"
"sync/atomic"
"syscall"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"tailscale.com/net/neterror"
"tailscale.com/types/nettype"
)
// xnetBatchReaderWriter defines the batching i/o methods of
// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn).
// TODO(jwhited): This should eventually be replaced with the standard library
// implementation of https://github.com/golang/go/issues/45886
type xnetBatchReaderWriter interface {
xnetBatchReader
xnetBatchWriter
}
type xnetBatchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type xnetBatchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
// batchingUDPConn is a UDP socket that provides batched i/o.
type batchingUDPConn struct {
pc nettype.PacketConn
xpc xnetBatchReaderWriter
rxOffload bool // supports UDP GRO or similar
txOffload atomic.Bool // supports UDP GSO or similar
setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing
getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing
sendBatchPool sync.Pool
}
func (c *batchingUDPConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
if c.rxOffload {
// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
// receive a "monster datagram" from any read call. The ReadFrom() API
// does not support passing the GSO size and is unsafe to use in such a
// case. Other platforms may vary in behavior, but we go with the most
// conservative approach to prevent this from becoming a footgun in the
// future.
return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
}
return c.pc.ReadFromUDPAddrPort(p)
}
func (c *batchingUDPConn) SetDeadline(t time.Time) error {
return c.pc.SetDeadline(t)
}
func (c *batchingUDPConn) SetReadDeadline(t time.Time) error {
return c.pc.SetReadDeadline(t)
}
func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error {
return c.pc.SetWriteDeadline(t)
}
const (
// This was initially established for Linux, but may split out to
// GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the
// kernel's TX path, and UDP_GRO_CNT_MAX for RX.
udpSegmentMaxDatagrams = 64
var (
// This acts as a compile-time check for our usage of ipv6.Message in
// batchingConn for both IPv6 and IPv4 operations.
_ ipv6.Message = ipv4.Message{}
)
const (
// Exceeding these values results in EMSGSIZE.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
)
// coalesceMessages iterates msgs, coalescing them where possible while
// maintaining datagram order. All msgs have their Addr field set to addr.
func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of buffs
)
maxPayloadLen := maxIPv4PayloadLen
if addr.IP.To4() == nil {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buff := range buffs {
if i > 0 {
msgLen := len(buff)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...)
copy(msgs[base].Buffers[0][baseLenBefore:], buff)
if i == len(buffs)-1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buff)
msgs[base].OOB = msgs[base].OOB[:0]
msgs[base].Buffers[0] = buff
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type sendBatch struct {
msgs []ipv6.Message
ua *net.UDPAddr
}
func (c *batchingUDPConn) getSendBatch() *sendBatch {
batch := c.sendBatchPool.Get().(*sendBatch)
return batch
}
func (c *batchingUDPConn) putSendBatch(batch *sendBatch) {
for i := range batch.msgs {
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
}
c.sendBatchPool.Put(batch)
}
func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
batch := c.getSendBatch()
defer c.putSendBatch(batch)
if addr.Addr().Is6() {
as16 := addr.Addr().As16()
copy(batch.ua.IP, as16[:])
batch.ua.IP = batch.ua.IP[:16]
} else {
as4 := addr.Addr().As4()
copy(batch.ua.IP, as4[:])
batch.ua.IP = batch.ua.IP[:4]
}
batch.ua.Port = int(addr.Port())
var (
n int
retried bool
)
retry:
if c.txOffload.Load() {
n = c.coalesceMessages(batch.ua, buffs, batch.msgs)
} else {
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
batch.msgs[i].Addr = batch.ua
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
}
n = len(buffs)
}
err := c.writeBatch(batch.msgs[:n])
if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) {
c.txOffload.Store(false)
retried = true
goto retry
}
if retried {
return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err}
}
return err
}
func (c *batchingUDPConn) SyscallConn() (syscall.RawConn, error) {
sc, ok := c.pc.(syscall.Conn)
if !ok {
return nil, errUnsupportedConnType
}
return sc.SyscallConn()
// batchingConn is a nettype.PacketConn that provides batched i/o.
type batchingConn interface {
nettype.PacketConn
ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error
}

View File

@@ -0,0 +1,14 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux
package magicsock
import (
"tailscale.com/types/nettype"
)
func tryUpgradeToBatchingConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
return pconn
}

View File

@@ -0,0 +1,419 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"tailscale.com/hostinfo"
"tailscale.com/net/neterror"
"tailscale.com/types/nettype"
)
// xnetBatchReaderWriter defines the batching i/o methods of
// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn).
// TODO(jwhited): This should eventually be replaced with the standard library
// implementation of https://github.com/golang/go/issues/45886
type xnetBatchReaderWriter interface {
xnetBatchReader
xnetBatchWriter
}
type xnetBatchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type xnetBatchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
// linuxBatchingConn is a UDP socket that provides batched i/o. It implements
// batchingConn.
type linuxBatchingConn struct {
pc nettype.PacketConn
xpc xnetBatchReaderWriter
rxOffload bool // supports UDP GRO or similar
txOffload atomic.Bool // supports UDP GSO or similar
setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing
getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing
sendBatchPool sync.Pool
}
func (c *linuxBatchingConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
if c.rxOffload {
// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
// receive a "monster datagram" from any read call. The ReadFrom() API
// does not support passing the GSO size and is unsafe to use in such a
// case. Other platforms may vary in behavior, but we go with the most
// conservative approach to prevent this from becoming a footgun in the
// future.
return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
}
return c.pc.ReadFromUDPAddrPort(p)
}
func (c *linuxBatchingConn) SetDeadline(t time.Time) error {
return c.pc.SetDeadline(t)
}
func (c *linuxBatchingConn) SetReadDeadline(t time.Time) error {
return c.pc.SetReadDeadline(t)
}
func (c *linuxBatchingConn) SetWriteDeadline(t time.Time) error {
return c.pc.SetWriteDeadline(t)
}
const (
// This was initially established for Linux, but may split out to
// GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the
// kernel's TX path, and UDP_GRO_CNT_MAX for RX.
udpSegmentMaxDatagrams = 64
)
const (
// Exceeding these values results in EMSGSIZE.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
)
// coalesceMessages iterates msgs, coalescing them where possible while
// maintaining datagram order. All msgs have their Addr field set to addr.
func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of buffs
)
maxPayloadLen := maxIPv4PayloadLen
if addr.IP.To4() == nil {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buff := range buffs {
if i > 0 {
msgLen := len(buff)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...)
copy(msgs[base].Buffers[0][baseLenBefore:], buff)
if i == len(buffs)-1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buff)
msgs[base].OOB = msgs[base].OOB[:0]
msgs[base].Buffers[0] = buff
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type sendBatch struct {
msgs []ipv6.Message
ua *net.UDPAddr
}
func (c *linuxBatchingConn) getSendBatch() *sendBatch {
batch := c.sendBatchPool.Get().(*sendBatch)
return batch
}
func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) {
for i := range batch.msgs {
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
}
c.sendBatchPool.Put(batch)
}
func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
batch := c.getSendBatch()
defer c.putSendBatch(batch)
if addr.Addr().Is6() {
as16 := addr.Addr().As16()
copy(batch.ua.IP, as16[:])
batch.ua.IP = batch.ua.IP[:16]
} else {
as4 := addr.Addr().As4()
copy(batch.ua.IP, as4[:])
batch.ua.IP = batch.ua.IP[:4]
}
batch.ua.Port = int(addr.Port())
var (
n int
retried bool
)
retry:
if c.txOffload.Load() {
n = c.coalesceMessages(batch.ua, buffs, batch.msgs)
} else {
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
batch.msgs[i].Addr = batch.ua
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
}
n = len(buffs)
}
err := c.writeBatch(batch.msgs[:n])
if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) {
c.txOffload.Store(false)
retried = true
goto retry
}
if retried {
return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err}
}
return err
}
func (c *linuxBatchingConn) SyscallConn() (syscall.RawConn, error) {
sc, ok := c.pc.(syscall.Conn)
if !ok {
return nil, errUnsupportedConnType
}
return sc.SyscallConn()
}
func (c *linuxBatchingConn) writeBatch(msgs []ipv6.Message) error {
var head int
for {
n, err := c.xpc.WriteBatch(msgs[head:], 0)
if err != nil || n == len(msgs[head:]) {
// Returning the number of packets written would require
// unraveling individual msg len and gso size during a coalesced
// write. The top of the call stack disregards partial success,
// so keep this simple for now.
return err
}
head += n
}
}
// splitCoalescedMessages splits coalesced messages from the tail of dst
// beginning at index 'firstMsgAt' into the head of the same slice. It reports
// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An
// error is returned if a socket control message cannot be parsed or a split
// operation would overflow msgs.
func (c *linuxBatchingConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}
func (c *linuxBatchingConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
if !c.rxOffload || len(msgs) < 2 {
return c.xpc.ReadBatch(msgs, flags)
}
// Read into the tail of msgs, split into the head.
readAt := len(msgs) - 2
numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0)
if err != nil || numRead == 0 {
return 0, err
}
return c.splitCoalescedMessages(msgs, readAt)
}
func (c *linuxBatchingConn) LocalAddr() net.Addr {
return c.pc.LocalAddr().(*net.UDPAddr)
}
func (c *linuxBatchingConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
return c.pc.WriteToUDPAddrPort(b, addr)
}
func (c *linuxBatchingConn) Close() error {
return c.pc.Close()
}
// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn,
// and returns two booleans indicating TX and RX UDP offload support.
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
if c, ok := pconn.(*net.UDPConn); ok {
rc, err := c.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
hasTX = errSyscall == nil
errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
hasRX = errSyscall == nil
})
if err != nil {
return false, false
}
}
return hasTX, hasRX
}
// getGSOSizeFromControl returns the GSO size found in control. If no GSO size
// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0.
// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but
// its contents cannot be parsed as a socket control message.
func getGSOSizeFromControl(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 {
return int(binary.NativeEndian.Uint16(data[:2])), nil
}
}
return 0, nil
}
// setGSOSizeInControl sets a socket control message in control containing
// gsoSize. If len(control) < controlMessageSize control's len will be set to 0.
func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
*control = (*control)[:0]
if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
return
}
if cap(*control) < controlMessageSize {
return
}
*control = (*control)[:cap(*control)]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(2))
binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize)
*control = (*control)[:unix.CmsgSpace(2)]
}
// tryUpgradeToBatchingConn probes the capabilities of the OS and pconn, and
// upgrades pconn to a *linuxBatchingConn if appropriate.
func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if network != "udp4" && network != "udp6" {
return pconn
}
if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") {
// recvmmsg/sendmmsg were added in 2.6.33, but we support down to
// 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
// As a cheap heuristic: if the Linux kernel starts with "2", just
// consider it too old for mmsg. Nobody who cares about performance runs
// such ancient kernels. UDP offload was added much later, so no
// upgrades are available.
return pconn
}
uc, ok := pconn.(*net.UDPConn)
if !ok {
return pconn
}
b := &linuxBatchingConn{
pc: pconn,
getGSOSizeFromControl: getGSOSizeFromControl,
setGSOSizeInControl: setGSOSizeInControl,
sendBatchPool: sync.Pool{
New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 16),
}
msgs := make([]ipv6.Message, batchSize)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
msgs[i].OOB = make([]byte, controlMessageSize)
}
return &sendBatch{
ua: ua,
msgs: msgs,
}
},
},
}
switch network {
case "udp4":
b.xpc = ipv4.NewPacketConn(uc)
case "udp6":
b.xpc = ipv6.NewPacketConn(uc)
default:
panic("bogus network")
}
var txOffload bool
txOffload, b.rxOffload = tryEnableUDPOffload(uc)
b.txOffload.Store(txOffload)
return b
}

View File

@@ -0,0 +1,244 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"encoding/binary"
"net"
"testing"
"golang.org/x/net/ipv6"
)
func setGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func getGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) {
c := &linuxBatchingConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1024)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := c.splitCoalescedMessages(tt.msgs, 2)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}
func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
c := &linuxBatchingConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := c.coalesceMessages(addr, tt.buffs, msgs)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := range got {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := getGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}

View File

@@ -25,7 +25,6 @@ import (
"github.com/tailscale/wireguard-go/conn"
"go4.org/mem"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"tailscale.com/control/controlknobs"
@@ -1101,12 +1100,6 @@ var errNoUDP = errors.New("no UDP available on platform")
var errUnsupportedConnType = errors.New("unsupported connection type")
var (
// This acts as a compile-time check for our usage of ipv6.Message in
// batchingUDPConn for both IPv6 and IPv4 operations.
_ ipv6.Message = ipv4.Message{}
)
func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) {
isIPv6 := false
switch {
@@ -2656,153 +2649,6 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
return ep, nil
}
func (c *batchingUDPConn) writeBatch(msgs []ipv6.Message) error {
var head int
for {
n, err := c.xpc.WriteBatch(msgs[head:], 0)
if err != nil || n == len(msgs[head:]) {
// Returning the number of packets written would require
// unraveling individual msg len and gso size during a coalesced
// write. The top of the call stack disregards partial success,
// so keep this simple for now.
return err
}
head += n
}
}
// splitCoalescedMessages splits coalesced messages from the tail of dst
// beginning at index 'firstMsgAt' into the head of the same slice. It reports
// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An
// error is returned if a socket control message cannot be parsed or a split
// operation would overflow msgs.
func (c *batchingUDPConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}
func (c *batchingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
if !c.rxOffload || len(msgs) < 2 {
return c.xpc.ReadBatch(msgs, flags)
}
// Read into the tail of msgs, split into the head.
readAt := len(msgs) - 2
numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0)
if err != nil || numRead == 0 {
return 0, err
}
return c.splitCoalescedMessages(msgs, readAt)
}
func (c *batchingUDPConn) LocalAddr() net.Addr {
return c.pc.LocalAddr().(*net.UDPAddr)
}
func (c *batchingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
return c.pc.WriteToUDPAddrPort(b, addr)
}
func (c *batchingUDPConn) Close() error {
return c.pc.Close()
}
// tryUpgradeToBatchingUDPConn probes the capabilities of the OS and pconn, and
// upgrades pconn to a *batchingUDPConn if appropriate.
func tryUpgradeToBatchingUDPConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if network != "udp4" && network != "udp6" {
return pconn
}
if runtime.GOOS != "linux" {
return pconn
}
if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") {
// recvmmsg/sendmmsg were added in 2.6.33, but we support down to
// 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
// As a cheap heuristic: if the Linux kernel starts with "2", just
// consider it too old for mmsg. Nobody who cares about performance runs
// such ancient kernels. UDP offload was added much later, so no
// upgrades are available.
return pconn
}
uc, ok := pconn.(*net.UDPConn)
if !ok {
return pconn
}
b := &batchingUDPConn{
pc: pconn,
getGSOSizeFromControl: getGSOSizeFromControl,
setGSOSizeInControl: setGSOSizeInControl,
sendBatchPool: sync.Pool{
New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 16),
}
msgs := make([]ipv6.Message, batchSize)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
msgs[i].OOB = make([]byte, controlMessageSize)
}
return &sendBatch{
ua: ua,
msgs: msgs,
}
},
},
}
switch network {
case "udp4":
b.xpc = ipv4.NewPacketConn(uc)
case "udp6":
b.xpc = ipv6.NewPacketConn(uc)
default:
panic("bogus network")
}
var txOffload bool
txOffload, b.rxOffload = tryEnableUDPOffload(uc)
b.txOffload.Store(txOffload)
return b
}
func newBlockForeverConn() *blockForeverConn {
c := new(blockForeverConn)
c.cond = sync.NewCond(&c.mu)

View File

@@ -21,16 +21,6 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
portableTrySetSocketBuffer(pconn, logf)
}
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
return false, false
}
func getGSOSizeFromControl(control []byte) (int, error) {
return 0, nil
}
func setGSOSizeInControl(control *[]byte, gso uint16) {}
const (
controlMessageSize = 0
)

View File

@@ -318,70 +318,6 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
}
}
// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn,
// and returns two booleans indicating TX and RX UDP offload support.
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
if c, ok := pconn.(*net.UDPConn); ok {
rc, err := c.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
hasTX = errSyscall == nil
errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
hasRX = errSyscall == nil
})
if err != nil {
return false, false
}
}
return hasTX, hasRX
}
// getGSOSizeFromControl returns the GSO size found in control. If no GSO size
// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0.
// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but
// its contents cannot be parsed as a socket control message.
func getGSOSizeFromControl(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 {
return int(binary.NativeEndian.Uint16(data[:2])), nil
}
}
return 0, nil
}
// setGSOSizeInControl sets a socket control message in control containing
// gsoSize. If len(control) < controlMessageSize control's len will be set to 0.
func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
*control = (*control)[:0]
if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
return
}
if cap(*control) < controlMessageSize {
return
}
*control = (*control)[:cap(*control)]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(2))
binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize)
*control = (*control)[:unix.CmsgSpace(2)]
}
var controlMessageSize = -1 // bomb if used for allocation before init
func init() {

View File

@@ -35,7 +35,6 @@ import (
xmaps "golang.org/x/exp/maps"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/control/controlknobs"
"tailscale.com/derp"
@@ -2038,238 +2037,6 @@ func TestBufferedDerpWritesBeforeDrop(t *testing.T) {
t.Logf("bufferedDerpWritesBeforeDrop = %d", vv)
}
func setGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func getGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_batchingUDPConn_splitCoalescedMessages(t *testing.T) {
c := &batchingUDPConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1024)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := c.splitCoalescedMessages(tt.msgs, 2)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}
func Test_batchingUDPConn_coalesceMessages(t *testing.T) {
c := &batchingUDPConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := c.coalesceMessages(addr, tt.buffs, msgs)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := range got {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := getGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}
// newWireguard starts up a new wireguard-go device attached to a test tun, and
// returns the device, tun and endpoint port. To add peers call device.IpcSet with UAPI instructions.
func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, uint16) {

View File

@@ -35,12 +35,12 @@ type RebindingUDPConn struct {
// setConnLocked sets the provided nettype.PacketConn. It should be called only
// after acquiring RebindingUDPConn.mu. It upgrades the provided
// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade
// is intentionally pushed closest to where read/write ops occur in order to
// avoid disrupting surrounding code that assumes nettype.PacketConn is a
// nettype.PacketConn to a batchingConn when appropriate. This upgrade is
// intentionally pushed closest to where read/write ops occur in order to avoid
// disrupting surrounding code that assumes nettype.PacketConn is a
// *net.UDPConn.
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) {
upc := tryUpgradeToBatchingUDPConn(p, network, batchSize)
upc := tryUpgradeToBatchingConn(p, network, batchSize)
c.pconn = upc
c.pconnAtomic.Store(&upc)
c.port = uint16(c.localAddrLocked().Port)
@@ -74,7 +74,7 @@ func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, e
func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
for {
pconn := *c.pconnAtomic.Load()
b, ok := pconn.(*batchingUDPConn)
b, ok := pconn.(batchingConn)
if !ok {
for _, buf := range buffs {
_, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr)
@@ -101,7 +101,7 @@ func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) err
func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) {
for {
pconn := *c.pconnAtomic.Load()
b, ok := pconn.(*batchingUDPConn)
b, ok := pconn.(batchingConn)
if !ok {
n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0])
if err == nil {

View File

@@ -692,3 +692,5 @@ azules
tabby
ussuri
kitty
tanuki
neko