Compare commits
26 Commits
nickkhyl/s
...
awly/cli-j
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
776ab357b1 | ||
|
|
a93dc6cdb1 | ||
|
|
7bac5dffcb | ||
|
|
b3fc345aba | ||
|
|
9106187a95 | ||
|
|
9b08399d9e | ||
|
|
153a476957 | ||
|
|
227509547f | ||
|
|
e3f047618b | ||
|
|
91d2e1772d | ||
|
|
3b6849e362 | ||
|
|
0fd73746dd | ||
|
|
17c88a19be | ||
|
|
25f0a3fc8f | ||
|
|
a7a394e7d9 | ||
|
|
07e2487c1d | ||
|
|
1dd9c44d51 | ||
|
|
0a6eb12f05 | ||
|
|
f205efcf18 | ||
|
|
a917718353 | ||
|
|
4099a36468 | ||
|
|
d9d9d525d9 | ||
|
|
9939374c48 | ||
|
|
4055b63b9b | ||
|
|
f0230ce0b5 | ||
|
|
cc370314e7 |
@@ -7,8 +7,6 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sort"
|
||||
@@ -70,8 +68,13 @@ func main() {
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
tsweb.Debugger(mux)
|
||||
mux.HandleFunc("/", http.HandlerFunc(serveFunc(p)))
|
||||
d := tsweb.Debugger(mux)
|
||||
d.Handle("probe-run", "Run a probe", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{Logf: log.Printf}))
|
||||
mux.Handle("/", tsweb.StdHandler(p.StatusHandler(
|
||||
prober.WithTitle("DERP Prober"),
|
||||
prober.WithPageLink("Prober metrics", "/debug/varz"),
|
||||
prober.WithProbeLink("Run Probe", "/debug/probe-run?name={{.Name}}"),
|
||||
), tsweb.HandlerOptions{Logf: log.Printf}))
|
||||
log.Printf("Listening on %s", *listen)
|
||||
log.Fatal(http.ListenAndServe(*listen, mux))
|
||||
}
|
||||
@@ -105,26 +108,3 @@ func getOverallStatus(p *prober.Prober) (o overallStatus) {
|
||||
sort.Strings(o.good)
|
||||
return
|
||||
}
|
||||
|
||||
func serveFunc(p *prober.Prober) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
st := getOverallStatus(p)
|
||||
summary := "All good"
|
||||
if (float64(len(st.bad)) / float64(len(st.bad)+len(st.good))) > 0.25 {
|
||||
// Returning a 500 allows monitoring this server externally and configuring
|
||||
// an alert on HTTP response code.
|
||||
w.WriteHeader(500)
|
||||
summary = fmt.Sprintf("%d problems", len(st.bad))
|
||||
}
|
||||
|
||||
io.WriteString(w, "<html><head><style>.bad { font-weight: bold; color: #700; }</style></head>\n")
|
||||
fmt.Fprintf(w, "<body><h1>derp probe</h1>\n%s:<ul>", summary)
|
||||
for _, s := range st.bad {
|
||||
fmt.Fprintf(w, "<li class=bad>%s</li>\n", html.EscapeString(s))
|
||||
}
|
||||
for _, s := range st.good {
|
||||
fmt.Fprintf(w, "<li>%s</li>\n", html.EscapeString(s))
|
||||
}
|
||||
io.WriteString(w, "</ul></body></html>\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,6 +310,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+
|
||||
gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+
|
||||
💣 gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+
|
||||
gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack
|
||||
gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
|
||||
gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack
|
||||
gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
|
||||
|
||||
@@ -7,6 +7,7 @@ package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
@@ -159,8 +160,10 @@ func newRootCmd() *ffcli.Command {
|
||||
return nil
|
||||
})
|
||||
rootfs.Lookup("socket").DefValue = localClient.Socket
|
||||
jsonDocs := rootfs.Bool("json-docs", false, hidden+"print JSON-encoded docs for all subcommands and flags")
|
||||
|
||||
rootCmd := &ffcli.Command{
|
||||
var rootCmd *ffcli.Command
|
||||
rootCmd = &ffcli.Command{
|
||||
Name: "tailscale",
|
||||
ShortUsage: "tailscale [flags] <subcommand> [command flags]",
|
||||
ShortHelp: "The easiest, most secure way to use WireGuard.",
|
||||
@@ -202,6 +205,9 @@ change in the future.
|
||||
},
|
||||
FlagSet: rootfs,
|
||||
Exec: func(ctx context.Context, args []string) error {
|
||||
if *jsonDocs {
|
||||
return printJSONDocs(rootCmd)
|
||||
}
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("tailscale: unknown subcommand: %s", args[0])
|
||||
}
|
||||
@@ -401,3 +407,54 @@ func colorableOutput() (w io.Writer, ok bool) {
|
||||
}
|
||||
return colorable.NewColorableStdout(), true
|
||||
}
|
||||
|
||||
type commandDoc struct {
|
||||
Name string
|
||||
Desc string
|
||||
Subcommands []commandDoc `json:",omitempty"`
|
||||
Flags []flagDoc `json:",omitempty"`
|
||||
}
|
||||
|
||||
type flagDoc struct {
|
||||
Name string
|
||||
Desc string
|
||||
}
|
||||
|
||||
func printJSONDocs(root *ffcli.Command) error {
|
||||
docs := jsonDocsWalk(root)
|
||||
return json.NewEncoder(os.Stdout).Encode(docs)
|
||||
}
|
||||
|
||||
func jsonDocsWalk(cmd *ffcli.Command) *commandDoc {
|
||||
res := &commandDoc{
|
||||
Name: cmd.Name,
|
||||
}
|
||||
if cmd.LongHelp != "" {
|
||||
res.Desc = cmd.LongHelp
|
||||
} else if cmd.ShortHelp != "" {
|
||||
res.Desc = cmd.ShortHelp
|
||||
} else {
|
||||
res.Desc = cmd.ShortUsage
|
||||
}
|
||||
if strings.HasPrefix(res.Desc, hidden) {
|
||||
return nil
|
||||
}
|
||||
if cmd.FlagSet != nil {
|
||||
cmd.FlagSet.VisitAll(func(f *flag.Flag) {
|
||||
if strings.HasPrefix(f.Usage, hidden) {
|
||||
return
|
||||
}
|
||||
res.Flags = append(res.Flags, flagDoc{
|
||||
Name: f.Name,
|
||||
Desc: f.Usage,
|
||||
})
|
||||
})
|
||||
}
|
||||
for _, sub := range cmd.Subcommands {
|
||||
subj := jsonDocsWalk(sub)
|
||||
if subj != nil {
|
||||
res.Subcommands = append(res.Subcommands, *subj)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -789,7 +789,7 @@ func runNetworkLockRevokeKeys(ctx context.Context, args []string) error {
|
||||
}
|
||||
|
||||
fmt.Printf(`Run the following command on another machine with a trusted tailnet lock key:
|
||||
%s lock recover-compromised-key --cosign %X
|
||||
%s lock revoke-keys --cosign %X
|
||||
`, os.Args[0], aumBytes)
|
||||
return nil
|
||||
}
|
||||
@@ -813,10 +813,10 @@ func runNetworkLockRevokeKeys(ctx context.Context, args []string) error {
|
||||
fmt.Printf(`Co-signing completed successfully.
|
||||
|
||||
To accumulate an additional signature, run the following command on another machine with a trusted tailnet lock key:
|
||||
%s lock recover-compromised-key --cosign %X
|
||||
%s lock revoke-keys --cosign %X
|
||||
|
||||
Alternatively if you are done with co-signing, complete recovery by running the following command:
|
||||
%s lock recover-compromised-key --finish %X
|
||||
%s lock revoke-keys --finish %X
|
||||
`, os.Args[0], aumBytes, os.Args[0], aumBytes)
|
||||
}
|
||||
|
||||
|
||||
@@ -221,6 +221,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+
|
||||
gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+
|
||||
💣 gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+
|
||||
gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack
|
||||
gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
|
||||
gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack
|
||||
gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
|
||||
|
||||
@@ -333,6 +333,9 @@ func (c *Direct) Close() error {
|
||||
}
|
||||
}
|
||||
c.noiseClient = nil
|
||||
if tr, ok := c.httpc.Transport.(*http.Transport); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ import (
|
||||
"tailscale.com/net/sockstats"
|
||||
"tailscale.com/net/tlsdial"
|
||||
"tailscale.com/net/tshttpproxy"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstime"
|
||||
"tailscale.com/util/multierr"
|
||||
@@ -497,11 +498,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
|
||||
tr.DisableCompression = true
|
||||
|
||||
// (mis)use httptrace to extract the underlying net.Conn from the
|
||||
// transport. We make exactly 1 request using this transport, so
|
||||
// there will be exactly 1 GotConn call. Additionally, the
|
||||
// transport handles 101 Switching Protocols correctly, such that
|
||||
// the Conn will not be reused or kept alive by the transport once
|
||||
// the response has been handed back from RoundTrip.
|
||||
// transport. The transport handles 101 Switching Protocols correctly,
|
||||
// such that the Conn will not be reused or kept alive by the transport
|
||||
// once the response has been handed back from RoundTrip.
|
||||
//
|
||||
// In theory, the machinery of net/http should make it such that
|
||||
// the trace callback happens-before we get the response, but
|
||||
@@ -517,10 +516,16 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
|
||||
// unexpected EOFs...), and we're bound to forget someday and
|
||||
// introduce a protocol optimization at a higher level that starts
|
||||
// eagerly transmitting from the server.
|
||||
connCh := make(chan net.Conn, 1)
|
||||
var lastConn syncs.AtomicValue[net.Conn]
|
||||
trace := httptrace.ClientTrace{
|
||||
// Even though we only make a single HTTP request which should
|
||||
// require a single connection, the context (with the attached
|
||||
// trace configuration) might be used by our custom dialer to
|
||||
// make other HTTP requests (e.g. BootstrapDNS). We only care
|
||||
// about the last connection made, which should be the one to
|
||||
// the control server.
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
connCh <- info.Conn
|
||||
lastConn.Store(info.Conn)
|
||||
},
|
||||
}
|
||||
ctx = httptrace.WithClientTrace(ctx, &trace)
|
||||
@@ -548,11 +553,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
|
||||
// is still a read buffer attached to it within resp.Body. So, we
|
||||
// must direct I/O through resp.Body, but we can still use the
|
||||
// underlying net.Conn for stuff like deadlines.
|
||||
var switchedConn net.Conn
|
||||
select {
|
||||
case switchedConn = <-connCh:
|
||||
default:
|
||||
}
|
||||
switchedConn := lastConn.Load()
|
||||
if switchedConn == nil {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("httptrace didn't provide a connection")
|
||||
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -41,6 +43,8 @@ type httpTestParam struct {
|
||||
makeHTTPHangAfterUpgrade bool
|
||||
|
||||
doEarlyWrite bool
|
||||
|
||||
httpInDial bool
|
||||
}
|
||||
|
||||
func TestControlHTTP(t *testing.T) {
|
||||
@@ -120,6 +124,12 @@ func TestControlHTTP(t *testing.T) {
|
||||
name: "early_write",
|
||||
doEarlyWrite: true,
|
||||
},
|
||||
// Dialer needed to make another HTTP request along the way (e.g. to
|
||||
// resolve the hostname via BootstrapDNS).
|
||||
{
|
||||
name: "http_request_in_dial",
|
||||
httpInDial: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -217,6 +227,29 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
||||
Clock: clock,
|
||||
}
|
||||
|
||||
if param.httpInDial {
|
||||
// Spin up a separate server to get a different port on localhost.
|
||||
secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return }))
|
||||
defer secondServer.Close()
|
||||
|
||||
prev := a.Dialer
|
||||
a.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", secondServer.URL, nil)
|
||||
if err != nil {
|
||||
t.Errorf("http.NewRequest: %v", err)
|
||||
}
|
||||
r, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Errorf("http.Get: %v", err)
|
||||
}
|
||||
r.Body.Close()
|
||||
|
||||
return prev(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
if proxy != nil {
|
||||
proxyEnv := proxy.Start(t)
|
||||
defer proxy.Close()
|
||||
@@ -238,6 +271,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
||||
t.Fatalf("dialing controlhttp: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
si := <-sch
|
||||
if si.conn != nil {
|
||||
defer si.conn.Close()
|
||||
@@ -266,6 +300,19 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
||||
t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// When no proxy is used, the RemoteAddr of the returned connection should match
|
||||
// one of the listeners of the test server.
|
||||
if proxy == nil {
|
||||
var expectedAddrs []string
|
||||
for _, ln := range []net.Listener{httpLn, httpsLn} {
|
||||
expectedAddrs = append(expectedAddrs, fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port))
|
||||
expectedAddrs = append(expectedAddrs, fmt.Sprintf("[::1]:%d", ln.Addr().(*net.TCPAddr).Port))
|
||||
}
|
||||
if !slices.Contains(expectedAddrs, conn.RemoteAddr().String()) {
|
||||
t.Errorf("unexpected remote addr: %s, want %s", conn.RemoteAddr(), expectedAddrs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type serverResult struct {
|
||||
|
||||
@@ -120,4 +120,4 @@
|
||||
in
|
||||
flake-utils.lib.eachDefaultSystem (system: flakeForSystem nixpkgs system);
|
||||
}
|
||||
# nix-direnv cache busting line: sha256-N0TZ1JuDqh6bZjOHcMfoEDOsiUlrC/tR72fBns1GwrM=
|
||||
# nix-direnv cache busting line: sha256-1hekcJr1jEJFu4ZnapNkbAAv+8phTQuMloULIZ0f018=
|
||||
|
||||
4
go.mod
4
go.mod
@@ -80,7 +80,7 @@ require (
|
||||
github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4
|
||||
github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1
|
||||
github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6
|
||||
github.com/tailscale/wireguard-go v0.0.0-20240724015428-60eeedfd624b
|
||||
github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98
|
||||
github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e
|
||||
github.com/tc-hib/winres v0.2.1
|
||||
github.com/tcnksm/go-httpstat v0.2.0
|
||||
@@ -104,7 +104,7 @@ require (
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
gopkg.in/square/go-jose.v2 v2.6.0
|
||||
gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3
|
||||
gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987
|
||||
honnef.co/go/tools v0.4.6
|
||||
k8s.io/api v0.30.3
|
||||
k8s.io/apimachinery v0.30.3
|
||||
|
||||
@@ -1 +1 @@
|
||||
sha256-N0TZ1JuDqh6bZjOHcMfoEDOsiUlrC/tR72fBns1GwrM=
|
||||
sha256-1hekcJr1jEJFu4ZnapNkbAAv+8phTQuMloULIZ0f018=
|
||||
|
||||
8
go.sum
8
go.sum
@@ -934,8 +934,8 @@ github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:t
|
||||
github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ=
|
||||
github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M=
|
||||
github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y=
|
||||
github.com/tailscale/wireguard-go v0.0.0-20240724015428-60eeedfd624b h1:8U9NaPB32iFoNjJ+H/yPkAVqXw/dudtj+fLTE4edF+Q=
|
||||
github.com/tailscale/wireguard-go v0.0.0-20240724015428-60eeedfd624b/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4=
|
||||
github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso=
|
||||
github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4=
|
||||
github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek=
|
||||
github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg=
|
||||
github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA=
|
||||
@@ -1491,8 +1491,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o=
|
||||
gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g=
|
||||
gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM=
|
||||
gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0=
|
||||
gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8=
|
||||
gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
@@ -987,8 +988,12 @@ func (t *Tracker) updateBuiltinWarnablesLocked() {
|
||||
}
|
||||
|
||||
if t.lastLoginErr != nil {
|
||||
var errMsg string
|
||||
if !errors.Is(t.lastLoginErr, context.Canceled) {
|
||||
errMsg = t.lastLoginErr.Error()
|
||||
}
|
||||
t.setUnhealthyLocked(LoginStateWarnable, Args{
|
||||
ArgError: t.lastLoginErr.Error(),
|
||||
ArgError: errMsg,
|
||||
})
|
||||
return
|
||||
} else {
|
||||
|
||||
@@ -65,8 +65,8 @@ See also the dependencies in the [Tailscale CLI][].
|
||||
- [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE))
|
||||
- [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/cabfb018fe85/LICENSE))
|
||||
- [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE))
|
||||
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/2f5d148bcfe1/LICENSE))
|
||||
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/62b9a7c569f9/LICENSE))
|
||||
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/71393c576b98/LICENSE))
|
||||
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE))
|
||||
- [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE))
|
||||
- [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE))
|
||||
- [github.com/vishvananda/netlink/nl](https://pkg.go.dev/github.com/vishvananda/netlink/nl) ([Apache-2.0](https://github.com/vishvananda/netlink/blob/v1.2.1-beta.2/LICENSE))
|
||||
@@ -82,7 +82,7 @@ See also the dependencies in the [Tailscale CLI][].
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
|
||||
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/ee1e1f6070e3/LICENSE))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE))
|
||||
- [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) ([ISC](https://github.com/nhooyr/websocket/blob/v1.8.10/LICENSE.txt))
|
||||
- [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE))
|
||||
|
||||
|
||||
@@ -84,8 +84,8 @@ Some packages may only be included on certain architectures or operating systems
|
||||
- [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE))
|
||||
- [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/5db17b287bf1/LICENSE))
|
||||
- [github.com/tailscale/wf](https://pkg.go.dev/github.com/tailscale/wf) ([BSD-3-Clause](https://github.com/tailscale/wf/blob/6fbb0a674ee6/LICENSE))
|
||||
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/2f5d148bcfe1/LICENSE))
|
||||
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/62b9a7c569f9/LICENSE))
|
||||
- [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/71393c576b98/LICENSE))
|
||||
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE))
|
||||
- [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE))
|
||||
- [github.com/toqueteos/webbrowser](https://pkg.go.dev/github.com/toqueteos/webbrowser) ([MIT](https://github.com/toqueteos/webbrowser/blob/v1.2.0/LICENSE.md))
|
||||
- [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.12.0/LICENSE))
|
||||
@@ -95,19 +95,19 @@ Some packages may only be included on certain architectures or operating systems
|
||||
- [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE))
|
||||
- [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/4f986261bf13/LICENSE))
|
||||
- [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE))
|
||||
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.24.0:LICENSE))
|
||||
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
|
||||
- [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/1b970713:LICENSE))
|
||||
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.26.0:LICENSE))
|
||||
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
|
||||
- [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.16.0:LICENSE))
|
||||
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.21.0:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.21.0:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
|
||||
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE))
|
||||
- [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2))
|
||||
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/ee1e1f6070e3/LICENSE))
|
||||
- [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.30.1/LICENSE))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE))
|
||||
- [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.30.3/LICENSE))
|
||||
- [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) ([ISC](https://github.com/nhooyr/websocket/blob/v1.8.10/LICENSE.txt))
|
||||
- [sigs.k8s.io/yaml](https://pkg.go.dev/sigs.k8s.io/yaml) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/LICENSE))
|
||||
- [sigs.k8s.io/yaml/goyaml.v2](https://pkg.go.dev/sigs.k8s.io/yaml/goyaml.v2) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/goyaml.v2/LICENSE))
|
||||
|
||||
@@ -57,9 +57,9 @@ Windows][]. See also the dependencies in the [Tailscale CLI][].
|
||||
- [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE))
|
||||
- [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE))
|
||||
- [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/cabfb018fe85/LICENSE))
|
||||
- [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/7601212d8e23/LICENSE))
|
||||
- [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/4327221bd339/LICENSE))
|
||||
- [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/6580b55d49ca/LICENSE))
|
||||
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/62b9a7c569f9/LICENSE))
|
||||
- [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE))
|
||||
- [github.com/tc-hib/winres](https://pkg.go.dev/github.com/tc-hib/winres) ([0BSD](https://github.com/tc-hib/winres/blob/v0.2.1/LICENSE))
|
||||
- [github.com/vishvananda/netlink/nl](https://pkg.go.dev/github.com/vishvananda/netlink/nl) ([Apache-2.0](https://github.com/vishvananda/netlink/blob/v1.2.1-beta.2/LICENSE))
|
||||
- [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE))
|
||||
@@ -69,7 +69,7 @@ Windows][]. See also the dependencies in the [Tailscale CLI][].
|
||||
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
|
||||
- [golang.org/x/exp/constraints](https://pkg.go.dev/golang.org/x/exp/constraints) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE))
|
||||
- [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE))
|
||||
- [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.18.0:LICENSE))
|
||||
- [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE))
|
||||
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
|
||||
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))
|
||||
|
||||
@@ -106,13 +106,19 @@ func (d *Detector) detectCaptivePortalWithGOOS(ctx context.Context, netMon *netm
|
||||
return false
|
||||
}
|
||||
|
||||
// interfaceNameDoesNotNeedCaptiveDetection returns true if an interface does not require captive portal detection
|
||||
// based on its name. This is useful to avoid making unnecessary HTTP requests on interfaces that are known to not
|
||||
// require it. We also avoid making requests on the interface prefixes "pdp" and "rmnet", which are cellular data
|
||||
// interfaces on iOS and Android, respectively, and would be needlessly battery-draining.
|
||||
func interfaceNameDoesNotNeedCaptiveDetection(ifName string, goos string) bool {
|
||||
ifName = strings.ToLower(ifName)
|
||||
excludedPrefixes := []string{"tailscale", "tun", "tap", "docker", "kube", "wg"}
|
||||
if goos == "windows" {
|
||||
excludedPrefixes = append(excludedPrefixes, "loopback", "tunnel", "ppp", "isatap", "teredo", "6to4")
|
||||
} else if goos == "darwin" || goos == "ios" {
|
||||
excludedPrefixes = append(excludedPrefixes, "awdl", "bridge", "ap", "utun", "tap", "llw", "anpi", "lo", "stf", "gif", "xhc")
|
||||
excludedPrefixes = append(excludedPrefixes, "pdp", "awdl", "bridge", "ap", "utun", "tap", "llw", "anpi", "lo", "stf", "gif", "xhc", "pktap")
|
||||
} else if goos == "android" {
|
||||
excludedPrefixes = append(excludedPrefixes, "rmnet", "p2p", "dummy", "sit")
|
||||
}
|
||||
for _, prefix := range excludedPrefixes {
|
||||
if strings.HasPrefix(ifName, prefix) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/cmd/testwrapper/flakytest"
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
@@ -36,6 +37,7 @@ func TestDetectCaptivePortalReturnsFalse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAllEndpointsAreUpAndReturnExpectedResponse(t *testing.T) {
|
||||
flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13019")
|
||||
d := NewDetector(t.Logf)
|
||||
endpoints := availableEndpoints(nil, 0, t.Logf, runtime.GOOS)
|
||||
|
||||
|
||||
@@ -281,6 +281,7 @@ func lookup(ctx context.Context, host string, logf logger.Logf, ht *health.Track
|
||||
func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr, queryName string, logf logger.Logf, ht *health.Tracker, netMon *netmon.Monitor) (dnsMap, error) {
|
||||
dialer := netns.NewDialer(logf, netMon)
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.DisableKeepAlives = true // This transport is meant to be used once.
|
||||
tr.Proxy = tshttpproxy.ProxyFromEnvironment
|
||||
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443"))
|
||||
|
||||
@@ -61,7 +61,7 @@ func UpdateDstAddr(q *packet.Parsed, dst netip.Addr) {
|
||||
b := q.Buffer()
|
||||
if dst.Is6() {
|
||||
v6 := dst.As16()
|
||||
copy(b[24:36], v6[:])
|
||||
copy(b[24:40], v6[:])
|
||||
updateV6PacketChecksums(q, old, dst)
|
||||
} else {
|
||||
v4 := dst.As4()
|
||||
|
||||
@@ -5,6 +5,7 @@ package checksum
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/rand/v2"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -94,7 +95,7 @@ func TestHeaderChecksumsV4(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNatChecksumsV6UDP(t *testing.T) {
|
||||
a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1")
|
||||
a1, a2 := randV6Addr(), randV6Addr()
|
||||
|
||||
// Make a fake UDP packet with 32 bytes of zeros as the datagram payload.
|
||||
b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.UDPMinimumSize+32))
|
||||
@@ -124,25 +125,43 @@ func TestNatChecksumsV6UDP(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse the packet.
|
||||
var p packet.Parsed
|
||||
var p, p2 packet.Parsed
|
||||
p.Decode(b)
|
||||
t.Log(p.String())
|
||||
|
||||
// Update the source address of the packet to be the same as the dest.
|
||||
UpdateSrcAddr(&p, a2)
|
||||
p2.Decode(p.Buffer())
|
||||
if p2.Src.Addr() != a2 {
|
||||
t.Fatalf("got %v, want %v", p2.Src, a2)
|
||||
}
|
||||
if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
|
||||
t.Fatal("incorrect checksum after updating source address")
|
||||
}
|
||||
|
||||
// Update the dest address of the packet to be the original source address.
|
||||
UpdateDstAddr(&p, a1)
|
||||
p2.Decode(p.Buffer())
|
||||
if p2.Dst.Addr() != a1 {
|
||||
t.Fatalf("got %v, want %v", p2.Dst, a1)
|
||||
}
|
||||
if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
|
||||
t.Fatal("incorrect checksum after updating destination address")
|
||||
}
|
||||
}
|
||||
|
||||
func randV6Addr() netip.Addr {
|
||||
a1, a2 := rand.Int64(), rand.Int64()
|
||||
return netip.AddrFrom16([16]byte{
|
||||
byte(a1 >> 56), byte(a1 >> 48), byte(a1 >> 40), byte(a1 >> 32),
|
||||
byte(a1 >> 24), byte(a1 >> 16), byte(a1 >> 8), byte(a1),
|
||||
byte(a2 >> 56), byte(a2 >> 48), byte(a2 >> 40), byte(a2 >> 32),
|
||||
byte(a2 >> 24), byte(a2 >> 16), byte(a2 >> 8), byte(a2),
|
||||
})
|
||||
}
|
||||
|
||||
func TestNatChecksumsV6TCP(t *testing.T) {
|
||||
a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1")
|
||||
a1, a2 := randV6Addr(), randV6Addr()
|
||||
|
||||
// Make a fake TCP packet with no payload.
|
||||
b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.TCPMinimumSize))
|
||||
@@ -178,18 +197,26 @@ func TestNatChecksumsV6TCP(t *testing.T) {
|
||||
}
|
||||
|
||||
// Parse the packet.
|
||||
var p packet.Parsed
|
||||
var p, p2 packet.Parsed
|
||||
p.Decode(b)
|
||||
t.Log(p.String())
|
||||
|
||||
// Update the source address of the packet to be the same as the dest.
|
||||
UpdateSrcAddr(&p, a2)
|
||||
p2.Decode(p.Buffer())
|
||||
if p2.Src.Addr() != a2 {
|
||||
t.Fatalf("got %v, want %v", p2.Src, a2)
|
||||
}
|
||||
if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) {
|
||||
t.Fatal("incorrect checksum after updating source address")
|
||||
}
|
||||
|
||||
// Update the dest address of the packet to be the original source address.
|
||||
UpdateDstAddr(&p, a1)
|
||||
p2.Decode(p.Buffer())
|
||||
if p2.Dst.Addr() != a1 {
|
||||
t.Fatalf("got %v, want %v", p2.Dst, a1)
|
||||
}
|
||||
if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), 0, 0) {
|
||||
t.Fatal("incorrect checksum after updating destination address")
|
||||
}
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -121,7 +123,7 @@ func (s *Server) Serve(l net.Listener) error {
|
||||
}
|
||||
go func() {
|
||||
defer c.Close()
|
||||
conn := &Conn{clientConn: c, srv: s}
|
||||
conn := &Conn{logf: s.Logf, clientConn: c, srv: s}
|
||||
err := conn.Run()
|
||||
if err != nil {
|
||||
s.logf("client connection failed: %v", err)
|
||||
@@ -136,9 +138,12 @@ type Conn struct {
|
||||
// The struct is filled by each of the internal
|
||||
// methods in turn as the transaction progresses.
|
||||
|
||||
logf logger.Logf
|
||||
srv *Server
|
||||
clientConn net.Conn
|
||||
request *request
|
||||
|
||||
udpClientAddr net.Addr
|
||||
}
|
||||
|
||||
// Run starts the new connection.
|
||||
@@ -172,58 +177,59 @@ func (c *Conn) Run() error {
|
||||
func (c *Conn) handleRequest() error {
|
||||
req, err := parseClientRequest(c.clientConn)
|
||||
if err != nil {
|
||||
res := &response{reply: generalFailure}
|
||||
res := errorResponse(generalFailure)
|
||||
buf, _ := res.marshal()
|
||||
c.clientConn.Write(buf)
|
||||
return err
|
||||
}
|
||||
if req.command != connect {
|
||||
res := &response{reply: commandNotSupported}
|
||||
|
||||
c.request = req
|
||||
switch req.command {
|
||||
case connect:
|
||||
return c.handleTCP()
|
||||
case udpAssociate:
|
||||
return c.handleUDP()
|
||||
default:
|
||||
res := errorResponse(commandNotSupported)
|
||||
buf, _ := res.marshal()
|
||||
c.clientConn.Write(buf)
|
||||
return fmt.Errorf("unsupported command %v", req.command)
|
||||
}
|
||||
c.request = req
|
||||
}
|
||||
|
||||
func (c *Conn) handleTCP() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
srv, err := c.srv.dial(
|
||||
ctx,
|
||||
"tcp",
|
||||
net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))),
|
||||
c.request.destination.hostPort(),
|
||||
)
|
||||
if err != nil {
|
||||
res := &response{reply: generalFailure}
|
||||
res := errorResponse(generalFailure)
|
||||
buf, _ := res.marshal()
|
||||
c.clientConn.Write(buf)
|
||||
return err
|
||||
}
|
||||
defer srv.Close()
|
||||
serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
|
||||
|
||||
localAddr := srv.LocalAddr().String()
|
||||
serverAddr, serverPort, err := splitHostPort(localAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverPort, _ := strconv.Atoi(serverPortStr)
|
||||
|
||||
var bindAddrType addrType
|
||||
if ip := net.ParseIP(serverAddr); ip != nil {
|
||||
if ip.To4() != nil {
|
||||
bindAddrType = ipv4
|
||||
} else {
|
||||
bindAddrType = ipv6
|
||||
}
|
||||
} else {
|
||||
bindAddrType = domainName
|
||||
}
|
||||
res := &response{
|
||||
reply: success,
|
||||
bindAddrType: bindAddrType,
|
||||
bindAddr: serverAddr,
|
||||
bindPort: uint16(serverPort),
|
||||
reply: success,
|
||||
bindAddr: socksAddr{
|
||||
addrType: getAddrType(serverAddr),
|
||||
addr: serverAddr,
|
||||
port: serverPort,
|
||||
},
|
||||
}
|
||||
buf, err := res.marshal()
|
||||
if err != nil {
|
||||
res = &response{reply: generalFailure}
|
||||
res = errorResponse(generalFailure)
|
||||
buf, _ = res.marshal()
|
||||
}
|
||||
c.clientConn.Write(buf)
|
||||
@@ -246,6 +252,208 @@ func (c *Conn) handleRequest() error {
|
||||
return <-errc
|
||||
}
|
||||
|
||||
func (c *Conn) handleUDP() error {
|
||||
// The DST.ADDR and DST.PORT fields contain the address and port that
|
||||
// the client expects to use to send UDP datagrams on for the
|
||||
// association. The server MAY use this information to limit access
|
||||
// to the association.
|
||||
// @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928.
|
||||
//
|
||||
// We do NOT limit the access from the client currently in this implementation.
|
||||
_ = c.request.destination
|
||||
|
||||
addr := c.clientConn.LocalAddr()
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0"))
|
||||
if err != nil {
|
||||
res := errorResponse(generalFailure)
|
||||
buf, _ := res.marshal()
|
||||
c.clientConn.Write(buf)
|
||||
return err
|
||||
}
|
||||
defer clientUDPConn.Close()
|
||||
|
||||
serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
|
||||
if err != nil {
|
||||
res := errorResponse(generalFailure)
|
||||
buf, _ := res.marshal()
|
||||
c.clientConn.Write(buf)
|
||||
return err
|
||||
}
|
||||
defer serverUDPConn.Close()
|
||||
|
||||
bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := &response{
|
||||
reply: success,
|
||||
bindAddr: socksAddr{
|
||||
addrType: getAddrType(bindAddr),
|
||||
addr: bindAddr,
|
||||
port: bindPort,
|
||||
},
|
||||
}
|
||||
buf, err := res.marshal()
|
||||
if err != nil {
|
||||
res = errorResponse(generalFailure)
|
||||
buf, _ = res.marshal()
|
||||
}
|
||||
c.clientConn.Write(buf)
|
||||
|
||||
return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
|
||||
}
|
||||
|
||||
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
const bufferSize = 8 * 1024
|
||||
const readTimeout = 5 * time.Second
|
||||
|
||||
// client -> target
|
||||
go func() {
|
||||
defer cancel()
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
c.logf("udp transfer: handle udp request fail: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// target -> client
|
||||
go func() {
|
||||
defer cancel()
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
c.logf("udp transfer: handle udp response fail: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// A UDP association terminates when the TCP connection that the UDP
|
||||
// ASSOCIATE request arrived on terminates. RFC1928
|
||||
_, err := io.Copy(io.Discard, associatedTCP)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("udp associated tcp conn: %w", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) handleUDPRequest(
|
||||
clientConn net.PacketConn,
|
||||
targetConn net.PacketConn,
|
||||
buf []byte,
|
||||
readTimeout time.Duration,
|
||||
) error {
|
||||
// add a deadline for the read to avoid blocking forever
|
||||
_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
n, addr, err := clientConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read from client: %w", err)
|
||||
}
|
||||
c.udpClientAddr = addr
|
||||
req, data, err := parseUDPRequest(buf[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse udp request: %w", err)
|
||||
}
|
||||
targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
|
||||
if err != nil {
|
||||
c.logf("resolve target addr fail: %v", err)
|
||||
}
|
||||
|
||||
nn, err := targetConn.WriteTo(data, targetAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
|
||||
}
|
||||
if nn != len(data) {
|
||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleUDPResponse(
|
||||
targetConn net.PacketConn,
|
||||
clientConn net.PacketConn,
|
||||
buf []byte,
|
||||
readTimeout time.Duration,
|
||||
) error {
|
||||
// add a deadline for the read to avoid blocking forever
|
||||
_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
n, addr, err := targetConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read from target: %w", err)
|
||||
}
|
||||
host, port, err := splitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
|
||||
pkt, err := hdr.marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal udp request: %w", err)
|
||||
}
|
||||
data := append(pkt, buf[:n]...)
|
||||
// use addr from client to send back
|
||||
nn, err := clientConn.WriteTo(data, c.udpClientAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to client: %w", err)
|
||||
}
|
||||
if nn != len(data) {
|
||||
return fmt.Errorf("write to client: %w", io.ErrShortWrite)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
terr, ok := errors.Unwrap(err).(interface{ Timeout() bool })
|
||||
return ok && terr.Timeout()
|
||||
}
|
||||
|
||||
func splitHostPort(hostport string) (host string, port uint16, err error) {
|
||||
host, portStr, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
portInt, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if portInt < 0 || portInt > 65535 {
|
||||
return "", 0, fmt.Errorf("invalid port number %d", portInt)
|
||||
}
|
||||
return host, uint16(portInt), nil
|
||||
}
|
||||
|
||||
// parseClientGreeting parses a request initiation packet.
|
||||
func parseClientGreeting(r io.Reader, authMethod byte) error {
|
||||
var hdr [2]byte
|
||||
@@ -295,114 +503,118 @@ func parseClientAuth(r io.Reader) (usr, pwd string, err error) {
|
||||
return string(usrBytes), string(pwdBytes), nil
|
||||
}
|
||||
|
||||
func getAddrType(addr string) addrType {
|
||||
if ip := net.ParseIP(addr); ip != nil {
|
||||
if ip.To4() != nil {
|
||||
return ipv4
|
||||
}
|
||||
return ipv6
|
||||
}
|
||||
return domainName
|
||||
}
|
||||
|
||||
// request represents data contained within a SOCKS5
|
||||
// connection request packet.
|
||||
type request struct {
|
||||
command commandType
|
||||
destination string
|
||||
port uint16
|
||||
destAddrType addrType
|
||||
command commandType
|
||||
destination socksAddr
|
||||
}
|
||||
|
||||
// parseClientRequest converts raw packet bytes into a
|
||||
// SOCKS5Request struct.
|
||||
func parseClientRequest(r io.Reader) (*request, error) {
|
||||
var hdr [4]byte
|
||||
var hdr [3]byte
|
||||
_, err := io.ReadFull(r, hdr[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read packet header")
|
||||
}
|
||||
cmd := hdr[1]
|
||||
destAddrType := addrType(hdr[3])
|
||||
|
||||
destination, err := parseSocksAddr(r)
|
||||
return &request{
|
||||
command: commandType(cmd),
|
||||
destination: destination,
|
||||
}, err
|
||||
}
|
||||
|
||||
type socksAddr struct {
|
||||
addrType addrType
|
||||
addr string
|
||||
port uint16
|
||||
}
|
||||
|
||||
var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
|
||||
|
||||
func parseSocksAddr(r io.Reader) (addr socksAddr, err error) {
|
||||
var addrTypeData [1]byte
|
||||
_, err = io.ReadFull(r, addrTypeData[:])
|
||||
if err != nil {
|
||||
return socksAddr{}, fmt.Errorf("could not read address type")
|
||||
}
|
||||
|
||||
dstAddrType := addrType(addrTypeData[0])
|
||||
var destination string
|
||||
var port uint16
|
||||
|
||||
if destAddrType == ipv4 {
|
||||
switch dstAddrType {
|
||||
case ipv4:
|
||||
var ip [4]byte
|
||||
_, err = io.ReadFull(r, ip[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read IPv4 address")
|
||||
return socksAddr{}, fmt.Errorf("could not read IPv4 address")
|
||||
}
|
||||
destination = net.IP(ip[:]).String()
|
||||
} else if destAddrType == domainName {
|
||||
case domainName:
|
||||
var dstSizeByte [1]byte
|
||||
_, err = io.ReadFull(r, dstSizeByte[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read domain name size")
|
||||
return socksAddr{}, fmt.Errorf("could not read domain name size")
|
||||
}
|
||||
dstSize := int(dstSizeByte[0])
|
||||
domainName := make([]byte, dstSize)
|
||||
_, err = io.ReadFull(r, domainName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read domain name")
|
||||
return socksAddr{}, fmt.Errorf("could not read domain name")
|
||||
}
|
||||
destination = string(domainName)
|
||||
} else if destAddrType == ipv6 {
|
||||
case ipv6:
|
||||
var ip [16]byte
|
||||
_, err = io.ReadFull(r, ip[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read IPv6 address")
|
||||
return socksAddr{}, fmt.Errorf("could not read IPv6 address")
|
||||
}
|
||||
destination = net.IP(ip[:]).String()
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported address type")
|
||||
default:
|
||||
return socksAddr{}, fmt.Errorf("unsupported address type")
|
||||
}
|
||||
var portBytes [2]byte
|
||||
_, err = io.ReadFull(r, portBytes[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read port")
|
||||
return socksAddr{}, fmt.Errorf("could not read port")
|
||||
}
|
||||
port = binary.BigEndian.Uint16(portBytes[:])
|
||||
|
||||
return &request{
|
||||
command: commandType(cmd),
|
||||
destination: destination,
|
||||
port: port,
|
||||
destAddrType: destAddrType,
|
||||
port := binary.BigEndian.Uint16(portBytes[:])
|
||||
return socksAddr{
|
||||
addrType: dstAddrType,
|
||||
addr: destination,
|
||||
port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// response contains the contents of
|
||||
// a response packet sent from the proxy
|
||||
// to the client.
|
||||
type response struct {
|
||||
reply replyCode
|
||||
bindAddrType addrType
|
||||
bindAddr string
|
||||
bindPort uint16
|
||||
}
|
||||
|
||||
// marshal converts a SOCKS5Response struct into
|
||||
// a packet. If res.reply == Success, it may throw an error on
|
||||
// receiving an invalid bind address. Otherwise, it will not throw.
|
||||
func (res *response) marshal() ([]byte, error) {
|
||||
pkt := make([]byte, 4)
|
||||
pkt[0] = socks5Version
|
||||
pkt[1] = byte(res.reply)
|
||||
pkt[2] = 0 // null reserved byte
|
||||
pkt[3] = byte(res.bindAddrType)
|
||||
|
||||
if res.reply != success {
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func (s socksAddr) marshal() ([]byte, error) {
|
||||
var addr []byte
|
||||
switch res.bindAddrType {
|
||||
switch s.addrType {
|
||||
case ipv4:
|
||||
addr = net.ParseIP(res.bindAddr).To4()
|
||||
addr = net.ParseIP(s.addr).To4()
|
||||
if addr == nil {
|
||||
return nil, fmt.Errorf("invalid IPv4 address for binding")
|
||||
}
|
||||
case domainName:
|
||||
if len(res.bindAddr) > 255 {
|
||||
if len(s.addr) > 255 {
|
||||
return nil, fmt.Errorf("invalid domain name for binding")
|
||||
}
|
||||
addr = make([]byte, 0, len(res.bindAddr)+1)
|
||||
addr = append(addr, byte(len(res.bindAddr)))
|
||||
addr = append(addr, []byte(res.bindAddr)...)
|
||||
addr = make([]byte, 0, len(s.addr)+1)
|
||||
addr = append(addr, byte(len(s.addr)))
|
||||
addr = append(addr, []byte(s.addr)...)
|
||||
case ipv6:
|
||||
addr = net.ParseIP(res.bindAddr).To16()
|
||||
addr = net.ParseIP(s.addr).To16()
|
||||
if addr == nil {
|
||||
return nil, fmt.Errorf("invalid IPv6 address for binding")
|
||||
}
|
||||
@@ -410,8 +622,86 @@ func (res *response) marshal() ([]byte, error) {
|
||||
return nil, fmt.Errorf("unsupported address type")
|
||||
}
|
||||
|
||||
pkt := []byte{byte(s.addrType)}
|
||||
pkt = append(pkt, addr...)
|
||||
pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort))
|
||||
|
||||
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
|
||||
return pkt, nil
|
||||
}
|
||||
func (s socksAddr) hostPort() string {
|
||||
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
|
||||
}
|
||||
|
||||
// response contains the contents of
|
||||
// a response packet sent from the proxy
|
||||
// to the client.
|
||||
type response struct {
|
||||
reply replyCode
|
||||
bindAddr socksAddr
|
||||
}
|
||||
|
||||
func errorResponse(code replyCode) *response {
|
||||
return &response{reply: code, bindAddr: zeroSocksAddr}
|
||||
}
|
||||
|
||||
// marshal converts a SOCKS5Response struct into
|
||||
// a packet. If res.reply == Success, it may throw an error on
|
||||
// receiving an invalid bind address. Otherwise, it will not throw.
|
||||
func (res *response) marshal() ([]byte, error) {
|
||||
pkt := make([]byte, 3)
|
||||
pkt[0] = socks5Version
|
||||
pkt[1] = byte(res.reply)
|
||||
pkt[2] = 0 // null reserved byte
|
||||
|
||||
addrPkt, err := res.bindAddr.marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(pkt, addrPkt...), nil
|
||||
}
|
||||
|
||||
type udpRequest struct {
|
||||
frag byte
|
||||
addr socksAddr
|
||||
}
|
||||
|
||||
// +----+------+------+----------+----------+----------+
|
||||
// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
|
||||
// +----+------+------+----------+----------+----------+
|
||||
// | 2 | 1 | 1 | Variable | 2 | Variable |
|
||||
// +----+------+------+----------+----------+----------+
|
||||
func parseUDPRequest(data []byte) (*udpRequest, []byte, error) {
|
||||
if len(data) < 4 {
|
||||
return nil, nil, fmt.Errorf("invalid packet length")
|
||||
}
|
||||
|
||||
// reserved bytes
|
||||
if !(data[0] == 0 && data[1] == 0) {
|
||||
return nil, nil, fmt.Errorf("invalid udp request header")
|
||||
}
|
||||
|
||||
frag := data[2]
|
||||
|
||||
reader := bytes.NewReader(data[3:])
|
||||
addr, err := parseSocksAddr(reader)
|
||||
bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length
|
||||
body := data[len(data)-bodyLen:]
|
||||
return &udpRequest{
|
||||
frag: frag,
|
||||
addr: addr,
|
||||
}, body, err
|
||||
}
|
||||
|
||||
func (u *udpRequest) marshal() ([]byte, error) {
|
||||
pkt := make([]byte, 3)
|
||||
pkt[0] = 0
|
||||
pkt[1] = 0
|
||||
pkt[2] = u.frag
|
||||
|
||||
addrPkt, err := u.addr.marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(pkt, addrPkt...), nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -32,6 +33,19 @@ func backendServer(listener net.Listener) {
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
func udpEchoServer(conn net.PacketConn) {
|
||||
var buf [1024]byte
|
||||
n, addr, err := conn.ReadFrom(buf[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = conn.WriteTo(buf[:n], addr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
// backend server which we'll use SOCKS5 to connect to
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
@@ -152,3 +166,102 @@ func TestReadPassword(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP(t *testing.T) {
|
||||
// backend UDP server which we'll use SOCKS5 to connect to
|
||||
listener, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
|
||||
go udpEchoServer(listener)
|
||||
|
||||
// SOCKS5 server
|
||||
socks5, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
socks5Port := socks5.Addr().(*net.TCPAddr).Port
|
||||
go socks5Server(socks5)
|
||||
|
||||
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf) // server hello
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
|
||||
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
|
||||
}
|
||||
|
||||
targetAddr := socksAddr{
|
||||
addrType: domainName,
|
||||
addr: "localhost",
|
||||
port: uint16(backendServerPort),
|
||||
}
|
||||
targetAddrPkt, err := targetAddr.marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n, err = conn.Read(buf) // server response
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
|
||||
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
|
||||
}
|
||||
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
udpPayload = append(udpPayload, []byte("Test")...)
|
||||
_, err = udpConn.Write(udpPayload) // send udp package
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
n, _, err = udpConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(responseBody) != "Test" {
|
||||
t.Fatalf("got: %q want: Test", responseBody)
|
||||
}
|
||||
err = udpConn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ func (d *Dialer) Close() error {
|
||||
c.Close()
|
||||
}
|
||||
d.activeSysConns = nil
|
||||
d.PeerAPITransport().CloseIdleConnections()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -162,6 +162,10 @@ type Wrapper struct {
|
||||
PreFilterPacketInboundFromWireGuard FilterFunc
|
||||
// PostFilterPacketInboundFromWireGuard is the inbound filter function that runs after the main filter.
|
||||
PostFilterPacketInboundFromWireGuard FilterFunc
|
||||
// EndPacketVectorInboundFromWireGuardFlush is a function that runs after all packets in a given vector
|
||||
// have been handled by all filters. Filters may queue packets for the purposes of GRO, requiring an
|
||||
// explicit flush.
|
||||
EndPacketVectorInboundFromWireGuardFlush func()
|
||||
// PreFilterPacketOutboundToWireGuardNetstackIntercept is a filter function that runs before the main filter
|
||||
// for packets from the local system. This filter is populated by netstack to hook
|
||||
// packets that should be handled by netstack. If set, this filter runs before
|
||||
@@ -1179,6 +1183,9 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.EndPacketVectorInboundFromWireGuardFlush != nil {
|
||||
t.EndPacketVectorInboundFromWireGuardFlush()
|
||||
}
|
||||
if t.disableFilter {
|
||||
i = len(buffs)
|
||||
}
|
||||
|
||||
214
prober/prober.go
214
prober/prober.go
@@ -7,19 +7,26 @@
|
||||
package prober
|
||||
|
||||
import (
|
||||
"container/ring"
|
||||
"context"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"maps"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"tailscale.com/tsweb"
|
||||
)
|
||||
|
||||
// recentHistSize is the number of recent probe results and latencies to keep
|
||||
// in memory.
|
||||
const recentHistSize = 10
|
||||
|
||||
// ProbeClass defines a probe of a specific type: a probing function that will
|
||||
// be regularly ran, and metric labels that will be added automatically to all
|
||||
// probes using this class.
|
||||
@@ -106,6 +113,14 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob
|
||||
l[k] = v
|
||||
}
|
||||
|
||||
probe := newProbe(p, name, interval, l, pc)
|
||||
p.probes[name] = probe
|
||||
go probe.loop()
|
||||
return probe
|
||||
}
|
||||
|
||||
// newProbe creates a new Probe with the given parameters, but does not start it.
|
||||
func newProbe(p *Prober, name string, interval time.Duration, l prometheus.Labels, pc ProbeClass) *Probe {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
probe := &Probe{
|
||||
prober: p,
|
||||
@@ -117,6 +132,9 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob
|
||||
probeClass: pc,
|
||||
interval: interval,
|
||||
initialDelay: initialDelay(name, interval),
|
||||
successHist: ring.New(recentHistSize),
|
||||
latencyHist: ring.New(recentHistSize),
|
||||
|
||||
metrics: prometheus.NewRegistry(),
|
||||
metricLabels: l,
|
||||
mInterval: prometheus.NewDesc("interval_secs", "Probe interval in seconds", nil, l),
|
||||
@@ -131,15 +149,14 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob
|
||||
Name: "seconds_total", Help: "Total amount of time spent executing the probe", ConstLabels: l,
|
||||
}, []string{"status"}),
|
||||
}
|
||||
|
||||
prometheus.WrapRegistererWithPrefix(p.namespace+"_", p.metrics).MustRegister(probe.metrics)
|
||||
if p.metrics != nil {
|
||||
prometheus.WrapRegistererWithPrefix(p.namespace+"_", p.metrics).MustRegister(probe.metrics)
|
||||
}
|
||||
probe.metrics.MustRegister(probe)
|
||||
|
||||
p.probes[name] = probe
|
||||
go probe.loop()
|
||||
return probe
|
||||
}
|
||||
|
||||
// unregister removes a probe from the prober's internal state.
|
||||
func (p *Prober) unregister(probe *Probe) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
@@ -206,6 +223,7 @@ type Probe struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc // run to initiate shutdown
|
||||
stopped chan struct{} // closed when shutdown is complete
|
||||
runMu sync.Mutex // ensures only one probe runs at a time
|
||||
|
||||
name string
|
||||
probeClass ProbeClass
|
||||
@@ -232,6 +250,10 @@ type Probe struct {
|
||||
latency time.Duration // last successful probe latency
|
||||
succeeded bool // whether the last doProbe call succeeded
|
||||
lastErr error
|
||||
|
||||
// History of recent probe results and latencies.
|
||||
successHist *ring.Ring
|
||||
latencyHist *ring.Ring
|
||||
}
|
||||
|
||||
// Close shuts down the Probe and unregisters it from its Prober.
|
||||
@@ -278,13 +300,17 @@ func (p *Probe) loop() {
|
||||
}
|
||||
}
|
||||
|
||||
// run invokes fun and records the results.
|
||||
// run invokes the probe function and records the result. It returns the probe
|
||||
// result and an error if the probe failed.
|
||||
//
|
||||
// fun is invoked with a timeout slightly less than interval, so that
|
||||
// the probe either succeeds or fails before the next cycle is
|
||||
// scheduled to start.
|
||||
func (p *Probe) run() {
|
||||
start := p.recordStart()
|
||||
// The probe function is invoked with a timeout slightly less than interval, so
|
||||
// that the probe either succeeds or fails before the next cycle is scheduled to
|
||||
// start.
|
||||
func (p *Probe) run() (pi ProbeInfo, err error) {
|
||||
p.runMu.Lock()
|
||||
defer p.runMu.Unlock()
|
||||
|
||||
p.recordStart()
|
||||
defer func() {
|
||||
// Prevent a panic within one probe function from killing the
|
||||
// entire prober, so that a single buggy probe doesn't destroy
|
||||
@@ -293,29 +319,30 @@ func (p *Probe) run() {
|
||||
// alert for debugging.
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("probe %s panicked: %v", p.name, r)
|
||||
p.recordEnd(start, errors.New("panic"))
|
||||
err = fmt.Errorf("panic: %v", r)
|
||||
p.recordEnd(err)
|
||||
}
|
||||
}()
|
||||
timeout := time.Duration(float64(p.interval) * 0.8)
|
||||
ctx, cancel := context.WithTimeout(p.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
err := p.probeClass.Probe(ctx)
|
||||
p.recordEnd(start, err)
|
||||
err = p.probeClass.Probe(ctx)
|
||||
p.recordEnd(err)
|
||||
if err != nil {
|
||||
log.Printf("probe %s: %v", p.name, err)
|
||||
}
|
||||
pi = p.probeInfoLocked()
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Probe) recordStart() time.Time {
|
||||
st := p.prober.now()
|
||||
func (p *Probe) recordStart() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.start = st
|
||||
return st
|
||||
p.start = p.prober.now()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *Probe) recordEnd(start time.Time, err error) {
|
||||
func (p *Probe) recordEnd(err error) {
|
||||
end := p.prober.now()
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
@@ -327,22 +354,55 @@ func (p *Probe) recordEnd(start time.Time, err error) {
|
||||
p.latency = latency
|
||||
p.mAttempts.WithLabelValues("ok").Inc()
|
||||
p.mSeconds.WithLabelValues("ok").Add(latency.Seconds())
|
||||
p.latencyHist.Value = latency
|
||||
p.latencyHist = p.latencyHist.Next()
|
||||
} else {
|
||||
p.latency = 0
|
||||
p.mAttempts.WithLabelValues("fail").Inc()
|
||||
p.mSeconds.WithLabelValues("fail").Add(latency.Seconds())
|
||||
}
|
||||
p.successHist.Value = p.succeeded
|
||||
p.successHist = p.successHist.Next()
|
||||
}
|
||||
|
||||
// ProbeInfo is the state of a Probe.
|
||||
// ProbeInfo is a snapshot of the configuration and state of a Probe.
|
||||
type ProbeInfo struct {
|
||||
Start time.Time
|
||||
End time.Time
|
||||
Latency string
|
||||
Result bool
|
||||
Error string
|
||||
Name string
|
||||
Class string
|
||||
Interval time.Duration
|
||||
Labels map[string]string
|
||||
Start time.Time
|
||||
End time.Time
|
||||
Latency time.Duration
|
||||
Result bool
|
||||
Error string
|
||||
RecentResults []bool
|
||||
RecentLatencies []time.Duration
|
||||
}
|
||||
|
||||
// RecentSuccessRatio returns the success ratio of the probe in the recent history.
|
||||
func (pb ProbeInfo) RecentSuccessRatio() float64 {
|
||||
if len(pb.RecentResults) == 0 {
|
||||
return 0
|
||||
}
|
||||
var sum int
|
||||
for _, r := range pb.RecentResults {
|
||||
if r {
|
||||
sum++
|
||||
}
|
||||
}
|
||||
return float64(sum) / float64(len(pb.RecentResults))
|
||||
}
|
||||
|
||||
// RecentMedianLatency returns the median latency of the probe in the recent history.
|
||||
func (pb ProbeInfo) RecentMedianLatency() time.Duration {
|
||||
if len(pb.RecentLatencies) == 0 {
|
||||
return 0
|
||||
}
|
||||
return pb.RecentLatencies[len(pb.RecentLatencies)/2]
|
||||
}
|
||||
|
||||
// ProbeInfo returns the state of all probes.
|
||||
func (p *Prober) ProbeInfo() map[string]ProbeInfo {
|
||||
out := map[string]ProbeInfo{}
|
||||
|
||||
@@ -352,26 +412,100 @@ func (p *Prober) ProbeInfo() map[string]ProbeInfo {
|
||||
probes = append(probes, probe)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
for _, probe := range probes {
|
||||
probe.mu.Lock()
|
||||
inf := ProbeInfo{
|
||||
Start: probe.start,
|
||||
End: probe.end,
|
||||
Result: probe.succeeded,
|
||||
}
|
||||
if probe.lastErr != nil {
|
||||
inf.Error = probe.lastErr.Error()
|
||||
}
|
||||
if probe.latency > 0 {
|
||||
inf.Latency = probe.latency.String()
|
||||
}
|
||||
out[probe.name] = inf
|
||||
out[probe.name] = probe.probeInfoLocked()
|
||||
probe.mu.Unlock()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// probeInfoLocked returns the state of the probe.
|
||||
func (probe *Probe) probeInfoLocked() ProbeInfo {
|
||||
inf := ProbeInfo{
|
||||
Name: probe.name,
|
||||
Class: probe.probeClass.Class,
|
||||
Interval: probe.interval,
|
||||
Labels: probe.metricLabels,
|
||||
Start: probe.start,
|
||||
End: probe.end,
|
||||
Result: probe.succeeded,
|
||||
}
|
||||
if probe.lastErr != nil {
|
||||
inf.Error = probe.lastErr.Error()
|
||||
}
|
||||
if probe.latency > 0 {
|
||||
inf.Latency = probe.latency
|
||||
}
|
||||
probe.latencyHist.Do(func(v any) {
|
||||
if l, ok := v.(time.Duration); ok {
|
||||
inf.RecentLatencies = append(inf.RecentLatencies, l)
|
||||
}
|
||||
})
|
||||
probe.successHist.Do(func(v any) {
|
||||
if r, ok := v.(bool); ok {
|
||||
inf.RecentResults = append(inf.RecentResults, r)
|
||||
}
|
||||
})
|
||||
return inf
|
||||
}
|
||||
|
||||
// RunHandlerResponse is the JSON response format for the RunHandler.
|
||||
type RunHandlerResponse struct {
|
||||
ProbeInfo ProbeInfo
|
||||
PreviousSuccessRatio float64
|
||||
PreviousMedianLatency time.Duration
|
||||
}
|
||||
|
||||
// RunHandler runs a probe by name and returns the result as an HTTP response.
|
||||
func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error {
|
||||
// Look up prober by name.
|
||||
name := r.FormValue("name")
|
||||
if name == "" {
|
||||
return tsweb.Error(http.StatusBadRequest, "missing name parameter", nil)
|
||||
}
|
||||
p.mu.Lock()
|
||||
probe, ok := p.probes[name]
|
||||
p.mu.Unlock()
|
||||
if !ok {
|
||||
return tsweb.Error(http.StatusNotFound, fmt.Sprintf("unknown probe %q", name), nil)
|
||||
}
|
||||
|
||||
probe.mu.Lock()
|
||||
prevInfo := probe.probeInfoLocked()
|
||||
probe.mu.Unlock()
|
||||
|
||||
info, err := probe.run()
|
||||
respStatus := http.StatusOK
|
||||
if err != nil {
|
||||
respStatus = http.StatusFailedDependency
|
||||
}
|
||||
|
||||
// Return serialized JSON response if the client requested JSON
|
||||
if r.Header.Get("Accept") == "application/json" {
|
||||
resp := &RunHandlerResponse{
|
||||
ProbeInfo: info,
|
||||
PreviousSuccessRatio: prevInfo.RecentSuccessRatio(),
|
||||
PreviousMedianLatency: prevInfo.RecentMedianLatency(),
|
||||
}
|
||||
w.WriteHeader(respStatus)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
return tsweb.Error(http.StatusInternalServerError, "error encoding JSON response", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stats := fmt.Sprintf("Previous runs: success rate %d%%, median latency %v",
|
||||
int(prevInfo.RecentSuccessRatio()*100), prevInfo.RecentMedianLatency())
|
||||
if err != nil {
|
||||
return tsweb.Error(respStatus, fmt.Sprintf("Probe failed: %s\n%s", err.Error(), stats), err)
|
||||
}
|
||||
w.WriteHeader(respStatus)
|
||||
w.Write([]byte(fmt.Sprintf("Probe succeeded in %v\n%s", info.Latency, stats)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Describe implements prometheus.Collector.
|
||||
func (p *Probe) Describe(ch chan<- *prometheus.Desc) {
|
||||
ch <- p.mInterval
|
||||
|
||||
@@ -5,16 +5,22 @@ package prober
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/tsweb"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -292,6 +298,254 @@ func TestOnceMode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProberProbeInfo(t *testing.T) {
|
||||
clk := newFakeTime()
|
||||
p := newForTest(clk.Now, clk.NewTicker).WithOnce(true)
|
||||
|
||||
p.Run("probe1", probeInterval, nil, FuncProbe(func(context.Context) error {
|
||||
clk.Advance(500 * time.Millisecond)
|
||||
return nil
|
||||
}))
|
||||
p.Run("probe2", probeInterval, nil, FuncProbe(func(context.Context) error { return fmt.Errorf("error2") }))
|
||||
p.Wait()
|
||||
|
||||
info := p.ProbeInfo()
|
||||
wantInfo := map[string]ProbeInfo{
|
||||
"probe1": {
|
||||
Name: "probe1",
|
||||
Interval: probeInterval,
|
||||
Labels: map[string]string{"class": "", "name": "probe1"},
|
||||
Latency: 500 * time.Millisecond,
|
||||
Result: true,
|
||||
RecentResults: []bool{true},
|
||||
RecentLatencies: []time.Duration{500 * time.Millisecond},
|
||||
},
|
||||
"probe2": {
|
||||
Name: "probe2",
|
||||
Interval: probeInterval,
|
||||
Labels: map[string]string{"class": "", "name": "probe2"},
|
||||
Error: "error2",
|
||||
RecentResults: []bool{false},
|
||||
RecentLatencies: nil, // no latency for failed probes
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(wantInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End")); diff != "" {
|
||||
t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeInfoRecent(t *testing.T) {
|
||||
type probeResult struct {
|
||||
latency time.Duration
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
results []probeResult
|
||||
wantProbeInfo ProbeInfo
|
||||
wantRecentSuccessRatio float64
|
||||
wantRecentMedianLatency time.Duration
|
||||
}{
|
||||
{
|
||||
name: "no_runs",
|
||||
wantProbeInfo: ProbeInfo{},
|
||||
wantRecentSuccessRatio: 0,
|
||||
wantRecentMedianLatency: 0,
|
||||
},
|
||||
{
|
||||
name: "single_success",
|
||||
results: []probeResult{{latency: 100 * time.Millisecond, err: nil}},
|
||||
wantProbeInfo: ProbeInfo{
|
||||
Latency: 100 * time.Millisecond,
|
||||
Result: true,
|
||||
RecentResults: []bool{true},
|
||||
RecentLatencies: []time.Duration{100 * time.Millisecond},
|
||||
},
|
||||
wantRecentSuccessRatio: 1,
|
||||
wantRecentMedianLatency: 100 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "single_failure",
|
||||
results: []probeResult{{latency: 100 * time.Millisecond, err: errors.New("error123")}},
|
||||
wantProbeInfo: ProbeInfo{
|
||||
Result: false,
|
||||
RecentResults: []bool{false},
|
||||
RecentLatencies: nil,
|
||||
Error: "error123",
|
||||
},
|
||||
wantRecentSuccessRatio: 0,
|
||||
wantRecentMedianLatency: 0,
|
||||
},
|
||||
{
|
||||
name: "recent_mix",
|
||||
results: []probeResult{
|
||||
{latency: 10 * time.Millisecond, err: errors.New("error1")},
|
||||
{latency: 20 * time.Millisecond, err: nil},
|
||||
{latency: 30 * time.Millisecond, err: nil},
|
||||
{latency: 40 * time.Millisecond, err: errors.New("error4")},
|
||||
{latency: 50 * time.Millisecond, err: nil},
|
||||
{latency: 60 * time.Millisecond, err: nil},
|
||||
{latency: 70 * time.Millisecond, err: errors.New("error7")},
|
||||
{latency: 80 * time.Millisecond, err: nil},
|
||||
},
|
||||
wantProbeInfo: ProbeInfo{
|
||||
Result: true,
|
||||
Latency: 80 * time.Millisecond,
|
||||
RecentResults: []bool{false, true, true, false, true, true, false, true},
|
||||
RecentLatencies: []time.Duration{
|
||||
20 * time.Millisecond,
|
||||
30 * time.Millisecond,
|
||||
50 * time.Millisecond,
|
||||
60 * time.Millisecond,
|
||||
80 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
wantRecentSuccessRatio: 0.625,
|
||||
wantRecentMedianLatency: 50 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "only_last_10",
|
||||
results: []probeResult{
|
||||
{latency: 10 * time.Millisecond, err: errors.New("old_error")},
|
||||
{latency: 20 * time.Millisecond, err: nil},
|
||||
{latency: 30 * time.Millisecond, err: nil},
|
||||
{latency: 40 * time.Millisecond, err: nil},
|
||||
{latency: 50 * time.Millisecond, err: nil},
|
||||
{latency: 60 * time.Millisecond, err: nil},
|
||||
{latency: 70 * time.Millisecond, err: nil},
|
||||
{latency: 80 * time.Millisecond, err: nil},
|
||||
{latency: 90 * time.Millisecond, err: nil},
|
||||
{latency: 100 * time.Millisecond, err: nil},
|
||||
{latency: 110 * time.Millisecond, err: nil},
|
||||
},
|
||||
wantProbeInfo: ProbeInfo{
|
||||
Result: true,
|
||||
Latency: 110 * time.Millisecond,
|
||||
RecentResults: []bool{true, true, true, true, true, true, true, true, true, true},
|
||||
RecentLatencies: []time.Duration{
|
||||
20 * time.Millisecond,
|
||||
30 * time.Millisecond,
|
||||
40 * time.Millisecond,
|
||||
50 * time.Millisecond,
|
||||
60 * time.Millisecond,
|
||||
70 * time.Millisecond,
|
||||
80 * time.Millisecond,
|
||||
90 * time.Millisecond,
|
||||
100 * time.Millisecond,
|
||||
110 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
wantRecentSuccessRatio: 1,
|
||||
wantRecentMedianLatency: 70 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
|
||||
clk := newFakeTime()
|
||||
p := newForTest(clk.Now, clk.NewTicker).WithOnce(true)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
probe := newProbe(p, "", probeInterval, nil, FuncProbe(func(context.Context) error { return nil }))
|
||||
for _, r := range tt.results {
|
||||
probe.recordStart()
|
||||
clk.Advance(r.latency)
|
||||
probe.recordEnd(r.err)
|
||||
}
|
||||
info := probe.probeInfoLocked()
|
||||
if diff := cmp.Diff(tt.wantProbeInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Interval")); diff != "" {
|
||||
t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff)
|
||||
}
|
||||
if got := info.RecentSuccessRatio(); got != tt.wantRecentSuccessRatio {
|
||||
t.Errorf("recentSuccessRatio() = %v, want %v", got, tt.wantRecentSuccessRatio)
|
||||
}
|
||||
if got := info.RecentMedianLatency(); got != tt.wantRecentMedianLatency {
|
||||
t.Errorf("recentMedianLatency() = %v, want %v", got, tt.wantRecentMedianLatency)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProberRunHandler(t *testing.T) {
|
||||
clk := newFakeTime()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
probeFunc func(context.Context) error
|
||||
wantResponseCode int
|
||||
wantJSONResponse RunHandlerResponse
|
||||
wantPlaintextResponse string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
probeFunc: func(context.Context) error { return nil },
|
||||
wantResponseCode: 200,
|
||||
wantJSONResponse: RunHandlerResponse{
|
||||
ProbeInfo: ProbeInfo{
|
||||
Name: "success",
|
||||
Interval: probeInterval,
|
||||
Result: true,
|
||||
RecentResults: []bool{true, true},
|
||||
},
|
||||
PreviousSuccessRatio: 1,
|
||||
},
|
||||
wantPlaintextResponse: "Probe succeeded",
|
||||
},
|
||||
{
|
||||
name: "failure",
|
||||
probeFunc: func(context.Context) error { return fmt.Errorf("error123") },
|
||||
wantResponseCode: 424,
|
||||
wantJSONResponse: RunHandlerResponse{
|
||||
ProbeInfo: ProbeInfo{
|
||||
Name: "failure",
|
||||
Interval: probeInterval,
|
||||
Result: false,
|
||||
Error: "error123",
|
||||
RecentResults: []bool{false, false},
|
||||
},
|
||||
},
|
||||
wantPlaintextResponse: "Probe failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for _, reqJSON := range []bool{true, false} {
|
||||
t.Run(fmt.Sprintf("%s_json-%v", tt.name, reqJSON), func(t *testing.T) {
|
||||
p := newForTest(clk.Now, clk.NewTicker).WithOnce(true)
|
||||
probe := p.Run(tt.name, probeInterval, nil, FuncProbe(tt.probeFunc))
|
||||
defer probe.Close()
|
||||
<-probe.stopped // wait for the first run.
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest("GET", "/prober/run/?name="+tt.name, nil)
|
||||
if reqJSON {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{}).ServeHTTP(w, req)
|
||||
if w.Result().StatusCode != tt.wantResponseCode {
|
||||
t.Errorf("unexpected response code: got %d, want %d", w.Code, tt.wantResponseCode)
|
||||
}
|
||||
|
||||
if reqJSON {
|
||||
var gotJSON RunHandlerResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &gotJSON); err != nil {
|
||||
t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, w.Body.String())
|
||||
}
|
||||
if diff := cmp.Diff(tt.wantJSONResponse, gotJSON, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Labels", "RecentLatencies")); diff != "" {
|
||||
t.Errorf("unexpected JSON response (-want +got):\n%s", diff)
|
||||
}
|
||||
} else {
|
||||
body, _ := io.ReadAll(w.Result().Body)
|
||||
if !strings.Contains(string(body), tt.wantPlaintextResponse) {
|
||||
t.Errorf("unexpected response body: got %q, want to contain %q", body, tt.wantPlaintextResponse)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type fakeTicker struct {
|
||||
ch chan time.Time
|
||||
interval time.Duration
|
||||
|
||||
124
prober/status.go
Normal file
124
prober/status.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package prober
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tsweb"
|
||||
"tailscale.com/util/mak"
|
||||
)
|
||||
|
||||
//go:embed status.html
|
||||
var statusFiles embed.FS
|
||||
var statusTpl = template.Must(template.ParseFS(statusFiles, "status.html"))
|
||||
|
||||
type statusHandlerOpt func(*statusHandlerParams)
|
||||
type statusHandlerParams struct {
|
||||
title string
|
||||
|
||||
pageLinks map[string]string
|
||||
probeLinks map[string]string
|
||||
}
|
||||
|
||||
// WithTitle sets the title of the status page.
|
||||
func WithTitle(title string) statusHandlerOpt {
|
||||
return func(opts *statusHandlerParams) {
|
||||
opts.title = title
|
||||
}
|
||||
}
|
||||
|
||||
// WithPageLink adds a top-level link to the status page.
|
||||
func WithPageLink(text, url string) statusHandlerOpt {
|
||||
return func(opts *statusHandlerParams) {
|
||||
mak.Set(&opts.pageLinks, text, url)
|
||||
}
|
||||
}
|
||||
|
||||
// WithProbeLink adds a link to each probe on the status page.
|
||||
// The textTpl and urlTpl are Go templates that will be rendered
|
||||
// with the respective ProbeInfo struct as the data.
|
||||
func WithProbeLink(textTpl, urlTpl string) statusHandlerOpt {
|
||||
return func(opts *statusHandlerParams) {
|
||||
mak.Set(&opts.probeLinks, textTpl, urlTpl)
|
||||
}
|
||||
}
|
||||
|
||||
// StatusHandler is a handler for the probe overview HTTP endpoint.
|
||||
// It shows a list of probes and their current status.
|
||||
func (p *Prober) StatusHandler(opts ...statusHandlerOpt) tsweb.ReturnHandlerFunc {
|
||||
params := &statusHandlerParams{
|
||||
title: "Prober Status",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(params)
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
type probeStatus struct {
|
||||
ProbeInfo
|
||||
TimeSinceLast time.Duration
|
||||
Links map[string]template.URL
|
||||
}
|
||||
vars := struct {
|
||||
Title string
|
||||
Links map[string]template.URL
|
||||
TotalProbes int64
|
||||
UnhealthyProbes int64
|
||||
Probes map[string]probeStatus
|
||||
}{
|
||||
Title: params.title,
|
||||
}
|
||||
|
||||
for text, url := range params.pageLinks {
|
||||
mak.Set(&vars.Links, text, template.URL(url))
|
||||
}
|
||||
|
||||
for name, info := range p.ProbeInfo() {
|
||||
vars.TotalProbes++
|
||||
if !info.Result {
|
||||
vars.UnhealthyProbes++
|
||||
}
|
||||
s := probeStatus{ProbeInfo: info}
|
||||
if !info.End.IsZero() {
|
||||
s.TimeSinceLast = time.Since(info.End)
|
||||
}
|
||||
for textTpl, urlTpl := range params.probeLinks {
|
||||
text, err := renderTemplate(textTpl, info)
|
||||
if err != nil {
|
||||
return tsweb.Error(500, err.Error(), err)
|
||||
}
|
||||
url, err := renderTemplate(urlTpl, info)
|
||||
if err != nil {
|
||||
return tsweb.Error(500, err.Error(), err)
|
||||
}
|
||||
mak.Set(&s.Links, text, template.URL(url))
|
||||
}
|
||||
mak.Set(&vars.Probes, name, s)
|
||||
}
|
||||
|
||||
if err := statusTpl.ExecuteTemplate(w, "status", vars); err != nil {
|
||||
return tsweb.HTTPError{Code: 500, Err: err, Msg: "error rendering status page"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// renderTemplate renders the given Go template with the provided data
|
||||
// and returns the result as a string.
|
||||
func renderTemplate(tpl string, data any) (string, error) {
|
||||
t, err := template.New("").Parse(tpl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing template %q: %w", tpl, err)
|
||||
}
|
||||
var buf strings.Builder
|
||||
if err := t.ExecuteTemplate(&buf, "", data); err != nil {
|
||||
return "", fmt.Errorf("error rendering template %q with data %v: %w", tpl, data, err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
132
prober/status.html
Normal file
132
prober/status.html
Normal file
@@ -0,0 +1,132 @@
|
||||
{{define "status"}}
|
||||
<html>
|
||||
<head><title>{{.Title}}</title></head>
|
||||
<style>
|
||||
body {
|
||||
/* max-width: 60rem; */
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
padding: 3rem 1rem 8rem;
|
||||
line-height: 1.4;
|
||||
font-size: 1rem;
|
||||
font-weight: 400;
|
||||
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, Arial, Noto Sans, sans-serif, Apple Color Emoji, Segoe UI Emoji, Segoe UI Symbol, Noto Color Emoji;
|
||||
text-rendering: optimizeLegibility;
|
||||
}
|
||||
.small {
|
||||
font-size: 0.7rem;
|
||||
}
|
||||
h1 {
|
||||
font-weight: 500;
|
||||
letter-spacing: -.025em;
|
||||
}
|
||||
a { color: rgb(74 125 221); }
|
||||
a:hover { color: rgb(73 100 149); }
|
||||
ul {
|
||||
list-style: none;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
ul>li::before {
|
||||
position: absolute;
|
||||
top: .625rem;
|
||||
left: .125rem;
|
||||
height: .375rem;
|
||||
width: .375rem;
|
||||
border-radius: 9999px;
|
||||
background-color: currentColor;
|
||||
opacity: .4;
|
||||
content: "";
|
||||
}
|
||||
ul>li {
|
||||
position: relative;
|
||||
padding-left: 1.25rem;
|
||||
}
|
||||
th, td {
|
||||
padding: 5px;
|
||||
text-align: left;
|
||||
background: #eeeeee;
|
||||
}
|
||||
.error {
|
||||
color: red;
|
||||
}
|
||||
</style>
|
||||
<body>
|
||||
<h1>{{.Title}}</h1>
|
||||
<ul>
|
||||
<li>Prober Status:
|
||||
{{if .UnhealthyProbes }}
|
||||
<span class="error">{{.UnhealthyProbes}}</span>
|
||||
out of {{.TotalProbes}} probes failed or never ran.
|
||||
{{else}}
|
||||
All {{.TotalProbes}} probes are healthy
|
||||
{{end}}
|
||||
</li>
|
||||
{{ range $text, $url := .Links }}
|
||||
<li><a href="{{$url}}">{{$text}}</a></li>
|
||||
{{end}}
|
||||
</ul>
|
||||
|
||||
<h1>Probes:</h1>
|
||||
<table class="sortable">
|
||||
<thead><tr>
|
||||
<th>Name</th>
|
||||
<th>Class & Labels</th>
|
||||
<th>Interval</th>
|
||||
<th>Result</th>
|
||||
<th>Success</th>
|
||||
<th>Latency</th>
|
||||
<th>Error</th>
|
||||
</tr></thead>
|
||||
<tbody>
|
||||
{{range $name, $probeInfo := .Probes}}
|
||||
<tr>
|
||||
<td>
|
||||
{{$name}}
|
||||
{{range $text, $url := $probeInfo.Links}}
|
||||
<br/>
|
||||
<button onclick="location.href='{{$url}}';" type="button">
|
||||
{{$text}}
|
||||
</button>
|
||||
{{end}}
|
||||
</td>
|
||||
<td>{{$probeInfo.Class}}<br/>
|
||||
<div class="small">
|
||||
{{range $label, $value := $probeInfo.Labels}}
|
||||
{{$label}}={{$value}}<br/>
|
||||
{{end}}
|
||||
</div>
|
||||
</td>
|
||||
<td>{{$probeInfo.Interval}}</td>
|
||||
<td data-sort="{{$probeInfo.TimeSinceLast.Milliseconds}}">
|
||||
{{if $probeInfo.TimeSinceLast}}
|
||||
{{$probeInfo.TimeSinceLast.String}}<br/>
|
||||
<span class="small">{{$probeInfo.End}}</span>
|
||||
{{else}}
|
||||
Never
|
||||
{{end}}
|
||||
</td>
|
||||
<td>
|
||||
{{if $probeInfo.Result}}
|
||||
{{$probeInfo.Result}}
|
||||
{{else}}
|
||||
<span class="error">{{$probeInfo.Result}}</span>
|
||||
{{end}}<br/>
|
||||
<div class="small">Recent: {{$probeInfo.RecentResults}}</div>
|
||||
<div class="small">Mean: {{$probeInfo.RecentSuccessRatio}}</div>
|
||||
</td>
|
||||
<td data-sort="{{$probeInfo.Latency.Milliseconds}}">
|
||||
{{$probeInfo.Latency.String}}
|
||||
<div class="small">Recent: {{$probeInfo.RecentLatencies}}</div>
|
||||
<div class="small">Median: {{$probeInfo.RecentMedianLatency}}</div>
|
||||
</td>
|
||||
<td class="small">{{$probeInfo.Error}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
<link href="https://cdn.jsdelivr.net/gh/tofsjonas/sortable@latest/sortable-base.min.css" rel="stylesheet" />
|
||||
<script src="https://cdn.jsdelivr.net/gh/tofsjonas/sortable@latest/sortable.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
@@ -16,4 +16,4 @@
|
||||
) {
|
||||
src = ./.;
|
||||
}).shellNix
|
||||
# nix-direnv cache busting line: sha256-N0TZ1JuDqh6bZjOHcMfoEDOsiUlrC/tR72fBns1GwrM=
|
||||
# nix-direnv cache busting line: sha256-1hekcJr1jEJFu4ZnapNkbAAv+8phTQuMloULIZ0f018=
|
||||
|
||||
@@ -146,7 +146,8 @@ type CapabilityVersion int
|
||||
// - 101: 2024-07-01: Client supports SSH agent forwarding when handling connections with /bin/su
|
||||
// - 102: 2024-07-12: NodeAttrDisableMagicSockCryptoRouting support
|
||||
// - 103: 2024-07-24: Client supports NodeAttrDisableCaptivePortalDetection
|
||||
const CurrentCapabilityVersion CapabilityVersion = 103
|
||||
// - 104: 2024-08-03: SelfNodeV6MasqAddrForThisPeer now works
|
||||
const CurrentCapabilityVersion CapabilityVersion = 104
|
||||
|
||||
type StableID string
|
||||
|
||||
|
||||
@@ -842,6 +842,7 @@ func TestClientSideJailing(t *testing.T) {
|
||||
// TestNATPing creates two nodes, n1 and n2, sets up masquerades for both and
|
||||
// tries to do bi-directional pings between them.
|
||||
func TestNATPing(t *testing.T) {
|
||||
flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/12169")
|
||||
tstest.Shard(t)
|
||||
tstest.Parallel(t)
|
||||
for _, v6 := range []bool{false, true} {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"tailscale.com/cmd/testwrapper/flakytest"
|
||||
"tailscale.com/metrics"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/util/httpm"
|
||||
@@ -864,6 +865,7 @@ func TestStdHandler_CanceledAfterHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStdHandler_ConnectionClosedDuringBody(t *testing.T) {
|
||||
flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13017")
|
||||
now := time.Now()
|
||||
|
||||
// Start a HTTP server that returns 1MB of data.
|
||||
|
||||
@@ -592,9 +592,23 @@ func New(logf logger.Logf, prefHint string) (NetfilterRunner, error) {
|
||||
mode := detectFirewallMode(logf, prefHint)
|
||||
switch mode {
|
||||
case FirewallModeIPTables:
|
||||
return newIPTablesRunner(logf)
|
||||
// Note that we don't simply return an newIPTablesRunner here because it
|
||||
// would return a `nil` iptablesRunner which is different from returning
|
||||
// a nil NetfilterRunner.
|
||||
ipr, err := newIPTablesRunner(logf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ipr, nil
|
||||
case FirewallModeNfTables:
|
||||
return newNfTablesRunner(logf)
|
||||
// Note that we don't simply return an newNfTablesRunner here because it
|
||||
// would return a `nil` nftablesRunner which is different from returning
|
||||
// a nil NetfilterRunner.
|
||||
nfr, err := newNfTablesRunner(logf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nfr, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown firewall mode %v", mode)
|
||||
}
|
||||
|
||||
@@ -189,6 +189,7 @@ func (l *PolicyLock) lockSlow() (err error) {
|
||||
select {
|
||||
case resultCh <- policyLockResult{handle, err}:
|
||||
// lockSlow has received the result.
|
||||
break send_result
|
||||
default:
|
||||
select {
|
||||
case <-closing:
|
||||
|
||||
@@ -5,9 +5,9 @@ end
|
||||
tsdebug_ll = Proto("tsdebug", "Tailscale debug")
|
||||
PATH = ProtoField.string("tsdebug.PATH","PATH", base.ASCII)
|
||||
SNAT_IP_4 = ProtoField.ipv4("tsdebug.SNAT_IP_4", "Pre-NAT Source IPv4 address")
|
||||
SNAT_IP_6 = ProtoField.ipv4("tsdebug.SNAT_IP_6", "Pre-NAT Source IPv6 address")
|
||||
SNAT_IP_6 = ProtoField.ipv6("tsdebug.SNAT_IP_6", "Pre-NAT Source IPv6 address")
|
||||
DNAT_IP_4 = ProtoField.ipv4("tsdebug.DNAT_IP_4", "Pre-NAT Dest IPv4 address")
|
||||
DNAT_IP_6 = ProtoField.ipv4("tsdebug.DNAT_IP_6", "Pre-NAT Dest IPv6 address")
|
||||
DNAT_IP_6 = ProtoField.ipv6("tsdebug.DNAT_IP_6", "Pre-NAT Dest IPv6 address")
|
||||
tsdebug_ll.fields = {PATH, SNAT_IP_4, SNAT_IP_6, DNAT_IP_4, DNAT_IP_6}
|
||||
|
||||
function tsdebug_ll.dissector(buffer, pinfo, tree)
|
||||
@@ -63,7 +63,7 @@ local ts_dissectors = DissectorTable.new("ts.proto", "Tailscale-specific dissect
|
||||
tsdisco_meta = Proto("tsdisco", "Tailscale DISCO metadata")
|
||||
DISCO_IS_DERP = ProtoField.bool("tsdisco.IS_DERP","From DERP")
|
||||
DISCO_SRC_IP_4 = ProtoField.ipv4("tsdisco.SRC_IP_4", "Source IPv4 address")
|
||||
DISCO_SRC_IP_6 = ProtoField.ipv4("tsdisco.SRC_IP_6", "Source IPv6 address")
|
||||
DISCO_SRC_IP_6 = ProtoField.ipv6("tsdisco.SRC_IP_6", "Source IPv6 address")
|
||||
DISCO_SRC_PORT = ProtoField.uint16("tsdisco.SRC_PORT","Source port", base.DEC)
|
||||
DISCO_DERP_PUB = ProtoField.bytes("tsdisco.DERP_PUB", "DERP public key", base.SPACE)
|
||||
tsdisco_meta.fields = {DISCO_IS_DERP, DISCO_SRC_PORT, DISCO_DERP_PUB, DISCO_SRC_IP_4, DISCO_SRC_IP_6}
|
||||
|
||||
@@ -4,200 +4,22 @@
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"tailscale.com/net/neterror"
|
||||
"tailscale.com/types/nettype"
|
||||
)
|
||||
|
||||
// xnetBatchReaderWriter defines the batching i/o methods of
|
||||
// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn).
|
||||
// TODO(jwhited): This should eventually be replaced with the standard library
|
||||
// implementation of https://github.com/golang/go/issues/45886
|
||||
type xnetBatchReaderWriter interface {
|
||||
xnetBatchReader
|
||||
xnetBatchWriter
|
||||
}
|
||||
|
||||
type xnetBatchReader interface {
|
||||
ReadBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
type xnetBatchWriter interface {
|
||||
WriteBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
// batchingUDPConn is a UDP socket that provides batched i/o.
|
||||
type batchingUDPConn struct {
|
||||
pc nettype.PacketConn
|
||||
xpc xnetBatchReaderWriter
|
||||
rxOffload bool // supports UDP GRO or similar
|
||||
txOffload atomic.Bool // supports UDP GSO or similar
|
||||
setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing
|
||||
getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing
|
||||
sendBatchPool sync.Pool
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
|
||||
if c.rxOffload {
|
||||
// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
|
||||
// receive a "monster datagram" from any read call. The ReadFrom() API
|
||||
// does not support passing the GSO size and is unsafe to use in such a
|
||||
// case. Other platforms may vary in behavior, but we go with the most
|
||||
// conservative approach to prevent this from becoming a footgun in the
|
||||
// future.
|
||||
return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
|
||||
}
|
||||
return c.pc.ReadFromUDPAddrPort(p)
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) SetDeadline(t time.Time) error {
|
||||
return c.pc.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) SetReadDeadline(t time.Time) error {
|
||||
return c.pc.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.pc.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
const (
|
||||
// This was initially established for Linux, but may split out to
|
||||
// GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the
|
||||
// kernel's TX path, and UDP_GRO_CNT_MAX for RX.
|
||||
udpSegmentMaxDatagrams = 64
|
||||
var (
|
||||
// This acts as a compile-time check for our usage of ipv6.Message in
|
||||
// batchingConn for both IPv6 and IPv4 operations.
|
||||
_ ipv6.Message = ipv4.Message{}
|
||||
)
|
||||
|
||||
const (
|
||||
// Exceeding these values results in EMSGSIZE.
|
||||
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
||||
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
||||
)
|
||||
|
||||
// coalesceMessages iterates msgs, coalescing them where possible while
|
||||
// maintaining datagram order. All msgs have their Addr field set to addr.
|
||||
func (c *batchingUDPConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int {
|
||||
var (
|
||||
base = -1 // index of msg we are currently coalescing into
|
||||
gsoSize int // segmentation size of msgs[base]
|
||||
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||
endBatch bool // tracking flag to start a new batch on next iteration of buffs
|
||||
)
|
||||
maxPayloadLen := maxIPv4PayloadLen
|
||||
if addr.IP.To4() == nil {
|
||||
maxPayloadLen = maxIPv6PayloadLen
|
||||
}
|
||||
for i, buff := range buffs {
|
||||
if i > 0 {
|
||||
msgLen := len(buff)
|
||||
baseLenBefore := len(msgs[base].Buffers[0])
|
||||
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||
msgLen <= gsoSize &&
|
||||
msgLen <= freeBaseCap &&
|
||||
dgramCnt < udpSegmentMaxDatagrams &&
|
||||
!endBatch {
|
||||
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...)
|
||||
copy(msgs[base].Buffers[0][baseLenBefore:], buff)
|
||||
if i == len(buffs)-1 {
|
||||
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
dgramCnt++
|
||||
if msgLen < gsoSize {
|
||||
// A smaller than gsoSize packet on the tail is legal, but
|
||||
// it must end the batch.
|
||||
endBatch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if dgramCnt > 1 {
|
||||
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
// Reset prior to incrementing base since we are preparing to start a
|
||||
// new potential batch.
|
||||
endBatch = false
|
||||
base++
|
||||
gsoSize = len(buff)
|
||||
msgs[base].OOB = msgs[base].OOB[:0]
|
||||
msgs[base].Buffers[0] = buff
|
||||
msgs[base].Addr = addr
|
||||
dgramCnt = 1
|
||||
}
|
||||
return base + 1
|
||||
}
|
||||
|
||||
type sendBatch struct {
|
||||
msgs []ipv6.Message
|
||||
ua *net.UDPAddr
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) getSendBatch() *sendBatch {
|
||||
batch := c.sendBatchPool.Get().(*sendBatch)
|
||||
return batch
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) putSendBatch(batch *sendBatch) {
|
||||
for i := range batch.msgs {
|
||||
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
|
||||
}
|
||||
c.sendBatchPool.Put(batch)
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
|
||||
batch := c.getSendBatch()
|
||||
defer c.putSendBatch(batch)
|
||||
if addr.Addr().Is6() {
|
||||
as16 := addr.Addr().As16()
|
||||
copy(batch.ua.IP, as16[:])
|
||||
batch.ua.IP = batch.ua.IP[:16]
|
||||
} else {
|
||||
as4 := addr.Addr().As4()
|
||||
copy(batch.ua.IP, as4[:])
|
||||
batch.ua.IP = batch.ua.IP[:4]
|
||||
}
|
||||
batch.ua.Port = int(addr.Port())
|
||||
var (
|
||||
n int
|
||||
retried bool
|
||||
)
|
||||
retry:
|
||||
if c.txOffload.Load() {
|
||||
n = c.coalesceMessages(batch.ua, buffs, batch.msgs)
|
||||
} else {
|
||||
for i := range buffs {
|
||||
batch.msgs[i].Buffers[0] = buffs[i]
|
||||
batch.msgs[i].Addr = batch.ua
|
||||
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
|
||||
}
|
||||
n = len(buffs)
|
||||
}
|
||||
|
||||
err := c.writeBatch(batch.msgs[:n])
|
||||
if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) {
|
||||
c.txOffload.Store(false)
|
||||
retried = true
|
||||
goto retry
|
||||
}
|
||||
if retried {
|
||||
return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) SyscallConn() (syscall.RawConn, error) {
|
||||
sc, ok := c.pc.(syscall.Conn)
|
||||
if !ok {
|
||||
return nil, errUnsupportedConnType
|
||||
}
|
||||
return sc.SyscallConn()
|
||||
// batchingConn is a nettype.PacketConn that provides batched i/o.
|
||||
type batchingConn interface {
|
||||
nettype.PacketConn
|
||||
ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
|
||||
WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error
|
||||
}
|
||||
|
||||
14
wgengine/magicsock/batching_conn_default.go
Normal file
14
wgengine/magicsock/batching_conn_default.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !linux
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"tailscale.com/types/nettype"
|
||||
)
|
||||
|
||||
func tryUpgradeToBatchingConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
|
||||
return pconn
|
||||
}
|
||||
419
wgengine/magicsock/batching_conn_linux.go
Normal file
419
wgengine/magicsock/batching_conn_linux.go
Normal file
@@ -0,0 +1,419 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/net/neterror"
|
||||
"tailscale.com/types/nettype"
|
||||
)
|
||||
|
||||
// xnetBatchReaderWriter defines the batching i/o methods of
|
||||
// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn).
|
||||
// TODO(jwhited): This should eventually be replaced with the standard library
|
||||
// implementation of https://github.com/golang/go/issues/45886
|
||||
type xnetBatchReaderWriter interface {
|
||||
xnetBatchReader
|
||||
xnetBatchWriter
|
||||
}
|
||||
|
||||
type xnetBatchReader interface {
|
||||
ReadBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
type xnetBatchWriter interface {
|
||||
WriteBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
// linuxBatchingConn is a UDP socket that provides batched i/o. It implements
|
||||
// batchingConn.
|
||||
type linuxBatchingConn struct {
|
||||
pc nettype.PacketConn
|
||||
xpc xnetBatchReaderWriter
|
||||
rxOffload bool // supports UDP GRO or similar
|
||||
txOffload atomic.Bool // supports UDP GSO or similar
|
||||
setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing
|
||||
getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing
|
||||
sendBatchPool sync.Pool
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
|
||||
if c.rxOffload {
|
||||
// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
|
||||
// receive a "monster datagram" from any read call. The ReadFrom() API
|
||||
// does not support passing the GSO size and is unsafe to use in such a
|
||||
// case. Other platforms may vary in behavior, but we go with the most
|
||||
// conservative approach to prevent this from becoming a footgun in the
|
||||
// future.
|
||||
return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
|
||||
}
|
||||
return c.pc.ReadFromUDPAddrPort(p)
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) SetDeadline(t time.Time) error {
|
||||
return c.pc.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) SetReadDeadline(t time.Time) error {
|
||||
return c.pc.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) SetWriteDeadline(t time.Time) error {
|
||||
return c.pc.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
const (
|
||||
// This was initially established for Linux, but may split out to
|
||||
// GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the
|
||||
// kernel's TX path, and UDP_GRO_CNT_MAX for RX.
|
||||
udpSegmentMaxDatagrams = 64
|
||||
)
|
||||
|
||||
const (
|
||||
// Exceeding these values results in EMSGSIZE.
|
||||
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
||||
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
||||
)
|
||||
|
||||
// coalesceMessages iterates msgs, coalescing them where possible while
|
||||
// maintaining datagram order. All msgs have their Addr field set to addr.
|
||||
func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int {
|
||||
var (
|
||||
base = -1 // index of msg we are currently coalescing into
|
||||
gsoSize int // segmentation size of msgs[base]
|
||||
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||
endBatch bool // tracking flag to start a new batch on next iteration of buffs
|
||||
)
|
||||
maxPayloadLen := maxIPv4PayloadLen
|
||||
if addr.IP.To4() == nil {
|
||||
maxPayloadLen = maxIPv6PayloadLen
|
||||
}
|
||||
for i, buff := range buffs {
|
||||
if i > 0 {
|
||||
msgLen := len(buff)
|
||||
baseLenBefore := len(msgs[base].Buffers[0])
|
||||
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||
msgLen <= gsoSize &&
|
||||
msgLen <= freeBaseCap &&
|
||||
dgramCnt < udpSegmentMaxDatagrams &&
|
||||
!endBatch {
|
||||
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...)
|
||||
copy(msgs[base].Buffers[0][baseLenBefore:], buff)
|
||||
if i == len(buffs)-1 {
|
||||
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
dgramCnt++
|
||||
if msgLen < gsoSize {
|
||||
// A smaller than gsoSize packet on the tail is legal, but
|
||||
// it must end the batch.
|
||||
endBatch = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if dgramCnt > 1 {
|
||||
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
|
||||
}
|
||||
// Reset prior to incrementing base since we are preparing to start a
|
||||
// new potential batch.
|
||||
endBatch = false
|
||||
base++
|
||||
gsoSize = len(buff)
|
||||
msgs[base].OOB = msgs[base].OOB[:0]
|
||||
msgs[base].Buffers[0] = buff
|
||||
msgs[base].Addr = addr
|
||||
dgramCnt = 1
|
||||
}
|
||||
return base + 1
|
||||
}
|
||||
|
||||
type sendBatch struct {
|
||||
msgs []ipv6.Message
|
||||
ua *net.UDPAddr
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) getSendBatch() *sendBatch {
|
||||
batch := c.sendBatchPool.Get().(*sendBatch)
|
||||
return batch
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) {
|
||||
for i := range batch.msgs {
|
||||
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
|
||||
}
|
||||
c.sendBatchPool.Put(batch)
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
|
||||
batch := c.getSendBatch()
|
||||
defer c.putSendBatch(batch)
|
||||
if addr.Addr().Is6() {
|
||||
as16 := addr.Addr().As16()
|
||||
copy(batch.ua.IP, as16[:])
|
||||
batch.ua.IP = batch.ua.IP[:16]
|
||||
} else {
|
||||
as4 := addr.Addr().As4()
|
||||
copy(batch.ua.IP, as4[:])
|
||||
batch.ua.IP = batch.ua.IP[:4]
|
||||
}
|
||||
batch.ua.Port = int(addr.Port())
|
||||
var (
|
||||
n int
|
||||
retried bool
|
||||
)
|
||||
retry:
|
||||
if c.txOffload.Load() {
|
||||
n = c.coalesceMessages(batch.ua, buffs, batch.msgs)
|
||||
} else {
|
||||
for i := range buffs {
|
||||
batch.msgs[i].Buffers[0] = buffs[i]
|
||||
batch.msgs[i].Addr = batch.ua
|
||||
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
|
||||
}
|
||||
n = len(buffs)
|
||||
}
|
||||
|
||||
err := c.writeBatch(batch.msgs[:n])
|
||||
if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) {
|
||||
c.txOffload.Store(false)
|
||||
retried = true
|
||||
goto retry
|
||||
}
|
||||
if retried {
|
||||
return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) SyscallConn() (syscall.RawConn, error) {
|
||||
sc, ok := c.pc.(syscall.Conn)
|
||||
if !ok {
|
||||
return nil, errUnsupportedConnType
|
||||
}
|
||||
return sc.SyscallConn()
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) writeBatch(msgs []ipv6.Message) error {
|
||||
var head int
|
||||
for {
|
||||
n, err := c.xpc.WriteBatch(msgs[head:], 0)
|
||||
if err != nil || n == len(msgs[head:]) {
|
||||
// Returning the number of packets written would require
|
||||
// unraveling individual msg len and gso size during a coalesced
|
||||
// write. The top of the call stack disregards partial success,
|
||||
// so keep this simple for now.
|
||||
return err
|
||||
}
|
||||
head += n
|
||||
}
|
||||
}
|
||||
|
||||
// splitCoalescedMessages splits coalesced messages from the tail of dst
|
||||
// beginning at index 'firstMsgAt' into the head of the same slice. It reports
|
||||
// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An
|
||||
// error is returned if a socket control message cannot be parsed or a split
|
||||
// operation would overflow msgs.
|
||||
func (c *linuxBatchingConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) {
|
||||
for i := firstMsgAt; i < len(msgs); i++ {
|
||||
msg := &msgs[i]
|
||||
if msg.N == 0 {
|
||||
return n, err
|
||||
}
|
||||
var (
|
||||
gsoSize int
|
||||
start int
|
||||
end = msg.N
|
||||
numToSplit = 1
|
||||
)
|
||||
gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if gsoSize > 0 {
|
||||
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||
end = gsoSize
|
||||
}
|
||||
for j := 0; j < numToSplit; j++ {
|
||||
if n > i {
|
||||
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||
}
|
||||
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||
msgs[n].N = copied
|
||||
msgs[n].Addr = msg.Addr
|
||||
start = end
|
||||
end += gsoSize
|
||||
if end > msg.N {
|
||||
end = msg.N
|
||||
}
|
||||
n++
|
||||
}
|
||||
if i != n-1 {
|
||||
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||
// of splitting, so we only zero the source msg len when it is not
|
||||
// the destination of the last split operation above.
|
||||
msg.N = 0
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
|
||||
if !c.rxOffload || len(msgs) < 2 {
|
||||
return c.xpc.ReadBatch(msgs, flags)
|
||||
}
|
||||
// Read into the tail of msgs, split into the head.
|
||||
readAt := len(msgs) - 2
|
||||
numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0)
|
||||
if err != nil || numRead == 0 {
|
||||
return 0, err
|
||||
}
|
||||
return c.splitCoalescedMessages(msgs, readAt)
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) LocalAddr() net.Addr {
|
||||
return c.pc.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
|
||||
return c.pc.WriteToUDPAddrPort(b, addr)
|
||||
}
|
||||
|
||||
func (c *linuxBatchingConn) Close() error {
|
||||
return c.pc.Close()
|
||||
}
|
||||
|
||||
// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn,
|
||||
// and returns two booleans indicating TX and RX UDP offload support.
|
||||
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
|
||||
if c, ok := pconn.(*net.UDPConn); ok {
|
||||
rc, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = rc.Control(func(fd uintptr) {
|
||||
_, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||
hasTX = errSyscall == nil
|
||||
errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||
hasRX = errSyscall == nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
return hasTX, hasRX
|
||||
}
|
||||
|
||||
// getGSOSizeFromControl returns the GSO size found in control. If no GSO size
|
||||
// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0.
|
||||
// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but
|
||||
// its contents cannot be parsed as a socket control message.
|
||||
func getGSOSizeFromControl(control []byte) (int, error) {
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||
}
|
||||
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 {
|
||||
return int(binary.NativeEndian.Uint16(data[:2])), nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSizeInControl sets a socket control message in control containing
|
||||
// gsoSize. If len(control) < controlMessageSize control's len will be set to 0.
|
||||
func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
|
||||
*control = (*control)[:0]
|
||||
if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
|
||||
return
|
||||
}
|
||||
if cap(*control) < controlMessageSize {
|
||||
return
|
||||
}
|
||||
*control = (*control)[:cap(*control)]
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
hdr.SetLen(unix.CmsgLen(2))
|
||||
binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize)
|
||||
*control = (*control)[:unix.CmsgSpace(2)]
|
||||
}
|
||||
|
||||
// tryUpgradeToBatchingConn probes the capabilities of the OS and pconn, and
|
||||
// upgrades pconn to a *linuxBatchingConn if appropriate.
|
||||
func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
|
||||
if network != "udp4" && network != "udp6" {
|
||||
return pconn
|
||||
}
|
||||
if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") {
|
||||
// recvmmsg/sendmmsg were added in 2.6.33, but we support down to
|
||||
// 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
|
||||
// As a cheap heuristic: if the Linux kernel starts with "2", just
|
||||
// consider it too old for mmsg. Nobody who cares about performance runs
|
||||
// such ancient kernels. UDP offload was added much later, so no
|
||||
// upgrades are available.
|
||||
return pconn
|
||||
}
|
||||
uc, ok := pconn.(*net.UDPConn)
|
||||
if !ok {
|
||||
return pconn
|
||||
}
|
||||
b := &linuxBatchingConn{
|
||||
pc: pconn,
|
||||
getGSOSizeFromControl: getGSOSizeFromControl,
|
||||
setGSOSizeInControl: setGSOSizeInControl,
|
||||
sendBatchPool: sync.Pool{
|
||||
New: func() any {
|
||||
ua := &net.UDPAddr{
|
||||
IP: make([]byte, 16),
|
||||
}
|
||||
msgs := make([]ipv6.Message, batchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make([][]byte, 1)
|
||||
msgs[i].Addr = ua
|
||||
msgs[i].OOB = make([]byte, controlMessageSize)
|
||||
}
|
||||
return &sendBatch{
|
||||
ua: ua,
|
||||
msgs: msgs,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
switch network {
|
||||
case "udp4":
|
||||
b.xpc = ipv4.NewPacketConn(uc)
|
||||
case "udp6":
|
||||
b.xpc = ipv6.NewPacketConn(uc)
|
||||
default:
|
||||
panic("bogus network")
|
||||
}
|
||||
var txOffload bool
|
||||
txOffload, b.rxOffload = tryEnableUDPOffload(uc)
|
||||
b.txOffload.Store(txOffload)
|
||||
return b
|
||||
}
|
||||
244
wgengine/magicsock/batching_conn_linux_test.go
Normal file
244
wgengine/magicsock/batching_conn_linux_test.go
Normal file
@@ -0,0 +1,244 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
*control = (*control)[:cap(*control)]
|
||||
binary.LittleEndian.PutUint16(*control, gsoSize)
|
||||
}
|
||||
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
if len(control) < 2 {
|
||||
return 0, nil
|
||||
}
|
||||
return int(binary.LittleEndian.Uint16(control)), nil
|
||||
}
|
||||
|
||||
func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) {
|
||||
c := &linuxBatchingConn{
|
||||
setGSOSizeInControl: setGSOSize,
|
||||
getGSOSizeFromControl: getGSOSize,
|
||||
}
|
||||
|
||||
newMsg := func(n, gso int) ipv6.Message {
|
||||
msg := ipv6.Message{
|
||||
Buffers: [][]byte{make([]byte, 1024)},
|
||||
N: n,
|
||||
OOB: make([]byte, 2),
|
||||
}
|
||||
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
||||
if gso > 0 {
|
||||
msg.NN = 2
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
msgs []ipv6.Message
|
||||
firstMsgAt int
|
||||
wantNumEval int
|
||||
wantMsgLens []int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "second last split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(3, 1),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 3,
|
||||
wantMsgLens: []int{1, 1, 1, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 1,
|
||||
wantMsgLens: []int{1, 0, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last no split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(1, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 2,
|
||||
wantMsgLens: []int{1, 1, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(3, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(2, 1),
|
||||
newMsg(2, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split overflow",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(4, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := c.splitCoalescedMessages(tt.msgs, 2)
|
||||
if err != nil && !tt.wantErr {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if got != tt.wantNumEval {
|
||||
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
||||
}
|
||||
for i, msg := range tt.msgs {
|
||||
if msg.N != tt.wantMsgLens[i] {
|
||||
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
|
||||
c := &linuxBatchingConn{
|
||||
setGSOSizeInControl: setGSOSize,
|
||||
getGSOSizeFromControl: getGSOSize,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
buffs [][]byte
|
||||
wantLens []int
|
||||
wantGSO []int
|
||||
}{
|
||||
{
|
||||
name: "one message no coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{1},
|
||||
wantGSO: []int{0},
|
||||
},
|
||||
{
|
||||
name: "two messages equal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 1, 2),
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{2},
|
||||
wantGSO: []int{1},
|
||||
},
|
||||
{
|
||||
name: "two messages unequal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 3),
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{3},
|
||||
wantGSO: []int{2},
|
||||
},
|
||||
{
|
||||
name: "three messages second unequal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 3),
|
||||
make([]byte, 1, 1),
|
||||
make([]byte, 2, 2),
|
||||
},
|
||||
wantLens: []int{3, 2},
|
||||
wantGSO: []int{2, 0},
|
||||
},
|
||||
{
|
||||
name: "three messages limited cap coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 4),
|
||||
make([]byte, 2, 2),
|
||||
make([]byte, 2, 2),
|
||||
},
|
||||
wantLens: []int{4, 2},
|
||||
wantGSO: []int{2, 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 1,
|
||||
}
|
||||
msgs := make([]ipv6.Message, len(tt.buffs))
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make([][]byte, 1)
|
||||
msgs[i].OOB = make([]byte, 0, 2)
|
||||
}
|
||||
got := c.coalesceMessages(addr, tt.buffs, msgs)
|
||||
if got != len(tt.wantLens) {
|
||||
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
|
||||
}
|
||||
for i := range got {
|
||||
if msgs[i].Addr != addr {
|
||||
t.Errorf("msgs[%d].Addr != passed addr", i)
|
||||
}
|
||||
gotLen := len(msgs[i].Buffers[0])
|
||||
if gotLen != tt.wantLens[i] {
|
||||
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
|
||||
}
|
||||
gotGSO, err := getGSOSize(msgs[i].OOB)
|
||||
if err != nil {
|
||||
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
|
||||
}
|
||||
if gotGSO != tt.wantGSO[i] {
|
||||
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
182
wgengine/magicsock/cloudinfo.go
Normal file
182
wgengine/magicsock/cloudinfo.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !(ios || android || js)
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/cloudenv"
|
||||
)
|
||||
|
||||
const maxCloudInfoWait = 2 * time.Second
|
||||
|
||||
type cloudInfo struct {
|
||||
client http.Client
|
||||
logf logger.Logf
|
||||
|
||||
// The following parameters are fixed for the lifetime of the cloudInfo
|
||||
// object, but are used for testing.
|
||||
cloud cloudenv.Cloud
|
||||
endpoint string
|
||||
}
|
||||
|
||||
func newCloudInfo(logf logger.Logf) *cloudInfo {
|
||||
tr := &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: maxCloudInfoWait,
|
||||
}).Dial,
|
||||
}
|
||||
|
||||
return &cloudInfo{
|
||||
client: http.Client{Transport: tr},
|
||||
logf: logf,
|
||||
cloud: cloudenv.Get(),
|
||||
endpoint: "http://" + cloudenv.CommonNonRoutableMetadataIP,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPublicIPs returns any public IPs attached to the current cloud instance,
|
||||
// if the tailscaled process is running in a known cloud and there are any such
|
||||
// IPs present.
|
||||
func (ci *cloudInfo) GetPublicIPs(ctx context.Context) ([]netip.Addr, error) {
|
||||
switch ci.cloud {
|
||||
case cloudenv.AWS:
|
||||
ret, err := ci.getAWS(ctx)
|
||||
ci.logf("[v1] cloudinfo.GetPublicIPs: AWS: %v, %v", ret, err)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getAWSMetadata makes a request to the AWS metadata service at the given
|
||||
// path, authenticating with the provided IMDSv2 token. The returned metadata
|
||||
// is split by newline and returned as a slice.
|
||||
func (ci *cloudInfo) getAWSMetadata(ctx context.Context, token, path string) ([]string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", ci.endpoint+path, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request to %q: %w", path, err)
|
||||
}
|
||||
req.Header.Set("X-aws-ec2-metadata-token", token)
|
||||
|
||||
resp, err := ci.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("making request to metadata service %q: %w", path, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
// Good
|
||||
case http.StatusNotFound:
|
||||
// Nothing found, but this isn't an error; just return
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading response body for %q: %w", path, err)
|
||||
}
|
||||
|
||||
return strings.Split(strings.TrimSpace(string(body)), "\n"), nil
|
||||
}
|
||||
|
||||
// getAWS returns all public IPv4 and IPv6 addresses present in the AWS instance metadata.
|
||||
func (ci *cloudInfo) getAWS(ctx context.Context) ([]netip.Addr, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, maxCloudInfoWait)
|
||||
defer cancel()
|
||||
|
||||
// Get a token so we can query the metadata service.
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", ci.endpoint+"/latest/api/token", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating token request: %w", err)
|
||||
}
|
||||
req.Header.Set("X-Aws-Ec2-Metadata-Token-Ttl-Seconds", "10")
|
||||
|
||||
resp, err := ci.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("making token request to metadata service: %w", err)
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading token response body: %w", err)
|
||||
}
|
||||
token := string(body)
|
||||
|
||||
server := resp.Header.Get("Server")
|
||||
if server != "EC2ws" {
|
||||
return nil, fmt.Errorf("unexpected server header: %q", server)
|
||||
}
|
||||
|
||||
// Iterate over all interfaces and get their public IP addresses, both IPv4 and IPv6.
|
||||
macAddrs, err := ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting interface MAC addresses: %w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
addrs []netip.Addr
|
||||
errs []error
|
||||
)
|
||||
|
||||
addAddr := func(addr string) {
|
||||
ip, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("parsing IP address %q: %w", addr, err))
|
||||
return
|
||||
}
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
for _, mac := range macAddrs {
|
||||
ips, err := ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/"+mac+"/public-ipv4s")
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("getting IPv4 addresses for %q: %w", mac, err))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
addAddr(ip)
|
||||
}
|
||||
|
||||
// Try querying for IPv6 addresses.
|
||||
ips, err = ci.getAWSMetadata(ctx, token, "/latest/meta-data/network/interfaces/macs/"+mac+"/ipv6s")
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("getting IPv6 addresses for %q: %w", mac, err))
|
||||
continue
|
||||
}
|
||||
for _, ip := range ips {
|
||||
addAddr(ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort the returned addresses for determinism.
|
||||
slices.SortFunc(addrs, func(a, b netip.Addr) int {
|
||||
return a.Compare(b)
|
||||
})
|
||||
|
||||
// Preferentially return any addresses we found, even if there were errors.
|
||||
if len(addrs) > 0 {
|
||||
return addrs, nil
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return nil, fmt.Errorf("getting IP addresses: %w", errors.Join(errs...))
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
23
wgengine/magicsock/cloudinfo_nocloud.go
Normal file
23
wgengine/magicsock/cloudinfo_nocloud.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build ios || android || js
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
type cloudInfo struct{}
|
||||
|
||||
func newCloudInfo(_ logger.Logf) *cloudInfo {
|
||||
return &cloudInfo{}
|
||||
}
|
||||
|
||||
func (ci *cloudInfo) GetPublicIPs(_ context.Context) ([]netip.Addr, error) {
|
||||
return nil, nil
|
||||
}
|
||||
123
wgengine/magicsock/cloudinfo_test.go
Normal file
123
wgengine/magicsock/cloudinfo_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/cloudenv"
|
||||
)
|
||||
|
||||
func TestCloudInfo_AWS(t *testing.T) {
|
||||
const (
|
||||
mac1 = "06:1d:00:00:00:00"
|
||||
mac2 = "06:1d:00:00:00:01"
|
||||
publicV4 = "1.2.3.4"
|
||||
otherV4_1 = "5.6.7.8"
|
||||
otherV4_2 = "11.12.13.14"
|
||||
v6addr = "2001:db8::1"
|
||||
|
||||
macsPrefix = "/latest/meta-data/network/interfaces/macs/"
|
||||
)
|
||||
// Launch a fake AWS IMDS server
|
||||
fake := &fakeIMDS{
|
||||
tb: t,
|
||||
paths: map[string]string{
|
||||
macsPrefix: mac1 + "\n" + mac2,
|
||||
// This is the "main" public IP address for the instance
|
||||
macsPrefix + mac1 + "/public-ipv4s": publicV4,
|
||||
|
||||
// There's another interface with two public IPs
|
||||
// attached to it and an IPv6 address, all of which we
|
||||
// should discover.
|
||||
macsPrefix + mac2 + "/public-ipv4s": otherV4_1 + "\n" + otherV4_2,
|
||||
macsPrefix + mac2 + "/ipv6s": v6addr,
|
||||
},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(fake)
|
||||
defer srv.Close()
|
||||
|
||||
ci := newCloudInfo(t.Logf)
|
||||
ci.cloud = cloudenv.AWS
|
||||
ci.endpoint = srv.URL
|
||||
|
||||
ips, err := ci.GetPublicIPs(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
wantIPs := []netip.Addr{
|
||||
netip.MustParseAddr(publicV4),
|
||||
netip.MustParseAddr(otherV4_1),
|
||||
netip.MustParseAddr(otherV4_2),
|
||||
netip.MustParseAddr(v6addr),
|
||||
}
|
||||
if !slices.Equal(ips, wantIPs) {
|
||||
t.Fatalf("got %v, want %v", ips, wantIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudInfo_AWSNotPublic(t *testing.T) {
|
||||
returns404 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "PUT" && r.URL.Path == "/latest/api/token" {
|
||||
w.Header().Set("Server", "EC2ws")
|
||||
w.Write([]byte("fake-imds-token"))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
srv := httptest.NewServer(returns404)
|
||||
defer srv.Close()
|
||||
|
||||
ci := newCloudInfo(t.Logf)
|
||||
ci.cloud = cloudenv.AWS
|
||||
ci.endpoint = srv.URL
|
||||
|
||||
// If the IMDS server doesn't return any public IPs, it's not an error
|
||||
// and we should just get an empty list.
|
||||
ips, err := ci.GetPublicIPs(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Fatalf("got %v, want none", ips)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeIMDS struct {
|
||||
tb testing.TB
|
||||
paths map[string]string
|
||||
}
|
||||
|
||||
func (f *fakeIMDS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
f.tb.Logf("%s %s", r.Method, r.URL.Path)
|
||||
path := r.URL.Path
|
||||
|
||||
// Handle the /latest/api/token case
|
||||
const token = "fake-imds-token"
|
||||
if r.Method == "PUT" && path == "/latest/api/token" {
|
||||
w.Header().Set("Server", "EC2ws")
|
||||
w.Write([]byte(token))
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, require the IMDSv2 token to be set
|
||||
if r.Header.Get("X-aws-ec2-metadata-token") != token {
|
||||
f.tb.Errorf("missing or invalid IMDSv2 token")
|
||||
http.Error(w, "missing or invalid IMDSv2 token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := f.paths[path]; ok {
|
||||
w.Write([]byte(v))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"go4.org/mem"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
@@ -133,6 +132,9 @@ type Conn struct {
|
||||
// bind is the wireguard-go conn.Bind for Conn.
|
||||
bind *connBind
|
||||
|
||||
// cloudInfo is used to query cloud metadata services.
|
||||
cloudInfo *cloudInfo
|
||||
|
||||
// ============================================================
|
||||
// Fields that must be accessed via atomic load/stores.
|
||||
|
||||
@@ -425,9 +427,10 @@ func (o *Options) derpActiveFunc() func() {
|
||||
|
||||
// newConn is the error-free, network-listening-side-effect-free based
|
||||
// of NewConn. Mostly for tests.
|
||||
func newConn() *Conn {
|
||||
func newConn(logf logger.Logf) *Conn {
|
||||
discoPrivate := key.NewDisco()
|
||||
c := &Conn{
|
||||
logf: logf,
|
||||
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
|
||||
derpStarted: make(chan struct{}),
|
||||
peerLastDerp: make(map[key.NodePublic]int),
|
||||
@@ -435,6 +438,7 @@ func newConn() *Conn {
|
||||
discoInfo: make(map[key.DiscoPublic]*discoInfo),
|
||||
discoPrivate: discoPrivate,
|
||||
discoPublic: discoPrivate.Public(),
|
||||
cloudInfo: newCloudInfo(logf),
|
||||
}
|
||||
c.discoShort = c.discoPublic.ShortString()
|
||||
c.bind = &connBind{Conn: c, closed: true}
|
||||
@@ -462,10 +466,9 @@ func NewConn(opts Options) (*Conn, error) {
|
||||
return nil, errors.New("magicsock.Options.NetMon must be non-nil")
|
||||
}
|
||||
|
||||
c := newConn()
|
||||
c := newConn(opts.logf())
|
||||
c.port.Store(uint32(opts.Port))
|
||||
c.controlKnobs = opts.ControlKnobs
|
||||
c.logf = opts.logf()
|
||||
c.epFunc = opts.endpointsFunc()
|
||||
c.derpActiveFunc = opts.derpActiveFunc()
|
||||
c.idleFunc = opts.IdleFunc
|
||||
@@ -952,6 +955,27 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
|
||||
addAddr(ap, tailcfg.EndpointExplicitConf)
|
||||
}
|
||||
|
||||
// If we're on a cloud instance, we might have a public IPv4 or IPv6
|
||||
// address that we can be reached at. Find those, if they exist, and
|
||||
// add them.
|
||||
if addrs, err := c.cloudInfo.GetPublicIPs(ctx); err == nil {
|
||||
var port4, port6 uint16
|
||||
if addr := c.pconn4.LocalAddr(); addr != nil {
|
||||
port4 = uint16(addr.Port)
|
||||
}
|
||||
if addr := c.pconn6.LocalAddr(); addr != nil {
|
||||
port6 = uint16(addr.Port)
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
if addr.Is4() && port4 > 0 {
|
||||
addAddr(netip.AddrPortFrom(addr, port4), tailcfg.EndpointLocal)
|
||||
} else if addr.Is6() && port6 > 0 {
|
||||
addAddr(netip.AddrPortFrom(addr, port6), tailcfg.EndpointLocal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update our set of endpoints by adding any endpoints that we
|
||||
// previously found but haven't expired yet. This also updates the
|
||||
// cache with the set of endpoints discovered in this function.
|
||||
@@ -1076,12 +1100,6 @@ var errNoUDP = errors.New("no UDP available on platform")
|
||||
|
||||
var errUnsupportedConnType = errors.New("unsupported connection type")
|
||||
|
||||
var (
|
||||
// This acts as a compile-time check for our usage of ipv6.Message in
|
||||
// batchingUDPConn for both IPv6 and IPv4 operations.
|
||||
_ ipv6.Message = ipv4.Message{}
|
||||
)
|
||||
|
||||
func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) {
|
||||
isIPv6 := false
|
||||
switch {
|
||||
@@ -2631,153 +2649,6 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
|
||||
return ep, nil
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) writeBatch(msgs []ipv6.Message) error {
|
||||
var head int
|
||||
for {
|
||||
n, err := c.xpc.WriteBatch(msgs[head:], 0)
|
||||
if err != nil || n == len(msgs[head:]) {
|
||||
// Returning the number of packets written would require
|
||||
// unraveling individual msg len and gso size during a coalesced
|
||||
// write. The top of the call stack disregards partial success,
|
||||
// so keep this simple for now.
|
||||
return err
|
||||
}
|
||||
head += n
|
||||
}
|
||||
}
|
||||
|
||||
// splitCoalescedMessages splits coalesced messages from the tail of dst
|
||||
// beginning at index 'firstMsgAt' into the head of the same slice. It reports
|
||||
// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An
|
||||
// error is returned if a socket control message cannot be parsed or a split
|
||||
// operation would overflow msgs.
|
||||
func (c *batchingUDPConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) {
|
||||
for i := firstMsgAt; i < len(msgs); i++ {
|
||||
msg := &msgs[i]
|
||||
if msg.N == 0 {
|
||||
return n, err
|
||||
}
|
||||
var (
|
||||
gsoSize int
|
||||
start int
|
||||
end = msg.N
|
||||
numToSplit = 1
|
||||
)
|
||||
gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if gsoSize > 0 {
|
||||
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||
end = gsoSize
|
||||
}
|
||||
for j := 0; j < numToSplit; j++ {
|
||||
if n > i {
|
||||
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||
}
|
||||
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||
msgs[n].N = copied
|
||||
msgs[n].Addr = msg.Addr
|
||||
start = end
|
||||
end += gsoSize
|
||||
if end > msg.N {
|
||||
end = msg.N
|
||||
}
|
||||
n++
|
||||
}
|
||||
if i != n-1 {
|
||||
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||
// of splitting, so we only zero the source msg len when it is not
|
||||
// the destination of the last split operation above.
|
||||
msg.N = 0
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
|
||||
if !c.rxOffload || len(msgs) < 2 {
|
||||
return c.xpc.ReadBatch(msgs, flags)
|
||||
}
|
||||
// Read into the tail of msgs, split into the head.
|
||||
readAt := len(msgs) - 2
|
||||
numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0)
|
||||
if err != nil || numRead == 0 {
|
||||
return 0, err
|
||||
}
|
||||
return c.splitCoalescedMessages(msgs, readAt)
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) LocalAddr() net.Addr {
|
||||
return c.pc.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
|
||||
return c.pc.WriteToUDPAddrPort(b, addr)
|
||||
}
|
||||
|
||||
func (c *batchingUDPConn) Close() error {
|
||||
return c.pc.Close()
|
||||
}
|
||||
|
||||
// tryUpgradeToBatchingUDPConn probes the capabilities of the OS and pconn, and
|
||||
// upgrades pconn to a *batchingUDPConn if appropriate.
|
||||
func tryUpgradeToBatchingUDPConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
|
||||
if network != "udp4" && network != "udp6" {
|
||||
return pconn
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
return pconn
|
||||
}
|
||||
if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") {
|
||||
// recvmmsg/sendmmsg were added in 2.6.33, but we support down to
|
||||
// 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
|
||||
// As a cheap heuristic: if the Linux kernel starts with "2", just
|
||||
// consider it too old for mmsg. Nobody who cares about performance runs
|
||||
// such ancient kernels. UDP offload was added much later, so no
|
||||
// upgrades are available.
|
||||
return pconn
|
||||
}
|
||||
uc, ok := pconn.(*net.UDPConn)
|
||||
if !ok {
|
||||
return pconn
|
||||
}
|
||||
b := &batchingUDPConn{
|
||||
pc: pconn,
|
||||
getGSOSizeFromControl: getGSOSizeFromControl,
|
||||
setGSOSizeInControl: setGSOSizeInControl,
|
||||
sendBatchPool: sync.Pool{
|
||||
New: func() any {
|
||||
ua := &net.UDPAddr{
|
||||
IP: make([]byte, 16),
|
||||
}
|
||||
msgs := make([]ipv6.Message, batchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make([][]byte, 1)
|
||||
msgs[i].Addr = ua
|
||||
msgs[i].OOB = make([]byte, controlMessageSize)
|
||||
}
|
||||
return &sendBatch{
|
||||
ua: ua,
|
||||
msgs: msgs,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
switch network {
|
||||
case "udp4":
|
||||
b.xpc = ipv4.NewPacketConn(uc)
|
||||
case "udp6":
|
||||
b.xpc = ipv6.NewPacketConn(uc)
|
||||
default:
|
||||
panic("bogus network")
|
||||
}
|
||||
var txOffload bool
|
||||
txOffload, b.rxOffload = tryEnableUDPOffload(uc)
|
||||
b.txOffload.Store(txOffload)
|
||||
return b
|
||||
}
|
||||
|
||||
func newBlockForeverConn() *blockForeverConn {
|
||||
c := new(blockForeverConn)
|
||||
c.cond = sync.NewCond(&c.mu)
|
||||
|
||||
@@ -21,16 +21,6 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
|
||||
portableTrySetSocketBuffer(pconn, logf)
|
||||
}
|
||||
|
||||
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
func getGSOSizeFromControl(control []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func setGSOSizeInControl(control *[]byte, gso uint16) {}
|
||||
|
||||
const (
|
||||
controlMessageSize = 0
|
||||
)
|
||||
|
||||
@@ -318,70 +318,6 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
|
||||
}
|
||||
}
|
||||
|
||||
// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn,
|
||||
// and returns two booleans indicating TX and RX UDP offload support.
|
||||
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
|
||||
if c, ok := pconn.(*net.UDPConn); ok {
|
||||
rc, err := c.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = rc.Control(func(fd uintptr) {
|
||||
_, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||
hasTX = errSyscall == nil
|
||||
errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||
hasRX = errSyscall == nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
return hasTX, hasRX
|
||||
}
|
||||
|
||||
// getGSOSizeFromControl returns the GSO size found in control. If no GSO size
|
||||
// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0.
|
||||
// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but
|
||||
// its contents cannot be parsed as a socket control message.
|
||||
func getGSOSizeFromControl(control []byte) (int, error) {
|
||||
var (
|
||||
hdr unix.Cmsghdr
|
||||
data []byte
|
||||
rem = control
|
||||
err error
|
||||
)
|
||||
|
||||
for len(rem) > unix.SizeofCmsghdr {
|
||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||
}
|
||||
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 {
|
||||
return int(binary.NativeEndian.Uint16(data[:2])), nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// setGSOSizeInControl sets a socket control message in control containing
|
||||
// gsoSize. If len(control) < controlMessageSize control's len will be set to 0.
|
||||
func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
|
||||
*control = (*control)[:0]
|
||||
if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
|
||||
return
|
||||
}
|
||||
if cap(*control) < controlMessageSize {
|
||||
return
|
||||
}
|
||||
*control = (*control)[:cap(*control)]
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
hdr.SetLen(unix.CmsgLen(2))
|
||||
binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize)
|
||||
*control = (*control)[:unix.CmsgSpace(2)]
|
||||
}
|
||||
|
||||
var controlMessageSize = -1 // bomb if used for allocation before init
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -35,7 +35,6 @@ import (
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"tailscale.com/cmd/testwrapper/flakytest"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/derp"
|
||||
@@ -452,7 +451,7 @@ func TestPickDERPFallback(t *testing.T) {
|
||||
tstest.PanicOnLog()
|
||||
tstest.ResourceCheck(t)
|
||||
|
||||
c := newConn()
|
||||
c := newConn(t.Logf)
|
||||
dm := &tailcfg.DERPMap{
|
||||
Regions: map[int]*tailcfg.DERPRegion{
|
||||
1: {},
|
||||
@@ -483,7 +482,7 @@ func TestPickDERPFallback(t *testing.T) {
|
||||
// distribution over nodes works.
|
||||
got := map[int]int{}
|
||||
for range 50 {
|
||||
c = newConn()
|
||||
c = newConn(t.Logf)
|
||||
c.derpMap = dm
|
||||
got[c.pickDERPFallback()]++
|
||||
}
|
||||
@@ -1185,8 +1184,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
|
||||
}
|
||||
|
||||
func TestDiscoMessage(t *testing.T) {
|
||||
c := newConn()
|
||||
c.logf = t.Logf
|
||||
c := newConn(t.Logf)
|
||||
c.privateKey = key.NewNode()
|
||||
|
||||
peer1Pub := c.DiscoPublicKey()
|
||||
@@ -2039,238 +2037,6 @@ func TestBufferedDerpWritesBeforeDrop(t *testing.T) {
|
||||
t.Logf("bufferedDerpWritesBeforeDrop = %d", vv)
|
||||
}
|
||||
|
||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||
*control = (*control)[:cap(*control)]
|
||||
binary.LittleEndian.PutUint16(*control, gsoSize)
|
||||
}
|
||||
|
||||
func getGSOSize(control []byte) (int, error) {
|
||||
if len(control) < 2 {
|
||||
return 0, nil
|
||||
}
|
||||
return int(binary.LittleEndian.Uint16(control)), nil
|
||||
}
|
||||
|
||||
func Test_batchingUDPConn_splitCoalescedMessages(t *testing.T) {
|
||||
c := &batchingUDPConn{
|
||||
setGSOSizeInControl: setGSOSize,
|
||||
getGSOSizeFromControl: getGSOSize,
|
||||
}
|
||||
|
||||
newMsg := func(n, gso int) ipv6.Message {
|
||||
msg := ipv6.Message{
|
||||
Buffers: [][]byte{make([]byte, 1024)},
|
||||
N: n,
|
||||
OOB: make([]byte, 2),
|
||||
}
|
||||
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
||||
if gso > 0 {
|
||||
msg.NN = 2
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
msgs []ipv6.Message
|
||||
firstMsgAt int
|
||||
wantNumEval int
|
||||
wantMsgLens []int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "second last split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(3, 1),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 3,
|
||||
wantMsgLens: []int{1, 1, 1, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last empty",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(0, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 1,
|
||||
wantMsgLens: []int{1, 0, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last no split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(1, 0),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 2,
|
||||
wantMsgLens: []int{1, 1, 0, 0},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(3, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last split last split",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(2, 1),
|
||||
newMsg(2, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second last no split last split overflow",
|
||||
msgs: []ipv6.Message{
|
||||
newMsg(0, 0),
|
||||
newMsg(0, 0),
|
||||
newMsg(1, 0),
|
||||
newMsg(4, 1),
|
||||
},
|
||||
firstMsgAt: 2,
|
||||
wantNumEval: 4,
|
||||
wantMsgLens: []int{1, 1, 1, 1},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := c.splitCoalescedMessages(tt.msgs, 2)
|
||||
if err != nil && !tt.wantErr {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if got != tt.wantNumEval {
|
||||
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
||||
}
|
||||
for i, msg := range tt.msgs {
|
||||
if msg.N != tt.wantMsgLens[i] {
|
||||
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_batchingUDPConn_coalesceMessages(t *testing.T) {
|
||||
c := &batchingUDPConn{
|
||||
setGSOSizeInControl: setGSOSize,
|
||||
getGSOSizeFromControl: getGSOSize,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
buffs [][]byte
|
||||
wantLens []int
|
||||
wantGSO []int
|
||||
}{
|
||||
{
|
||||
name: "one message no coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{1},
|
||||
wantGSO: []int{0},
|
||||
},
|
||||
{
|
||||
name: "two messages equal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 1, 2),
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{2},
|
||||
wantGSO: []int{1},
|
||||
},
|
||||
{
|
||||
name: "two messages unequal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 3),
|
||||
make([]byte, 1, 1),
|
||||
},
|
||||
wantLens: []int{3},
|
||||
wantGSO: []int{2},
|
||||
},
|
||||
{
|
||||
name: "three messages second unequal len coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 3),
|
||||
make([]byte, 1, 1),
|
||||
make([]byte, 2, 2),
|
||||
},
|
||||
wantLens: []int{3, 2},
|
||||
wantGSO: []int{2, 0},
|
||||
},
|
||||
{
|
||||
name: "three messages limited cap coalesce",
|
||||
buffs: [][]byte{
|
||||
make([]byte, 2, 4),
|
||||
make([]byte, 2, 2),
|
||||
make([]byte, 2, 2),
|
||||
},
|
||||
wantLens: []int{4, 2},
|
||||
wantGSO: []int{2, 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 1,
|
||||
}
|
||||
msgs := make([]ipv6.Message, len(tt.buffs))
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make([][]byte, 1)
|
||||
msgs[i].OOB = make([]byte, 0, 2)
|
||||
}
|
||||
got := c.coalesceMessages(addr, tt.buffs, msgs)
|
||||
if got != len(tt.wantLens) {
|
||||
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
|
||||
}
|
||||
for i := range got {
|
||||
if msgs[i].Addr != addr {
|
||||
t.Errorf("msgs[%d].Addr != passed addr", i)
|
||||
}
|
||||
gotLen := len(msgs[i].Buffers[0])
|
||||
if gotLen != tt.wantLens[i] {
|
||||
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
|
||||
}
|
||||
gotGSO, err := getGSOSize(msgs[i].OOB)
|
||||
if err != nil {
|
||||
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
|
||||
}
|
||||
if gotGSO != tt.wantGSO[i] {
|
||||
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newWireguard starts up a new wireguard-go device attached to a test tun, and
|
||||
// returns the device, tun and endpoint port. To add peers call device.IpcSet with UAPI instructions.
|
||||
func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, uint16) {
|
||||
@@ -3161,8 +2927,7 @@ func TestMaybeSetNearestDERP(t *testing.T) {
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ht := new(health.Tracker)
|
||||
c := newConn()
|
||||
c.logf = t.Logf
|
||||
c := newConn(t.Logf)
|
||||
c.myDerp = tt.old
|
||||
c.derpMap = derpMap
|
||||
c.health = ht
|
||||
|
||||
@@ -35,12 +35,12 @@ type RebindingUDPConn struct {
|
||||
|
||||
// setConnLocked sets the provided nettype.PacketConn. It should be called only
|
||||
// after acquiring RebindingUDPConn.mu. It upgrades the provided
|
||||
// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade
|
||||
// is intentionally pushed closest to where read/write ops occur in order to
|
||||
// avoid disrupting surrounding code that assumes nettype.PacketConn is a
|
||||
// nettype.PacketConn to a batchingConn when appropriate. This upgrade is
|
||||
// intentionally pushed closest to where read/write ops occur in order to avoid
|
||||
// disrupting surrounding code that assumes nettype.PacketConn is a
|
||||
// *net.UDPConn.
|
||||
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) {
|
||||
upc := tryUpgradeToBatchingUDPConn(p, network, batchSize)
|
||||
upc := tryUpgradeToBatchingConn(p, network, batchSize)
|
||||
c.pconn = upc
|
||||
c.pconnAtomic.Store(&upc)
|
||||
c.port = uint16(c.localAddrLocked().Port)
|
||||
@@ -74,7 +74,7 @@ func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, e
|
||||
func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
|
||||
for {
|
||||
pconn := *c.pconnAtomic.Load()
|
||||
b, ok := pconn.(*batchingUDPConn)
|
||||
b, ok := pconn.(batchingConn)
|
||||
if !ok {
|
||||
for _, buf := range buffs {
|
||||
_, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr)
|
||||
@@ -101,7 +101,7 @@ func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) err
|
||||
func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) {
|
||||
for {
|
||||
pconn := *c.pconnAtomic.Load()
|
||||
b, ok := pconn.(*batchingUDPConn)
|
||||
b, ok := pconn.(batchingConn)
|
||||
if !ok {
|
||||
n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0])
|
||||
if err == nil {
|
||||
|
||||
16
wgengine/netstack/gro_default.go
Normal file
16
wgengine/netstack/gro_default.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !ios
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
nsgro "gvisor.dev/gvisor/pkg/tcpip/stack/gro"
|
||||
)
|
||||
|
||||
// gro wraps a gVisor GRO implementation. It exists solely to prevent iOS from
|
||||
// importing said package (see _ios.go).
|
||||
type gro struct {
|
||||
nsgro.GRO
|
||||
}
|
||||
30
wgengine/netstack/gro_ios.go
Normal file
30
wgengine/netstack/gro_ios.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build ios
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
// gro on iOS delivers packets to its Dispatcher, immediately. This type exists
|
||||
// to prevent importation of the gVisor GRO implementation as said package
|
||||
// increases binary size. This is a penalty we do not wish to pay since we
|
||||
// currently do not leverage GRO on iOS.
|
||||
type gro struct {
|
||||
Dispatcher stack.NetworkDispatcher
|
||||
}
|
||||
|
||||
func (g *gro) Init(v bool) {
|
||||
if v {
|
||||
panic("GRO is not supported on this platform")
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gro) Flush() {}
|
||||
|
||||
func (g *gro) Enqueue(pkt *stack.PacketBuffer) {
|
||||
g.Dispatcher.DeliverNetworkPacket(pkt.NetworkProtocolNumber, pkt)
|
||||
}
|
||||
@@ -4,12 +4,18 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/ipproto"
|
||||
)
|
||||
|
||||
type queue struct {
|
||||
@@ -79,34 +85,53 @@ var _ stack.GSOEndpoint = (*linkEndpoint)(nil)
|
||||
|
||||
// linkEndpoint implements stack.LinkEndpoint and stack.GSOEndpoint. Outbound
|
||||
// packets written by gVisor towards Tailscale are stored in a channel.
|
||||
// Inbound is fed to gVisor via InjectInbound. This is loosely modeled after
|
||||
// gvisor.dev/pkg/tcpip/link/channel.Endpoint.
|
||||
// Inbound is fed to gVisor via injectInbound or enqueueGRO. This is loosely
|
||||
// modeled after gvisor.dev/pkg/tcpip/link/channel.Endpoint.
|
||||
type linkEndpoint struct {
|
||||
LinkEPCapabilities stack.LinkEndpointCapabilities
|
||||
SupportedGSOKind stack.SupportedGSO
|
||||
SupportedGSOKind stack.SupportedGSO
|
||||
initGRO initGRO
|
||||
|
||||
mu sync.RWMutex // mu guards the following fields
|
||||
dispatcher stack.NetworkDispatcher
|
||||
linkAddr tcpip.LinkAddress
|
||||
mtu uint32
|
||||
gro gro // mu only guards access to gro.Dispatcher
|
||||
|
||||
q *queue // outbound
|
||||
}
|
||||
|
||||
func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress) *linkEndpoint {
|
||||
return &linkEndpoint{
|
||||
// TODO(jwhited): move to linkEndpointOpts struct or similar.
|
||||
type initGRO bool
|
||||
|
||||
const (
|
||||
disableGRO initGRO = false
|
||||
enableGRO initGRO = true
|
||||
)
|
||||
|
||||
func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, gro initGRO) *linkEndpoint {
|
||||
le := &linkEndpoint{
|
||||
q: &queue{
|
||||
c: make(chan *stack.PacketBuffer, size),
|
||||
},
|
||||
mtu: mtu,
|
||||
linkAddr: linkAddr,
|
||||
}
|
||||
le.initGRO = gro
|
||||
le.gro.Init(bool(gro))
|
||||
return le
|
||||
}
|
||||
|
||||
// Close closes l. Further packet injections will return an error, and all
|
||||
// pending packets are discarded. Close may be called concurrently with
|
||||
// WritePackets.
|
||||
func (l *linkEndpoint) Close() {
|
||||
l.mu.Lock()
|
||||
if l.gro.Dispatcher != nil {
|
||||
l.gro.Flush()
|
||||
}
|
||||
l.dispatcher = nil
|
||||
l.gro.Dispatcher = nil
|
||||
l.mu.Unlock()
|
||||
l.q.Close()
|
||||
l.Drain()
|
||||
}
|
||||
@@ -132,19 +157,149 @@ func (l *linkEndpoint) Drain() int {
|
||||
return c
|
||||
}
|
||||
|
||||
// NumQueued returns the number of packet queued for outbound.
|
||||
// NumQueued returns the number of packets queued for outbound.
|
||||
func (l *linkEndpoint) NumQueued() int {
|
||||
return l.q.Num()
|
||||
}
|
||||
|
||||
// InjectInbound injects an inbound packet. If the endpoint is not attached, the
|
||||
// packet is not delivered.
|
||||
func (l *linkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
|
||||
// rxChecksumOffload validates IPv4, TCP, and UDP header checksums in p,
|
||||
// returning an equivalent *stack.PacketBuffer if they are valid, otherwise nil.
|
||||
// The set of headers validated covers where gVisor would perform validation if
|
||||
// !stack.PacketBuffer.RXChecksumValidated, i.e. it satisfies
|
||||
// stack.CapabilityRXChecksumOffload. Other protocols with checksum fields,
|
||||
// e.g. ICMP{v6}, are still validated by gVisor regardless of rx checksum
|
||||
// offloading capabilities.
|
||||
func rxChecksumOffload(p *packet.Parsed) *stack.PacketBuffer {
|
||||
var (
|
||||
pn tcpip.NetworkProtocolNumber
|
||||
csumStart int
|
||||
)
|
||||
buf := p.Buffer()
|
||||
|
||||
switch p.IPVersion {
|
||||
case 4:
|
||||
if len(buf) < header.IPv4MinimumSize {
|
||||
return nil
|
||||
}
|
||||
csumStart = int((buf[0] & 0x0F) * 4)
|
||||
if csumStart < header.IPv4MinimumSize || csumStart > header.IPv4MaximumHeaderSize || len(buf) < csumStart {
|
||||
return nil
|
||||
}
|
||||
if ^tun.Checksum(buf[:csumStart], 0) != 0 {
|
||||
return nil
|
||||
}
|
||||
pn = header.IPv4ProtocolNumber
|
||||
case 6:
|
||||
if len(buf) < header.IPv6FixedHeaderSize {
|
||||
return nil
|
||||
}
|
||||
csumStart = header.IPv6FixedHeaderSize
|
||||
pn = header.IPv6ProtocolNumber
|
||||
if p.IPProto != ipproto.ICMPv6 && p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP {
|
||||
// buf could have extension headers before a UDP or TCP header, but
|
||||
// packet.Parsed.IPProto will be set to the ext header type, so we
|
||||
// have to look deeper. We are still responsible for validating the
|
||||
// L4 checksum in this case. So, make use of gVisor's existing
|
||||
// extension header parsing via parse.IPv6() in order to unpack the
|
||||
// L4 csumStart index. This is not particularly efficient as we have
|
||||
// to allocate a short-lived stack.PacketBuffer that cannot be
|
||||
// re-used. parse.IPv6() "consumes" the IPv6 headers, so we can't
|
||||
// inject this stack.PacketBuffer into the stack at a later point.
|
||||
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(bytes.Clone(buf)),
|
||||
})
|
||||
defer packetBuf.DecRef()
|
||||
// The rightmost bool returns false only if packetBuf is too short,
|
||||
// which we've already accounted for above.
|
||||
transportProto, _, _, _, _ := parse.IPv6(packetBuf)
|
||||
if transportProto == header.TCPProtocolNumber || transportProto == header.UDPProtocolNumber {
|
||||
csumLen := packetBuf.Data().Size()
|
||||
if len(buf) < csumLen {
|
||||
return nil
|
||||
}
|
||||
csumStart = len(buf) - csumLen
|
||||
p.IPProto = ipproto.Proto(transportProto)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
|
||||
lenForPseudo := len(buf) - csumStart
|
||||
csum := tun.PseudoHeaderChecksum(
|
||||
uint8(p.IPProto),
|
||||
p.Src.Addr().AsSlice(),
|
||||
p.Dst.Addr().AsSlice(),
|
||||
uint16(lenForPseudo))
|
||||
csum = tun.Checksum(buf[csumStart:], csum)
|
||||
if ^csum != 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(bytes.Clone(buf)),
|
||||
})
|
||||
packetBuf.NetworkProtocolNumber = pn
|
||||
// Setting this is not technically required. gVisor overrides where
|
||||
// stack.CapabilityRXChecksumOffload is advertised from Capabilities().
|
||||
// https://github.com/google/gvisor/blob/64c016c92987cc04dfd4c7b091ddd21bdad875f8/pkg/tcpip/stack/nic.go#L763
|
||||
// This is also why we offload for all packets since we cannot signal this
|
||||
// per-packet.
|
||||
packetBuf.RXChecksumValidated = true
|
||||
return packetBuf
|
||||
}
|
||||
|
||||
func (l *linkEndpoint) injectInbound(p *packet.Parsed) {
|
||||
l.mu.RLock()
|
||||
d := l.dispatcher
|
||||
l.mu.RUnlock()
|
||||
if d != nil {
|
||||
d.DeliverNetworkPacket(protocol, pkt)
|
||||
if d == nil {
|
||||
return
|
||||
}
|
||||
pkt := rxChecksumOffload(p)
|
||||
if pkt == nil {
|
||||
return
|
||||
}
|
||||
d.DeliverNetworkPacket(pkt.NetworkProtocolNumber, pkt)
|
||||
pkt.DecRef()
|
||||
}
|
||||
|
||||
// enqueueGRO enqueues the provided packet for GRO. It may immediately deliver
|
||||
// it to the underlying stack.NetworkDispatcher depending on its contents and if
|
||||
// GRO was initialized via newLinkEndpoint. To explicitly flush previously
|
||||
// enqueued packets see flushGRO. enqueueGRO is not thread-safe and must not
|
||||
// be called concurrently with flushGRO.
|
||||
func (l *linkEndpoint) enqueueGRO(p *packet.Parsed) {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
if l.gro.Dispatcher == nil {
|
||||
return
|
||||
}
|
||||
pkt := rxChecksumOffload(p)
|
||||
if pkt == nil {
|
||||
return
|
||||
}
|
||||
// TODO(jwhited): gro.Enqueue() duplicates a lot of p.Decode().
|
||||
// We may want to push stack.PacketBuffer further up as a
|
||||
// replacement for packet.Parsed, or inversely push packet.Parsed
|
||||
// down into refactored GRO logic.
|
||||
l.gro.Enqueue(pkt)
|
||||
pkt.DecRef()
|
||||
}
|
||||
|
||||
// flushGRO flushes previously enqueueGRO'd packets to the underlying
|
||||
// stack.NetworkDispatcher. flushGRO is not thread-safe, and must not be
|
||||
// called concurrently with enqueueGRO.
|
||||
func (l *linkEndpoint) flushGRO() {
|
||||
if !l.initGRO {
|
||||
// If GRO was not initialized fast path return to avoid scanning GRO
|
||||
// buckets (see l.gro.Flush()) that will always be empty.
|
||||
return
|
||||
}
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
if l.gro.Dispatcher != nil {
|
||||
l.gro.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,6 +309,7 @@ func (l *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.dispatcher = dispatcher
|
||||
l.gro.Dispatcher = dispatcher
|
||||
}
|
||||
|
||||
// IsAttached implements stack.LinkEndpoint.IsAttached.
|
||||
@@ -179,7 +335,9 @@ func (l *linkEndpoint) SetMTU(mtu uint32) {
|
||||
|
||||
// Capabilities implements stack.LinkEndpoint.Capabilities.
|
||||
func (l *linkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return l.LinkEPCapabilities
|
||||
// We are required to offload RX checksum validation for the purposes of
|
||||
// GRO.
|
||||
return stack.CapabilityRXChecksumOffload
|
||||
}
|
||||
|
||||
// GSOMaxSize implements stack.GSOEndpoint.
|
||||
|
||||
112
wgengine/netstack/link_endpoint_test.go
Normal file
112
wgengine/netstack/link_endpoint_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"tailscale.com/net/packet"
|
||||
)
|
||||
|
||||
func Test_rxChecksumOffload(t *testing.T) {
|
||||
payloadLen := 100
|
||||
|
||||
tcpFields := &header.TCPFields{
|
||||
SrcPort: 1,
|
||||
DstPort: 1,
|
||||
SeqNum: 1,
|
||||
AckNum: 1,
|
||||
DataOffset: 20,
|
||||
Flags: header.TCPFlagAck | header.TCPFlagPsh,
|
||||
WindowSize: 3000,
|
||||
}
|
||||
tcp4 := make([]byte, 20+20+payloadLen)
|
||||
ipv4H := header.IPv4(tcp4)
|
||||
ipv4H.Encode(&header.IPv4Fields{
|
||||
SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()),
|
||||
DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()),
|
||||
Protocol: uint8(header.TCPProtocolNumber),
|
||||
TTL: 64,
|
||||
TotalLength: uint16(len(tcp4)),
|
||||
})
|
||||
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
|
||||
tcpH := header.TCP(tcp4[20:])
|
||||
tcpH.Encode(tcpFields)
|
||||
pseudoCsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+payloadLen))
|
||||
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
||||
|
||||
tcp6ExtHeader := make([]byte, 40+8+20+payloadLen)
|
||||
ipv6H := header.IPv6(tcp6ExtHeader)
|
||||
ipv6H.Encode(&header.IPv6Fields{
|
||||
SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()),
|
||||
DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()),
|
||||
TransportProtocol: 60, // really next header; destination options ext header
|
||||
HopLimit: 64,
|
||||
PayloadLength: uint16(8 + 20 + payloadLen),
|
||||
})
|
||||
tcp6ExtHeader[40] = uint8(header.TCPProtocolNumber) // next header
|
||||
tcp6ExtHeader[41] = 0 // length of ext header in 8-octet units, exclusive of first 8 octets.
|
||||
// 42-47 options and padding
|
||||
tcpH = header.TCP(tcp6ExtHeader[48:])
|
||||
tcpH.Encode(tcpFields)
|
||||
pseudoCsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+payloadLen))
|
||||
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
||||
|
||||
tcp4InvalidCsum := make([]byte, len(tcp4))
|
||||
copy(tcp4InvalidCsum, tcp4)
|
||||
at := 20 + 16
|
||||
tcp4InvalidCsum[at] = ^tcp4InvalidCsum[at]
|
||||
|
||||
tcp6ExtHeaderInvalidCsum := make([]byte, len(tcp6ExtHeader))
|
||||
copy(tcp6ExtHeaderInvalidCsum, tcp6ExtHeader)
|
||||
at = 40 + 8 + 16
|
||||
tcp6ExtHeaderInvalidCsum[at] = ^tcp6ExtHeaderInvalidCsum[at]
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantPB bool
|
||||
}{
|
||||
{
|
||||
"tcp4 packet valid csum",
|
||||
tcp4,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"tcp6 with ext header valid csum",
|
||||
tcp6ExtHeader,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"tcp4 packet invalid csum",
|
||||
tcp4InvalidCsum,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"tcp6 with ext header invalid csum",
|
||||
tcp6ExtHeaderInvalidCsum,
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &packet.Parsed{}
|
||||
p.Decode(tt.input)
|
||||
got := rxChecksumOffload(p)
|
||||
if tt.wantPB != (got != nil) {
|
||||
t.Fatalf("wantPB = %v != (got != nil): %v", tt.wantPB, got != nil)
|
||||
}
|
||||
if tt.wantPB {
|
||||
gotBuf := got.ToBuffer()
|
||||
if !bytes.Equal(tt.input, gotBuf.Flatten()) {
|
||||
t.Fatal("output packet unequal to input")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"expvar"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/refs"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
@@ -244,6 +242,44 @@ const nicID = 1
|
||||
// have a UDP packet as big as the MTU.
|
||||
const maxUDPPacketSize = tstun.MaxPacketSize
|
||||
|
||||
func setTCPBufSizes(ipstack *stack.Stack) error {
|
||||
// tcpip.TCP{Receive,Send}BufferSizeRangeOption is gVisor's version of
|
||||
// Linux's tcp_{r,w}mem. Application within gVisor differs as some Linux
|
||||
// features are not (yet) implemented, and socket buffer memory is not
|
||||
// controlled within gVisor, e.g. we allocate *stack.PacketBuffer's for the
|
||||
// write path within Tailscale. Therefore, we loosen our understanding of
|
||||
// the relationship between these Linux and gVisor tunables. The chosen
|
||||
// values are biased towards higher throughput on high bandwidth-delay
|
||||
// product paths, except on memory-constrained platforms.
|
||||
tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{
|
||||
// Min is unused by gVisor at the time of writing, but partially plumbed
|
||||
// for application by the TCP_WINDOW_CLAMP socket option.
|
||||
Min: tcpRXBufMinSize,
|
||||
// Default is used by gVisor at socket creation.
|
||||
Default: tcpRXBufDefSize,
|
||||
// Max is used by gVisor to cap the advertised receive window post-read.
|
||||
// (tcp_moderate_rcvbuf=true, the default).
|
||||
Max: tcpRXBufMaxSize,
|
||||
}
|
||||
tcpipErr := ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt)
|
||||
if tcpipErr != nil {
|
||||
return fmt.Errorf("could not set TCP RX buf size: %v", tcpipErr)
|
||||
}
|
||||
tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{
|
||||
// Min in unused by gVisor at the time of writing.
|
||||
Min: tcpTXBufMinSize,
|
||||
// Default is used by gVisor at socket creation.
|
||||
Default: tcpTXBufDefSize,
|
||||
// Max is used by gVisor to cap the send window.
|
||||
Max: tcpTXBufMaxSize,
|
||||
}
|
||||
tcpipErr = ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt)
|
||||
if tcpipErr != nil {
|
||||
return fmt.Errorf("could not set TCP TX buf size: %v", tcpipErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create creates and populates a new Impl.
|
||||
func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper, driveForLocal drive.FileSystemForLocal) (*Impl, error) {
|
||||
if mc == nil {
|
||||
@@ -284,10 +320,17 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
|
||||
return nil, fmt.Errorf("could not disable TCP RACK: %v", tcpipErr)
|
||||
}
|
||||
}
|
||||
linkEP := newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "")
|
||||
err := setTCPBufSizes(ipstack)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var linkEP *linkEndpoint
|
||||
if runtime.GOOS == "linux" {
|
||||
// TODO(jwhited): add Windows support https://github.com/tailscale/corp/issues/21874
|
||||
linkEP = newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "", enableGRO)
|
||||
linkEP.SupportedGSOKind = stack.HostGSOSupported
|
||||
} else {
|
||||
linkEP = newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "", disableGRO)
|
||||
}
|
||||
if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
|
||||
return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
|
||||
@@ -336,6 +379,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
|
||||
ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
|
||||
ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc())
|
||||
ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound
|
||||
ns.tundev.EndPacketVectorInboundFromWireGuardFlush = linkEP.flushGRO
|
||||
ns.tundev.PreFilterPacketOutboundToWireGuardNetstackIntercept = ns.handleLocalPackets
|
||||
stacksForMetrics.Store(ns, struct{}{})
|
||||
return ns, nil
|
||||
@@ -512,9 +556,7 @@ func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error {
|
||||
panic("nil LocalBackend")
|
||||
}
|
||||
ns.lb = lb
|
||||
// size = 0 means use default buffer size
|
||||
const tcpReceiveBufferSize = 0
|
||||
tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts(), ns.acceptTCP)
|
||||
tcpFwd := tcp.NewForwarder(ns.ipstack, tcpRXBufDefSize, maxInFlightConnectionAttempts(), ns.acceptTCP)
|
||||
udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
|
||||
ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapTCPProtocolHandler(tcpFwd.HandlePacket))
|
||||
ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapUDPProtocolHandler(udpFwd.HandlePacket))
|
||||
@@ -737,23 +779,11 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re
|
||||
// care about the packet; resume processing.
|
||||
return filter.Accept
|
||||
}
|
||||
|
||||
var pn tcpip.NetworkProtocolNumber
|
||||
switch p.IPVersion {
|
||||
case 4:
|
||||
pn = header.IPv4ProtocolNumber
|
||||
case 6:
|
||||
pn = header.IPv6ProtocolNumber
|
||||
}
|
||||
if debugPackets {
|
||||
ns.logf("[v2] service packet in (from %v): % x", p.Src, p.Buffer())
|
||||
}
|
||||
|
||||
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(bytes.Clone(p.Buffer())),
|
||||
})
|
||||
ns.linkEP.InjectInbound(pn, packetBuf)
|
||||
packetBuf.DecRef()
|
||||
ns.linkEP.injectInbound(p)
|
||||
return filter.DropSilently
|
||||
}
|
||||
|
||||
@@ -794,7 +824,7 @@ func (ns *Impl) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet.
|
||||
func (ns *Impl) inject() {
|
||||
for {
|
||||
pkt := ns.linkEP.ReadContext(ns.ctx)
|
||||
if pkt.IsNil() {
|
||||
if pkt == nil {
|
||||
if ns.ctx.Err() != nil {
|
||||
// Return without logging.
|
||||
return
|
||||
@@ -1038,21 +1068,10 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons
|
||||
return filter.DropSilently
|
||||
}
|
||||
|
||||
var pn tcpip.NetworkProtocolNumber
|
||||
switch p.IPVersion {
|
||||
case 4:
|
||||
pn = header.IPv4ProtocolNumber
|
||||
case 6:
|
||||
pn = header.IPv6ProtocolNumber
|
||||
}
|
||||
if debugPackets {
|
||||
ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer())
|
||||
}
|
||||
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(bytes.Clone(p.Buffer())),
|
||||
})
|
||||
ns.linkEP.InjectInbound(pn, packetBuf)
|
||||
packetBuf.DecRef()
|
||||
ns.linkEP.enqueueGRO(p)
|
||||
|
||||
// We've now delivered this to netstack, so we're done.
|
||||
// Instead of returning a filter.Accept here (which would also
|
||||
|
||||
20
wgengine/netstack/netstack_tcpbuf_default.go
Normal file
20
wgengine/netstack/netstack_tcpbuf_default.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !ios
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
)
|
||||
|
||||
const (
|
||||
tcpRXBufMinSize = tcp.MinBufferSize
|
||||
tcpRXBufDefSize = tcp.DefaultSendBufferSize
|
||||
tcpRXBufMaxSize = 8 << 20 // 8MiB
|
||||
|
||||
tcpTXBufMinSize = tcp.MinBufferSize
|
||||
tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
|
||||
tcpTXBufMaxSize = 6 << 20 // 6MiB
|
||||
)
|
||||
24
wgengine/netstack/netstack_tcpbuf_ios.go
Normal file
24
wgengine/netstack/netstack_tcpbuf_ios.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build ios
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
)
|
||||
|
||||
const (
|
||||
// tcp{RX,TX}Buf{Min,Def,Max}Size mirror gVisor defaults. We leave these
|
||||
// unchanged on iOS for now as to not increase pressure towards the
|
||||
// NetworkExtension memory limit.
|
||||
// TODO(jwhited): test memory/throughput impact of collapsing to values in _default.go
|
||||
tcpRXBufMinSize = tcp.MinBufferSize
|
||||
tcpRXBufDefSize = tcp.DefaultSendBufferSize
|
||||
tcpRXBufMaxSize = tcp.MaxBufferSize
|
||||
|
||||
tcpTXBufMinSize = tcp.MinBufferSize
|
||||
tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
|
||||
tcpTXBufMaxSize = tcp.MaxBufferSize
|
||||
)
|
||||
@@ -374,7 +374,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
|
||||
e.logf("onPortUpdate(port=%v, network=%s)", port, network)
|
||||
|
||||
if err := e.router.UpdateMagicsockPort(port, network); err != nil {
|
||||
e.logf("UpdateMagicsockPort(port=%v, network=%s) failed: %w", port, network, err)
|
||||
e.logf("UpdateMagicsockPort(port=%v, network=%s) failed: %v", port, network, err)
|
||||
}
|
||||
}
|
||||
magicsockOpts := magicsock.Options{
|
||||
|
||||
@@ -692,3 +692,5 @@ azules
|
||||
tabby
|
||||
ussuri
|
||||
kitty
|
||||
tanuki
|
||||
neko
|
||||
|
||||
Reference in New Issue
Block a user