Compare commits
54 Commits
tom/tka2
...
dsnet/tuns
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b63618452c | ||
|
|
91794f6498 | ||
|
|
2c447de6cc | ||
|
|
021bedfb89 | ||
|
|
d988c9f098 | ||
|
|
0607832397 | ||
|
|
565dbc599a | ||
|
|
aadf63da1d | ||
|
|
d5781f61a9 | ||
|
|
a7a0baf6b9 | ||
|
|
e9b98dd2e1 | ||
|
|
b9b0bf65a0 | ||
|
|
c6162c2a94 | ||
|
|
aa5e494aba | ||
|
|
ff13c66f55 | ||
|
|
ed248b04a7 | ||
|
|
8158dd2edc | ||
|
|
6632504f45 | ||
|
|
054ef4de56 | ||
|
|
d045462dfb | ||
|
|
d8eb111ac8 | ||
|
|
832031d54b | ||
|
|
42f1d92ae0 | ||
|
|
41bb47de0e | ||
|
|
3562b5bdfa | ||
|
|
5c42990c2f | ||
|
|
65c24b6334 | ||
|
|
4bda41e701 | ||
|
|
9b71008ef2 | ||
|
|
5623ef0271 | ||
|
|
486eecc063 | ||
|
|
b830c9975f | ||
|
|
4a82b317b7 | ||
|
|
f0347e841f | ||
|
|
027111fb5a | ||
|
|
1ce0e558a7 | ||
|
|
74674b110d | ||
|
|
33ee2c058e | ||
|
|
d34dd43562 | ||
|
|
cf61070e26 | ||
|
|
81574a5c8d | ||
|
|
9c6bdae556 | ||
|
|
82e82d9b7a | ||
|
|
0f16640546 | ||
|
|
aa0064db4d | ||
|
|
45a3de14a6 | ||
|
|
f6da2220d3 | ||
|
|
b22b565947 | ||
|
|
7c49db02a2 | ||
|
|
c312e0d264 | ||
|
|
11fcc3a7b0 | ||
|
|
f03a63910d | ||
|
|
024257ef5a | ||
|
|
eb5939289c |
54
.github/workflows/cross-android.yml
vendored
Normal file
54
.github/workflows/cross-android.yml
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
name: Android-Cross
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Android smoke build
|
||||
# Super minimal Android build that doesn't even use CGO and doesn't build everything that's needed
|
||||
# and is only arm64. But it's a smoke build: it's not meant to catch everything. But it'll catch
|
||||
# some Android breakages early.
|
||||
# TODO(bradfitz): better; see https://github.com/tailscale/tailscale/issues/4482
|
||||
env:
|
||||
GOOS: android
|
||||
GOARCH: arm64
|
||||
run: go install ./net/netns ./ipn/ipnlocal ./wgengine/magicsock/ ./wgengine/ ./wgengine/router/ ./wgengine/netstack ./util/dnsname/ ./ipn/ ./net/interfaces ./wgengine/router/ ./tailcfg/ ./types/logger/ ./net/dns ./hostinfo ./version
|
||||
|
||||
- uses: k0kubun/action-slack@v2.0.0
|
||||
with:
|
||||
payload: |
|
||||
{
|
||||
"attachments": [{
|
||||
"text": "${{ job.status }}: ${{ github.workflow }} <https://github.com/${{ github.repository }}/commit/${{ github.sha }}/checks|${{ env.COMMIT_DATE }} #${{ env.COMMIT_NUMBER_OF_DAY }}> " +
|
||||
"(<https://github.com/${{ github.repository }}/commit/${{ github.sha }}|" + "${{ github.sha }}".substring(0, 10) + ">) " +
|
||||
"of ${{ github.repository }}@" + "${{ github.ref }}".split('/').reverse()[0] + " by ${{ github.event.head_commit.committer.name }}",
|
||||
"color": "danger"
|
||||
}]
|
||||
}
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
|
||||
if: failure() && github.event_name == 'push'
|
||||
@@ -32,7 +32,7 @@
|
||||
# $ docker exec tailscaled tailscale status
|
||||
|
||||
|
||||
FROM golang:1.18-alpine AS build-env
|
||||
FROM golang:1.19-alpine AS build-env
|
||||
|
||||
WORKDIR /go/src/tailscale
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
package atomicfile // import "tailscale.com/atomicfile"
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -18,7 +17,7 @@ import (
|
||||
// WriteFile writes data to filename+some suffix, then renames it
|
||||
// into filename. The perm argument is ignored on Windows.
|
||||
func WriteFile(filename string, data []byte, perm os.FileMode) (err error) {
|
||||
f, err := ioutil.TempFile(filepath.Dir(filename), filepath.Base(filename)+".tmp")
|
||||
f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
@@ -137,7 +136,7 @@ func (lc *LocalClient) doLocalRequestNiceError(req *http.Request) (*http.Respons
|
||||
onVersionMismatch(ipn.IPCVersion(), server)
|
||||
}
|
||||
if res.StatusCode == 403 {
|
||||
all, _ := ioutil.ReadAll(res.Body)
|
||||
all, _ := io.ReadAll(res.Body)
|
||||
return nil, &AccessDeniedError{errors.New(errorMessageFromBody(all))}
|
||||
}
|
||||
return res, nil
|
||||
@@ -207,7 +206,7 @@ func (lc *LocalClient) send(ctx context.Context, method, path string, wantStatus
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
slurp, err := ioutil.ReadAll(res.Body)
|
||||
slurp, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -365,7 +364,7 @@ func (lc *LocalClient) GetWaitingFile(ctx context.Context, baseName string) (rc
|
||||
return nil, 0, fmt.Errorf("unexpected chunking")
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
body, _ := ioutil.ReadAll(res.Body)
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
return nil, 0, fmt.Errorf("HTTP %s: %s", res.Status, body)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -131,7 +130,7 @@ func (c *Client) sendRequest(req *http.Request) ([]byte, *http.Response, error)
|
||||
|
||||
// Read response. Limit the response to 10MB.
|
||||
body := io.LimitReader(resp.Body, maxReadSize+1)
|
||||
b, err := ioutil.ReadAll(body)
|
||||
b, err := io.ReadAll(body)
|
||||
if len(b) > maxReadSize {
|
||||
err = errors.New("API response too large")
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
tailscale.com/safesocket from tailscale.com/client/tailscale
|
||||
tailscale.com/syncs from tailscale.com/cmd/derper+
|
||||
tailscale.com/tailcfg from tailscale.com/client/tailscale+
|
||||
tailscale.com/tka from tailscale.com/client/tailscale
|
||||
tailscale.com/tka from tailscale.com/client/tailscale+
|
||||
W tailscale.com/tsconst from tailscale.com/net/interfaces
|
||||
💣 tailscale.com/tstime/mono from tailscale.com/tstime/rate
|
||||
tailscale.com/tstime/rate from tailscale.com/wgengine/filter
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
@@ -99,7 +98,7 @@ func loadConfig() config {
|
||||
}
|
||||
log.Printf("no config path specified; using %s", *configPath)
|
||||
}
|
||||
b, err := ioutil.ReadFile(*configPath)
|
||||
b, err := os.ReadFile(*configPath)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
return writeNewConfig()
|
||||
@@ -155,7 +154,7 @@ func main() {
|
||||
s.SetVerifyClient(*verifyClients)
|
||||
|
||||
if *meshPSKFile != "" {
|
||||
b, err := ioutil.ReadFile(*meshPSKFile)
|
||||
b, err := os.ReadFile(*meshPSKFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -206,6 +205,7 @@ func main() {
|
||||
mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, "User-agent: *\nDisallow: /\n")
|
||||
}))
|
||||
mux.Handle("/generate_204", http.HandlerFunc(serveNoContent))
|
||||
debug := tsweb.Debugger(mux)
|
||||
debug.KV("TLS hostname", *hostname)
|
||||
debug.KV("Mesh key", s.HasMeshKey())
|
||||
@@ -293,9 +293,12 @@ func main() {
|
||||
})
|
||||
if *httpPort > -1 {
|
||||
go func() {
|
||||
port80mux := http.NewServeMux()
|
||||
port80mux.HandleFunc("/generate_204", serveNoContent)
|
||||
port80mux.Handle("/", certManager.HTTPHandler(tsweb.Port80Handler{Main: mux}))
|
||||
port80srv := &http.Server{
|
||||
Addr: net.JoinHostPort(listenHost, fmt.Sprintf("%d", *httpPort)),
|
||||
Handler: certManager.HTTPHandler(tsweb.Port80Handler{Main: mux}),
|
||||
Handler: port80mux,
|
||||
ErrorLog: quietLogger,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
// Crank up WriteTimeout a bit more than usually
|
||||
@@ -322,6 +325,11 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// For captive portal detection
|
||||
func serveNoContent(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// probeHandler is the endpoint that js/wasm clients hit to measure
|
||||
// DERP latency, since they can't do UDP STUN queries.
|
||||
func probeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -33,6 +33,12 @@ func addWebSocketSupport(s *derp.Server, base http.Handler) http.Handler {
|
||||
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
Subprotocols: []string{"derp"},
|
||||
OriginPatterns: []string{"*"},
|
||||
// Disable compression because we transmit WireGuard messages that
|
||||
// are not compressible.
|
||||
// Additionally, Safari has a broken implementation of compression
|
||||
// (see https://github.com/nhooyr/websocket/issues/218) that makes
|
||||
// enabling it actively harmful.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("websocket.Accept: %v", err)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -106,7 +105,7 @@ func devMode() bool { return *httpsAddr == "" && *httpAddr != "" }
|
||||
|
||||
func getTmpl() (*template.Template, error) {
|
||||
if devMode() {
|
||||
tmplData, err := ioutil.ReadFile("hello.tmpl.html")
|
||||
tmplData, err := os.ReadFile("hello.tmpl.html")
|
||||
if os.IsNotExist(err) {
|
||||
log.Printf("using baked-in template in dev mode; can't find hello.tmpl.html in current directory")
|
||||
return tmpl, nil
|
||||
|
||||
@@ -789,6 +789,10 @@ func TestUpdatePrefs(t *testing.T) {
|
||||
curPrefs *ipn.Prefs
|
||||
env upCheckEnv // empty goos means "linux"
|
||||
|
||||
// sshOverTailscale specifies if the cmd being run over SSH over Tailscale.
|
||||
// It is used to test the --accept-risks flag.
|
||||
sshOverTailscale bool
|
||||
|
||||
// checkUpdatePrefsMutations, if non-nil, is run with the new prefs after
|
||||
// updatePrefs might've mutated them (from applyImplicitPrefs).
|
||||
checkUpdatePrefsMutations func(t *testing.T, newPrefs *ipn.Prefs)
|
||||
@@ -916,15 +920,159 @@ func TestUpdatePrefs(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enable_ssh",
|
||||
flags: []string{"--ssh"},
|
||||
curPrefs: &ipn.Prefs{
|
||||
ControlURL: "https://login.tailscale.com",
|
||||
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||
AllowSingleHosts: true,
|
||||
CorpDNS: true,
|
||||
NetfilterMode: preftype.NetfilterOn,
|
||||
},
|
||||
wantJustEditMP: &ipn.MaskedPrefs{
|
||||
RunSSHSet: true,
|
||||
WantRunningSet: true,
|
||||
},
|
||||
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||
if !newPrefs.RunSSH {
|
||||
t.Errorf("RunSSH not set to true")
|
||||
}
|
||||
},
|
||||
env: upCheckEnv{backendState: "Running"},
|
||||
},
|
||||
{
|
||||
name: "disable_ssh",
|
||||
flags: []string{"--ssh=false"},
|
||||
curPrefs: &ipn.Prefs{
|
||||
ControlURL: "https://login.tailscale.com",
|
||||
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||
AllowSingleHosts: true,
|
||||
CorpDNS: true,
|
||||
RunSSH: true,
|
||||
NetfilterMode: preftype.NetfilterOn,
|
||||
},
|
||||
wantJustEditMP: &ipn.MaskedPrefs{
|
||||
RunSSHSet: true,
|
||||
WantRunningSet: true,
|
||||
},
|
||||
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||
if newPrefs.RunSSH {
|
||||
t.Errorf("RunSSH not set to false")
|
||||
}
|
||||
},
|
||||
env: upCheckEnv{backendState: "Running", upArgs: upArgsT{
|
||||
runSSH: true,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "disable_ssh_over_ssh_no_risk",
|
||||
flags: []string{"--ssh=false"},
|
||||
sshOverTailscale: true,
|
||||
curPrefs: &ipn.Prefs{
|
||||
ControlURL: "https://login.tailscale.com",
|
||||
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||
AllowSingleHosts: true,
|
||||
CorpDNS: true,
|
||||
NetfilterMode: preftype.NetfilterOn,
|
||||
RunSSH: true,
|
||||
},
|
||||
wantJustEditMP: &ipn.MaskedPrefs{
|
||||
RunSSHSet: true,
|
||||
WantRunningSet: true,
|
||||
},
|
||||
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||
if !newPrefs.RunSSH {
|
||||
t.Errorf("RunSSH not set to true")
|
||||
}
|
||||
},
|
||||
env: upCheckEnv{backendState: "Running"},
|
||||
wantErrSubtr: "aborted, no changes made",
|
||||
},
|
||||
{
|
||||
name: "enable_ssh_over_ssh_no_risk",
|
||||
flags: []string{"--ssh=true"},
|
||||
sshOverTailscale: true,
|
||||
curPrefs: &ipn.Prefs{
|
||||
ControlURL: "https://login.tailscale.com",
|
||||
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||
AllowSingleHosts: true,
|
||||
CorpDNS: true,
|
||||
NetfilterMode: preftype.NetfilterOn,
|
||||
},
|
||||
wantJustEditMP: &ipn.MaskedPrefs{
|
||||
RunSSHSet: true,
|
||||
WantRunningSet: true,
|
||||
},
|
||||
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||
if !newPrefs.RunSSH {
|
||||
t.Errorf("RunSSH not set to true")
|
||||
}
|
||||
},
|
||||
env: upCheckEnv{backendState: "Running"},
|
||||
wantErrSubtr: "aborted, no changes made",
|
||||
},
|
||||
{
|
||||
name: "enable_ssh_over_ssh",
|
||||
flags: []string{"--ssh=true", "--accept-risk=lose-ssh"},
|
||||
sshOverTailscale: true,
|
||||
curPrefs: &ipn.Prefs{
|
||||
ControlURL: "https://login.tailscale.com",
|
||||
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||
AllowSingleHosts: true,
|
||||
CorpDNS: true,
|
||||
NetfilterMode: preftype.NetfilterOn,
|
||||
},
|
||||
wantJustEditMP: &ipn.MaskedPrefs{
|
||||
RunSSHSet: true,
|
||||
WantRunningSet: true,
|
||||
},
|
||||
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||
if !newPrefs.RunSSH {
|
||||
t.Errorf("RunSSH not set to true")
|
||||
}
|
||||
},
|
||||
env: upCheckEnv{backendState: "Running"},
|
||||
},
|
||||
{
|
||||
name: "disable_ssh_over_ssh",
|
||||
flags: []string{"--ssh=false", "--accept-risk=lose-ssh"},
|
||||
sshOverTailscale: true,
|
||||
curPrefs: &ipn.Prefs{
|
||||
ControlURL: "https://login.tailscale.com",
|
||||
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||
AllowSingleHosts: true,
|
||||
CorpDNS: true,
|
||||
RunSSH: true,
|
||||
NetfilterMode: preftype.NetfilterOn,
|
||||
},
|
||||
wantJustEditMP: &ipn.MaskedPrefs{
|
||||
RunSSHSet: true,
|
||||
WantRunningSet: true,
|
||||
},
|
||||
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||
if newPrefs.RunSSH {
|
||||
t.Errorf("RunSSH not set to false")
|
||||
}
|
||||
},
|
||||
env: upCheckEnv{backendState: "Running"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.sshOverTailscale {
|
||||
old := getSSHClientEnvVar
|
||||
getSSHClientEnvVar = func() string { return "100.100.100.100 1 1" }
|
||||
t.Cleanup(func() { getSSHClientEnvVar = old })
|
||||
}
|
||||
if tt.env.goos == "" {
|
||||
tt.env.goos = "linux"
|
||||
}
|
||||
tt.env.flagSet = newUpFlagSet(tt.env.goos, &tt.env.upArgs)
|
||||
flags := CleanUpArgs(tt.flags)
|
||||
tt.env.flagSet.Parse(flags)
|
||||
if err := tt.env.flagSet.Parse(flags); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newPrefs, err := prefsFromUpArgs(tt.env.upArgs, t.Logf, new(ipnstate.Status), tt.env.goos)
|
||||
if err != nil {
|
||||
@@ -939,6 +1087,8 @@ func TestUpdatePrefs(t *testing.T) {
|
||||
return
|
||||
}
|
||||
t.Fatal(err)
|
||||
} else if tt.wantErrSubtr != "" {
|
||||
t.Fatalf("want error %q, got nil", tt.wantErrSubtr)
|
||||
}
|
||||
if tt.checkUpdatePrefsMutations != nil {
|
||||
tt.checkUpdatePrefsMutations(t, newPrefs)
|
||||
@@ -952,13 +1102,18 @@ func TestUpdatePrefs(t *testing.T) {
|
||||
justEditMP.Prefs = ipn.Prefs{} // uninteresting
|
||||
}
|
||||
if !reflect.DeepEqual(justEditMP, tt.wantJustEditMP) {
|
||||
t.Logf("justEditMP != wantJustEditMP; following diff omits the Prefs field, which was %+v", oldEditPrefs)
|
||||
t.Logf("justEditMP != wantJustEditMP; following diff omits the Prefs field, which was \n%v", asJSON(oldEditPrefs))
|
||||
t.Fatalf("justEditMP: %v\n\n: ", cmp.Diff(justEditMP, tt.wantJustEditMP, cmpIP))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func asJSON(v any) string {
|
||||
b, _ := json.MarshalIndent(v, "", "\t")
|
||||
return string(b)
|
||||
}
|
||||
|
||||
var cmpIP = cmp.Comparer(func(a, b netip.Addr) bool {
|
||||
return a == b
|
||||
})
|
||||
|
||||
@@ -48,11 +48,11 @@ func runConfigureHost(ctx context.Context, args []string) error {
|
||||
if uid := os.Getuid(); uid != 0 {
|
||||
return fmt.Errorf("must be run as root, not %q (%v)", os.Getenv("USER"), uid)
|
||||
}
|
||||
osVer := hostinfo.GetOSVersion()
|
||||
isDSM6 := strings.HasPrefix(osVer, "Synology 6")
|
||||
isDSM7 := strings.HasPrefix(osVer, "Synology 7")
|
||||
hi:= hostinfo.New()
|
||||
isDSM6 := strings.HasPrefix(hi.DistroVersion, "6.")
|
||||
isDSM7 := strings.HasPrefix(hi.DistroVersion, "7.")
|
||||
if !isDSM6 && !isDSM7 {
|
||||
return fmt.Errorf("unsupported DSM version %q", osVer)
|
||||
return fmt.Errorf("unsupported DSM version %q", hi.DistroVersion)
|
||||
}
|
||||
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll("/dev/net", 0755); err != nil {
|
||||
|
||||
@@ -489,7 +489,15 @@ func runTS2021(ctx context.Context, args []string) error {
|
||||
return c, err
|
||||
}
|
||||
|
||||
conn, err := controlhttp.Dial(ctx, ts2021Args.host, "80", "443", machinePrivate, keys.PublicKey, uint16(ts2021Args.version), dialFunc)
|
||||
conn, err := (&controlhttp.Dialer{
|
||||
Hostname: ts2021Args.host,
|
||||
HTTPPort: "80",
|
||||
HTTPSPort: "443",
|
||||
MachineKey: machinePrivate,
|
||||
ControlKey: keys.PublicKey,
|
||||
ProtocolVersion: uint16(ts2021Args.version),
|
||||
Dialer: dialFunc,
|
||||
}).Dial(ctx)
|
||||
log.Printf("controlhttp.Dial = %p, %v", conn, err)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -22,9 +22,13 @@ var downCmd = &ffcli.Command{
|
||||
FlagSet: newDownFlagSet(),
|
||||
}
|
||||
|
||||
var downArgs struct {
|
||||
acceptedRisks string
|
||||
}
|
||||
|
||||
func newDownFlagSet() *flag.FlagSet {
|
||||
downf := newFlagSet("down")
|
||||
registerAcceptRiskFlag(downf)
|
||||
registerAcceptRiskFlag(downf, &downArgs.acceptedRisks)
|
||||
return downf
|
||||
}
|
||||
|
||||
@@ -34,7 +38,7 @@ func runDown(ctx context.Context, args []string) error {
|
||||
}
|
||||
|
||||
if isSSHOverTailscale() {
|
||||
if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`); err != nil {
|
||||
if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`, downArgs.acceptedRisks); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,24 +19,27 @@ var licensesCmd = &ffcli.Command{
|
||||
Exec: runLicenses,
|
||||
}
|
||||
|
||||
func runLicenses(ctx context.Context, args []string) error {
|
||||
var licenseURL string
|
||||
// licensesURL returns the absolute URL containing open source license information for the current platform.
|
||||
func licensesURL() string {
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
licenseURL = "https://tailscale.com/licenses/android"
|
||||
return "https://tailscale.com/licenses/android"
|
||||
case "darwin", "ios":
|
||||
licenseURL = "https://tailscale.com/licenses/apple"
|
||||
return "https://tailscale.com/licenses/apple"
|
||||
case "windows":
|
||||
licenseURL = "https://tailscale.com/licenses/windows"
|
||||
return "https://tailscale.com/licenses/windows"
|
||||
default:
|
||||
licenseURL = "https://tailscale.com/licenses/tailscale"
|
||||
return "https://tailscale.com/licenses/tailscale"
|
||||
}
|
||||
}
|
||||
|
||||
func runLicenses(ctx context.Context, args []string) error {
|
||||
licenses := licensesURL()
|
||||
outln(`
|
||||
Tailscale wouldn't be possible without the contributions of thousands of open
|
||||
source developers. To see the open source packages included in Tailscale and
|
||||
their respective license information, visit:
|
||||
|
||||
` + licenseURL)
|
||||
` + licenses)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"sort"
|
||||
@@ -134,6 +133,9 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error {
|
||||
printf("\t* MappingVariesByDestIP: %v\n", report.MappingVariesByDestIP)
|
||||
printf("\t* HairPinning: %v\n", report.HairPinning)
|
||||
printf("\t* PortMapping: %v\n", portMapping(report))
|
||||
if report.CaptivePortal != "" {
|
||||
printf("\t* CaptivePortal: %v\n", report.CaptivePortal)
|
||||
}
|
||||
|
||||
// When DERP latency checking failed,
|
||||
// magicsock will try to pick the DERP server that
|
||||
@@ -202,7 +204,7 @@ func prodDERPMap(ctx context.Context, httpc *http.Client) (*tailcfg.DERPMap, err
|
||||
return nil, fmt.Errorf("fetch prodDERPMap failed: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
b, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
|
||||
b, err := io.ReadAll(io.LimitReader(res.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch prodDERPMap failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -16,9 +16,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
riskTypes []string
|
||||
acceptedRisks string
|
||||
riskLoseSSH = registerRiskType("lose-ssh")
|
||||
riskTypes []string
|
||||
riskLoseSSH = registerRiskType("lose-ssh")
|
||||
)
|
||||
|
||||
func registerRiskType(riskType string) string {
|
||||
@@ -28,12 +27,13 @@ func registerRiskType(riskType string) string {
|
||||
|
||||
// registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for
|
||||
// in presentRiskToUser.
|
||||
func registerAcceptRiskFlag(f *flag.FlagSet) {
|
||||
f.StringVar(&acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ","))
|
||||
func registerAcceptRiskFlag(f *flag.FlagSet, acceptedRisks *string) {
|
||||
f.StringVar(acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ","))
|
||||
}
|
||||
|
||||
// riskAccepted reports whether riskType is in acceptedRisks.
|
||||
func riskAccepted(riskType string) bool {
|
||||
// isRiskAccepted reports whether riskType is in the comma-separated list of
|
||||
// risks in acceptedRisks.
|
||||
func isRiskAccepted(riskType, acceptedRisks string) bool {
|
||||
for _, r := range strings.Split(acceptedRisks, ",") {
|
||||
if r == riskType {
|
||||
return true
|
||||
@@ -49,12 +49,16 @@ var errAborted = errors.New("aborted, no changes made")
|
||||
// It is used by the presentRiskToUser function below.
|
||||
const riskAbortTimeSeconds = 5
|
||||
|
||||
// presentRiskToUser displays the risk message and waits for the user to
|
||||
// cancel. It returns errorAborted if the user aborts.
|
||||
func presentRiskToUser(riskType, riskMessage string) error {
|
||||
if riskAccepted(riskType) {
|
||||
// presentRiskToUser displays the risk message and waits for the user to cancel.
|
||||
// It returns errorAborted if the user aborts. In tests it returns errAborted
|
||||
// immediately unless the risk has been explicitly accepted.
|
||||
func presentRiskToUser(riskType, riskMessage, acceptedRisks string) error {
|
||||
if isRiskAccepted(riskType, acceptedRisks) {
|
||||
return nil
|
||||
}
|
||||
if inTest() {
|
||||
return errAborted
|
||||
}
|
||||
outln(riskMessage)
|
||||
printf("To skip this warning, use --accept-risk=%s\n", riskType)
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT) *flag.FlagSet {
|
||||
upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)")
|
||||
}
|
||||
upf.DurationVar(&upArgs.timeout, "timeout", 0, "maximum amount of time to wait for tailscaled to enter a Running state; default (0s) blocks forever")
|
||||
registerAcceptRiskFlag(upf)
|
||||
registerAcceptRiskFlag(upf, &upArgs.acceptedRisks)
|
||||
return upf
|
||||
}
|
||||
|
||||
@@ -150,6 +150,7 @@ type upArgsT struct {
|
||||
opUser string
|
||||
json bool
|
||||
timeout time.Duration
|
||||
acceptedRisks string
|
||||
}
|
||||
|
||||
func (a upArgsT) getAuthKey() (string, error) {
|
||||
@@ -376,6 +377,20 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus
|
||||
return false, nil, fmt.Errorf("can't change --login-server without --force-reauth")
|
||||
}
|
||||
|
||||
// Do this after validations to avoid the 5s delay if we're going to error
|
||||
// out anyway.
|
||||
wantSSH, haveSSH := env.upArgs.runSSH, curPrefs.RunSSH
|
||||
if wantSSH != haveSSH && isSSHOverTailscale() {
|
||||
if wantSSH {
|
||||
err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`, env.upArgs.acceptedRisks)
|
||||
} else {
|
||||
err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`, env.upArgs.acceptedRisks)
|
||||
}
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
tagsChanged := !reflect.DeepEqual(curPrefs.AdvertiseTags, prefs.AdvertiseTags)
|
||||
|
||||
simpleUp = env.flagSet.NFlag() == 0 &&
|
||||
@@ -475,17 +490,6 @@ func runUp(ctx context.Context, args []string) (retErr error) {
|
||||
curExitNodeIP: exitNodeIP(curPrefs, st),
|
||||
}
|
||||
|
||||
if upArgs.runSSH != curPrefs.RunSSH && isSSHOverTailscale() {
|
||||
if upArgs.runSSH {
|
||||
err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`)
|
||||
} else {
|
||||
err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if retErr == nil {
|
||||
checkSSHUpWarnings(ctx)
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -59,6 +58,7 @@ type tmplData struct {
|
||||
IP string
|
||||
AdvertiseExitNode bool
|
||||
AdvertiseRoutes string
|
||||
LicensesURL string
|
||||
}
|
||||
|
||||
var webCmd = &ffcli.Command{
|
||||
@@ -253,7 +253,7 @@ func qnapAuthnFinish(user, url string) (string, *qnapAuthResponse, error) {
|
||||
return "", nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, err := ioutil.ReadAll(resp.Body)
|
||||
out, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@@ -392,6 +392,7 @@ func webHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Profile: profile,
|
||||
Status: st.BackendState,
|
||||
DeviceName: deviceName,
|
||||
LicensesURL: licensesURL(),
|
||||
}
|
||||
exitNodeRouteV4 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
exitNodeRouteV6 := netip.MustParsePrefix("::/0")
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
</head>
|
||||
|
||||
<body class="py-14">
|
||||
<main class="container max-w-lg mx-auto py-6 px-8 bg-white rounded-md shadow-2xl" style="width: 95%">
|
||||
<main class="container max-w-lg mx-auto mb-8 py-6 px-8 bg-white rounded-md shadow-2xl" style="width: 95%">
|
||||
<header class="flex justify-between items-center min-width-0 py-2 mb-8">
|
||||
<svg width="26" height="26" viewBox="0 0 23 23" title="Tailscale" fill="none" xmlns="http://www.w3.org/2000/svg"
|
||||
class="flex-shrink-0 mr-4">
|
||||
@@ -100,6 +100,9 @@
|
||||
</div>
|
||||
{{ end }}
|
||||
</main>
|
||||
<footer class="container max-w-lg mx-auto text-center">
|
||||
<a class="text-xs text-gray-500 hover:text-gray-600" href="{{ .LicensesURL }}">Open Source Licenses</a>
|
||||
</footer>
|
||||
<script>(function () {
|
||||
const advertiseExitNode = {{.AdvertiseExitNode}};
|
||||
let fetchingUrl = false;
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -173,7 +172,7 @@ func checkDerp(ctx context.Context, derpRegion string) error {
|
||||
return fmt.Errorf("fetch derp map failed: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
b, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
|
||||
b, err := io.ReadAll(io.LimitReader(res.Body, 1<<20))
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch derp map failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -142,7 +141,7 @@ func installSystemDaemonDarwin(args []string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil {
|
||||
if err := os.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -97,6 +98,20 @@ func defaultTunName() string {
|
||||
return "tailscale0"
|
||||
}
|
||||
|
||||
// defaultPort returns the default UDP port to listen on for disco+wireguard.
|
||||
// By default it returns 0, to pick one randomly from the kernel.
|
||||
// If the environment variable PORT is set, that's used instead.
|
||||
// The PORT environment variable is chosen to match what the Linux systemd
|
||||
// unit uses, to make documentation more consistent.
|
||||
func defaultPort() uint16 {
|
||||
if s := envknob.String("PORT"); s != "" {
|
||||
if p, err := strconv.ParseUint(s, 10, 16); err == nil {
|
||||
return uint16(p)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
var args struct {
|
||||
// tunname is a /dev/net/tun tunnel name ("tailscale0"), the
|
||||
// string "userspace-networking", "tap:TAPNAME[:BRIDGENAME]"
|
||||
@@ -113,6 +128,7 @@ var args struct {
|
||||
verbose int
|
||||
socksAddr string // listen address for SOCKS5 server
|
||||
httpProxyAddr string // listen address for HTTP proxy server
|
||||
disableLogs bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -131,6 +147,9 @@ var subCommands = map[string]*func([]string) error{
|
||||
var beCLI func() // non-nil if CLI is linked in
|
||||
|
||||
func main() {
|
||||
envknob.PanicIfAnyEnvCheckedInInit()
|
||||
envknob.ApplyDiskConfig()
|
||||
|
||||
printVersion := false
|
||||
flag.IntVar(&args.verbose, "verbose", 0, "log verbosity level; 0 is default, 1 or higher are increasingly verbose")
|
||||
flag.BoolVar(&args.cleanup, "cleanup", false, "clean up system state and exit")
|
||||
@@ -138,12 +157,13 @@ func main() {
|
||||
flag.StringVar(&args.socksAddr, "socks5-server", "", `optional [ip]:port to run a SOCK5 server (e.g. "localhost:1080")`)
|
||||
flag.StringVar(&args.httpProxyAddr, "outbound-http-proxy-listen", "", `optional [ip]:port to run an outbound HTTP proxy (e.g. "localhost:8080")`)
|
||||
flag.StringVar(&args.tunname, "tun", defaultTunName(), `tunnel interface name; use "userspace-networking" (beta) to not use TUN`)
|
||||
flag.Var(flagtype.PortValue(&args.port, 0), "port", "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
|
||||
flag.Var(flagtype.PortValue(&args.port, defaultPort()), "port", "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
|
||||
flag.StringVar(&args.statepath, "state", "", "absolute path of state file; use 'kube:<secret-name>' to use Kubernetes secrets or 'arn:aws:ssm:...' to store in AWS SSM; use 'mem:' to not store state and register as an emphemeral node. If empty and --statedir is provided, the default is <statedir>/tailscaled.state. Default: "+paths.DefaultTailscaledStateFile())
|
||||
flag.StringVar(&args.statedir, "statedir", "", "path to directory for storage of config state, TLS certs, temporary incoming Taildrop files, etc. If empty, it's derived from --state when possible.")
|
||||
flag.StringVar(&args.socketpath, "socket", paths.DefaultTailscaledSocket(), "path of the service unix socket")
|
||||
flag.StringVar(&args.birdSocketPath, "bird-socket", "", "path of the bird unix socket")
|
||||
flag.BoolVar(&printVersion, "version", false, "print version information and exit")
|
||||
flag.BoolVar(&args.disableLogs, "no-logs-no-support", false, "disable log uploads; this also disables any technical support")
|
||||
|
||||
if len(os.Args) > 0 && filepath.Base(os.Args[0]) == "tailscale" && beCLI != nil {
|
||||
beCLI()
|
||||
@@ -199,6 +219,10 @@ func main() {
|
||||
args.statepath = paths.DefaultTailscaledStateFile()
|
||||
}
|
||||
|
||||
if args.disableLogs {
|
||||
envknob.SetNoLogsNoSupport()
|
||||
}
|
||||
|
||||
if beWindowsSubprocess() {
|
||||
return
|
||||
}
|
||||
@@ -302,6 +326,10 @@ func run() error {
|
||||
pol.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
if err := envknob.ApplyDiskConfigError(); err != nil {
|
||||
log.Printf("Error reading environment config: %v", err)
|
||||
}
|
||||
|
||||
if isWindowsService() {
|
||||
// Run the IPN server from the Windows service manager.
|
||||
log.Printf("Running service...")
|
||||
@@ -370,7 +398,7 @@ func run() error {
|
||||
return fmt.Errorf("newNetstack: %w", err)
|
||||
}
|
||||
ns.ProcessLocalIPs = useNetstack
|
||||
ns.ProcessSubnets = useNetstack || wrapNetstack
|
||||
ns.ProcessSubnets = useNetstack || shouldWrapNetstack()
|
||||
|
||||
if useNetstack {
|
||||
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
|
||||
@@ -471,8 +499,6 @@ func createEngine(logf logger.Logf, linkMon *monitor.Mon, dialer *tsdial.Dialer)
|
||||
return nil, false, multierr.New(errs...)
|
||||
}
|
||||
|
||||
var wrapNetstack = shouldWrapNetstack()
|
||||
|
||||
func shouldWrapNetstack() bool {
|
||||
if v, ok := envknob.LookupBool("TS_DEBUG_WRAP_NETSTACK"); ok {
|
||||
return v
|
||||
@@ -543,7 +569,7 @@ func tryEngine(logf logger.Logf, linkMon *monitor.Mon, dialer *tsdial.Dialer, na
|
||||
}
|
||||
conf.DNS = d
|
||||
conf.Router = r
|
||||
if wrapNetstack {
|
||||
if shouldWrapNetstack() {
|
||||
conf.Router = netstack.NewSubnetRouterWrapper(conf.Router)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ After=network-pre.target NetworkManager.service systemd-resolved.service
|
||||
[Service]
|
||||
EnvironmentFile=/etc/default/tailscaled
|
||||
ExecStartPre=/usr/sbin/tailscaled --cleanup
|
||||
ExecStart=/usr/sbin/tailscaled --state=/var/lib/tailscale/tailscaled.state --socket=/run/tailscale/tailscaled.sock --port $PORT $FLAGS
|
||||
ExecStart=/usr/sbin/tailscaled --state=/var/lib/tailscale/tailscaled.state --socket=/run/tailscale/tailscaled.sock --port=${PORT} $FLAGS
|
||||
ExecStopPost=/usr/sbin/tailscaled --cleanup
|
||||
|
||||
Restart=on-failure
|
||||
|
||||
@@ -197,6 +197,9 @@ func beWindowsSubprocess() bool {
|
||||
|
||||
log.Printf("Program starting: v%v: %#v", version.Long, os.Args)
|
||||
log.Printf("subproc mode: logid=%v", logid)
|
||||
if err := envknob.ApplyDiskConfigError(); err != nil {
|
||||
log.Printf("Error reading environment config: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
b := make([]byte, 16)
|
||||
@@ -274,7 +277,7 @@ func startIPNServer(ctx context.Context, logid string) error {
|
||||
dev.Close()
|
||||
return nil, nil, fmt.Errorf("router: %w", err)
|
||||
}
|
||||
if wrapNetstack {
|
||||
if shouldWrapNetstack() {
|
||||
r = netstack.NewSubnetRouterWrapper(r)
|
||||
}
|
||||
d, err := dns.NewOSConfigurator(logf, devName)
|
||||
@@ -301,7 +304,7 @@ func startIPNServer(ctx context.Context, logid string) error {
|
||||
return nil, nil, fmt.Errorf("newNetstack: %w", err)
|
||||
}
|
||||
ns.ProcessLocalIPs = false
|
||||
ns.ProcessSubnets = wrapNetstack
|
||||
ns.ProcessSubnets = shouldWrapNetstack()
|
||||
if err := ns.Start(); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to start netstack: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
@@ -47,7 +46,7 @@ func runBuild() {
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot fix esbuild metadata paths: %v", err)
|
||||
}
|
||||
if err := ioutil.WriteFile(path.Join(*distDir, "/esbuild-metadata.json"), metadataBytes, 0666); err != nil {
|
||||
if err := os.WriteFile(path.Join(*distDir, "/esbuild-metadata.json"), metadataBytes, 0666); err != nil {
|
||||
log.Fatalf("Cannot write metadata: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
@@ -183,7 +182,7 @@ func setupEsbuildWasm(build esbuild.PluginBuild, dev bool) {
|
||||
|
||||
func buildWasm(dev bool) ([]byte, error) {
|
||||
start := time.Now()
|
||||
outputFile, err := ioutil.TempFile("", "main.*.wasm")
|
||||
outputFile, err := os.CreateTemp("", "main.*.wasm")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Cannot create main.wasm output file: %w", err)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -75,7 +74,7 @@ func generateServeIndex(distFS fs.FS) ([]byte, error) {
|
||||
return nil, fmt.Errorf("Could not open esbuild-metadata.json: %w", err)
|
||||
}
|
||||
defer esbuildMetadataFile.Close()
|
||||
esbuildMetadataBytes, err := ioutil.ReadAll(esbuildMetadataFile)
|
||||
esbuildMetadataBytes, err := io.ReadAll(esbuildMetadataFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not read esbuild-metadata.json: %w", err)
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ function SSHSession({
|
||||
const ref = useRef<HTMLDivElement>(null)
|
||||
useEffect(() => {
|
||||
if (ref.current) {
|
||||
runSSHSession(ref.current, def, ipn, onDone)
|
||||
runSSHSession(ref.current, def, ipn, onDone, (err) => console.error(err))
|
||||
}
|
||||
}, [ref])
|
||||
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
import { Terminal } from "xterm"
|
||||
import { Terminal, ITerminalOptions } from "xterm"
|
||||
import { FitAddon } from "xterm-addon-fit"
|
||||
import { WebLinksAddon } from "xterm-addon-web-links"
|
||||
|
||||
export type SSHSessionDef = {
|
||||
username: string
|
||||
hostname: string
|
||||
/** Defaults to 5 seconds */
|
||||
timeoutSeconds?: number
|
||||
}
|
||||
|
||||
export function runSSHSession(
|
||||
termContainerNode: HTMLDivElement,
|
||||
def: SSHSessionDef,
|
||||
ipn: IPN,
|
||||
onDone: () => void
|
||||
onDone: () => void,
|
||||
onError?: (err: string) => void,
|
||||
terminalOptions?: ITerminalOptions
|
||||
) {
|
||||
const parentWindow = termContainerNode.ownerDocument.defaultView ?? window
|
||||
const term = new Terminal({
|
||||
cursorBlink: true,
|
||||
allowProposedApi: true,
|
||||
...terminalOptions,
|
||||
})
|
||||
|
||||
const fitAddon = new FitAddon()
|
||||
@@ -23,7 +29,9 @@ export function runSSHSession(
|
||||
term.open(termContainerNode)
|
||||
fitAddon.fit()
|
||||
|
||||
const webLinksAddon = new WebLinksAddon()
|
||||
const webLinksAddon = new WebLinksAddon((event, uri) =>
|
||||
event.view?.open(uri, "_blank", "noopener")
|
||||
)
|
||||
term.loadAddon(webLinksAddon)
|
||||
|
||||
let onDataHook: ((data: string) => void) | undefined
|
||||
@@ -41,7 +49,7 @@ export function runSSHSession(
|
||||
term.write(input)
|
||||
},
|
||||
writeErrorFn(err) {
|
||||
console.error(err)
|
||||
onError?.(err)
|
||||
term.write(err)
|
||||
},
|
||||
setReadFn(hook) {
|
||||
@@ -53,22 +61,20 @@ export function runSSHSession(
|
||||
resizeObserver?.disconnect()
|
||||
term.dispose()
|
||||
if (handleBeforeUnload) {
|
||||
window.removeEventListener("beforeunload", handleBeforeUnload)
|
||||
parentWindow.removeEventListener("beforeunload", handleBeforeUnload)
|
||||
}
|
||||
onDone()
|
||||
},
|
||||
timeoutSeconds: def.timeoutSeconds,
|
||||
})
|
||||
|
||||
// Make terminal and SSH session track the size of the containing DOM node.
|
||||
resizeObserver =
|
||||
new termContainerNode.ownerDocument.defaultView!.ResizeObserver(() =>
|
||||
fitAddon.fit()
|
||||
)
|
||||
resizeObserver = new parentWindow.ResizeObserver(() => fitAddon.fit())
|
||||
resizeObserver.observe(termContainerNode)
|
||||
term.onResize(({ rows, cols }) => sshSession.resize(rows, cols))
|
||||
|
||||
// Close the session if the user closes the window without an explicit
|
||||
// exit.
|
||||
handleBeforeUnload = () => sshSession.close()
|
||||
window.addEventListener("beforeunload", handleBeforeUnload)
|
||||
parentWindow.addEventListener("beforeunload", handleBeforeUnload)
|
||||
}
|
||||
|
||||
3
cmd/tsconnect/src/types/wasm_js.d.ts
vendored
3
cmd/tsconnect/src/types/wasm_js.d.ts
vendored
@@ -23,6 +23,8 @@ declare global {
|
||||
setReadFn: (readFn: (data: string) => void) => void
|
||||
rows: number
|
||||
cols: number
|
||||
/** Defaults to 5 seconds */
|
||||
timeoutSeconds?: number
|
||||
onDone: () => void
|
||||
}
|
||||
): IPNSSHSession
|
||||
@@ -47,6 +49,7 @@ declare global {
|
||||
stateStorage?: IPNStateStorage
|
||||
authKey?: string
|
||||
controlURL?: string
|
||||
hostname?: string
|
||||
}
|
||||
|
||||
type IPNCallbacks = {
|
||||
|
||||
@@ -61,26 +61,30 @@ func main() {
|
||||
func newIPN(jsConfig js.Value) map[string]any {
|
||||
netns.SetEnabled(false)
|
||||
|
||||
jsStateStorage := jsConfig.Get("stateStorage")
|
||||
var store ipn.StateStore
|
||||
if jsStateStorage.IsUndefined() {
|
||||
store = new(mem.Store)
|
||||
} else {
|
||||
if jsStateStorage := jsConfig.Get("stateStorage"); !jsStateStorage.IsUndefined() {
|
||||
store = &jsStateStore{jsStateStorage}
|
||||
} else {
|
||||
store = new(mem.Store)
|
||||
}
|
||||
|
||||
jsControlURL := jsConfig.Get("controlURL")
|
||||
controlURL := ControlURL
|
||||
if jsControlURL.Type() == js.TypeString {
|
||||
if jsControlURL := jsConfig.Get("controlURL"); jsControlURL.Type() == js.TypeString {
|
||||
controlURL = jsControlURL.String()
|
||||
}
|
||||
|
||||
jsAuthKey := jsConfig.Get("authKey")
|
||||
var authKey string
|
||||
if jsAuthKey.Type() == js.TypeString {
|
||||
if jsAuthKey := jsConfig.Get("authKey"); jsAuthKey.Type() == js.TypeString {
|
||||
authKey = jsAuthKey.String()
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if jsHostname := jsConfig.Get("hostname"); jsHostname.Type() == js.TypeString {
|
||||
hostname = jsHostname.String()
|
||||
} else {
|
||||
hostname = generateHostname()
|
||||
}
|
||||
|
||||
lpc := getOrCreateLogPolicyConfig(store)
|
||||
c := logtail.Config{
|
||||
Collection: lpc.Collection,
|
||||
@@ -136,6 +140,7 @@ func newIPN(jsConfig js.Value) map[string]any {
|
||||
lb: lb,
|
||||
controlURL: controlURL,
|
||||
authKey: authKey,
|
||||
hostname: hostname,
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
@@ -196,6 +201,7 @@ type jsIPN struct {
|
||||
lb *ipnlocal.LocalBackend
|
||||
controlURL string
|
||||
authKey string
|
||||
hostname string
|
||||
}
|
||||
|
||||
var jsIPNState = map[ipn.State]string{
|
||||
@@ -284,7 +290,7 @@ func (i *jsIPN) run(jsCallbacks js.Value) {
|
||||
RouteAll: false,
|
||||
AllowSingleHosts: true,
|
||||
WantRunning: true,
|
||||
Hostname: generateHostname(),
|
||||
Hostname: i.hostname,
|
||||
},
|
||||
AuthKey: i.authKey,
|
||||
})
|
||||
@@ -354,17 +360,18 @@ func (s *jsSSHSession) Run() {
|
||||
setReadFn := s.termConfig.Get("setReadFn")
|
||||
rows := s.termConfig.Get("rows").Int()
|
||||
cols := s.termConfig.Get("cols").Int()
|
||||
timeoutSeconds := 5.0
|
||||
if jsTimeoutSeconds := s.termConfig.Get("timeoutSeconds"); jsTimeoutSeconds.Type() == js.TypeNumber {
|
||||
timeoutSeconds = jsTimeoutSeconds.Float()
|
||||
}
|
||||
onDone := s.termConfig.Get("onDone")
|
||||
defer onDone.Invoke()
|
||||
|
||||
write := func(s string) {
|
||||
writeFn.Invoke(s)
|
||||
}
|
||||
writeError := func(label string, err error) {
|
||||
writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds*float64(time.Second)))
|
||||
defer cancel()
|
||||
c, err := s.jsIPN.dialer.UserDial(ctx, "tcp", net.JoinHostPort(s.host, "22"))
|
||||
if err != nil {
|
||||
@@ -384,7 +391,6 @@ func (s *jsSSHSession) Run() {
|
||||
return
|
||||
}
|
||||
defer sshConn.Close()
|
||||
write("SSH Connected\r\n")
|
||||
|
||||
sshClient := ssh.NewClient(sshConn, nil, nil)
|
||||
defer sshClient.Close()
|
||||
@@ -395,7 +401,6 @@ func (s *jsSSHSession) Run() {
|
||||
return
|
||||
}
|
||||
s.session = session
|
||||
write("Session Established\r\n")
|
||||
defer session.Close()
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -107,6 +106,7 @@ type Options struct {
|
||||
KeepAlive bool
|
||||
Logf logger.Logf
|
||||
HTTPTestClient *http.Client // optional HTTP client to use (for tests only)
|
||||
NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only)
|
||||
DebugFlags []string // debug settings to send to control
|
||||
LinkMonitor *monitor.Mon // optional link monitor
|
||||
PopBrowserURL func(url string) // optional func to open browser
|
||||
@@ -227,6 +227,12 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
c.SetNetInfo(ni)
|
||||
}
|
||||
}
|
||||
if opts.NoiseTestClient != nil {
|
||||
c.noiseClient = &noiseClient{
|
||||
Client: opts.NoiseTestClient,
|
||||
}
|
||||
c.serverNoiseKey = key.NewMachine().Public() // prevent early error before hitting test client
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -490,7 +496,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
c.logf("RegisterReq sign error: %v", err)
|
||||
}
|
||||
}
|
||||
if debugRegister {
|
||||
if debugRegister() {
|
||||
j, _ := json.MarshalIndent(request, "", "\t")
|
||||
c.logf("RegisterRequest: %s", j)
|
||||
}
|
||||
@@ -523,7 +529,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
return regen, opt.URL, fmt.Errorf("register request: %w", err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
msg, _ := ioutil.ReadAll(res.Body)
|
||||
msg, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
return regen, opt.URL, fmt.Errorf("register request: http %d: %.200s",
|
||||
res.StatusCode, strings.TrimSpace(string(msg)))
|
||||
@@ -533,7 +539,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err)
|
||||
return regen, opt.URL, fmt.Errorf("register request: %v", err)
|
||||
}
|
||||
if debugRegister {
|
||||
if debugRegister() {
|
||||
j, _ := json.MarshalIndent(resp, "", "\t")
|
||||
c.logf("RegisterResponse: %s", j)
|
||||
}
|
||||
@@ -715,7 +721,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
|
||||
c.logf("[v1] PollNetMap: stream=%v ep=%v", allowStream, epStrs)
|
||||
|
||||
vlogf := logger.Discard
|
||||
if DevKnob.DumpNetMaps {
|
||||
if DevKnob.DumpNetMaps() {
|
||||
// TODO(bradfitz): update this to use "[v2]" prefix perhaps? but we don't
|
||||
// want to upload it always.
|
||||
vlogf = c.logf
|
||||
@@ -804,7 +810,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
|
||||
}
|
||||
vlogf("netmap: Do = %v after %v", res.StatusCode, time.Since(t0).Round(time.Millisecond))
|
||||
if res.StatusCode != 200 {
|
||||
msg, _ := ioutil.ReadAll(res.Body)
|
||||
msg, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
return fmt.Errorf("initial fetch failed %d: %.200s",
|
||||
res.StatusCode, strings.TrimSpace(string(msg)))
|
||||
@@ -814,7 +820,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
|
||||
health.NoteMapRequestHeard(request)
|
||||
|
||||
if cb == nil {
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
io.Copy(io.Discard, res.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -937,6 +943,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
|
||||
}
|
||||
if resp.Debug.DisableLogTail {
|
||||
logtail.Disable()
|
||||
envknob.SetNoLogsNoSupport()
|
||||
}
|
||||
if resp.Debug.LogHeapPprof {
|
||||
go logheap.LogHeap(resp.Debug.LogHeapURL)
|
||||
@@ -962,12 +969,12 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
|
||||
controlTrimWGConfig.Store(d.TrimWGConfig)
|
||||
}
|
||||
|
||||
if DevKnob.StripEndpoints {
|
||||
if DevKnob.StripEndpoints() {
|
||||
for _, p := range resp.Peers {
|
||||
p.Endpoints = nil
|
||||
}
|
||||
}
|
||||
if DevKnob.StripCaps {
|
||||
if DevKnob.StripCaps() {
|
||||
nm.SelfNode.Capabilities = nil
|
||||
}
|
||||
|
||||
@@ -997,7 +1004,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
|
||||
// it uses the serverKey and mkey to decode the message from the NaCl-crypto-box.
|
||||
func decode(res *http.Response, v any, serverKey, serverNoiseKey key.MachinePublic, mkey key.MachinePrivate) error {
|
||||
defer res.Body.Close()
|
||||
msg, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
|
||||
msg, err := io.ReadAll(io.LimitReader(res.Body, 1<<20))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1011,8 +1018,8 @@ func decode(res *http.Response, v any, serverKey, serverNoiseKey key.MachinePubl
|
||||
}
|
||||
|
||||
var (
|
||||
debugMap = envknob.Bool("TS_DEBUG_MAP")
|
||||
debugRegister = envknob.Bool("TS_DEBUG_REGISTER")
|
||||
debugMap = envknob.RegisterBool("TS_DEBUG_MAP")
|
||||
debugRegister = envknob.RegisterBool("TS_DEBUG_REGISTER")
|
||||
)
|
||||
|
||||
var jsonEscapedZero = []byte(`\u0000`)
|
||||
@@ -1050,7 +1057,7 @@ func (c *Direct) decodeMsg(msg []byte, v any, mkey key.MachinePrivate) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if debugMap {
|
||||
if debugMap() {
|
||||
var buf bytes.Buffer
|
||||
json.Indent(&buf, b, "", " ")
|
||||
log.Printf("MapResponse: %s", buf.Bytes())
|
||||
@@ -1087,7 +1094,7 @@ func encode(v any, serverKey, serverNoiseKey key.MachinePublic, mkey key.Machine
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if debugMap {
|
||||
if debugMap() {
|
||||
if _, ok := v.(*tailcfg.MapRequest); ok {
|
||||
log.Printf("MapRequest: %s", b)
|
||||
}
|
||||
@@ -1109,7 +1116,7 @@ func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string
|
||||
return nil, fmt.Errorf("fetch control key: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
b, err := ioutil.ReadAll(io.LimitReader(res.Body, 64<<10))
|
||||
b, err := io.ReadAll(io.LimitReader(res.Body, 64<<10))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch control key response: %v", err)
|
||||
}
|
||||
@@ -1138,18 +1145,18 @@ func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string
|
||||
var DevKnob = initDevKnob()
|
||||
|
||||
type devKnobs struct {
|
||||
DumpNetMaps bool
|
||||
ForceProxyDNS bool
|
||||
StripEndpoints bool // strip endpoints from control (only use disco messages)
|
||||
StripCaps bool // strip all local node's control-provided capabilities
|
||||
DumpNetMaps func() bool
|
||||
ForceProxyDNS func() bool
|
||||
StripEndpoints func() bool // strip endpoints from control (only use disco messages)
|
||||
StripCaps func() bool // strip all local node's control-provided capabilities
|
||||
}
|
||||
|
||||
func initDevKnob() devKnobs {
|
||||
return devKnobs{
|
||||
DumpNetMaps: envknob.Bool("TS_DEBUG_NETMAP"),
|
||||
ForceProxyDNS: envknob.Bool("TS_DEBUG_PROXY_DNS"),
|
||||
StripEndpoints: envknob.Bool("TS_DEBUG_STRIP_ENDPOINTS"),
|
||||
StripCaps: envknob.Bool("TS_DEBUG_STRIP_CAPS"),
|
||||
DumpNetMaps: envknob.RegisterBool("TS_DEBUG_NETMAP"),
|
||||
ForceProxyDNS: envknob.RegisterBool("TS_DEBUG_PROXY_DNS"),
|
||||
StripEndpoints: envknob.RegisterBool("TS_DEBUG_STRIP_ENDPOINTS"),
|
||||
StripCaps: envknob.RegisterBool("TS_DEBUG_STRIP_CAPS"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1397,7 +1404,7 @@ func (c *Direct) setDNSNoise(ctx context.Context, req *tailcfg.SetDNSRequest) er
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
msg, _ := ioutil.ReadAll(res.Body)
|
||||
msg, _ := io.ReadAll(res.Body)
|
||||
return fmt.Errorf("set-dns response: %v, %.200s", res.Status, strings.TrimSpace(string(msg)))
|
||||
}
|
||||
var setDNSRes tailcfg.SetDNSResponse
|
||||
@@ -1463,7 +1470,7 @@ func (c *Direct) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) (err er
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
msg, _ := ioutil.ReadAll(res.Body)
|
||||
msg, _ := io.ReadAll(res.Body)
|
||||
return fmt.Errorf("set-dns response: %v, %.200s", res.Status, strings.TrimSpace(string(msg)))
|
||||
}
|
||||
var setDNSRes tailcfg.SetDNSResponse
|
||||
|
||||
@@ -48,6 +48,7 @@ type mapSession struct {
|
||||
lastHealth []string
|
||||
lastPopBrowserURL string
|
||||
stickyDebug tailcfg.Debug // accumulated opt.Bool values
|
||||
lastTKAInfo *tailcfg.TKAInfo
|
||||
|
||||
// netMapBuilding is non-nil during a netmapForResponse call,
|
||||
// containing the value to be returned, once fully populated.
|
||||
@@ -115,6 +116,9 @@ func (ms *mapSession) netmapForResponse(resp *tailcfg.MapResponse) *netmap.Netwo
|
||||
if resp.Health != nil {
|
||||
ms.lastHealth = resp.Health
|
||||
}
|
||||
if resp.TKAInfo != nil {
|
||||
ms.lastTKAInfo = resp.TKAInfo
|
||||
}
|
||||
|
||||
debug := resp.Debug
|
||||
if debug != nil {
|
||||
@@ -152,9 +156,17 @@ func (ms *mapSession) netmapForResponse(resp *tailcfg.MapResponse) *netmap.Netwo
|
||||
DERPMap: ms.lastDERPMap,
|
||||
Debug: debug,
|
||||
ControlHealth: ms.lastHealth,
|
||||
TKAEnabled: ms.lastTKAInfo != nil && !ms.lastTKAInfo.Disabled,
|
||||
}
|
||||
ms.netMapBuilding = nm
|
||||
|
||||
if ms.lastTKAInfo != nil && ms.lastTKAInfo.Head != "" {
|
||||
if err := nm.TKAHead.UnmarshalText([]byte(ms.lastTKAInfo.Head)); err != nil {
|
||||
ms.logf("error unmarshalling TKAHead: %v", err)
|
||||
nm.TKAEnabled = false
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Node != nil {
|
||||
ms.lastNode = resp.Node
|
||||
}
|
||||
@@ -190,7 +202,7 @@ func (ms *mapSession) netmapForResponse(resp *tailcfg.MapResponse) *netmap.Netwo
|
||||
}
|
||||
ms.addUserProfile(peer.User)
|
||||
}
|
||||
if DevKnob.ForceProxyDNS {
|
||||
if DevKnob.ForceProxyDNS() {
|
||||
nm.DNS.Proxied = true
|
||||
}
|
||||
ms.netMapBuilding = nil
|
||||
@@ -356,13 +368,13 @@ func cloneNodes(v1 []*tailcfg.Node) []*tailcfg.Node {
|
||||
return v2
|
||||
}
|
||||
|
||||
var debugSelfIPv6Only = envknob.Bool("TS_DEBUG_SELF_V6_ONLY")
|
||||
var debugSelfIPv6Only = envknob.RegisterBool("TS_DEBUG_SELF_V6_ONLY")
|
||||
|
||||
func filterSelfAddresses(in []netip.Prefix) (ret []netip.Prefix) {
|
||||
switch {
|
||||
default:
|
||||
return in
|
||||
case debugSelfIPv6Only:
|
||||
case debugSelfIPv6Only():
|
||||
for _, a := range in {
|
||||
if a.Addr().Is6() {
|
||||
ret = append(ret, a)
|
||||
|
||||
@@ -165,7 +165,15 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
|
||||
// thousand version numbers before getting to this point.
|
||||
panic("capability version is too high to fit in the wire protocol")
|
||||
}
|
||||
conn, err := controlhttp.Dial(ctx, nc.host, nc.httpPort, nc.httpsPort, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion), nc.dialer.SystemDial)
|
||||
conn, err := (&controlhttp.Dialer{
|
||||
Hostname: nc.host,
|
||||
HTTPPort: nc.httpPort,
|
||||
HTTPSPort: nc.httpsPort,
|
||||
MachineKey: nc.privKey,
|
||||
ControlKey: nc.serverPubKey,
|
||||
ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion),
|
||||
Dialer: nc.dialer.SystemDial,
|
||||
}).Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -40,57 +40,49 @@ import (
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/net/tlsdial"
|
||||
"tailscale.com/net/tshttpproxy"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// Dial connects to the HTTP server at host:httpPort, requests to switch to the
|
||||
// Tailscale control protocol, and returns an established control
|
||||
var stdDialer net.Dialer
|
||||
|
||||
// Dial connects to the HTTP server at this Dialer's Host:HTTPPort, requests to
|
||||
// switch to the Tailscale control protocol, and returns an established control
|
||||
// protocol connection.
|
||||
//
|
||||
// If Dial fails to connect using addr, it also tries to tunnel over
|
||||
// TLS to host:httpsPort as a compatibility fallback.
|
||||
// If Dial fails to connect using HTTP, it also tries to tunnel over TLS to the
|
||||
// Dialer's Host:HTTPSPort as a compatibility fallback.
|
||||
//
|
||||
// The provided ctx is only used for the initial connection, until
|
||||
// Dial returns. It does not affect the connection once established.
|
||||
func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) {
|
||||
a := &dialParams{
|
||||
host: host,
|
||||
httpPort: httpPort,
|
||||
httpsPort: httpsPort,
|
||||
machineKey: machineKey,
|
||||
controlKey: controlKey,
|
||||
version: protocolVersion,
|
||||
proxyFunc: tshttpproxy.ProxyFromEnvironment,
|
||||
dialer: dialer,
|
||||
func (a *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
if a.Hostname == "" {
|
||||
return nil, errors.New("required Dialer.Hostname empty")
|
||||
}
|
||||
return a.dial(ctx)
|
||||
}
|
||||
|
||||
type dialParams struct {
|
||||
host string
|
||||
httpPort string
|
||||
httpsPort string
|
||||
machineKey key.MachinePrivate
|
||||
controlKey key.MachinePublic
|
||||
version uint16
|
||||
proxyFunc func(*http.Request) (*url.URL, error) // or nil
|
||||
dialer dnscache.DialContextFunc
|
||||
|
||||
// For tests only
|
||||
insecureTLS bool
|
||||
testFallbackDelay time.Duration
|
||||
func (a *Dialer) logf(format string, args ...any) {
|
||||
if a.Logf != nil {
|
||||
a.Logf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// httpsFallbackDelay is how long we'll wait for a.httpPort to work before
|
||||
// starting to try a.httpsPort.
|
||||
func (a *dialParams) httpsFallbackDelay() time.Duration {
|
||||
func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) {
|
||||
if a.proxyFunc != nil {
|
||||
return a.proxyFunc
|
||||
}
|
||||
return tshttpproxy.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
// httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before
|
||||
// starting to try a.HTTPSPort.
|
||||
func (a *Dialer) httpsFallbackDelay() time.Duration {
|
||||
if v := a.testFallbackDelay; v != 0 {
|
||||
return v
|
||||
}
|
||||
return 500 * time.Millisecond
|
||||
}
|
||||
|
||||
func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
// Create one shared context used by both port 80 and port 443 dials.
|
||||
// If port 80 is still in flight when 443 returns, this deferred cancel
|
||||
// will stop the port 80 dial.
|
||||
@@ -102,12 +94,12 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
// we'll speak Noise.
|
||||
u80 := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort(a.host, a.httpPort),
|
||||
Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPPort, "80")),
|
||||
Path: serverUpgradePath,
|
||||
}
|
||||
u443 := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(a.host, a.httpsPort),
|
||||
Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")),
|
||||
Path: serverUpgradePath,
|
||||
}
|
||||
|
||||
@@ -169,8 +161,8 @@ func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
}
|
||||
|
||||
// dialURL attempts to connect to the given URL.
|
||||
func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
|
||||
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
|
||||
func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
|
||||
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,26 +181,34 @@ func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn
|
||||
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
|
||||
//
|
||||
// Only the provided ctx is used, not a.ctx.
|
||||
func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
|
||||
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
|
||||
dns := &dnscache.Resolver{
|
||||
Forward: dnscache.Get().Forward,
|
||||
LookupIPFallback: dnsfallback.Lookup,
|
||||
UseLastGood: true,
|
||||
}
|
||||
|
||||
var dialer dnscache.DialContextFunc
|
||||
if a.Dialer != nil {
|
||||
dialer = a.Dialer
|
||||
} else {
|
||||
dialer = stdDialer.DialContext
|
||||
}
|
||||
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defer tr.CloseIdleConnections()
|
||||
tr.Proxy = a.proxyFunc
|
||||
tr.Proxy = a.getProxyFunc()
|
||||
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
|
||||
tr.DialContext = dnscache.Dialer(a.dialer, dns)
|
||||
tr.DialContext = dnscache.Dialer(dialer, dns)
|
||||
// Disable HTTP2, since h2 can't do protocol switching.
|
||||
tr.TLSClientConfig.NextProtos = []string{}
|
||||
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
|
||||
tr.TLSClientConfig = tlsdial.Config(a.host, tr.TLSClientConfig)
|
||||
tr.TLSClientConfig = tlsdial.Config(a.Hostname, tr.TLSClientConfig)
|
||||
if a.insecureTLS {
|
||||
tr.TLSClientConfig.InsecureSkipVerify = true
|
||||
tr.TLSClientConfig.VerifyConnection = nil
|
||||
}
|
||||
tr.DialTLSContext = dnscache.TLSDialer(a.dialer, dns, tr.TLSClientConfig)
|
||||
tr.DialTLSContext = dnscache.TLSDialer(dialer, dns, tr.TLSClientConfig)
|
||||
tr.DisableCompression = true
|
||||
|
||||
// (mis)use httptrace to extract the underlying net.Conn from the
|
||||
|
||||
@@ -7,27 +7,31 @@ package controlhttp
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// Variant of Dial that tunnels the request over WebSockets, since we cannot do
|
||||
// bi-directional communication over an HTTP connection when in JS.
|
||||
func Dial(ctx context.Context, host string, httpPort string, httpsPort string, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16, dialer dnscache.DialContextFunc) (*controlbase.Conn, error) {
|
||||
init, cont, err := controlbase.ClientDeferred(machineKey, controlKey, protocolVersion)
|
||||
func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
if d.Hostname == "" {
|
||||
return nil, errors.New("required Dialer.Hostname empty")
|
||||
}
|
||||
|
||||
init, cont, err := controlbase.ClientDeferred(d.MachineKey, d.ControlKey, d.ProtocolVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wsScheme := "wss"
|
||||
host := d.Hostname
|
||||
if host == "localhost" {
|
||||
wsScheme = "ws"
|
||||
host = net.JoinHostPort(host, httpPort)
|
||||
host = net.JoinHostPort(host, strDef(d.HTTPPort, "80"))
|
||||
}
|
||||
wsURL := &url.URL{
|
||||
Scheme: wsScheme,
|
||||
@@ -52,5 +56,4 @@ func Dial(ctx context.Context, host string, httpPort string, httpsPort string, m
|
||||
return nil, err
|
||||
}
|
||||
return cbConn, nil
|
||||
|
||||
}
|
||||
|
||||
@@ -4,6 +4,16 @@
|
||||
|
||||
package controlhttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// upgradeHeader is the value of the Upgrade HTTP header used to
|
||||
// indicate the Tailscale control protocol.
|
||||
@@ -18,3 +28,58 @@ const (
|
||||
// to do the protocol switch is located.
|
||||
serverUpgradePath = "/ts2021"
|
||||
)
|
||||
|
||||
// Dialer contains configuration on how to dial the Tailscale control server.
|
||||
type Dialer struct {
|
||||
// Hostname is the hostname to connect to, with no port number.
|
||||
//
|
||||
// This field is required.
|
||||
Hostname string
|
||||
|
||||
// MachineKey contains the current machine's private key.
|
||||
//
|
||||
// This field is required.
|
||||
MachineKey key.MachinePrivate
|
||||
|
||||
// ControlKey contains the expected public key for the control server.
|
||||
//
|
||||
// This field is required.
|
||||
ControlKey key.MachinePublic
|
||||
|
||||
// ProtocolVersion is the expected protocol version to negotiate.
|
||||
//
|
||||
// This field is required.
|
||||
ProtocolVersion uint16
|
||||
|
||||
// HTTPPort is the port number to use when making a HTTP connection.
|
||||
//
|
||||
// If not specified, this defaults to port 80.
|
||||
HTTPPort string
|
||||
|
||||
// HTTPSPort is the port number to use when making a HTTPS connection.
|
||||
//
|
||||
// If not specified, this defaults to port 443.
|
||||
HTTPSPort string
|
||||
|
||||
// Dialer is the dialer used to make outbound connections.
|
||||
//
|
||||
// If not specified, this defaults to net.Dialer.DialContext.
|
||||
Dialer dnscache.DialContextFunc
|
||||
|
||||
// Logf, if set, is a logging function to use; if unset, logs are
|
||||
// dropped.
|
||||
Logf logger.Logf
|
||||
|
||||
proxyFunc func(*http.Request) (*url.URL, error) // or nil
|
||||
|
||||
// For tests only
|
||||
insecureTLS bool
|
||||
testFallbackDelay time.Duration
|
||||
}
|
||||
|
||||
func strDef(v1, v2 string) string {
|
||||
if v1 != "" {
|
||||
return v1
|
||||
}
|
||||
return v2
|
||||
}
|
||||
|
||||
@@ -170,15 +170,16 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
a := dialParams{
|
||||
host: "localhost",
|
||||
httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
|
||||
httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
|
||||
machineKey: client,
|
||||
controlKey: server.Public(),
|
||||
version: testProtocolVersion,
|
||||
a := &Dialer{
|
||||
Hostname: "localhost",
|
||||
HTTPPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
|
||||
HTTPSPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
|
||||
MachineKey: client,
|
||||
ControlKey: server.Public(),
|
||||
ProtocolVersion: testProtocolVersion,
|
||||
Dialer: new(tsdial.Dialer).SystemDial,
|
||||
Logf: t.Logf,
|
||||
insecureTLS: true,
|
||||
dialer: new(tsdial.Dialer).SystemDial,
|
||||
testFallbackDelay: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
|
||||
@@ -82,6 +82,12 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
Subprotocols: []string{upgradeHeaderValue},
|
||||
OriginPatterns: []string{"*"},
|
||||
// Disable compression because we transmit Noise messages that are not
|
||||
// compressible.
|
||||
// Additionally, Safari has a broken implementation of compression
|
||||
// (see https://github.com/nhooyr/websocket/issues/218) that makes
|
||||
// enabling it actively harmful.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not accept WebSocket connection %v", err)
|
||||
|
||||
@@ -13,20 +13,18 @@ import (
|
||||
)
|
||||
|
||||
// disableUPnP indicates whether to attempt UPnP mapping.
|
||||
var disableUPnP atomic.Bool
|
||||
var disableUPnPControl atomic.Bool
|
||||
|
||||
func init() {
|
||||
SetDisableUPnP(envknob.Bool("TS_DISABLE_UPNP"))
|
||||
}
|
||||
var disableUPnpEnv = envknob.RegisterBool("TS_DISABLE_UPNP")
|
||||
|
||||
// DisableUPnP reports the last reported value from control
|
||||
// whether UPnP portmapping should be disabled.
|
||||
func DisableUPnP() bool {
|
||||
return disableUPnP.Load()
|
||||
return disableUPnPControl.Load() || disableUPnpEnv()
|
||||
}
|
||||
|
||||
// SetDisableUPnP sets whether control says that UPnP should be
|
||||
// disabled.
|
||||
func SetDisableUPnP(v bool) {
|
||||
disableUPnP.Store(v)
|
||||
disableUPnPControl.Store(v)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package derp implements DERP, the Detour Encrypted Routing Protocol.
|
||||
// Package derp implements the Designated Encrypted Relay for Packets (DERP)
|
||||
// protocol.
|
||||
//
|
||||
// DERP routes packets to clients using curve25519 keys as addresses.
|
||||
//
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -194,7 +194,7 @@ func readFrame(br *bufio.Reader, maxSize uint32, b []byte) (t frameType, frameLe
|
||||
}
|
||||
remain := frameLen - uint32(n)
|
||||
if remain > 0 {
|
||||
if _, err := io.CopyN(ioutil.Discard, br, int64(remain)); err != nil {
|
||||
if _, err := io.CopyN(io.Discard, br, int64(remain)); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
err = io.ErrShortBuffer
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
"math/big"
|
||||
@@ -47,8 +46,6 @@ import (
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
var debug = envknob.Bool("DERP_DEBUG_LOGS")
|
||||
|
||||
// verboseDropKeys is the set of destination public keys that should
|
||||
// verbosely log whenever DERP drops a packet.
|
||||
var verboseDropKeys = map[key.NodePublic]bool{}
|
||||
@@ -106,6 +103,7 @@ type Server struct {
|
||||
limitedLogf logger.Logf
|
||||
metaCert []byte // the encoded x509 cert to send after LetsEncrypt cert+intermediate
|
||||
dupPolicy dupPolicy
|
||||
debug bool
|
||||
|
||||
// Counters:
|
||||
packetsSent, bytesSent expvar.Int
|
||||
@@ -299,6 +297,7 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server {
|
||||
runtime.ReadMemStats(&ms)
|
||||
|
||||
s := &Server{
|
||||
debug: envknob.Bool("DERP_DEBUG_LOGS"),
|
||||
privateKey: privateKey,
|
||||
publicKey: privateKey.Public(),
|
||||
logf: logf,
|
||||
@@ -758,7 +757,7 @@ func (c *sclient) run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *sclient) handleUnknownFrame(ft frameType, fl uint32) error {
|
||||
_, err := io.CopyN(ioutil.Discard, c.br, int64(fl))
|
||||
_, err := io.CopyN(io.Discard, c.br, int64(fl))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -801,7 +800,7 @@ func (c *sclient) handleFramePing(ft frameType, fl uint32) error {
|
||||
return err
|
||||
}
|
||||
if extra := int64(fl) - int64(len(m)); extra > 0 {
|
||||
_, err = io.CopyN(ioutil.Discard, c.br, extra)
|
||||
_, err = io.CopyN(io.Discard, c.br, extra)
|
||||
}
|
||||
select {
|
||||
case c.sendPongCh <- [8]byte(m):
|
||||
@@ -980,7 +979,7 @@ func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, r
|
||||
msg := fmt.Sprintf("drop (%s) %s -> %s", srcKey.ShortString(), reason, dstKey.ShortString())
|
||||
s.limitedLogf(msg)
|
||||
}
|
||||
if debug {
|
||||
if s.debug {
|
||||
s.logf("dropping packet reason=%s dst=%s disco=%v", reason, dstKey, disco.LooksLikeDiscoWrapper(packetBytes))
|
||||
}
|
||||
}
|
||||
@@ -1828,7 +1827,7 @@ func (s *Server) ServeDebugTraffic(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var bufioWriterPool = &sync.Pool{
|
||||
New: func() any {
|
||||
return bufio.NewWriterSize(ioutil.Discard, 2<<10)
|
||||
return bufio.NewWriterSize(io.Discard, 2<<10)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1861,7 +1860,7 @@ func (w *lazyBufioWriter) Flush() error {
|
||||
}
|
||||
err := w.lbw.Flush()
|
||||
|
||||
w.lbw.Reset(ioutil.Discard)
|
||||
w.lbw.Reset(io.Discard)
|
||||
bufioWriterPool.Put(w.lbw)
|
||||
w.lbw = nil
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -1240,7 +1240,7 @@ func benchmarkSendRecvSize(b *testing.B, packetSize int) {
|
||||
}
|
||||
|
||||
func BenchmarkWriteUint32(b *testing.B) {
|
||||
w := bufio.NewWriter(ioutil.Discard)
|
||||
w := bufio.NewWriter(io.Discard)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -1279,9 +1279,9 @@ func waitConnect(t testing.TB, c *Client) {
|
||||
}
|
||||
|
||||
func TestParseSSOutput(t *testing.T) {
|
||||
contents, err := ioutil.ReadFile("testdata/example_ss.txt")
|
||||
contents, err := os.ReadFile("testdata/example_ss.txt")
|
||||
if err != nil {
|
||||
t.Errorf("ioutil.Readfile(example_ss.txt) failed: %v", err)
|
||||
t.Errorf("os.ReadFile(example_ss.txt) failed: %v", err)
|
||||
}
|
||||
seen := parseSSOutput(string(contents))
|
||||
if len(seen) == 0 {
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -432,7 +431,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, 0, fmt.Errorf("GET failed: %v: %s", err, b)
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ spec:
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: tailscale-auth
|
||||
key: AUTH_KEY
|
||||
key: TS_AUTH_KEY
|
||||
optional: true
|
||||
- name: TS_DEST_IP
|
||||
value: "{{TS_DEST_IP}}"
|
||||
|
||||
@@ -17,10 +17,11 @@ TS_KUBE_SECRET="${TS_KUBE_SECRET:-tailscale}"
|
||||
TS_SOCKS5_SERVER="${TS_SOCKS5_SERVER:-}"
|
||||
TS_OUTBOUND_HTTP_PROXY_LISTEN="${TS_OUTBOUND_HTTP_PROXY_LISTEN:-}"
|
||||
TS_TAILSCALED_EXTRA_ARGS="${TS_TAILSCALED_EXTRA_ARGS:-}"
|
||||
TS_SOCKET="${TS_SOCKET:-/tmp/tailscaled.sock}"
|
||||
|
||||
set -e
|
||||
|
||||
TAILSCALED_ARGS="--socket=/tmp/tailscaled.sock"
|
||||
TAILSCALED_ARGS="--socket=${TS_SOCKET}"
|
||||
|
||||
if [[ ! -z "${KUBERNETES_SERVICE_HOST}" ]]; then
|
||||
TAILSCALED_ARGS="${TAILSCALED_ARGS} --state=kube:${TS_KUBE_SECRET} --statedir=${TS_STATE_DIR:-/tmp}"
|
||||
@@ -81,11 +82,11 @@ if [[ ! -z "${TS_EXTRA_ARGS}" ]]; then
|
||||
fi
|
||||
|
||||
echo "Running tailscale up"
|
||||
tailscale --socket=/tmp/tailscaled.sock up ${UP_ARGS}
|
||||
tailscale --socket="${TS_SOCKET}" up ${UP_ARGS}
|
||||
|
||||
if [[ ! -z "${TS_DEST_IP}" ]]; then
|
||||
echo "Adding iptables rule for DNAT"
|
||||
iptables -t nat -I PREROUTING -d "$(tailscale --socket=/tmp/tailscaled.sock ip -4)" -j DNAT --to-destination "${TS_DEST_IP}"
|
||||
iptables -t nat -I PREROUTING -d "$(tailscale --socket=${TS_SOCKET} ip -4)" -j DNAT --to-destination "${TS_DEST_IP}"
|
||||
fi
|
||||
|
||||
echo "Waiting for tailscaled to exit"
|
||||
|
||||
@@ -23,7 +23,7 @@ spec:
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: tailscale-auth
|
||||
key: AUTH_KEY
|
||||
key: TS_AUTH_KEY
|
||||
optional: true
|
||||
securityContext:
|
||||
capabilities:
|
||||
|
||||
@@ -23,7 +23,7 @@ spec:
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: tailscale-auth
|
||||
key: AUTH_KEY
|
||||
key: TS_AUTH_KEY
|
||||
optional: true
|
||||
- name: TS_ROUTES
|
||||
value: "{{TS_ROUTES}}"
|
||||
|
||||
@@ -26,5 +26,5 @@ spec:
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: tailscale-auth
|
||||
key: AUTH_KEY
|
||||
key: TS_AUTH_KEY
|
||||
optional: true
|
||||
|
||||
@@ -17,30 +17,43 @@
|
||||
package envknob
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"tailscale.com/types/opt"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
set = map[string]string{}
|
||||
list []string
|
||||
mu sync.Mutex
|
||||
set = map[string]string{}
|
||||
regStr = map[string]*string{}
|
||||
regBool = map[string]*bool{}
|
||||
regOptBool = map[string]*opt.Bool{}
|
||||
)
|
||||
|
||||
func noteEnv(k, v string) {
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if _, ok := set[k]; !ok {
|
||||
list = append(list, k)
|
||||
noteEnvLocked(k, v)
|
||||
}
|
||||
|
||||
func noteEnvLocked(k, v string) {
|
||||
if v != "" {
|
||||
set[k] = v
|
||||
} else {
|
||||
delete(set, k)
|
||||
}
|
||||
set[k] = v
|
||||
}
|
||||
|
||||
// logf is logger.Logf, but logger depends on envknob, so for circular
|
||||
@@ -52,11 +65,39 @@ type logf = func(format string, args ...any)
|
||||
func LogCurrent(logf logf) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
list := make([]string, 0, len(set))
|
||||
for k := range set {
|
||||
list = append(list, k)
|
||||
}
|
||||
sort.Strings(list)
|
||||
for _, k := range list {
|
||||
logf("envknob: %s=%q", k, set[k])
|
||||
}
|
||||
}
|
||||
|
||||
// Setenv changes an environment variable.
|
||||
//
|
||||
// It is not safe for concurrent reading of environment variables via the
|
||||
// Register functions. All Setenv calls are meant to happen early in main before
|
||||
// any goroutines are started.
|
||||
func Setenv(envVar, val string) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
os.Setenv(envVar, val)
|
||||
noteEnvLocked(envVar, val)
|
||||
|
||||
if p := regStr[envVar]; p != nil {
|
||||
*p = val
|
||||
}
|
||||
if p := regBool[envVar]; p != nil {
|
||||
setBoolLocked(p, envVar, val)
|
||||
}
|
||||
if p := regOptBool[envVar]; p != nil {
|
||||
setOptBoolLocked(p, envVar, val)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the named environment variable, using os.Getenv.
|
||||
//
|
||||
// If the variable is non-empty, it's also tracked & logged as being
|
||||
@@ -67,6 +108,82 @@ func String(envVar string) string {
|
||||
return v
|
||||
}
|
||||
|
||||
// RegisterString returns a func that gets the named environment variable,
|
||||
// without a map lookup per call. It assumes that mutations happen via
|
||||
// envknob.Setenv.
|
||||
func RegisterString(envVar string) func() string {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
p, ok := regStr[envVar]
|
||||
if !ok {
|
||||
val := os.Getenv(envVar)
|
||||
if val != "" {
|
||||
noteEnvLocked(envVar, val)
|
||||
}
|
||||
p = &val
|
||||
regStr[envVar] = p
|
||||
}
|
||||
return func() string { return *p }
|
||||
}
|
||||
|
||||
// RegisterBool returns a func that gets the named environment variable,
|
||||
// without a map lookup per call. It assumes that mutations happen via
|
||||
// envknob.Setenv.
|
||||
func RegisterBool(envVar string) func() bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
p, ok := regBool[envVar]
|
||||
if !ok {
|
||||
var b bool
|
||||
p = &b
|
||||
setBoolLocked(p, envVar, os.Getenv(envVar))
|
||||
regBool[envVar] = p
|
||||
}
|
||||
return func() bool { return *p }
|
||||
}
|
||||
|
||||
// RegisterOptBool returns a func that gets the named environment variable,
|
||||
// without a map lookup per call. It assumes that mutations happen via
|
||||
// envknob.Setenv.
|
||||
func RegisterOptBool(envVar string) func() opt.Bool {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
p, ok := regOptBool[envVar]
|
||||
if !ok {
|
||||
var b opt.Bool
|
||||
p = &b
|
||||
setOptBoolLocked(p, envVar, os.Getenv(envVar))
|
||||
regOptBool[envVar] = p
|
||||
}
|
||||
return func() opt.Bool { return *p }
|
||||
}
|
||||
|
||||
func setBoolLocked(p *bool, envVar, val string) {
|
||||
noteEnvLocked(envVar, val)
|
||||
if val == "" {
|
||||
*p = false
|
||||
return
|
||||
}
|
||||
var err error
|
||||
*p, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid boolean environment variable %s value %q", envVar, val)
|
||||
}
|
||||
}
|
||||
|
||||
func setOptBoolLocked(p *opt.Bool, envVar, val string) {
|
||||
noteEnvLocked(envVar, val)
|
||||
if val == "" {
|
||||
*p = ""
|
||||
return
|
||||
}
|
||||
b, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid boolean environment variable %s value %q", envVar, val)
|
||||
}
|
||||
p.Set(b)
|
||||
}
|
||||
|
||||
// Bool returns the boolean value of the named environment variable.
|
||||
// If the variable is not set, it returns false.
|
||||
// An invalid value exits the binary with a failure.
|
||||
@@ -81,6 +198,7 @@ func BoolDefaultTrue(envVar string) bool {
|
||||
}
|
||||
|
||||
func boolOr(envVar string, implicitValue bool) bool {
|
||||
assertNotInInit()
|
||||
val := os.Getenv(envVar)
|
||||
if val == "" {
|
||||
return implicitValue
|
||||
@@ -98,6 +216,7 @@ func boolOr(envVar string, implicitValue bool) bool {
|
||||
// The ok result is whether a value was set.
|
||||
// If the value isn't a valid int, it exits the program with a failure.
|
||||
func LookupBool(envVar string) (v bool, ok bool) {
|
||||
assertNotInInit()
|
||||
val := os.Getenv(envVar)
|
||||
if val == "" {
|
||||
return false, false
|
||||
@@ -113,6 +232,7 @@ func LookupBool(envVar string) (v bool, ok bool) {
|
||||
// OptBool is like Bool, but returns an opt.Bool, so the caller can
|
||||
// distinguish between implicitly and explicitly false.
|
||||
func OptBool(envVar string) opt.Bool {
|
||||
assertNotInInit()
|
||||
b, ok := LookupBool(envVar)
|
||||
if !ok {
|
||||
return ""
|
||||
@@ -126,6 +246,7 @@ func OptBool(envVar string) opt.Bool {
|
||||
// The ok result is whether a value was set.
|
||||
// If the value isn't a valid int, it exits the program with a failure.
|
||||
func LookupInt(envVar string) (v int, ok bool) {
|
||||
assertNotInInit()
|
||||
val := os.Getenv(envVar)
|
||||
if val == "" {
|
||||
return 0, false
|
||||
@@ -155,3 +276,151 @@ func SSHPolicyFile() string { return String("TS_DEBUG_SSH_POLICY_FILE") }
|
||||
|
||||
// SSHIgnoreTailnetPolicy is whether to ignore the Tailnet SSH policy for development.
|
||||
func SSHIgnoreTailnetPolicy() bool { return Bool("TS_DEBUG_SSH_IGNORE_TAILNET_POLICY") }
|
||||
|
||||
// NoLogsNoSupport reports whether the client's opted out of log uploads and
|
||||
// technical support.
|
||||
func NoLogsNoSupport() bool {
|
||||
return Bool("TS_NO_LOGS_NO_SUPPORT")
|
||||
}
|
||||
|
||||
// SetNoLogsNoSupport enables no-logs-no-support mode.
|
||||
func SetNoLogsNoSupport() {
|
||||
Setenv("TS_NO_LOGS_NO_SUPPORT", "true")
|
||||
}
|
||||
|
||||
// notInInit is set true the first time we've seen a non-init stack trace.
|
||||
var notInInit atomic.Bool
|
||||
|
||||
func assertNotInInit() {
|
||||
if notInInit.Load() {
|
||||
return
|
||||
}
|
||||
skip := 0
|
||||
for {
|
||||
pc, _, _, ok := runtime.Caller(skip)
|
||||
if !ok {
|
||||
notInInit.Store(true)
|
||||
return
|
||||
}
|
||||
fu := runtime.FuncForPC(pc)
|
||||
if fu == nil {
|
||||
return
|
||||
}
|
||||
name := fu.Name()
|
||||
name = strings.TrimRightFunc(name, func(r rune) bool { return r >= '0' && r <= '9' })
|
||||
if strings.HasSuffix(name, ".init") || strings.HasSuffix(name, ".init.") {
|
||||
stack := make([]byte, 1<<10)
|
||||
stack = stack[:runtime.Stack(stack, false)]
|
||||
envCheckedInInitStack = stack
|
||||
}
|
||||
skip++
|
||||
}
|
||||
}
|
||||
|
||||
var envCheckedInInitStack []byte
|
||||
|
||||
// PanicIfAnyEnvCheckedInInit panics if environment variables were read during
|
||||
// init.
|
||||
func PanicIfAnyEnvCheckedInInit() {
|
||||
if envCheckedInInitStack != nil {
|
||||
panic("envknob check of called from init function: " + string(envCheckedInInitStack))
|
||||
}
|
||||
}
|
||||
|
||||
var applyDiskConfigErr error
|
||||
|
||||
// ApplyDiskConfigError returns the most recent result of ApplyDiskConfig.
|
||||
func ApplyDiskConfigError() error { return applyDiskConfigErr }
|
||||
|
||||
// ApplyDiskConfig returns a platform-specific config file of environment keys/values and
|
||||
// applies them. On Linux and Unix operating systems, it's a no-op and always returns nil.
|
||||
// If no platform-specific config file is found, it also returns nil.
|
||||
//
|
||||
// It exists primarily for Windows to make it easy to apply environment variables to
|
||||
// a running service in a way similar to modifying /etc/default/tailscaled on Linux.
|
||||
// On Windows, you use %ProgramData%\Tailscale\tailscaled-env.txt instead.
|
||||
func ApplyDiskConfig() (err error) {
|
||||
var f *os.File
|
||||
defer func() {
|
||||
if err != nil {
|
||||
// Stash away our return error for the healthcheck package to use.
|
||||
applyDiskConfigErr = fmt.Errorf("error parsing %s: %w", f.Name(), err)
|
||||
}
|
||||
}()
|
||||
|
||||
// First try the explicitly-provided value for development testing. Not
|
||||
// useful for users to use on their own. (if they can set this, they can set
|
||||
// any environment variable anyway)
|
||||
if name := os.Getenv("TS_DEBUG_ENV_FILE"); name != "" {
|
||||
f, err = os.Open(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening explicitly configured TS_DEBUG_ENV_FILE: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
return applyKeyValueEnv(f)
|
||||
}
|
||||
|
||||
name := getPlatformEnvFile()
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
f, err = os.Open(name)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return applyKeyValueEnv(f)
|
||||
}
|
||||
|
||||
// getPlatformEnvFile returns the current platform's path to an optional
|
||||
// tailscaled-env.txt file. It returns an empty string if none is defined
|
||||
// for the platform.
|
||||
func getPlatformEnvFile() string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "tailscaled-env.txt")
|
||||
case "linux":
|
||||
if distro.Get() == distro.Synology {
|
||||
return "/etc/tailscale/tailscaled-env.txt"
|
||||
}
|
||||
case "darwin":
|
||||
// TODO(bradfitz): figure this out. There are three ways to run
|
||||
// Tailscale on macOS (tailscaled, GUI App Store, GUI System Extension)
|
||||
// and we should deal with all three.
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// applyKeyValueEnv reads key=value lines r and calls Setenv for each.
|
||||
//
|
||||
// Empty lines and lines beginning with '#' are skipped.
|
||||
//
|
||||
// Values can be double quoted, in which case they're unquoted using
|
||||
// strconv.Unquote.
|
||||
func applyKeyValueEnv(r io.Reader) error {
|
||||
bs := bufio.NewScanner(r)
|
||||
for bs.Scan() {
|
||||
line := strings.TrimSpace(bs.Text())
|
||||
if line == "" || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
k, v, ok := strings.Cut(line, "=")
|
||||
k = strings.TrimSpace(k)
|
||||
if !ok || k == "" {
|
||||
continue
|
||||
}
|
||||
v = strings.TrimSpace(v)
|
||||
if strings.HasPrefix(v, `"`) {
|
||||
var err error
|
||||
v, err = strconv.Unquote(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value in line %q: %v", line, err)
|
||||
}
|
||||
}
|
||||
Setenv(k, v)
|
||||
}
|
||||
return bs.Err()
|
||||
}
|
||||
|
||||
@@ -325,7 +325,7 @@ func OverallError() error {
|
||||
return overallErrorLocked()
|
||||
}
|
||||
|
||||
var fakeErrForTesting = envknob.String("TS_DEBUG_FAKE_HEALTH_ERROR")
|
||||
var fakeErrForTesting = envknob.RegisterString("TS_DEBUG_FAKE_HEALTH_ERROR")
|
||||
|
||||
func overallErrorLocked() error {
|
||||
if !anyInterfaceUp {
|
||||
@@ -383,7 +383,10 @@ func overallErrorLocked() error {
|
||||
for _, s := range controlHealth {
|
||||
errs = append(errs, errors.New(s))
|
||||
}
|
||||
if e := fakeErrForTesting; len(errs) == 0 && e != "" {
|
||||
if err := envknob.ApplyDiskConfigError(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if e := fakeErrForTesting(); len(errs) == 0 && e != "" {
|
||||
return errors.New(e)
|
||||
}
|
||||
sort.Slice(errs, func(i, j int) bool {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/opt"
|
||||
"tailscale.com/util/cloudenv"
|
||||
@@ -32,21 +33,22 @@ func New() *tailcfg.Hostinfo {
|
||||
hostname, _ := os.Hostname()
|
||||
hostname = dnsname.FirstLabel(hostname)
|
||||
return &tailcfg.Hostinfo{
|
||||
IPNVersion: version.Long,
|
||||
Hostname: hostname,
|
||||
OS: version.OS(),
|
||||
OSVersion: GetOSVersion(),
|
||||
Container: lazyInContainer.Get(),
|
||||
Distro: condCall(distroName),
|
||||
DistroVersion: condCall(distroVersion),
|
||||
DistroCodeName: condCall(distroCodeName),
|
||||
Env: string(GetEnvType()),
|
||||
Desktop: desktop(),
|
||||
Package: packageTypeCached(),
|
||||
GoArch: runtime.GOARCH,
|
||||
GoVersion: runtime.Version(),
|
||||
DeviceModel: deviceModel(),
|
||||
Cloud: string(cloudenv.Get()),
|
||||
IPNVersion: version.Long,
|
||||
Hostname: hostname,
|
||||
OS: version.OS(),
|
||||
OSVersion: GetOSVersion(),
|
||||
Container: lazyInContainer.Get(),
|
||||
Distro: condCall(distroName),
|
||||
DistroVersion: condCall(distroVersion),
|
||||
DistroCodeName: condCall(distroCodeName),
|
||||
Env: string(GetEnvType()),
|
||||
Desktop: desktop(),
|
||||
Package: packageTypeCached(),
|
||||
GoArch: runtime.GOARCH,
|
||||
GoVersion: runtime.Version(),
|
||||
DeviceModel: deviceModel(),
|
||||
Cloud: string(cloudenv.Get()),
|
||||
NoLogsNoSupport: envknob.NoLogsNoSupport(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ package hostinfo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@@ -99,11 +98,11 @@ func linuxVersionMeta() (meta versionMeta) {
|
||||
case distro.OpenWrt:
|
||||
propFile = "/etc/openwrt_release"
|
||||
case distro.WDMyCloud:
|
||||
slurp, _ := ioutil.ReadFile("/etc/version")
|
||||
slurp, _ := os.ReadFile("/etc/version")
|
||||
meta.DistroVersion = string(bytes.TrimSpace(slurp))
|
||||
return
|
||||
case distro.QNAP:
|
||||
slurp, _ := ioutil.ReadFile("/etc/version_info")
|
||||
slurp, _ := os.ReadFile("/etc/version_info")
|
||||
meta.DistroVersion = getQnapQtsVersion(string(slurp))
|
||||
return
|
||||
}
|
||||
@@ -133,7 +132,7 @@ func linuxVersionMeta() (meta versionMeta) {
|
||||
case "debian":
|
||||
// Debian's VERSION_ID is just like "11". But /etc/debian_version has "11.5" normally.
|
||||
// Or "bookworm/sid" on sid/testing.
|
||||
slurp, _ := ioutil.ReadFile("/etc/debian_version")
|
||||
slurp, _ := os.ReadFile("/etc/debian_version")
|
||||
if v := string(bytes.TrimSpace(slurp)); v != "" {
|
||||
if '0' <= v[0] && v[0] <= '9' {
|
||||
meta.DistroVersion = v
|
||||
@@ -143,7 +142,7 @@ func linuxVersionMeta() (meta versionMeta) {
|
||||
}
|
||||
case "", "centos": // CentOS 6 has no /etc/os-release, so its id is ""
|
||||
if meta.DistroVersion == "" {
|
||||
if cr, _ := ioutil.ReadFile("/etc/centos-release"); len(cr) > 0 { // "CentOS release 6.10 (Final)
|
||||
if cr, _ := os.ReadFile("/etc/centos-release"); len(cr) > 0 { // "CentOS release 6.10 (Final)
|
||||
meta.DistroVersion = string(bytes.TrimSpace(cr))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,11 @@
|
||||
package ipnlocal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -15,6 +18,21 @@ func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) {
|
||||
// Test handler.
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Write(body)
|
||||
case "/ssh/usernames":
|
||||
var req tailcfg.C2NSSHUsernamesRequest
|
||||
if r.Method == "POST" {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
res, err := b.getSSHUsernames(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(res)
|
||||
default:
|
||||
http.Error(w, "unknown c2n path", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"time"
|
||||
|
||||
"go4.org/netipx"
|
||||
"golang.org/x/exp/slices"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/envknob"
|
||||
@@ -68,7 +67,6 @@ import (
|
||||
)
|
||||
|
||||
var controlDebugFlags = getControlDebugFlags()
|
||||
var canSSH = envknob.CanSSHD()
|
||||
|
||||
func getControlDebugFlags() []string {
|
||||
if e := envknob.String("TS_DEBUG_CONTROL_FLAGS"); e != "" {
|
||||
@@ -575,6 +573,10 @@ func (b *LocalBackend) WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.Use
|
||||
func (b *LocalBackend) PeerCaps(src netip.Addr) []string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.peerCapsLocked(src)
|
||||
}
|
||||
|
||||
func (b *LocalBackend) peerCapsLocked(src netip.Addr) []string {
|
||||
if b.netMap == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -586,9 +588,9 @@ func (b *LocalBackend) PeerCaps(src netip.Addr) []string {
|
||||
if !a.IsSingleIP() {
|
||||
continue
|
||||
}
|
||||
dstIP := a.Addr()
|
||||
if dstIP.BitLen() == src.BitLen() {
|
||||
return filt.AppendCaps(nil, src, a.Addr())
|
||||
dst := a.Addr()
|
||||
if dst.BitLen() == src.BitLen() { // match on family
|
||||
return filt.AppendCaps(nil, src, dst)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -682,6 +684,9 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) {
|
||||
}
|
||||
}
|
||||
if st.NetMap != nil {
|
||||
if err := b.tkaSyncIfNeededLocked(st.NetMap); err != nil {
|
||||
b.logf("[v1] TKA sync error: %v", err)
|
||||
}
|
||||
if b.findExitNodeIDLocked(st.NetMap) {
|
||||
prefsChanged = true
|
||||
}
|
||||
@@ -1510,12 +1515,12 @@ func (b *LocalBackend) tellClientToBrowseToURL(url string) {
|
||||
}
|
||||
|
||||
// For testing lazy machine key generation.
|
||||
var panicOnMachineKeyGeneration = envknob.Bool("TS_DEBUG_PANIC_MACHINE_KEY")
|
||||
var panicOnMachineKeyGeneration = envknob.RegisterBool("TS_DEBUG_PANIC_MACHINE_KEY")
|
||||
|
||||
func (b *LocalBackend) createGetMachinePrivateKeyFunc() func() (key.MachinePrivate, error) {
|
||||
var cache syncs.AtomicValue[key.MachinePrivate]
|
||||
return func() (key.MachinePrivate, error) {
|
||||
if panicOnMachineKeyGeneration {
|
||||
if panicOnMachineKeyGeneration() {
|
||||
panic("machine key generated")
|
||||
}
|
||||
if v, ok := cache.LoadOk(); ok {
|
||||
@@ -1752,7 +1757,7 @@ func (b *LocalBackend) loadStateLocked(key ipn.StateKey, prefs *ipn.Prefs) (err
|
||||
// setAtomicValuesFromPrefs populates sshAtomicBool and containsViaIPFuncAtomic
|
||||
// from the prefs p, which may be nil.
|
||||
func (b *LocalBackend) setAtomicValuesFromPrefs(p *ipn.Prefs) {
|
||||
b.sshAtomicBool.Store(p != nil && p.RunSSH && canSSH)
|
||||
b.sshAtomicBool.Store(p != nil && p.RunSSH && envknob.CanSSHD())
|
||||
|
||||
if p == nil {
|
||||
b.containsViaIPFuncAtomic.Store(tsaddr.NewContainsIPFunc(nil))
|
||||
@@ -1967,7 +1972,7 @@ func (b *LocalBackend) checkSSHPrefsLocked(p *ipn.Prefs) error {
|
||||
default:
|
||||
return errors.New("The Tailscale SSH server is not supported on " + runtime.GOOS)
|
||||
}
|
||||
if !canSSH {
|
||||
if !envknob.CanSSHD() {
|
||||
return errors.New("The Tailscale SSH server has been administratively disabled.")
|
||||
}
|
||||
if envknob.SSHIgnoreTailnetPolicy() || envknob.SSHPolicyFile() != "" {
|
||||
@@ -2032,7 +2037,7 @@ func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (*ipn.Prefs, error) {
|
||||
b.logf("EditPrefs check error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if p1.RunSSH && !canSSH {
|
||||
if p1.RunSSH && !envknob.CanSSHD() {
|
||||
b.mu.Unlock()
|
||||
b.logf("EditPrefs requests SSH, but disabled by envknob; returning error")
|
||||
return nil, errors.New("Tailscale SSH server administratively disabled.")
|
||||
@@ -2854,7 +2859,7 @@ func (b *LocalBackend) applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *ipn.Pre
|
||||
hi.ShieldsUp = prefs.ShieldsUp
|
||||
|
||||
var sshHostKeys []string
|
||||
if prefs.RunSSH && canSSH {
|
||||
if prefs.RunSSH && envknob.CanSSHD() {
|
||||
// TODO(bradfitz): this is called with b.mu held. Not ideal.
|
||||
// If the filesystem gets wedged or something we could block for
|
||||
// a long time. But probably fine.
|
||||
@@ -3073,7 +3078,7 @@ func (b *LocalBackend) ResetForClientDisconnect() {
|
||||
b.setAtomicValuesFromPrefs(nil)
|
||||
}
|
||||
|
||||
func (b *LocalBackend) ShouldRunSSH() bool { return b.sshAtomicBool.Load() && canSSH }
|
||||
func (b *LocalBackend) ShouldRunSSH() bool { return b.sshAtomicBool.Load() && envknob.CanSSHD() }
|
||||
|
||||
// ShouldHandleViaIP reports whether whether ip is an IPv6 address in the
|
||||
// Tailscale ULA's v6 "via" range embedding an IPv4 address to be forwarded to
|
||||
@@ -3223,6 +3228,17 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) {
|
||||
}
|
||||
}
|
||||
|
||||
// operatorUserName returns the current pref's OperatorUser's name, or the
|
||||
// empty string if none.
|
||||
func (b *LocalBackend) operatorUserName() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.prefs == nil {
|
||||
return ""
|
||||
}
|
||||
return b.prefs.OperatorUser
|
||||
}
|
||||
|
||||
// OperatorUserID returns the current pref's OperatorUser's ID (in
|
||||
// os/user.User.Uid string form), or the empty string if none.
|
||||
func (b *LocalBackend) OperatorUserID() string {
|
||||
@@ -3305,13 +3321,15 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) {
|
||||
return nil, errors.New("file sharing not enabled by Tailscale admin")
|
||||
}
|
||||
for _, p := range nm.Peers {
|
||||
if p.User != nm.User && !slices.Contains(p.Capabilities, tailcfg.CapabilityFileSharingTarget) {
|
||||
if len(p.Addresses) == 0 {
|
||||
continue
|
||||
}
|
||||
if p.User != nm.User && b.peerHasCapLocked(p.Addresses[0].Addr(), tailcfg.CapabilityFileSharing) {
|
||||
continue
|
||||
}
|
||||
peerAPI := peerAPIBase(b.netMap, p)
|
||||
if peerAPI == "" {
|
||||
continue
|
||||
|
||||
}
|
||||
ret = append(ret, &apitype.FileTarget{
|
||||
Node: p,
|
||||
@@ -3322,6 +3340,15 @@ func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (b *LocalBackend) peerHasCapLocked(addr netip.Addr, wantCap string) bool {
|
||||
for _, hasCap := range b.peerCapsLocked(addr) {
|
||||
if hasCap == wantCap {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetDNS adds a DNS record for the given domain name & TXT record
|
||||
// value.
|
||||
//
|
||||
@@ -3580,6 +3607,17 @@ func (b *LocalBackend) DoNoiseRequest(req *http.Request) (*http.Response, error)
|
||||
return cc.DoNoiseRequest(req)
|
||||
}
|
||||
|
||||
// tailscaleSSHEnabled reports whether Tailscale SSH is currently enabled based
|
||||
// on prefs. It returns false if there are no prefs set.
|
||||
func (b *LocalBackend) tailscaleSSHEnabled() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.prefs == nil {
|
||||
return false
|
||||
}
|
||||
return b.prefs.RunSSH
|
||||
}
|
||||
|
||||
func (b *LocalBackend) sshServerOrInit() (_ SSHServer, err error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
@@ -478,8 +478,8 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
|
||||
// Issue 1573: don't generate a machine key if we don't want to be running.
|
||||
func TestLazyMachineKeyGeneration(t *testing.T) {
|
||||
defer func(old bool) { panicOnMachineKeyGeneration = old }(panicOnMachineKeyGeneration)
|
||||
panicOnMachineKeyGeneration = true
|
||||
defer func(old func() bool) { panicOnMachineKeyGeneration = old }(panicOnMachineKeyGeneration)
|
||||
panicOnMachineKeyGeneration = func() bool { return true }
|
||||
|
||||
var logf logger.Logf = logger.Discard
|
||||
store := new(mem.Store)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"tailscale.com/envknob"
|
||||
@@ -24,13 +26,125 @@ import (
|
||||
"tailscale.com/types/tkatype"
|
||||
)
|
||||
|
||||
var networkLockAvailable = envknob.Bool("TS_EXPERIMENTAL_NETWORK_LOCK")
|
||||
var networkLockAvailable = envknob.RegisterBool("TS_EXPERIMENTAL_NETWORK_LOCK")
|
||||
|
||||
type tkaState struct {
|
||||
authority *tka.Authority
|
||||
storage *tka.FS
|
||||
}
|
||||
|
||||
// tkaSyncIfNeededLocked examines TKA info reported from the control plane,
|
||||
// performing the steps necessary to synchronize local tka state.
|
||||
//
|
||||
// There are 4 scenarios handled here:
|
||||
// - Enablement: nm.TKAEnabled but b.tka == nil
|
||||
// ∴ reach out to /machine/tka/boostrap to get the genesis AUM, then
|
||||
// initialize TKA.
|
||||
// - Disablement: !nm.TKAEnabled but b.tka != nil
|
||||
// ∴ reach out to /machine/tka/boostrap to read the disablement secret,
|
||||
// then verify and clear tka local state.
|
||||
// - Sync needed: b.tka.Head != nm.TKAHead
|
||||
// ∴ complete multi-step synchronization flow.
|
||||
// - Everything up to date: All other cases.
|
||||
// ∴ no action necessary.
|
||||
//
|
||||
// b.mu must be held. b.mu will be stepped out of (and back in) during network
|
||||
// RPCs.
|
||||
func (b *LocalBackend) tkaSyncIfNeededLocked(nm *netmap.NetworkMap) error {
|
||||
if !networkLockAvailable() {
|
||||
// If the feature flag is not enabled, pretend we don't exist.
|
||||
return nil
|
||||
}
|
||||
if nm.SelfNode == nil {
|
||||
return errors.New("SelfNode missing")
|
||||
}
|
||||
|
||||
isEnabled := b.tka != nil
|
||||
wantEnabled := nm.TKAEnabled
|
||||
if isEnabled != wantEnabled {
|
||||
var ourHead tka.AUMHash
|
||||
if b.tka != nil {
|
||||
ourHead = b.tka.authority.Head()
|
||||
}
|
||||
|
||||
// Regardless of whether we are moving to disabled or enabled, we
|
||||
// need information from the tka bootstrap endpoint.
|
||||
b.mu.Unlock()
|
||||
bs, err := b.tkaFetchBootstrap(nm.SelfNode.ID, ourHead)
|
||||
b.mu.Lock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching bootstrap: %v", err)
|
||||
}
|
||||
|
||||
if wantEnabled && !isEnabled {
|
||||
if err := b.tkaBootstrapFromGenesisLocked(bs.GenesisAUM); err != nil {
|
||||
return fmt.Errorf("bootstrap: %v", err)
|
||||
}
|
||||
isEnabled = true
|
||||
} else if !wantEnabled && isEnabled {
|
||||
if b.tka.authority.ValidDisablement(bs.DisablementSecret) {
|
||||
b.tka = nil
|
||||
isEnabled = false
|
||||
|
||||
if err := os.RemoveAll(b.chonkPath()); err != nil {
|
||||
return fmt.Errorf("os.RemoveAll: %v", err)
|
||||
}
|
||||
} else {
|
||||
b.logf("Disablement secret did not verify, leaving TKA enabled.")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("[bug] unreachable invariant of wantEnabled /w isEnabled")
|
||||
}
|
||||
}
|
||||
|
||||
if isEnabled && b.tka.authority.Head() != nm.TKAHead {
|
||||
// TODO(tom): Implement sync
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// chonkPath returns the absolute path to the directory in which TKA
|
||||
// state (the 'tailchonk') is stored.
|
||||
func (b *LocalBackend) chonkPath() string {
|
||||
return filepath.Join(b.TailscaleVarRoot(), "tka")
|
||||
}
|
||||
|
||||
// tkaBootstrapFromGenesisLocked initializes the local (on-disk) state of the
|
||||
// tailnet key authority, based on the given genesis AUM.
|
||||
//
|
||||
// b.mu must be held.
|
||||
func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM) error {
|
||||
if !b.CanSupportNetworkLock() {
|
||||
return errors.New("network lock not supported in this configuration")
|
||||
}
|
||||
|
||||
var genesis tka.AUM
|
||||
if err := genesis.Unserialize(g); err != nil {
|
||||
return fmt.Errorf("reading genesis: %v", err)
|
||||
}
|
||||
|
||||
chonkDir := b.chonkPath()
|
||||
if err := os.Mkdir(chonkDir, 0755); err != nil && !os.IsExist(err) {
|
||||
return fmt.Errorf("mkdir: %v", err)
|
||||
}
|
||||
|
||||
chonk, err := tka.ChonkDir(chonkDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("chonk: %v", err)
|
||||
}
|
||||
authority, err := tka.Bootstrap(chonk, genesis)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tka bootstrap: %v", err)
|
||||
}
|
||||
|
||||
b.tka = &tkaState{
|
||||
authority: authority,
|
||||
storage: chonk,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanSupportNetworkLock returns true if tailscaled is able to operate
|
||||
// a local tailnet key authority (and hence enforce network lock).
|
||||
func (b *LocalBackend) CanSupportNetworkLock() bool {
|
||||
@@ -82,7 +196,7 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key) error {
|
||||
if b.tka != nil {
|
||||
return errors.New("network-lock is already initialized")
|
||||
}
|
||||
if !networkLockAvailable {
|
||||
if !networkLockAvailable() {
|
||||
return errors.New("this is an experimental feature in your version of tailscale - Please upgrade to the latest to use this.")
|
||||
}
|
||||
if !b.CanSupportNetworkLock() {
|
||||
@@ -237,3 +351,50 @@ func (b *LocalBackend) tkaInitFinish(nm *netmap.NetworkMap, nks map[tailcfg.Node
|
||||
return a, nil
|
||||
}
|
||||
}
|
||||
|
||||
// tkaFetchBootstrap sends a /machine/tka/bootstrap RPC to the control plane
|
||||
// over noise. This is used to get values necessary to enable or disable TKA.
|
||||
func (b *LocalBackend) tkaFetchBootstrap(nodeID tailcfg.NodeID, head tka.AUMHash) (*tailcfg.TKABootstrapResponse, error) {
|
||||
bootstrapReq := tailcfg.TKABootstrapRequest{
|
||||
NodeID: nodeID,
|
||||
}
|
||||
if !head.IsZero() {
|
||||
head, err := head.MarshalText()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("head.MarshalText failed: %v", err)
|
||||
}
|
||||
bootstrapReq.Head = string(head)
|
||||
}
|
||||
|
||||
var req bytes.Buffer
|
||||
if err := json.NewEncoder(&req).Encode(bootstrapReq); err != nil {
|
||||
return nil, fmt.Errorf("encoding request: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, fmt.Errorf("ctx: %w", err)
|
||||
}
|
||||
req2, err := http.NewRequestWithContext(ctx, "GET", "https://unused/machine/tka/bootstrap", &req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("req: %w", err)
|
||||
}
|
||||
res, err := b.DoNoiseRequest(req2)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resp: %w", err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
return nil, fmt.Errorf("request returned (%d): %s", res.StatusCode, string(body))
|
||||
}
|
||||
a := new(tailcfg.TKABootstrapResponse)
|
||||
err = json.NewDecoder(res.Body).Decode(a)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding JSON: %w", err)
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
243
ipn/ipnlocal/network-lock_test.go
Normal file
243
ipn/ipnlocal/network-lock_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ipnlocal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tka"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
)
|
||||
|
||||
func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto {
|
||||
hi := hostinfo.New()
|
||||
ni := tailcfg.NetInfo{LinkType: "wired"}
|
||||
hi.NetInfo = &ni
|
||||
|
||||
k := key.NewMachine()
|
||||
opts := controlclient.Options{
|
||||
ServerURL: "https://example.com",
|
||||
Hostinfo: hi,
|
||||
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
|
||||
return k, nil
|
||||
},
|
||||
HTTPTestClient: c,
|
||||
NoiseTestClient: c,
|
||||
Status: func(controlclient.Status) {},
|
||||
}
|
||||
|
||||
cc, err := controlclient.NewNoStart(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return cc
|
||||
}
|
||||
|
||||
// NOTE: URLs must have a https scheme and example.com domain to work with the underlying
|
||||
// httptest plumbing, despite the domain being unused in the actual noise request transport.
|
||||
func fakeNoiseServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, *http.Client) {
|
||||
ts := httptest.NewUnstartedServer(handler)
|
||||
ts.StartTLS()
|
||||
client := ts.Client()
|
||||
client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify = true
|
||||
client.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return (&net.Dialer{}).DialContext(ctx, network, ts.Listener.Addr().String())
|
||||
}
|
||||
return ts, client
|
||||
}
|
||||
|
||||
func TestTKAEnablementFlow(t *testing.T) {
|
||||
networkLockAvailable = func() bool { return true } // Enable the feature flag
|
||||
|
||||
// Make a fake TKA authority, getting a usable genesis AUM which
|
||||
// our mock server can communicate.
|
||||
nlPriv := key.NewNLPrivate()
|
||||
key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2}
|
||||
a1, genesisAUM, err := tka.Create(&tka.Mem{}, tka.State{
|
||||
Keys: []tka.Key{key},
|
||||
DisablementSecrets: [][]byte{bytes.Repeat([]byte{0xa5}, 32)},
|
||||
}, nlPriv)
|
||||
if err != nil {
|
||||
t.Fatalf("tka.Create() failed: %v", err)
|
||||
}
|
||||
|
||||
ts, client := fakeNoiseServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
switch r.URL.Path {
|
||||
case "/machine/tka/bootstrap":
|
||||
body := new(tailcfg.TKABootstrapRequest)
|
||||
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body.NodeID != 420 {
|
||||
t.Errorf("bootstrap nodeID=%v, want 420", body.NodeID)
|
||||
}
|
||||
if body.Head != "" {
|
||||
t.Errorf("bootstrap head=%s, want empty hash", body.Head)
|
||||
}
|
||||
|
||||
w.WriteHeader(200)
|
||||
out := tailcfg.TKABootstrapResponse{
|
||||
GenesisAUM: genesisAUM.Serialize(),
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(out); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
default:
|
||||
t.Errorf("unhandled endpoint path: %v", r.URL.Path)
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
temp := t.TempDir()
|
||||
|
||||
cc := fakeControlClient(t, client)
|
||||
b := LocalBackend{
|
||||
varRoot: temp,
|
||||
cc: cc,
|
||||
ccAuto: cc,
|
||||
logf: t.Logf,
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
err = b.tkaSyncIfNeededLocked(&netmap.NetworkMap{
|
||||
SelfNode: &tailcfg.Node{ID: 420},
|
||||
TKAEnabled: true,
|
||||
TKAHead: tka.AUMHash{},
|
||||
})
|
||||
b.mu.Unlock()
|
||||
if err != nil {
|
||||
t.Errorf("tkaSyncIfNeededLocked() failed: %v", err)
|
||||
}
|
||||
if b.tka == nil {
|
||||
t.Fatal("tka was not initialized")
|
||||
}
|
||||
if b.tka.authority.Head() != a1.Head() {
|
||||
t.Errorf("authority.Head() = %x, want %x", b.tka.authority.Head(), a1.Head())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTKADisablementFlow(t *testing.T) {
|
||||
networkLockAvailable = func() bool { return true } // Enable the feature flag
|
||||
temp := t.TempDir()
|
||||
os.Mkdir(filepath.Join(temp, "tka"), 0755)
|
||||
|
||||
// Make a fake TKA authority, to seed local state.
|
||||
disablementSecret := bytes.Repeat([]byte{0xa5}, 32)
|
||||
nlPriv := key.NewNLPrivate()
|
||||
key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2}
|
||||
chonk, err := tka.ChonkDir(filepath.Join(temp, "tka"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
authority, _, err := tka.Create(chonk, tka.State{
|
||||
Keys: []tka.Key{key},
|
||||
DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)},
|
||||
}, nlPriv)
|
||||
if err != nil {
|
||||
t.Fatalf("tka.Create() failed: %v", err)
|
||||
}
|
||||
|
||||
ts, client := fakeNoiseServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
switch r.URL.Path {
|
||||
case "/machine/tka/bootstrap":
|
||||
body := new(tailcfg.TKABootstrapRequest)
|
||||
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var disablement []byte
|
||||
switch body.NodeID {
|
||||
case 42:
|
||||
disablement = bytes.Repeat([]byte{0x42}, 32) // wrong secret
|
||||
case 420:
|
||||
disablement = disablementSecret
|
||||
default:
|
||||
t.Errorf("bootstrap nodeID=%v, wanted 42 or 420", body.NodeID)
|
||||
}
|
||||
var head tka.AUMHash
|
||||
if err := head.UnmarshalText([]byte(body.Head)); err != nil {
|
||||
t.Fatalf("failed unmarshal of body.Head: %v", err)
|
||||
}
|
||||
if head != authority.Head() {
|
||||
t.Errorf("reported head = %x, want %x", head, authority.Head())
|
||||
}
|
||||
|
||||
w.WriteHeader(200)
|
||||
out := tailcfg.TKABootstrapResponse{
|
||||
DisablementSecret: disablement,
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(out); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
default:
|
||||
t.Errorf("unhandled endpoint path: %v", r.URL.Path)
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
cc := fakeControlClient(t, client)
|
||||
b := LocalBackend{
|
||||
varRoot: temp,
|
||||
cc: cc,
|
||||
ccAuto: cc,
|
||||
logf: t.Logf,
|
||||
tka: &tkaState{
|
||||
authority: authority,
|
||||
storage: chonk,
|
||||
},
|
||||
}
|
||||
|
||||
// Test that the wrong disablement secret does not shut down the authority.
|
||||
// NodeID == 42 indicates this scenario to our mock server.
|
||||
b.mu.Lock()
|
||||
err = b.tkaSyncIfNeededLocked(&netmap.NetworkMap{
|
||||
SelfNode: &tailcfg.Node{ID: 42},
|
||||
TKAEnabled: false,
|
||||
TKAHead: authority.Head(),
|
||||
})
|
||||
b.mu.Unlock()
|
||||
if err != nil {
|
||||
t.Errorf("tkaSyncIfNeededLocked() failed: %v", err)
|
||||
}
|
||||
if b.tka == nil {
|
||||
t.Error("TKA was disabled despite incorrect disablement secret")
|
||||
}
|
||||
|
||||
// Test the correct disablement secret shuts down the authority.
|
||||
// NodeID == 420 indicates this scenario to our mock server.
|
||||
b.mu.Lock()
|
||||
err = b.tkaSyncIfNeededLocked(&netmap.NetworkMap{
|
||||
SelfNode: &tailcfg.Node{ID: 420},
|
||||
TKAEnabled: false,
|
||||
TKAHead: authority.Head(),
|
||||
})
|
||||
b.mu.Unlock()
|
||||
if err != nil {
|
||||
t.Errorf("tkaSyncIfNeededLocked() failed: %v", err)
|
||||
}
|
||||
|
||||
if b.tka != nil {
|
||||
t.Fatal("tka was not shut down")
|
||||
}
|
||||
if _, err := os.Stat(b.chonkPath()); err == nil || !os.IsNotExist(err) {
|
||||
t.Errorf("os.Stat(chonkDir) = %v, want ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
@@ -44,6 +44,7 @@ import (
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/strs"
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
@@ -720,8 +721,8 @@ func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
rawPath := r.URL.EscapedPath()
|
||||
suffix := strings.TrimPrefix(rawPath, "/v0/put/")
|
||||
if suffix == rawPath {
|
||||
suffix, ok := strs.CutPrefix(rawPath, "/v0/put/")
|
||||
if !ok {
|
||||
http.Error(w, "misconfigured internals", 500)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -87,7 +86,7 @@ func fileHasContents(name string, want string) check {
|
||||
return
|
||||
}
|
||||
path := filepath.Join(root, name)
|
||||
got, err := ioutil.ReadFile(path)
|
||||
got, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Errorf("fileHasContents: %v", err)
|
||||
return
|
||||
@@ -517,7 +516,7 @@ func TestDeletedMarkers(t *testing.T) {
|
||||
}
|
||||
wantEmptyTempDir := func() {
|
||||
t.Helper()
|
||||
if fis, err := ioutil.ReadDir(dir); err != nil {
|
||||
if fis, err := os.ReadDir(dir); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(fis) > 0 && runtime.GOOS != "windows" {
|
||||
for _, fi := range fis {
|
||||
|
||||
@@ -18,24 +18,98 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tailscale/golang-x-crypto/ssh"
|
||||
"tailscale.com/envknob"
|
||||
"go4.org/mem"
|
||||
"golang.org/x/exp/slices"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/lineread"
|
||||
"tailscale.com/util/mak"
|
||||
)
|
||||
|
||||
var useHostKeys = envknob.Bool("TS_USE_SYSTEM_SSH_HOST_KEYS")
|
||||
|
||||
// keyTypes are the SSH key types that we either try to read from the
|
||||
// system's OpenSSH keys or try to generate for ourselves when not
|
||||
// running as root.
|
||||
var keyTypes = []string{"rsa", "ecdsa", "ed25519"}
|
||||
|
||||
// getSSHUsernames discovers and returns the list of usernames that are
|
||||
// potential Tailscale SSH user targets.
|
||||
//
|
||||
// Invariant: must not be called with b.mu held.
|
||||
func (b *LocalBackend) getSSHUsernames(req *tailcfg.C2NSSHUsernamesRequest) (*tailcfg.C2NSSHUsernamesResponse, error) {
|
||||
res := new(tailcfg.C2NSSHUsernamesResponse)
|
||||
if !b.tailscaleSSHEnabled() {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
max := 10
|
||||
if req != nil && req.Max != 0 {
|
||||
max = req.Max
|
||||
}
|
||||
|
||||
add := func(u string) {
|
||||
if req != nil && req.Exclude[u] {
|
||||
return
|
||||
}
|
||||
switch u {
|
||||
case "nobody", "daemon", "sync":
|
||||
return
|
||||
}
|
||||
if slices.Contains(res.Usernames, u) {
|
||||
return
|
||||
}
|
||||
if len(res.Usernames) > max {
|
||||
// Enough for a hint.
|
||||
return
|
||||
}
|
||||
res.Usernames = append(res.Usernames, u)
|
||||
}
|
||||
|
||||
if opUser := b.operatorUserName(); opUser != "" {
|
||||
add(opUser)
|
||||
}
|
||||
|
||||
// Check popular usernames and see if they exist with a real shell.
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
out, err := exec.Command("dscl", ".", "list", "/Users").Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lineread.Reader(bytes.NewReader(out), func(line []byte) error {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || line[0] == '_' {
|
||||
return nil
|
||||
}
|
||||
add(string(line))
|
||||
return nil
|
||||
})
|
||||
default:
|
||||
lineread.File("/etc/passwd", func(line []byte) error {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || line[0] == '#' || line[0] == '_' {
|
||||
return nil
|
||||
}
|
||||
if mem.HasSuffix(mem.B(line), mem.S("/nologin")) ||
|
||||
mem.HasSuffix(mem.B(line), mem.S("/false")) {
|
||||
return nil
|
||||
}
|
||||
colon := bytes.IndexByte(line, ':')
|
||||
if colon != -1 {
|
||||
add(string(line[:colon]))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (b *LocalBackend) GetSSH_HostKeys() (keys []ssh.Signer, err error) {
|
||||
var existing map[string]ssh.Signer
|
||||
if os.Geteuid() == 0 {
|
||||
@@ -83,7 +157,7 @@ func (b *LocalBackend) hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) {
|
||||
defer keyGenMu.Unlock()
|
||||
|
||||
path := filepath.Join(keyDir, "ssh_host_"+typ+"_key")
|
||||
v, err := ioutil.ReadFile(path)
|
||||
v, err := os.ReadFile(path)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
@@ -124,7 +198,7 @@ func (b *LocalBackend) hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) {
|
||||
func (b *LocalBackend) getSystemSSH_HostKeys() (ret map[string]ssh.Signer) {
|
||||
for _, typ := range keyTypes {
|
||||
filename := "/etc/ssh/ssh_host_" + typ + "_key"
|
||||
hostKey, err := ioutil.ReadFile(filename)
|
||||
hostKey, err := os.ReadFile(filename)
|
||||
if err != nil || len(bytes.TrimSpace(hostKey)) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -6,6 +6,16 @@
|
||||
|
||||
package ipnlocal
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func (b *LocalBackend) getSSHHostKeyPublicStrings() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LocalBackend) getSSHUsernames(*tailcfg.C2NSSHUsernamesRequest) (*tailcfg.C2NSSHUsernamesResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -2,14 +2,18 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
//go:build linux || (darwin && !ios)
|
||||
// +build linux darwin,!ios
|
||||
|
||||
package ipnlocal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
func TestSSHKeyGen(t *testing.T) {
|
||||
@@ -40,3 +44,17 @@ func TestSSHKeyGen(t *testing.T) {
|
||||
t.Errorf("got different keys on second call")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeSSHServer struct {
|
||||
SSHServer
|
||||
}
|
||||
|
||||
func TestGetSSHUsernames(t *testing.T) {
|
||||
b := new(LocalBackend)
|
||||
b.sshServer = fakeSSHServer{}
|
||||
res, err := b.getSSHUsernames(new(tailcfg.C2NSSHUsernamesRequest))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Got: %s", must.Get(json.Marshal(res)))
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -215,7 +214,7 @@ func (s *Server) blockWhileInUse(conn io.Reader, ci connIdentity) {
|
||||
s.logf("blocking client while server in use; connIdentity=%v", ci)
|
||||
connDone := make(chan struct{})
|
||||
go func() {
|
||||
io.Copy(ioutil.Discard, conn)
|
||||
io.Copy(io.Discard, conn)
|
||||
close(connDone)
|
||||
}()
|
||||
ch := make(chan struct{}, 1)
|
||||
@@ -773,7 +772,7 @@ func New(logf logger.Logf, logid string, store ipn.StateStore, eng wgengine.Engi
|
||||
})
|
||||
|
||||
if root := b.TailscaleVarRoot(); root != "" {
|
||||
chonkDir := filepath.Join(root, "chonk")
|
||||
chonkDir := filepath.Join(root, "tka")
|
||||
if _, err := os.Stat(chonkDir); err == nil {
|
||||
// The directory exists, which means network-lock has been initialized.
|
||||
storage, err := tka.ChonkDir(chonkDir)
|
||||
@@ -933,14 +932,6 @@ func BabysitProc(ctx context.Context, args []string, logf logger.Logf) {
|
||||
startTime := time.Now()
|
||||
log.Printf("exec: %#v %v", executable, args)
|
||||
cmd := exec.Command(executable, args...)
|
||||
if runtime.GOOS == "windows" {
|
||||
extraEnv, err := loadExtraEnv()
|
||||
if err != nil {
|
||||
logf("errors loading extra env file; ignoring: %v", err)
|
||||
} else {
|
||||
cmd.Env = append(os.Environ(), extraEnv...)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a pipe object to use as the subproc's stdin.
|
||||
// When the writer goes away, the reader gets EOF.
|
||||
@@ -1175,7 +1166,7 @@ func findTrueNASTaildropDir(name string) (dir string, err error) {
|
||||
}
|
||||
|
||||
// but if running on the host, it may be something like /mnt/Primary/Taildrop
|
||||
fis, err := ioutil.ReadDir("/mnt")
|
||||
fis, err := os.ReadDir("/mnt")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading /mnt: %w", err)
|
||||
}
|
||||
@@ -1209,38 +1200,3 @@ func findQnapTaildropDir(name string) (string, error) {
|
||||
}
|
||||
return "", fmt.Errorf("shared folder %q not found", name)
|
||||
}
|
||||
|
||||
func loadExtraEnv() (env []string, err error) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil, nil
|
||||
}
|
||||
name := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "tailscaled-env.txt")
|
||||
contents, err := os.ReadFile(name)
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, line := range strings.Split(string(contents), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
k, v, ok := strings.Cut(line, "=")
|
||||
if !ok || k == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(v, `"`) {
|
||||
var err error
|
||||
v, err = strconv.Unquote(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid value in line %q: %v", line, err)
|
||||
}
|
||||
env = append(env, k+"="+v)
|
||||
} else {
|
||||
env = append(env, line)
|
||||
}
|
||||
}
|
||||
return env, nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -38,6 +37,7 @@ import (
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/strs"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
@@ -73,7 +73,7 @@ func (h *Handler) certDir() (string, error) {
|
||||
return full, nil
|
||||
}
|
||||
|
||||
var acmeDebug = envknob.Bool("TS_DEBUG_ACME")
|
||||
var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
|
||||
|
||||
func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.PermitWrite && !h.PermitCert {
|
||||
@@ -87,16 +87,19 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
domain := strings.TrimPrefix(r.URL.Path, "/localapi/v0/cert/")
|
||||
if domain == r.URL.Path {
|
||||
domain, ok := strs.CutPrefix(r.URL.Path, "/localapi/v0/cert/")
|
||||
if !ok {
|
||||
http.Error(w, "internal handler config wired wrong", 500)
|
||||
return
|
||||
}
|
||||
|
||||
if !validLookingCertDomain(domain) {
|
||||
http.Error(w, "invalid domain", 400)
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
logf := logger.WithPrefix(h.logf, fmt.Sprintf("cert(%q): ", domain))
|
||||
traceACME := func(v any) {
|
||||
if !acmeDebug {
|
||||
if !acmeDebug() {
|
||||
return
|
||||
}
|
||||
j, _ := json.MarshalIndent(v, "", "\t")
|
||||
@@ -165,6 +168,11 @@ func certFile(dir, domain string) string { return filepath.Join(dir, domain+".cr
|
||||
// keypair for domain exists on disk in dir that is valid at the
|
||||
// provided now time.
|
||||
func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) {
|
||||
if !validLookingCertDomain(domain) {
|
||||
// Before we read files from disk using it, validate it's halfway
|
||||
// reasonable looking.
|
||||
return nil, false
|
||||
}
|
||||
if keyPEM, err := os.ReadFile(keyFile(dir, domain)); err == nil {
|
||||
certPEM, _ := os.ReadFile(certFile(dir, domain))
|
||||
if validCertPEM(domain, keyPEM, certPEM, now) {
|
||||
@@ -293,7 +301,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
|
||||
if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ioutil.WriteFile(keyFile(dir, domain), privPEM.Bytes(), 0600); err != nil {
|
||||
if err := os.WriteFile(keyFile(dir, domain), privPEM.Bytes(), 0600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -316,7 +324,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := ioutil.WriteFile(certFile(dir, domain), certPEM.Bytes(), 0644); err != nil {
|
||||
if err := os.WriteFile(certFile(dir, domain), certPEM.Bytes(), 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -372,7 +380,7 @@ func parsePrivateKey(der []byte) (crypto.Signer, error) {
|
||||
|
||||
func acmeKey(dir string) (crypto.Signer, error) {
|
||||
pemName := filepath.Join(dir, "acme-account.key.pem")
|
||||
if v, err := ioutil.ReadFile(pemName); err == nil {
|
||||
if v, err := os.ReadFile(pemName); err == nil {
|
||||
priv, _ := pem.Decode(v)
|
||||
if priv == nil || !strings.Contains(priv.Type, "PRIVATE") {
|
||||
return nil, errors.New("acme/autocert: invalid account key found in cache")
|
||||
@@ -388,7 +396,7 @@ func acmeKey(dir string) (crypto.Signer, error) {
|
||||
if err := encodeECDSAKey(&pemBuf, privKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ioutil.WriteFile(pemName, pemBuf.Bytes(), 0600); err != nil {
|
||||
if err := os.WriteFile(pemName, pemBuf.Bytes(), 0600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return privKey, nil
|
||||
@@ -426,6 +434,21 @@ func validCertPEM(domain string, keyPEM, certPEM []byte, now time.Time) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// validLookingCertDomain reports whether name looks like a valid domain name that
|
||||
// we might be able to get a cert for.
|
||||
//
|
||||
// It's a light check primarily for double checking before it's used
|
||||
// as part of a filesystem path. The actual validation happens in checkCertDomain.
|
||||
func validLookingCertDomain(name string) bool {
|
||||
if name == "" ||
|
||||
strings.Contains(name, "..") ||
|
||||
strings.ContainsAny(name, ":/\\\x00") ||
|
||||
!strings.Contains(name, ".") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func checkCertDomain(st *ipnstate.Status, domain string) error {
|
||||
if domain == "" {
|
||||
return errors.New("missing domain name")
|
||||
|
||||
30
ipn/localapi/cert_test.go
Normal file
30
ipn/localapi/cert_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !ios && !android && !js
|
||||
// +build !ios,!android,!js
|
||||
|
||||
package localapi
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidLookingCertDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"foo.com", true},
|
||||
{"foo..com", false},
|
||||
{"foo/com.com", false},
|
||||
{"NUL", false},
|
||||
{"", false},
|
||||
{"foo\\bar.com", false},
|
||||
{"foo\x00bar.com", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := validLookingCertDomain(tt.in); got != tt.want {
|
||||
t.Errorf("validLookingCertDomain(%q) = %v, want %v", tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"time"
|
||||
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
@@ -213,6 +214,9 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
logMarker := fmt.Sprintf("BUG-%v-%v-%v", h.backendLogID, time.Now().UTC().Format("20060102150405Z"), randHex(8))
|
||||
if envknob.NoLogsNoSupport() {
|
||||
logMarker = "BUG-NO-LOGS-NO-SUPPORT-this-node-has-had-its-logging-disabled"
|
||||
}
|
||||
h.logf("user bugreport: %s", logMarker)
|
||||
if note := r.FormValue("note"); len(note) > 0 {
|
||||
h.logf("user bugreport note: %s", note)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -618,7 +617,7 @@ func PrefsFromBytes(b []byte) (*Prefs, error) {
|
||||
// LoadPrefs loads a legacy relaynode config file into Prefs
|
||||
// with sensible migration defaults set.
|
||||
func LoadPrefs(filename string) (*Prefs, error) {
|
||||
data, err := ioutil.ReadFile(filename)
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LoadPrefs open: %w", err) // err includes path
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/netip"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -474,7 +473,7 @@ func TestLoadPrefsNotExist(t *testing.T) {
|
||||
// TestLoadPrefsFileWithZeroInIt verifies that LoadPrefs hanldes corrupted input files.
|
||||
// See issue #954 for details.
|
||||
func TestLoadPrefsFileWithZeroInIt(t *testing.T) {
|
||||
f, err := ioutil.TempFile("", "TestLoadPrefsFileWithZeroInIt")
|
||||
f, err := os.CreateTemp("", "TestLoadPrefsFileWithZeroInIt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -128,7 +127,7 @@ func NewFileStore(logf logger.Logf, path string) (ipn.StateStore, error) {
|
||||
return nil, fmt.Errorf("creating state directory: %w", err)
|
||||
}
|
||||
|
||||
bs, err := ioutil.ReadFile(path)
|
||||
bs, err := os.ReadFile(path)
|
||||
|
||||
// Treat an empty file as a missing file.
|
||||
// (https://github.com/tailscale/tailscale/issues/895#issuecomment-723255589)
|
||||
|
||||
@@ -9,7 +9,6 @@ package filelogger
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -186,12 +185,18 @@ func (w *logFileWriter) startNewFileLocked() {
|
||||
//
|
||||
// w.mu must be held.
|
||||
func (w *logFileWriter) cleanLocked() {
|
||||
fis, _ := ioutil.ReadDir(w.dir)
|
||||
entries, _ := os.ReadDir(w.dir)
|
||||
prefix := w.fileBasePrefix + "-"
|
||||
fileSize := map[string]int64{}
|
||||
var files []string
|
||||
var sumSize int64
|
||||
for _, fi := range fis {
|
||||
for _, entry := range entries {
|
||||
fi, err := entry.Info()
|
||||
if err != nil {
|
||||
w.wrappedLogf("error getting log file info: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
baseName := filepath.Base(fi.Name())
|
||||
if !strings.HasPrefix(baseName, prefix) {
|
||||
continue
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -248,7 +247,7 @@ func logsDir(logf logger.Logf) string {
|
||||
// No idea where to put stuff. Try to create a temp dir. It'll
|
||||
// mean we might lose some logs and rotate through log IDs, but
|
||||
// it's something.
|
||||
tmp, err := ioutil.TempDir("", "tailscaled-log-*")
|
||||
tmp, err := os.MkdirTemp("", "tailscaled-log-*")
|
||||
if err != nil {
|
||||
panic("no safe place found to store log state")
|
||||
}
|
||||
@@ -259,7 +258,7 @@ func logsDir(logf logger.Logf) string {
|
||||
// runningUnderSystemd reports whether we're running under systemd.
|
||||
func runningUnderSystemd() bool {
|
||||
if runtime.GOOS == "linux" && os.Getppid() == 1 {
|
||||
slurp, _ := ioutil.ReadFile("/proc/1/stat")
|
||||
slurp, _ := os.ReadFile("/proc/1/stat")
|
||||
return bytes.HasPrefix(slurp, []byte("1 (systemd) "))
|
||||
}
|
||||
return false
|
||||
@@ -438,6 +437,13 @@ func tryFixLogStateLocation(dir, cmdname string) {
|
||||
// New returns a new log policy (a logger and its instance ID) for a
|
||||
// given collection name.
|
||||
func New(collection string) *Policy {
|
||||
return NewWithConfigPath(collection, "", "")
|
||||
}
|
||||
|
||||
// NewWithConfigPath is identical to New,
|
||||
// but uses the specified directory and command name.
|
||||
// If either is empty, it derives them automatically.
|
||||
func NewWithConfigPath(collection, dir, cmdName string) *Policy {
|
||||
var lflags int
|
||||
if term.IsTerminal(2) || runtime.GOOS == "windows" {
|
||||
lflags = 0
|
||||
@@ -460,9 +466,12 @@ func New(collection string) *Policy {
|
||||
earlyErrBuf.WriteByte('\n')
|
||||
}
|
||||
|
||||
dir := logsDir(earlyLogf)
|
||||
|
||||
cmdName := version.CmdName()
|
||||
if dir == "" {
|
||||
dir = logsDir(earlyLogf)
|
||||
}
|
||||
if cmdName == "" {
|
||||
cmdName = version.CmdName()
|
||||
}
|
||||
tryFixLogStateLocation(dir, cmdName)
|
||||
|
||||
cfgPath := filepath.Join(dir, fmt.Sprintf("%s.log.conf", cmdName))
|
||||
@@ -539,7 +548,10 @@ func New(collection string) *Policy {
|
||||
conf.IncludeProcSequence = true
|
||||
}
|
||||
|
||||
if val := getLogTarget(); val != "" {
|
||||
if envknob.NoLogsNoSupport() {
|
||||
log.Println("You have disabled logging. Tailscale will not be able to provide support.")
|
||||
conf.HTTPC = &http.Client{Transport: noopPretendSuccessTransport{}}
|
||||
} else if val := getLogTarget(); val != "" {
|
||||
log.Println("You have enabled a non-default log target. Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.")
|
||||
conf.BaseURL = val
|
||||
u, _ := url.Parse(val)
|
||||
@@ -735,3 +747,14 @@ func goVersion() string {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
type noopPretendSuccessTransport struct{}
|
||||
|
||||
func (noopPretendSuccessTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
io.ReadAll(req.Body)
|
||||
req.Body.Close()
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Status: "200 OK",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -39,7 +39,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Fatalf("logadopt: response read failed %d: %v", resp.StatusCode, err)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -50,7 +50,7 @@ func main() {
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ package filch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -195,7 +195,7 @@ func TestFilchStderr(t *testing.T) {
|
||||
f.close(t)
|
||||
|
||||
pipeW.Close()
|
||||
b, err := ioutil.ReadAll(pipeR)
|
||||
b, err := io.ReadAll(pipeR)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -430,7 +429,7 @@ func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (uploaded
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
uploaded = resp.StatusCode == 400 // the server saved the logs anyway
|
||||
b, _ := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
return uploaded, fmt.Errorf("log upload of %d bytes %s failed %d: %q", len(body), compressedNote, resp.StatusCode, b)
|
||||
}
|
||||
|
||||
@@ -654,7 +653,7 @@ func (l *Logger) Write(buf []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
level, buf := parseAndRemoveLogLevel(buf)
|
||||
if l.stderr != nil && l.stderr != ioutil.Discard && int64(level) <= atomic.LoadInt64(&l.stderrLevel) {
|
||||
if l.stderr != nil && l.stderr != io.Discard && int64(level) <= atomic.LoadInt64(&l.stderrLevel) {
|
||||
if buf[len(buf)-1] == '\n' {
|
||||
l.stderr.Write(buf)
|
||||
} else {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -52,7 +51,7 @@ func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) {
|
||||
|
||||
ts.srv = httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Error("failed to read HTTP request")
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -158,7 +157,7 @@ func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) {
|
||||
if sc.Text() == resolvconfConfigName {
|
||||
continue
|
||||
}
|
||||
bs, err := ioutil.ReadFile(filepath.Join(m.interfacesDir, sc.Text()))
|
||||
bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text()))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Probably raced with a deletion, that's okay.
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -452,7 +451,7 @@ func (fs directFS) Rename(oldName, newName string) error {
|
||||
func (fs directFS) Remove(name string) error { return os.Remove(fs.path(name)) }
|
||||
|
||||
func (fs directFS) ReadFile(name string) ([]byte, error) {
|
||||
return ioutil.ReadFile(fs.path(name))
|
||||
return os.ReadFile(fs.path(name))
|
||||
}
|
||||
|
||||
func (fs directFS) Truncate(name string) error {
|
||||
@@ -460,7 +459,7 @@ func (fs directFS) Truncate(name string) error {
|
||||
}
|
||||
|
||||
func (fs directFS) WriteFile(name string, contents []byte, perm os.FileMode) error {
|
||||
return ioutil.WriteFile(fs.path(name), contents, perm)
|
||||
return os.WriteFile(fs.path(name), contents, perm)
|
||||
}
|
||||
|
||||
// runningAsGUIDesktopUser reports whether it seems that this code is
|
||||
|
||||
@@ -6,14 +6,13 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
func NewOSConfigurator(logf logger.Logf, _ string) (OSConfigurator, error) {
|
||||
bs, err := ioutil.ReadFile("/etc/resolv.conf")
|
||||
bs, err := os.ReadFile("/etc/resolv.conf")
|
||||
if os.IsNotExist(err) {
|
||||
return newDirectManager(logf), nil
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ const (
|
||||
versionKey = `SOFTWARE\Microsoft\Windows NT\CurrentVersion`
|
||||
)
|
||||
|
||||
var configureWSL = envknob.Bool("TS_DEBUG_CONFIGURE_WSL")
|
||||
var configureWSL = envknob.RegisterBool("TS_DEBUG_CONFIGURE_WSL")
|
||||
|
||||
type windowsManager struct {
|
||||
logf logger.Logf
|
||||
@@ -359,7 +359,7 @@ func (m windowsManager) SetDNS(cfg OSConfig) error {
|
||||
|
||||
// On initial setup of WSL, the restart caused by --shutdown is slow,
|
||||
// so we do it out-of-line.
|
||||
if configureWSL {
|
||||
if configureWSL() {
|
||||
go func() {
|
||||
if err := m.wslManager.SetDNS(cfg); err != nil {
|
||||
m.logf("WSL SetDNS: %v", err) // continue
|
||||
|
||||
@@ -205,7 +205,7 @@ func (m *resolvedManager) run(ctx context.Context) {
|
||||
// When ctx goes away systemd-resolved auto reverts.
|
||||
// Keeping for potential use in future refactor.
|
||||
if call := rManager.CallWithContext(ctx, dbusResolvedInterface+".RevertLink", 0, m.ifidx); call.Err != nil {
|
||||
m.logf("[v1] RevertLink: %w", call.Err)
|
||||
m.logf("[v1] RevertLink: %v", call.Err)
|
||||
return
|
||||
}
|
||||
return
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -58,9 +57,6 @@ func truncatedFlagSet(pkt []byte) bool {
|
||||
}
|
||||
|
||||
const (
|
||||
// responseTimeout is the maximal amount of time to wait for a DNS response.
|
||||
responseTimeout = 5 * time.Second
|
||||
|
||||
// dohTransportTimeout is how long to keep idle HTTP
|
||||
// connections open to DNS-over-HTTPs servers. This is pretty
|
||||
// arbitrary.
|
||||
@@ -477,7 +473,7 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
|
||||
metricDNSFwdDoHErrorCT.Add(1)
|
||||
return nil, fmt.Errorf("unexpected response Content-Type %q", ct)
|
||||
}
|
||||
res, err := ioutil.ReadAll(hres.Body)
|
||||
res, err := io.ReadAll(hres.Body)
|
||||
if err != nil {
|
||||
metricDNSFwdDoHErrorBody.Add(1)
|
||||
}
|
||||
@@ -487,13 +483,13 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
|
||||
return res, err
|
||||
}
|
||||
|
||||
var verboseDNSForward = envknob.Bool("TS_DEBUG_DNS_FORWARD_SEND")
|
||||
var verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
|
||||
|
||||
// send sends packet to dst. It is best effort.
|
||||
//
|
||||
// send expects the reply to have the same txid as txidOut.
|
||||
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
||||
if verboseDNSForward {
|
||||
if verboseDNSForward() {
|
||||
f.logf("forwarder.send(%q) ...", rr.name.Addr)
|
||||
defer func() {
|
||||
f.logf("forwarder.send(%q) = %v, %v", rr.name.Addr, len(ret), err)
|
||||
|
||||
@@ -141,7 +141,7 @@ func (r *Resolver) ttl() time.Duration {
|
||||
return 10 * time.Minute
|
||||
}
|
||||
|
||||
var debug = envknob.Bool("TS_DEBUG_DNS_CACHE")
|
||||
var debug = envknob.RegisterBool("TS_DEBUG_DNS_CACHE")
|
||||
|
||||
// LookupIP returns the host's primary IP address (either IPv4 or
|
||||
// IPv6, but preferring IPv4) and optionally its IPv6 address, if
|
||||
@@ -167,14 +167,14 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr
|
||||
}
|
||||
if ip, err := netip.ParseAddr(host); err == nil {
|
||||
ip = ip.Unmap()
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: %q is an IP", host)
|
||||
}
|
||||
return ip, zaddr, []netip.Addr{ip}, nil
|
||||
}
|
||||
|
||||
if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok {
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: %q = %v (cached)", host, ip)
|
||||
}
|
||||
return ip, ip6, allIPs, nil
|
||||
@@ -192,13 +192,13 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr
|
||||
if res.Err != nil {
|
||||
if r.UseLastGood {
|
||||
if ip, ip6, allIPs, ok := r.lookupIPCacheExpired(host); ok {
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: %q using %v after error", host, ip)
|
||||
}
|
||||
return ip, ip6, allIPs, nil
|
||||
}
|
||||
}
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: error resolving %q: %v", host, res.Err)
|
||||
}
|
||||
return zaddr, zaddr, nil, res.Err
|
||||
@@ -206,7 +206,7 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr
|
||||
r := res.Val
|
||||
return r.ip, r.ip6, r.allIPs, nil
|
||||
case <-ctx.Done():
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err())
|
||||
}
|
||||
return zaddr, zaddr, nil, ctx.Err()
|
||||
@@ -250,7 +250,7 @@ func (r *Resolver) lookupTimeoutForHost(host string) time.Duration {
|
||||
|
||||
func (r *Resolver) lookupIP(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, err error) {
|
||||
if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok {
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: %q found in cache as %v", host, ip)
|
||||
}
|
||||
return ip, ip6, allIPs, nil
|
||||
@@ -300,13 +300,13 @@ func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Ad
|
||||
if ip.IsPrivate() {
|
||||
// Don't cache obviously wrong entries from captive portals.
|
||||
// TODO: use DoH or DoT for the forwarding resolver?
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: %q resolved to private IP %v; using but not caching", host, ip)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: %q resolved to IP %v; caching", host, ip)
|
||||
}
|
||||
|
||||
@@ -361,7 +361,7 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC
|
||||
defer func() {
|
||||
// On failure, consider that our DNS might be wrong and ask the DNS fallback mechanism for
|
||||
// some other IPs to try.
|
||||
if ret == nil || ctx.Err() != nil || d.dnsCache.LookupIPFallback == nil || dc.dnsWasTrustworthy() {
|
||||
if !d.shouldTryBootstrap(ctx, ret, dc) {
|
||||
return
|
||||
}
|
||||
ips, err := d.dnsCache.LookupIPFallback(ctx, host)
|
||||
@@ -382,7 +382,7 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC
|
||||
}
|
||||
i4s := v4addrs(allIPs)
|
||||
if len(i4s) < 2 {
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("dnscache: dialing %s, %s for %s", network, ip, address)
|
||||
}
|
||||
c, err := dc.dialOne(ctx, ip.Unmap())
|
||||
@@ -398,6 +398,40 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC
|
||||
return dc.raceDial(ctx, ipsToTry)
|
||||
}
|
||||
|
||||
func (d *dialer) shouldTryBootstrap(ctx context.Context, err error, dc *dialCall) bool {
|
||||
// No need to do anything when we succeeded.
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Can't try bootstrap DNS if we don't have a fallback function
|
||||
if d.dnsCache.LookupIPFallback == nil {
|
||||
if debug() {
|
||||
log.Printf("dnscache: not using bootstrap DNS: no fallback")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// We can't retry if the context is canceled, since any further
|
||||
// operations with this context will fail.
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
if debug() {
|
||||
log.Printf("dnscache: not using bootstrap DNS: context error: %v", ctxErr)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
wasTrustworthy := dc.dnsWasTrustworthy()
|
||||
if wasTrustworthy {
|
||||
if debug() {
|
||||
log.Printf("dnscache: not using bootstrap DNS: DNS was trustworthy")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// dialCall is the state around a single call to dial.
|
||||
type dialCall struct {
|
||||
d *dialer
|
||||
|
||||
@@ -164,3 +164,102 @@ func TestInterleaveSlices(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldTryBootstrap(t *testing.T) {
|
||||
oldDebug := debug
|
||||
t.Cleanup(func() { debug = oldDebug })
|
||||
debug = func() bool { return true }
|
||||
|
||||
type step struct {
|
||||
ip netip.Addr // IP we pretended to dial
|
||||
err error // the dial error or nil for success
|
||||
}
|
||||
|
||||
canceled, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0)
|
||||
defer cancel()
|
||||
|
||||
ctx := context.Background()
|
||||
errFailed := errors.New("some failure")
|
||||
|
||||
cacheWithFallback := &Resolver{
|
||||
LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) {
|
||||
panic("unimplemented")
|
||||
},
|
||||
}
|
||||
cacheNoFallback := &Resolver{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
steps []step
|
||||
ctx context.Context
|
||||
err error
|
||||
noFallback bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no-error",
|
||||
ctx: ctx,
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "canceled",
|
||||
ctx: canceled,
|
||||
err: errFailed,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "deadline-exceeded",
|
||||
ctx: deadlineExceeded,
|
||||
err: errFailed,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no-fallback",
|
||||
ctx: ctx,
|
||||
err: errFailed,
|
||||
noFallback: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "dns-was-trustworthy",
|
||||
ctx: ctx,
|
||||
err: errFailed,
|
||||
steps: []step{
|
||||
{netip.MustParseAddr("2003::1"), nil},
|
||||
{netip.MustParseAddr("2003::1"), errFailed},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "should-bootstrap",
|
||||
ctx: ctx,
|
||||
err: errFailed,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &dialer{
|
||||
pastConnect: map[netip.Addr]time.Time{},
|
||||
}
|
||||
if tt.noFallback {
|
||||
d.dnsCache = cacheNoFallback
|
||||
} else {
|
||||
d.dnsCache = cacheWithFallback
|
||||
}
|
||||
dc := &dialCall{d: d}
|
||||
for _, st := range tt.steps {
|
||||
dc.noteDialResult(st.ip, st.err)
|
||||
}
|
||||
got := d.shouldTryBootstrap(tt.ctx, tt.err, dc)
|
||||
if got != tt.want {
|
||||
t.Errorf("got %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -42,7 +41,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := ioutil.WriteFile("dns-fallback-servers.json", out, 0644); err != nil {
|
||||
if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -23,7 +22,7 @@ func TestGoogleCloudRunDefaultRouteInterface(t *testing.T) {
|
||||
buf := []byte("Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" +
|
||||
"eth0\t8008FEA9\t00000000\t0001\t0\t0\t0\t01FFFFFF\t0\t0\t0\n" +
|
||||
"eth1\t00000000\t00000000\t0001\t0\t0\t0\t00000000\t0\t0\t0\n")
|
||||
err := ioutil.WriteFile(procNetRoutePath, buf, 0644)
|
||||
err := os.WriteFile(procNetRoutePath, buf, 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -87,7 +86,7 @@ func TestAwsAppRunnerDefaultRouteInterface(t *testing.T) {
|
||||
"ecs-eth0\t02AAFEA9\t01ACFEA9\t0007\t0\t0\t0\tFFFFFFFF\t0\t0\t0\n" +
|
||||
"ecs-eth0\t00ACFEA9\t00000000\t0001\t0\t0\t0\t00FFFFFF\t0\t0\t0\n" +
|
||||
"eth0\t00AFFEA9\t00000000\t0001\t0\t0\t0\t00FFFFFF\t0\t0\t0\n")
|
||||
err := ioutil.WriteFile(procNetRoutePath, buf, 0644)
|
||||
err := os.WriteFile(procNetRoutePath, buf, 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -43,7 +43,7 @@ import (
|
||||
|
||||
// Debugging and experimentation tweakables.
|
||||
var (
|
||||
debugNetcheck = envknob.Bool("TS_DEBUG_NETCHECK")
|
||||
debugNetcheck = envknob.RegisterBool("TS_DEBUG_NETCHECK")
|
||||
)
|
||||
|
||||
// The various default timeouts for things.
|
||||
@@ -113,6 +113,10 @@ type Report struct {
|
||||
GlobalV4 string // ip:port of global IPv4
|
||||
GlobalV6 string // [ip]:port of global IPv6
|
||||
|
||||
// CaptivePortal is set when we think there's a captive portal that is
|
||||
// intercepting HTTP traffic.
|
||||
CaptivePortal opt.Bool
|
||||
|
||||
// TODO: update Clone when adding new fields
|
||||
}
|
||||
|
||||
@@ -176,6 +180,10 @@ type Client struct {
|
||||
// If nil, portmap discovery is not done.
|
||||
PortMapper *portmapper.Client // lazily initialized on first use
|
||||
|
||||
// For tests
|
||||
testEnoughRegions int
|
||||
testCaptivePortalDelay time.Duration
|
||||
|
||||
mu sync.Mutex // guards following
|
||||
nextFull bool // do a full region scan, even if last != nil
|
||||
prev map[time.Time]*Report // some previous reports
|
||||
@@ -193,6 +201,9 @@ type STUNConn interface {
|
||||
}
|
||||
|
||||
func (c *Client) enoughRegions() int {
|
||||
if c.testEnoughRegions > 0 {
|
||||
return c.testEnoughRegions
|
||||
}
|
||||
if c.Verbose {
|
||||
// Abuse verbose a bit here so netcheck can show all region latencies
|
||||
// in verbose mode.
|
||||
@@ -201,6 +212,14 @@ func (c *Client) enoughRegions() int {
|
||||
return 3
|
||||
}
|
||||
|
||||
func (c *Client) captivePortalDelay() time.Duration {
|
||||
if c.testCaptivePortalDelay > 0 {
|
||||
return c.testCaptivePortalDelay
|
||||
}
|
||||
// Chosen semi-arbitrarily
|
||||
return 200 * time.Millisecond
|
||||
}
|
||||
|
||||
func (c *Client) logf(format string, a ...any) {
|
||||
if c.Logf != nil {
|
||||
c.Logf(format, a...)
|
||||
@@ -210,7 +229,7 @@ func (c *Client) logf(format string, a ...any) {
|
||||
}
|
||||
|
||||
func (c *Client) vlogf(format string, a ...any) {
|
||||
if c.Verbose || debugNetcheck {
|
||||
if c.Verbose || debugNetcheck() {
|
||||
c.logf(format, a...)
|
||||
}
|
||||
}
|
||||
@@ -784,13 +803,35 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report,
|
||||
}
|
||||
c.curState = rs
|
||||
last := c.last
|
||||
|
||||
// Even if we're doing a non-incremental update, we may want to try our
|
||||
// preferred DERP region for captive portal detection. Save that, if we
|
||||
// have it.
|
||||
var preferredDERP int
|
||||
if last != nil {
|
||||
preferredDERP = last.PreferredDERP
|
||||
}
|
||||
|
||||
now := c.timeNow()
|
||||
|
||||
doFull := false
|
||||
if c.nextFull || now.Sub(c.lastFull) > 5*time.Minute {
|
||||
doFull = true
|
||||
}
|
||||
// If the last report had a captive portal and reported no UDP access,
|
||||
// it's possible that we didn't get a useful netcheck due to the
|
||||
// captive portal blocking us. If so, make this report a full
|
||||
// (non-incremental) one.
|
||||
if !doFull && last != nil {
|
||||
doFull = !last.UDP && last.CaptivePortal.EqualBool(true)
|
||||
}
|
||||
if doFull {
|
||||
last = nil // causes makeProbePlan below to do a full (initial) plan
|
||||
c.nextFull = false
|
||||
c.lastFull = now
|
||||
metricNumGetReportFull.Add(1)
|
||||
}
|
||||
|
||||
rs.incremental = last != nil
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -875,6 +916,48 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report,
|
||||
|
||||
plan := makeProbePlan(dm, ifState, last)
|
||||
|
||||
// If we're doing a full probe, also check for a captive portal. We
|
||||
// delay by a bit to wait for UDP STUN to finish, to avoid the probe if
|
||||
// it's unnecessary.
|
||||
captivePortalDone := syncs.ClosedChan()
|
||||
captivePortalStop := func() {}
|
||||
if !rs.incremental {
|
||||
// NOTE(andrew): we can't simply add this goroutine to the
|
||||
// `NewWaitGroupChan` below, since we don't wait for that
|
||||
// waitgroup to finish when exiting this function and thus get
|
||||
// a data race.
|
||||
ch := make(chan struct{})
|
||||
captivePortalDone = ch
|
||||
|
||||
tmr := time.AfterFunc(c.captivePortalDelay(), func() {
|
||||
defer close(ch)
|
||||
found, err := c.checkCaptivePortal(ctx, dm, preferredDERP)
|
||||
if err != nil {
|
||||
c.logf("[v1] checkCaptivePortal: %v", err)
|
||||
return
|
||||
}
|
||||
rs.report.CaptivePortal.Set(found)
|
||||
})
|
||||
|
||||
captivePortalStop = func() {
|
||||
// Don't cancel our captive portal check if we're
|
||||
// explicitly doing a verbose netcheck.
|
||||
if c.Verbose {
|
||||
return
|
||||
}
|
||||
|
||||
if tmr.Stop() {
|
||||
// Stopped successfully; need to close the
|
||||
// signal channel ourselves.
|
||||
close(ch)
|
||||
return
|
||||
}
|
||||
|
||||
// Did not stop; do nothing and it'll finish by itself
|
||||
// and close the signal channel.
|
||||
}
|
||||
}
|
||||
|
||||
wg := syncs.NewWaitGroupChan()
|
||||
wg.Add(len(plan))
|
||||
for _, probeSet := range plan {
|
||||
@@ -895,9 +978,17 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report,
|
||||
case <-stunTimer.C:
|
||||
case <-ctx.Done():
|
||||
case <-wg.DoneChan():
|
||||
// All of our probes finished, so if we have >0 responses, we
|
||||
// stop our captive portal check.
|
||||
if rs.anyUDP() {
|
||||
captivePortalStop()
|
||||
}
|
||||
case <-rs.stopProbeCh:
|
||||
// Saw enough regions.
|
||||
c.vlogf("saw enough regions; not waiting for rest")
|
||||
// We can stop the captive portal check since we know that we
|
||||
// got a bunch of STUN responses.
|
||||
captivePortalStop()
|
||||
}
|
||||
|
||||
rs.waitHairCheck(ctx)
|
||||
@@ -966,6 +1057,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report,
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Wait for captive portal check before finishing the report.
|
||||
<-captivePortalDone
|
||||
|
||||
return c.finishAndStoreReport(rs, dm), nil
|
||||
}
|
||||
|
||||
@@ -980,6 +1074,54 @@ func (c *Client) finishAndStoreReport(rs *reportState, dm *tailcfg.DERPMap) *Rep
|
||||
return report
|
||||
}
|
||||
|
||||
var noRedirectClient = &http.Client{
|
||||
// No redirects allowed
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
|
||||
// Remaining fields are the same as the default client.
|
||||
Transport: http.DefaultClient.Transport,
|
||||
Jar: http.DefaultClient.Jar,
|
||||
Timeout: http.DefaultClient.Timeout,
|
||||
}
|
||||
|
||||
// checkCaptivePortal reports whether or not we think the system is behind a
|
||||
// captive portal, detected by making a request to a URL that we know should
|
||||
// return a "204 No Content" response and checking if that's what we get.
|
||||
//
|
||||
// The boolean return is whether we think we have a captive portal.
|
||||
func (c *Client) checkCaptivePortal(ctx context.Context, dm *tailcfg.DERPMap, preferredDERP int) (bool, error) {
|
||||
defer noRedirectClient.CloseIdleConnections()
|
||||
|
||||
// If we have a preferred DERP region with more than one node, try
|
||||
// that; otherwise, pick a random one not marked as "Avoid".
|
||||
if preferredDERP == 0 || dm.Regions[preferredDERP] == nil ||
|
||||
(preferredDERP != 0 && len(dm.Regions[preferredDERP].Nodes) == 0) {
|
||||
rids := make([]int, 0, len(dm.Regions))
|
||||
for id, reg := range dm.Regions {
|
||||
if reg == nil || reg.Avoid || len(reg.Nodes) == 0 {
|
||||
continue
|
||||
}
|
||||
rids = append(rids, id)
|
||||
}
|
||||
preferredDERP = rids[rand.Intn(len(rids))]
|
||||
}
|
||||
|
||||
node := dm.Regions[preferredDERP].Nodes[0]
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://"+node.HostName+"/generate_204", nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
r, err := noRedirectClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
c.logf("[v2] checkCaptivePortal url=%q status_code=%d", req.URL.String(), r.StatusCode)
|
||||
|
||||
return r.StatusCode != 204, nil
|
||||
}
|
||||
|
||||
// runHTTPOnlyChecks is the netcheck done by environments that can
|
||||
// only do HTTP requests, such as ws/wasm.
|
||||
func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *reportState, dm *tailcfg.DERPMap) error {
|
||||
@@ -1096,7 +1238,7 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio
|
||||
return 0, ip, fmt.Errorf("unexpected status code: %d (%s)", resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
_, err = io.Copy(ioutil.Discard, io.LimitReader(resp.Body, 8<<10))
|
||||
_, err = io.Copy(io.Discard, io.LimitReader(resp.Body, 8<<10))
|
||||
if err != nil {
|
||||
return 0, ip, err
|
||||
}
|
||||
@@ -1201,6 +1343,9 @@ func (c *Client) logConciseReport(r *Report, dm *tailcfg.DERPMap) {
|
||||
if r.GlobalV6 != "" {
|
||||
fmt.Fprintf(w, " v6a=%v", r.GlobalV6)
|
||||
}
|
||||
if r.CaptivePortal != "" {
|
||||
fmt.Fprintf(w, " captiveportal=%v", r.CaptivePortal)
|
||||
}
|
||||
fmt.Fprintf(w, " derp=%v", r.PreferredDERP)
|
||||
if r.PreferredDERP != 0 {
|
||||
fmt.Fprintf(w, " derpdist=")
|
||||
|
||||
@@ -9,11 +9,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -115,6 +117,9 @@ func TestWorksWhenUDPBlocked(t *testing.T) {
|
||||
// OS IPv6 test is irrelevant here, accept whatever the current
|
||||
// machine has.
|
||||
want.OSHasIPv6 = r.OSHasIPv6
|
||||
// Captive portal test is irrelevant; accept what the current report
|
||||
// has.
|
||||
want.CaptivePortal = r.CaptivePortal
|
||||
|
||||
if !reflect.DeepEqual(r, want) {
|
||||
t.Errorf("mismatch\n got: %+v\nwant: %+v\n", r, want)
|
||||
@@ -661,3 +666,57 @@ func TestSortRegions(t *testing.T) {
|
||||
t.Errorf("got %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoCaptivePortalWhenUDP(t *testing.T) {
|
||||
// Override noRedirectClient to handle the /generate_204 endpoint
|
||||
var generate204Called atomic.Bool
|
||||
tr := RoundTripFunc(func(req *http.Request) *http.Response {
|
||||
if !strings.HasSuffix(req.URL.String(), "/generate_204") {
|
||||
panic("bad URL: " + req.URL.String())
|
||||
}
|
||||
generate204Called.Store(true)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusNoContent,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
})
|
||||
|
||||
oldTransport := noRedirectClient.Transport
|
||||
t.Cleanup(func() { noRedirectClient.Transport = oldTransport })
|
||||
noRedirectClient.Transport = tr
|
||||
|
||||
stunAddr, cleanup := stuntest.Serve(t)
|
||||
defer cleanup()
|
||||
|
||||
c := &Client{
|
||||
Logf: t.Logf,
|
||||
UDPBindAddr: "127.0.0.1:0",
|
||||
testEnoughRegions: 1,
|
||||
|
||||
// Set the delay long enough that we have time to cancel it
|
||||
// when our STUN probe succeeds.
|
||||
testCaptivePortalDelay: 10 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
r, err := c.GetReport(ctx, stuntest.DERPMapOf(stunAddr.String()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should not have called our captive portal function.
|
||||
if generate204Called.Load() {
|
||||
t.Errorf("captive portal check called; expected no call")
|
||||
}
|
||||
if r.CaptivePortal != "" {
|
||||
t.Errorf("got CaptivePortal=%q, want empty", r.CaptivePortal)
|
||||
}
|
||||
}
|
||||
|
||||
type RoundTripFunc func(req *http.Request) *http.Response
|
||||
|
||||
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req), nil
|
||||
}
|
||||
|
||||
@@ -20,6 +20,11 @@ var (
|
||||
androidProtectFunc func(fd int) error
|
||||
)
|
||||
|
||||
// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK.
|
||||
func UseSocketMark() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetAndroidProtectFunc register a func that Android provides that JNI calls into
|
||||
// https://developer.android.com/reference/android/net/VpnService#protect(int)
|
||||
// which is documented as:
|
||||
|
||||
@@ -63,12 +63,12 @@ func socketMarkWorks() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
var forceBindToDevice = envknob.Bool("TS_FORCE_LINUX_BIND_TO_DEVICE")
|
||||
var forceBindToDevice = envknob.RegisterBool("TS_FORCE_LINUX_BIND_TO_DEVICE")
|
||||
|
||||
// useSocketMark reports whether SO_MARK works.
|
||||
// UseSocketMark reports whether SO_MARK is in use.
|
||||
// If it doesn't, we have to use SO_BINDTODEVICE on our sockets instead.
|
||||
func useSocketMark() bool {
|
||||
if forceBindToDevice {
|
||||
func UseSocketMark() bool {
|
||||
if forceBindToDevice() {
|
||||
return false
|
||||
}
|
||||
socketMarkWorksOnce.Do(func() {
|
||||
@@ -103,7 +103,7 @@ func controlC(network, address string, c syscall.RawConn) error {
|
||||
|
||||
var sockErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
if useSocketMark() {
|
||||
if UseSocketMark() {
|
||||
sockErr = setBypassMark(fd)
|
||||
} else {
|
||||
sockErr = bindToDevice(fd)
|
||||
|
||||
@@ -208,7 +208,7 @@ func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
|
||||
b = b[:attrsLen] // trim trailing packet bytes
|
||||
}
|
||||
|
||||
var addr6, fallbackAddr, fallbackAddr6 netip.AddrPort
|
||||
var fallbackAddr netip.AddrPort
|
||||
|
||||
// Read through the attributes.
|
||||
// The the addr+port reported by XOR-MAPPED-ADDRESS
|
||||
@@ -218,24 +218,20 @@ func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
|
||||
if err := foreachAttr(b, func(attrType uint16, attr []byte) error {
|
||||
switch attrType {
|
||||
case attrXorMappedAddress, attrXorMappedAddressAlt:
|
||||
a, p, err := xorMappedAddress(tID, attr)
|
||||
ipSlice, port, err := xorMappedAddress(tID, attr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(a) == 16 {
|
||||
addr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p)
|
||||
} else {
|
||||
addr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p)
|
||||
if ip, ok := netip.AddrFromSlice(ipSlice); ok {
|
||||
addr = netip.AddrPortFrom(ip.Unmap(), port)
|
||||
}
|
||||
case attrMappedAddress:
|
||||
a, p, err := mappedAddress(attr)
|
||||
ipSlice, port, err := mappedAddress(attr)
|
||||
if err != nil {
|
||||
return ErrMalformedAttrs
|
||||
}
|
||||
if len(a) == 16 {
|
||||
fallbackAddr6 = netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)([]byte(a))), p)
|
||||
} else {
|
||||
fallbackAddr = netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)([]byte(a))), p)
|
||||
if ip, ok := netip.AddrFromSlice(ipSlice); ok {
|
||||
fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -250,12 +246,6 @@ func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) {
|
||||
if fallbackAddr.IsValid() {
|
||||
return tID, fallbackAddr, nil
|
||||
}
|
||||
if addr6.IsValid() {
|
||||
return tID, addr6, nil
|
||||
}
|
||||
if fallbackAddr6.IsValid() {
|
||||
return tID, fallbackAddr6, nil
|
||||
}
|
||||
return tID, netip.AddrPort{}, ErrMalformedAttrs
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,13 @@ package stun_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/net/stun"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
// TODO(bradfitz): fuzz this.
|
||||
@@ -175,6 +177,13 @@ var responseTests = []struct {
|
||||
wantAddr: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
|
||||
wantPort: 61300,
|
||||
},
|
||||
{
|
||||
name: "no-4in6",
|
||||
data: must.Get(hex.DecodeString("010100182112a4424fd5d202dcb37d31fc773306002000140002cd3d2112a4424fd5d202dcb382ce2dc3fcc7")),
|
||||
wantTID: []byte{79, 213, 210, 2, 220, 179, 125, 49, 252, 119, 51, 6},
|
||||
wantAddr: netip.AddrFrom4([4]byte{209, 180, 207, 193}),
|
||||
wantPort: 60463,
|
||||
},
|
||||
}
|
||||
|
||||
func TestParseResponse(t *testing.T) {
|
||||
|
||||
@@ -32,7 +32,7 @@ var counterFallbackOK int32 // atomic
|
||||
// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
|
||||
var sslKeyLogFile = os.Getenv("SSLKEYLOGFILE")
|
||||
|
||||
var debug = envknob.Bool("TS_DEBUG_TLS_DIAL")
|
||||
var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL")
|
||||
|
||||
// Config returns a tls.Config for connecting to a server.
|
||||
// If base is non-nil, it's cloned as the base config before
|
||||
@@ -77,7 +77,7 @@ func Config(host string, base *tls.Config) *tls.Config {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, errSys := cs.PeerCertificates[0].Verify(opts)
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("tlsdial(sys %q): %v", host, errSys)
|
||||
}
|
||||
if errSys == nil {
|
||||
@@ -88,7 +88,7 @@ func Config(host string, base *tls.Config) *tls.Config {
|
||||
// or broken, fall back to trying LetsEncrypt at least.
|
||||
opts.Roots = bakedInRoots()
|
||||
_, err := cs.PeerCertificates[0].Verify(opts)
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("tlsdial(bake %q): %v", host, err)
|
||||
}
|
||||
if err == nil {
|
||||
@@ -142,7 +142,7 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, errSys := certs[0].Verify(opts)
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("tlsdial(sys %q/%q): %v", c.ServerName, certDNSName, errSys)
|
||||
}
|
||||
if errSys == nil {
|
||||
@@ -150,7 +150,7 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
|
||||
}
|
||||
opts.Roots = bakedInRoots()
|
||||
_, err := certs[0].Verify(opts)
|
||||
if debug {
|
||||
if debug() {
|
||||
log.Printf("tlsdial(bake %q/%q): %v", c.ServerName, certDNSName, err)
|
||||
}
|
||||
if err == nil {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -60,7 +59,7 @@ func TestSynologyProxyFromConfigCached(t *testing.T) {
|
||||
cache.httpProxy = nil
|
||||
cache.httpsProxy = nil
|
||||
|
||||
if err := ioutil.WriteFile(synologyProxyConfigPath, []byte(`
|
||||
if err := os.WriteFile(synologyProxyConfigPath, []byte(`
|
||||
proxy_enabled=yes
|
||||
http_host=10.0.0.55
|
||||
http_port=80
|
||||
@@ -116,7 +115,7 @@ https_port=443
|
||||
cache.httpProxy = nil
|
||||
cache.httpsProxy = nil
|
||||
|
||||
if err := ioutil.WriteFile(synologyProxyConfigPath, []byte(`
|
||||
if err := os.WriteFile(synologyProxyConfigPath, []byte(`
|
||||
proxy_enabled=yes
|
||||
http_host=10.0.0.55
|
||||
http_port=80
|
||||
|
||||
@@ -20,14 +20,6 @@ import (
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
var tunMTU = DefaultMTU
|
||||
|
||||
func init() {
|
||||
if mtu, ok := envknob.LookupInt("TS_DEBUG_MTU"); ok {
|
||||
tunMTU = mtu
|
||||
}
|
||||
}
|
||||
|
||||
// createTAP is non-nil on Linux.
|
||||
var createTAP func(tapName, bridgeName string) (tun.Device, error)
|
||||
|
||||
@@ -52,6 +44,10 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) {
|
||||
}
|
||||
dev, err = createTAP(tapName, bridgeName)
|
||||
} else {
|
||||
tunMTU := DefaultMTU
|
||||
if mtu, ok := envknob.LookupInt("TS_DEBUG_MTU"); ok {
|
||||
tunMTU = mtu
|
||||
}
|
||||
dev, err = tun.CreateTUN(tunName, tunMTU)
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
367
net/tunstats/stats.go
Normal file
367
net/tunstats/stats.go
Normal file
@@ -0,0 +1,367 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tunstats maintains statistics about connections
|
||||
// flowing through a TUN device (which operate at the IP layer).
|
||||
package tunstats
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"hash/maphash"
|
||||
"math/bits"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"tailscale.com/net/flowtrack"
|
||||
"tailscale.com/types/ipproto"
|
||||
)
|
||||
|
||||
// Statistics maintains counters for every connection.
|
||||
// All methods are safe for concurrent use.
|
||||
// The zero value is ready for use.
|
||||
type Statistics struct {
|
||||
v4 hashTable[addrsPortsV4]
|
||||
v6 hashTable[addrsPortsV6]
|
||||
}
|
||||
|
||||
// Counts are statistics about a particular connection.
|
||||
type Counts struct {
|
||||
TxPackets uint64 `json:"txPkts,omitempty"`
|
||||
TxBytes uint64 `json:"txBytes,omitempty"`
|
||||
RxPackets uint64 `json:"rxPkts,omitempty"`
|
||||
RxBytes uint64 `json:"rxBytes,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
minTableLen = 8
|
||||
maxProbeLen = 64
|
||||
)
|
||||
|
||||
// hashTable is a hash table that uses open addressing with probing.
|
||||
// See https://en.wikipedia.org/wiki/Hash_table#Open_addressing.
|
||||
// The primary table is in the active field and can be retrieved atomically.
|
||||
// In the common case, this data structure is mostly lock free.
|
||||
//
|
||||
// If the current table is too small, a new table is allocated that
|
||||
// replaces the current active table. The contents of the older table are
|
||||
// NOT copied to the new table, but rather the older table is appended
|
||||
// to a list of outgrown tables. Re-growth happens under a lock,
|
||||
// but is expected to happen rarely as the table size grows exponentially.
|
||||
//
|
||||
// To reduce memory usage, the counters uses 32-bit unsigned integers,
|
||||
// which carry the risk of overflowing. If an overflow is detected,
|
||||
// we add the amount overflowed to the overflow map. This is a naive Go map
|
||||
// protected by a sync.Mutex. Overflow is rare that contention is not a concern.
|
||||
//
|
||||
// To extract all counters, we replace the active table with a zeroed table,
|
||||
// and clear out the outgrown and overflow tables.
|
||||
// We take advantage of the fact that all the tables can be merged together
|
||||
// by simply adding up all the counters for each connection.
|
||||
type hashTable[AddrsPorts addrsPorts] struct {
|
||||
// TODO: Get rid of this. It is just an atomic update in the common case,
|
||||
// but contention updating the same word still incurs a 25% performance hit.
|
||||
mu sync.RWMutex // RLock held while updating, Lock held while extracting
|
||||
|
||||
active atomic.Pointer[countsTable[AddrsPorts]]
|
||||
inserts atomic.Uint32 // heuristic for next active table to allocate
|
||||
|
||||
muGrow sync.Mutex // muGrow.Lock implies that mu.RLock held
|
||||
outgrown []countsTable[AddrsPorts]
|
||||
|
||||
muOverflow sync.Mutex // muOverflow.Lock implies that mu.RLock held
|
||||
overflow map[flowtrack.Tuple]Counts
|
||||
}
|
||||
|
||||
type countsTable[AddrsPorts addrsPorts] []counts[AddrsPorts]
|
||||
|
||||
func (t *countsTable[AddrsPorts]) len() int {
|
||||
if t == nil {
|
||||
return 0
|
||||
}
|
||||
return len(*t)
|
||||
}
|
||||
|
||||
type counts[AddrsPorts addrsPorts] struct {
|
||||
// initProto is both an initialization flag and the IP protocol.
|
||||
// It is 0 if uninitialized, 1 if initializing, and
|
||||
// 2+ipproto.Proto if initialized.
|
||||
initProto atomic.Uint32
|
||||
|
||||
addrsPorts AddrsPorts // only valid if initProto is initialized
|
||||
|
||||
txPackets atomic.Uint32
|
||||
txBytes atomic.Uint32
|
||||
rxPackets atomic.Uint32
|
||||
rxBytes atomic.Uint32
|
||||
}
|
||||
|
||||
// NOTE: There is some degree of duplicated code.
|
||||
// For example, the functionality to swap the addrsPorts and compute the hash
|
||||
// should be performed by hashTable.update rather than Statistics.update.
|
||||
// However, Go generics cannot invoke pointer methods on addressable values.
|
||||
// See https://go.googlesource.com/proposal/+/refs/heads/master/design/43651-type-parameters.md#no-way-to-require-pointer-methods
|
||||
|
||||
type addrsPorts interface {
|
||||
comparable
|
||||
asTuple(ipproto.Proto) flowtrack.Tuple
|
||||
}
|
||||
|
||||
type addrsPortsV4 [4 + 4 + 2 + 2]byte
|
||||
|
||||
func (x *addrsPortsV4) addrs() *[8]byte { return (*[8]byte)(x[:]) }
|
||||
func (x *addrsPortsV4) ports() *[4]byte { return (*[4]byte)(x[8:]) }
|
||||
func (x *addrsPortsV4) swap() {
|
||||
*(*[4]byte)(x[0:]), *(*[4]byte)(x[4:]) = *(*[4]byte)(x[4:]), *(*[4]byte)(x[0:])
|
||||
*(*[2]byte)(x[8:]), *(*[2]byte)(x[10:]) = *(*[2]byte)(x[10:]), *(*[2]byte)(x[8:])
|
||||
}
|
||||
func (x addrsPortsV4) asTuple(proto ipproto.Proto) flowtrack.Tuple {
|
||||
return flowtrack.Tuple{Proto: proto,
|
||||
Src: netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(x[0:])), binary.BigEndian.Uint16(x[8:])),
|
||||
Dst: netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(x[4:])), binary.BigEndian.Uint16(x[10:])),
|
||||
}
|
||||
}
|
||||
|
||||
type addrsPortsV6 [16 + 16 + 2 + 2]byte
|
||||
|
||||
func (x *addrsPortsV6) addrs() *[32]byte { return (*[32]byte)(x[:]) }
|
||||
func (x *addrsPortsV6) ports() *[4]byte { return (*[4]byte)(x[32:]) }
|
||||
func (x *addrsPortsV6) swap() {
|
||||
*(*[16]byte)(x[0:]), *(*[16]byte)(x[16:]) = *(*[16]byte)(x[16:]), *(*[16]byte)(x[0:])
|
||||
*(*[2]byte)(x[32:]), *(*[2]byte)(x[34:]) = *(*[2]byte)(x[34:]), *(*[2]byte)(x[32:])
|
||||
}
|
||||
func (x addrsPortsV6) asTuple(proto ipproto.Proto) flowtrack.Tuple {
|
||||
return flowtrack.Tuple{Proto: proto,
|
||||
Src: netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(x[0:])), binary.BigEndian.Uint16(x[32:])),
|
||||
Dst: netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(x[16:])), binary.BigEndian.Uint16(x[34:])),
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTx updates the statistics for a transmitted IP packet.
|
||||
func (s *Statistics) UpdateTx(b []byte) {
|
||||
s.update(b, false)
|
||||
}
|
||||
|
||||
// UpdateRx updates the statistics for a received IP packet.
|
||||
func (s *Statistics) UpdateRx(b []byte) {
|
||||
s.update(b, true)
|
||||
}
|
||||
|
||||
var seed = maphash.MakeSeed()
|
||||
|
||||
func (s *Statistics) update(b []byte, receive bool) {
|
||||
switch {
|
||||
case len(b) >= 20 && b[0]>>4 == 4: // IPv4
|
||||
proto := ipproto.Proto(b[9])
|
||||
hasPorts := proto == ipproto.TCP || proto == ipproto.UDP
|
||||
var addrsPorts addrsPortsV4
|
||||
if hdrLen := int(4 * (b[0] & 0xf)); hdrLen == 20 && len(b) >= 24 && hasPorts {
|
||||
addrsPorts = *(*addrsPortsV4)(b[12:]) // addresses and ports are contiguous
|
||||
} else {
|
||||
*addrsPorts.addrs() = *(*[8]byte)(b[12:])
|
||||
// May have IPv4 options in-between address and ports.
|
||||
if len(b) >= hdrLen+4 && hasPorts {
|
||||
*addrsPorts.ports() = *(*[4]byte)(b[hdrLen:])
|
||||
}
|
||||
}
|
||||
if receive {
|
||||
addrsPorts.swap()
|
||||
}
|
||||
hash := maphash.Bytes(seed, addrsPorts[:]) ^ uint64(proto) // TODO: Hash proto better?
|
||||
s.v4.update(receive, proto, &addrsPorts, hash, uint32(len(b)))
|
||||
return
|
||||
case len(b) >= 40 && b[0]>>4 == 6: // IPv6
|
||||
proto := ipproto.Proto(b[6])
|
||||
hasPorts := proto == ipproto.TCP || proto == ipproto.UDP
|
||||
var addrsPorts addrsPortsV6
|
||||
if len(b) >= 44 && hasPorts {
|
||||
addrsPorts = *(*addrsPortsV6)(b[8:]) // addresses and ports are contiguous
|
||||
} else {
|
||||
*addrsPorts.addrs() = *(*[32]byte)(b[8:])
|
||||
// TODO: Support IPv6 extension headers?
|
||||
if hdrLen := 40; len(b) > hdrLen+4 && hasPorts {
|
||||
*addrsPorts.ports() = *(*[4]byte)(b[hdrLen:])
|
||||
}
|
||||
}
|
||||
if receive {
|
||||
addrsPorts.swap()
|
||||
}
|
||||
hash := maphash.Bytes(seed, addrsPorts[:]) ^ uint64(proto) // TODO: Hash proto better?
|
||||
s.v6.update(receive, proto, &addrsPorts, hash, uint32(len(b)))
|
||||
return
|
||||
}
|
||||
// TODO: Track malformed packets?
|
||||
}
|
||||
|
||||
func (h *hashTable[AddrsPorts]) update(receive bool, proto ipproto.Proto, addrsPorts *AddrsPorts, hash uint64, size uint32) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
table := h.active.Load()
|
||||
for {
|
||||
// Start with an initialized table.
|
||||
if table.len() == 0 {
|
||||
table = h.grow(table)
|
||||
}
|
||||
|
||||
// Try to update an entry in the currently active table.
|
||||
for i := 0; i < len(*table) && i < maxProbeLen; i++ {
|
||||
probe := uint64(i) // linear probing for small tables
|
||||
if len(*table) > 2*maxProbeLen {
|
||||
probe *= probe // quadratic probing for large tables
|
||||
}
|
||||
entry := &(*table)[(hash+probe)%uint64(len(*table))]
|
||||
|
||||
// Spin-lock waiting for the entry to be initialized,
|
||||
// which should be quick as it only stores the AddrsPort.
|
||||
retry:
|
||||
switch initProto := entry.initProto.Load(); initProto {
|
||||
case 0: // uninitialized
|
||||
if !entry.initProto.CompareAndSwap(0, 1) {
|
||||
goto retry // raced with another initialization attempt
|
||||
}
|
||||
entry.addrsPorts = *addrsPorts
|
||||
entry.initProto.Store(uint32(proto) + 2) // initialization done
|
||||
h.inserts.Add(1)
|
||||
case 1: // initializing
|
||||
goto retry
|
||||
default: // initialized
|
||||
if ipproto.Proto(initProto-2) != proto || entry.addrsPorts != *addrsPorts {
|
||||
continue // this entry is for a different connection; try next entry
|
||||
}
|
||||
}
|
||||
|
||||
// Atomically update the counters for the connection entry.
|
||||
var overflowPackets, overflowBytes bool
|
||||
if receive {
|
||||
overflowPackets = entry.rxPackets.Add(1) < 1
|
||||
overflowBytes = entry.rxBytes.Add(size) < size
|
||||
} else {
|
||||
overflowPackets = entry.txPackets.Add(1) < 1
|
||||
overflowBytes = entry.txBytes.Add(size) < size
|
||||
}
|
||||
if overflowPackets || overflowBytes {
|
||||
h.updateOverflow(receive, proto, addrsPorts, overflowPackets, overflowBytes)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Unable to update, so grow the table and try again.
|
||||
// TODO: Use overflow map instead if table utilization is too low.
|
||||
table = h.grow(table)
|
||||
}
|
||||
}
|
||||
|
||||
// grow grows the table unless the active table is larger than oldTable.
|
||||
func (h *hashTable[AddrsPorts]) grow(oldTable *countsTable[AddrsPorts]) (newTable *countsTable[AddrsPorts]) {
|
||||
h.muGrow.Lock()
|
||||
defer h.muGrow.Unlock()
|
||||
|
||||
if newTable = h.active.Load(); newTable.len() > oldTable.len() {
|
||||
return newTable // raced with another grow
|
||||
}
|
||||
newTable = new(countsTable[AddrsPorts])
|
||||
if oldTable.len() == 0 {
|
||||
*newTable = make(countsTable[AddrsPorts], minTableLen)
|
||||
} else {
|
||||
*newTable = make(countsTable[AddrsPorts], 2*len(*oldTable))
|
||||
h.outgrown = append(h.outgrown, *oldTable)
|
||||
}
|
||||
h.active.Store(newTable)
|
||||
return newTable
|
||||
}
|
||||
|
||||
// updateOverflow updates the overflow map for counters that overflowed.
|
||||
// Using 32-bit counters, this condition happens rarely as it only triggers
|
||||
// after every 4 GiB of unidirectional network traffic on the same connection.
|
||||
func (h *hashTable[AddrsPorts]) updateOverflow(receive bool, proto ipproto.Proto, addrsPorts *AddrsPorts, overflowPackets, overflowBytes bool) {
|
||||
h.muOverflow.Lock()
|
||||
defer h.muOverflow.Unlock()
|
||||
if h.overflow == nil {
|
||||
h.overflow = make(map[flowtrack.Tuple]Counts)
|
||||
}
|
||||
tuple := (*addrsPorts).asTuple(proto)
|
||||
cnts := h.overflow[tuple]
|
||||
if overflowPackets {
|
||||
if receive {
|
||||
cnts.RxPackets += 1 << 32
|
||||
} else {
|
||||
cnts.TxPackets += 1 << 32
|
||||
}
|
||||
}
|
||||
if overflowBytes {
|
||||
if receive {
|
||||
cnts.RxBytes += 1 << 32
|
||||
} else {
|
||||
cnts.TxBytes += 1 << 32
|
||||
}
|
||||
}
|
||||
h.overflow[tuple] = cnts
|
||||
}
|
||||
|
||||
func (h *hashTable[AddrsPorts]) extractInto(out map[flowtrack.Tuple]Counts) {
|
||||
// Allocate a new table based on previous usage.
|
||||
var newTable *countsTable[AddrsPorts]
|
||||
if numInserts := h.inserts.Load(); numInserts > 0 {
|
||||
newLen := 1 << bits.Len(uint(4*numInserts/3)|uint(minTableLen-1))
|
||||
newTable = new(countsTable[AddrsPorts])
|
||||
*newTable = make(countsTable[AddrsPorts], newLen)
|
||||
}
|
||||
|
||||
// Swap out the old tables for new tables.
|
||||
// We do not need to lock h.muGrow or h.muOverflow since holding h.mu
|
||||
// implies that nothing else could be holding those locks.
|
||||
h.mu.Lock()
|
||||
oldTable := h.active.Swap(newTable)
|
||||
oldOutgrown := h.outgrown
|
||||
oldOverflow := h.overflow
|
||||
h.outgrown = nil
|
||||
h.overflow = nil
|
||||
h.inserts.Store(0)
|
||||
h.mu.Unlock()
|
||||
|
||||
// Merge tables into output.
|
||||
if oldTable != nil {
|
||||
mergeTable(out, *oldTable)
|
||||
}
|
||||
for _, table := range oldOutgrown {
|
||||
mergeTable(out, table)
|
||||
}
|
||||
mergeMap(out, oldOverflow)
|
||||
}
|
||||
|
||||
// Extract extracts and resets the counters for all active connections.
|
||||
// It must be called periodically otherwise the memory used is unbounded.
|
||||
func (s *Statistics) Extract() map[flowtrack.Tuple]Counts {
|
||||
out := make(map[flowtrack.Tuple]Counts)
|
||||
s.v4.extractInto(out)
|
||||
s.v6.extractInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func mergeTable[AddrsPorts addrsPorts](dst map[flowtrack.Tuple]Counts, src countsTable[AddrsPorts]) {
|
||||
for i := range src {
|
||||
entry := &src[i]
|
||||
if initProto := entry.initProto.Load(); initProto > 0 {
|
||||
tuple := entry.addrsPorts.asTuple(ipproto.Proto(initProto - 2))
|
||||
cnts := dst[tuple]
|
||||
cnts.TxPackets += uint64(entry.txPackets.Load())
|
||||
cnts.TxBytes += uint64(entry.txBytes.Load())
|
||||
cnts.RxPackets += uint64(entry.rxPackets.Load())
|
||||
cnts.RxBytes += uint64(entry.rxBytes.Load())
|
||||
dst[tuple] = cnts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeMap(dst, src map[flowtrack.Tuple]Counts) {
|
||||
for tuple, cntsSrc := range src {
|
||||
cntsDst := dst[tuple]
|
||||
cntsDst.TxPackets += cntsSrc.TxPackets
|
||||
cntsDst.TxBytes += cntsSrc.TxBytes
|
||||
cntsDst.RxPackets += cntsSrc.RxPackets
|
||||
cntsDst.RxBytes += cntsSrc.RxBytes
|
||||
dst[tuple] = cntsDst
|
||||
}
|
||||
}
|
||||
325
net/tunstats/stats_test.go
Normal file
325
net/tunstats/stats_test.go
Normal file
@@ -0,0 +1,325 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tunstats
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"math"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
qt "github.com/frankban/quicktest"
|
||||
"tailscale.com/net/flowtrack"
|
||||
"tailscale.com/types/ipproto"
|
||||
)
|
||||
|
||||
type SimpleStatistics struct {
|
||||
mu sync.Mutex
|
||||
m map[flowtrack.Tuple]Counts
|
||||
}
|
||||
|
||||
func (s *SimpleStatistics) UpdateTx(b []byte) {
|
||||
s.update(b, false)
|
||||
}
|
||||
func (s *SimpleStatistics) UpdateRx(b []byte) {
|
||||
s.update(b, true)
|
||||
}
|
||||
func (s *SimpleStatistics) update(b []byte, receive bool) {
|
||||
var tuple flowtrack.Tuple
|
||||
var size uint64
|
||||
if len(b) >= 1 {
|
||||
// This logic is mostly copied from Statistics.update.
|
||||
switch v := b[0] >> 4; {
|
||||
case v == 4 && len(b) >= 20: // IPv4
|
||||
proto := ipproto.Proto(b[9])
|
||||
size = uint64(binary.BigEndian.Uint16(b[2:]))
|
||||
var addrsPorts addrsPortsV4
|
||||
*(*[8]byte)(addrsPorts[0:]) = *(*[8]byte)(b[12:])
|
||||
if hdrLen := int(4 * (b[0] & 0xf)); len(b) >= hdrLen+4 && (proto == ipproto.TCP || proto == ipproto.UDP) {
|
||||
*(*[4]byte)(addrsPorts[8:]) = *(*[4]byte)(b[hdrLen:])
|
||||
}
|
||||
if receive {
|
||||
addrsPorts.swap()
|
||||
}
|
||||
tuple = addrsPorts.asTuple(proto)
|
||||
case v == 6 && len(b) >= 40: // IPv6
|
||||
proto := ipproto.Proto(b[6])
|
||||
size = uint64(binary.BigEndian.Uint16(b[4:]))
|
||||
var addrsPorts addrsPortsV6
|
||||
*(*[32]byte)(addrsPorts[0:]) = *(*[32]byte)(b[8:])
|
||||
if hdrLen := 40; len(b) > hdrLen+4 && (proto == ipproto.TCP || proto == ipproto.UDP) {
|
||||
*(*[4]byte)(addrsPorts[32:]) = *(*[4]byte)(b[hdrLen:])
|
||||
}
|
||||
if receive {
|
||||
addrsPorts.swap()
|
||||
}
|
||||
tuple = addrsPorts.asTuple(proto)
|
||||
default:
|
||||
return // non-IP packet
|
||||
}
|
||||
} else {
|
||||
return // invalid packet
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.m == nil {
|
||||
s.m = make(map[flowtrack.Tuple]Counts)
|
||||
}
|
||||
cnts := s.m[tuple]
|
||||
if receive {
|
||||
cnts.RxPackets++
|
||||
cnts.RxBytes += size
|
||||
} else {
|
||||
cnts.TxPackets++
|
||||
cnts.TxBytes += size
|
||||
}
|
||||
s.m[tuple] = cnts
|
||||
}
|
||||
|
||||
func TestEmpty(t *testing.T) {
|
||||
c := qt.New(t)
|
||||
var s Statistics
|
||||
c.Assert(s.Extract(), qt.DeepEquals, map[flowtrack.Tuple]Counts{})
|
||||
c.Assert(s.Extract(), qt.DeepEquals, map[flowtrack.Tuple]Counts{})
|
||||
}
|
||||
|
||||
func TestOverflow(t *testing.T) {
|
||||
c := qt.New(t)
|
||||
var s Statistics
|
||||
var cnts Counts
|
||||
|
||||
a := &addrsPortsV4{192, 168, 0, 1, 192, 168, 0, 2, 12, 34, 56, 78}
|
||||
h := maphash.Bytes(seed, a[:])
|
||||
|
||||
cnts.TxPackets++
|
||||
cnts.TxBytes += math.MaxUint32
|
||||
s.v4.update(false, ipproto.UDP, a, h, math.MaxUint32)
|
||||
for i := 0; i < 1e6; i++ {
|
||||
cnts.TxPackets++
|
||||
cnts.TxBytes += uint64(i)
|
||||
s.v4.update(false, ipproto.UDP, a, h, uint32(i))
|
||||
}
|
||||
c.Assert(s.Extract(), qt.DeepEquals, map[flowtrack.Tuple]Counts{a.asTuple(ipproto.UDP): cnts})
|
||||
c.Assert(s.Extract(), qt.DeepEquals, map[flowtrack.Tuple]Counts{})
|
||||
}
|
||||
|
||||
func FuzzParse(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
var s Statistics
|
||||
s.UpdateRx(b) // must not panic
|
||||
s.UpdateTx(b) // must not panic
|
||||
s.Extract() // must not panic
|
||||
})
|
||||
}
|
||||
|
||||
var testV4 = func() (b [24]byte) {
|
||||
b[0] = 4<<4 | 5 // version and header length
|
||||
binary.BigEndian.PutUint16(b[2:], 1234) // size
|
||||
b[9] = byte(ipproto.UDP) // protocol
|
||||
*(*[4]byte)(b[12:]) = [4]byte{192, 168, 0, 1} // src addr
|
||||
*(*[4]byte)(b[16:]) = [4]byte{192, 168, 0, 2} // dst addr
|
||||
binary.BigEndian.PutUint16(b[20:], 456) // src port
|
||||
binary.BigEndian.PutUint16(b[22:], 789) // dst port
|
||||
return b
|
||||
}()
|
||||
|
||||
/*
|
||||
func BenchmarkA(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
for j := 0; j < 1e3; j++ {
|
||||
s.UpdateTx(testV4[:])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkB(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s SimpleStatistics
|
||||
for j := 0; j < 1e3; j++ {
|
||||
s.UpdateTx(testV4[:])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkC(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
var group sync.WaitGroup
|
||||
for k := 0; k < runtime.NumCPU(); k++ {
|
||||
group.Add(1)
|
||||
go func(k int) {
|
||||
defer group.Done()
|
||||
b := testV4
|
||||
for j := 0; j < 1e3; j++ {
|
||||
binary.LittleEndian.PutUint32(b[12:], uint32(k))
|
||||
binary.LittleEndian.PutUint32(b[16:], uint32(j))
|
||||
s.UpdateTx(b[:])
|
||||
}
|
||||
}(k)
|
||||
}
|
||||
group.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkD(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s SimpleStatistics
|
||||
var group sync.WaitGroup
|
||||
for k := 0; k < runtime.NumCPU(); k++ {
|
||||
group.Add(1)
|
||||
go func(k int) {
|
||||
defer group.Done()
|
||||
b := testV4
|
||||
for j := 0; j < 1e3; j++ {
|
||||
binary.LittleEndian.PutUint32(b[12:], uint32(k))
|
||||
binary.LittleEndian.PutUint32(b[16:], uint32(j))
|
||||
s.UpdateTx(b[:])
|
||||
}
|
||||
}(k)
|
||||
}
|
||||
group.Wait()
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// FUZZ
|
||||
// Benchmark:
|
||||
// IPv4 vs IPv6
|
||||
// single vs all cores
|
||||
// same vs unique addresses
|
||||
|
||||
/*
|
||||
linear probing
|
||||
|
||||
1 => 115595714 ns/op 859003746 B/op
|
||||
2 => 9355585 ns/op 46454947 B/op
|
||||
4 => 3301663 ns/op 8706967 B/op
|
||||
8 => 2775162 ns/op 4176433 B/op
|
||||
16 => 2517899 ns/op 2099434 B/op
|
||||
32 => 2397939 ns/op 2098986 B/op
|
||||
64 => 2118390 ns/op 1197352 B/op
|
||||
128 => 2029255 ns/op 1046729 B/op
|
||||
256 => 2069939 ns/op 1042577 B/op
|
||||
|
||||
quadratic probing
|
||||
|
||||
1 => 111134367 ns/op 825962200 B/op
|
||||
2 => 8061189 ns/op 45106117 B/op
|
||||
4 => 3216728 ns/op 8079556 B/op
|
||||
8 => 2576443 ns/op 2355890 B/op
|
||||
16 => 2471713 ns/op 2097196 B/op
|
||||
32 => 2108294 ns/op 1050225 B/op
|
||||
64 => 1964441 ns/op 1048736 B/op
|
||||
128 => 2118538 ns/op 1046663 B/op
|
||||
256 => 1968353 ns/op 1042568 B/op
|
||||
512 => 2049336 ns/op 1034306 B/op
|
||||
1024 => 2001605 ns/op 1017786 B/op
|
||||
2048 => 2046972 ns/op 984988 B/op
|
||||
4096 => 2108753 ns/op 919105 B/op
|
||||
*/
|
||||
|
||||
func testPacketV4(proto ipproto.Proto, srcAddr, dstAddr [4]byte, srcPort, dstPort, size uint16) (out []byte) {
|
||||
var ipHdr [20]byte
|
||||
ipHdr[0] = 4<<4 | 5
|
||||
binary.BigEndian.PutUint16(ipHdr[2:], size)
|
||||
ipHdr[9] = byte(proto)
|
||||
*(*[4]byte)(ipHdr[12:]) = srcAddr
|
||||
*(*[4]byte)(ipHdr[16:]) = dstAddr
|
||||
out = append(out, ipHdr[:]...)
|
||||
switch proto {
|
||||
case ipproto.TCP:
|
||||
var tcpHdr [20]byte
|
||||
binary.BigEndian.PutUint16(tcpHdr[0:], srcPort)
|
||||
binary.BigEndian.PutUint16(tcpHdr[2:], dstPort)
|
||||
out = append(out, tcpHdr[:]...)
|
||||
case ipproto.UDP:
|
||||
var udpHdr [8]byte
|
||||
binary.BigEndian.PutUint16(udpHdr[0:], srcPort)
|
||||
binary.BigEndian.PutUint16(udpHdr[2:], dstPort)
|
||||
out = append(out, udpHdr[:]...)
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown proto: %d", proto))
|
||||
}
|
||||
return append(out, make([]byte, int(size)-len(out))...)
|
||||
}
|
||||
|
||||
func Benchmark(b *testing.B) {
|
||||
b.Run("SingleRoutine/SameConn", func(b *testing.B) {
|
||||
p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789)
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
for j := 0; j < 1e3; j++ {
|
||||
s.UpdateTx(p)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.Run("SingleRoutine/UniqueConns", func(b *testing.B) {
|
||||
p := testPacketV4(ipproto.UDP, [4]byte{}, [4]byte{}, 0, 0, 789)
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
for j := 0; j < 1e3; j++ {
|
||||
binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination
|
||||
s.UpdateTx(p)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.Run("MultiRoutine/SameConn", func(b *testing.B) {
|
||||
p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789)
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
var group sync.WaitGroup
|
||||
for j := 0; j < runtime.NumCPU(); j++ {
|
||||
group.Add(1)
|
||||
go func() {
|
||||
defer group.Done()
|
||||
for k := 0; k < 1e3; k++ {
|
||||
s.UpdateTx(p)
|
||||
}
|
||||
}()
|
||||
}
|
||||
group.Wait()
|
||||
}
|
||||
})
|
||||
b.Run("MultiRoutine/UniqueConns", func(b *testing.B) {
|
||||
ps := make([][]byte, runtime.NumCPU())
|
||||
for i := range ps {
|
||||
ps[i] = testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 0, 0, 789)
|
||||
}
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
var group sync.WaitGroup
|
||||
for j := 0; j < runtime.NumCPU(); j++ {
|
||||
group.Add(1)
|
||||
go func(j int) {
|
||||
defer group.Done()
|
||||
p := ps[j]
|
||||
j *= 1e3
|
||||
for k := 0; k < 1e3; k++ {
|
||||
binary.BigEndian.PutUint32(p[20:], uint32(j+k)) // unique port combination
|
||||
s.UpdateTx(p)
|
||||
}
|
||||
}(j)
|
||||
}
|
||||
group.Wait()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -74,7 +73,7 @@ func Read(r io.Reader) (*Info, error) {
|
||||
}
|
||||
|
||||
// Exhaust the remainder of r, so that the summers see the entire file.
|
||||
if _, err := io.Copy(ioutil.Discard, r); err != nil {
|
||||
if _, err := io.Copy(io.Discard, r); err != nil {
|
||||
return nil, fmt.Errorf("hashing file: %w", err)
|
||||
}
|
||||
|
||||
@@ -117,7 +116,7 @@ func findControlTar(r io.Reader) (tarReader io.Reader, err error) {
|
||||
if size%2 == 1 {
|
||||
size++
|
||||
}
|
||||
if _, err := io.CopyN(ioutil.Discard, r, size); err != nil {
|
||||
if _, err := io.CopyN(io.Discard, r, size); err != nil {
|
||||
return nil, fmt.Errorf("seeking past file %q: %w", filename, err)
|
||||
}
|
||||
}
|
||||
@@ -150,7 +149,7 @@ func findControlFile(r io.Reader) (control []byte, err error) {
|
||||
break
|
||||
}
|
||||
|
||||
bs, err := ioutil.ReadAll(tr)
|
||||
bs, err := io.ReadAll(tr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading control file: %w", err)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,9 @@ func argvSubject(argv ...string) string {
|
||||
ret = filepath.Base(argv[1])
|
||||
}
|
||||
|
||||
// Handle space separated argv
|
||||
ret, _, _ = strings.Cut(ret, " ")
|
||||
|
||||
// Remove common noise.
|
||||
ret = strings.TrimSpace(ret)
|
||||
ret = strings.TrimSuffix(ret, ".exe")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user