Compare commits
49 Commits
ip6tables
...
andrew/con
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39b45bb031 | ||
|
|
265b008e49 | ||
|
|
a5ad57472a | ||
|
|
3564fd61b5 | ||
|
|
cfbbcf6d07 | ||
|
|
9c66dce8e0 | ||
|
|
e470893ba0 | ||
|
|
c72caa6672 | ||
|
|
58f35261d0 | ||
|
|
be95aebabd | ||
|
|
490acdefb6 | ||
|
|
84b74825f0 | ||
|
|
9bd9f37d29 | ||
|
|
185f2e4768 | ||
|
|
53e08bd7ea | ||
|
|
70ed22ccf9 | ||
|
|
7ca17b6bdb | ||
|
|
e945d87d76 | ||
|
|
1ac4a26fee | ||
|
|
761163815c | ||
|
|
9f6c8517e0 | ||
|
|
27f36f77c3 | ||
|
|
122bd667dc | ||
|
|
21cd402204 | ||
|
|
0ae0439668 | ||
|
|
6dcc6313a6 | ||
|
|
78dbb59a00 | ||
|
|
7e40071571 | ||
|
|
90dc0e1702 | ||
|
|
2c18517121 | ||
|
|
d6c3588ed3 | ||
|
|
81dba3738e | ||
|
|
ad1cc6cff9 | ||
|
|
68d9d161f4 | ||
|
|
c66f99fcdc | ||
|
|
08b3f5f070 | ||
|
|
66d7d2549f | ||
|
|
d20392d413 | ||
|
|
58cc049a9f | ||
|
|
9b77ac128a | ||
|
|
e1738ea78e | ||
|
|
9bf13fc3d1 | ||
|
|
ab7e6f3f11 | ||
|
|
c5b1565337 | ||
|
|
d2e2d8438b | ||
|
|
23c3831ff9 | ||
|
|
296b008b9f | ||
|
|
31bf3874d6 | ||
|
|
e0c5ac1f02 |
7
.github/workflows/cross-darwin.yml
vendored
7
.github/workflows/cross-darwin.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: macOS build cmd
|
||||
env:
|
||||
GOOS: darwin
|
||||
|
||||
7
.github/workflows/cross-freebsd.yml
vendored
7
.github/workflows/cross-freebsd.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: FreeBSD build cmd
|
||||
env:
|
||||
GOOS: freebsd
|
||||
|
||||
7
.github/workflows/cross-openbsd.yml
vendored
7
.github/workflows/cross-openbsd.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: OpenBSD build cmd
|
||||
env:
|
||||
GOOS: openbsd
|
||||
|
||||
7
.github/workflows/cross-wasm.yml
vendored
7
.github/workflows/cross-wasm.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Wasm client build
|
||||
env:
|
||||
GOOS: js
|
||||
|
||||
7
.github/workflows/cross-windows.yml
vendored
7
.github/workflows/cross-windows.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Windows build cmd
|
||||
env:
|
||||
GOOS: windows
|
||||
|
||||
8
.github/workflows/depaware.yml
vendored
8
.github/workflows/depaware.yml
vendored
@@ -17,13 +17,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: depaware
|
||||
run: go run github.com/tailscale/depaware --check
|
||||
|
||||
10
.github/workflows/go_generate.yml
vendored
10
.github/workflows/go_generate.yml
vendored
@@ -18,16 +18,16 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: check 'go generate' is clean
|
||||
run: |
|
||||
if [[ "${{github.ref}}" == release-branch/* ]]
|
||||
|
||||
35
.github/workflows/go_mod_tidy.yml
vendored
Normal file
35
.github/workflows/go_mod_tidy.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: go mod tidy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- "*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: check 'go mod tidy' is clean
|
||||
run: |
|
||||
go mod tidy
|
||||
echo
|
||||
echo
|
||||
git diff --name-only --exit-code || (echo "Please run 'go mod tidy'."; exit 1)
|
||||
8
.github/workflows/license.yml
vendored
8
.github/workflows/license.yml
vendored
@@ -17,13 +17,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run license checker
|
||||
run: ./scripts/check_license_headers.sh .
|
||||
|
||||
7
.github/workflows/linux-race.yml
vendored
7
.github/workflows/linux-race.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Basic build
|
||||
run: go build ./cmd/...
|
||||
|
||||
|
||||
7
.github/workflows/linux.yml
vendored
7
.github/workflows/linux.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Basic build
|
||||
run: go build ./cmd/...
|
||||
|
||||
|
||||
7
.github/workflows/linux32.yml
vendored
7
.github/workflows/linux32.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Basic build
|
||||
run: GOARCH=386 go build ./cmd/...
|
||||
|
||||
|
||||
6
.github/workflows/static-analysis.yml
vendored
6
.github/workflows/static-analysis.yml
vendored
@@ -16,12 +16,12 @@ jobs:
|
||||
gofmt:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
- name: Run gofmt (goimports)
|
||||
run: go run golang.org/x/tools/cmd/goimports -d --format-only .
|
||||
- uses: k0kubun/action-slack@v2.0.0
|
||||
|
||||
30
.github/workflows/tsconnect-pkg-publish.yml
vendored
Normal file
30
.github/workflows/tsconnect-pkg-publish.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: "@tailscale/connect npm publish"
|
||||
|
||||
on: workflow_dispatch
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up node
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: "16.x"
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
|
||||
- name: Build package
|
||||
# Build with build_dist.sh to ensure that version information is embedded.
|
||||
# GOROOT is specified so that the Go/Wasm that is trigged by build-pk
|
||||
# also picks up our custom Go toolchain.
|
||||
run: |
|
||||
./build_dist.sh tailscale.com/cmd/tsconnect
|
||||
GOROOT="${HOME}/.cache/tailscale-go" ./tsconnect build-pkg
|
||||
|
||||
- name: Publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.TSCONNECT_NPM_PUBLISH_AUTH_TOKEN }}
|
||||
run: ./tool/yarn --cwd ./cmd/tsconnect/pkg publish --access public
|
||||
8
.github/workflows/vm.yml
vendored
8
.github/workflows/vm.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
- name: Set GOPATH
|
||||
run: echo "GOPATH=$HOME/go" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run VM tests
|
||||
run: go test ./tstest/integration/vms -v -no-s3 -run-vm-tests -run=TestRunUbuntu2004
|
||||
|
||||
7
.github/workflows/windows.yml
vendored
7
.github/workflows/windows.yml
vendored
@@ -19,14 +19,13 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Restore Cache
|
||||
uses: actions/cache@v3
|
||||
|
||||
@@ -1 +1 @@
|
||||
1.29.0
|
||||
1.31.0
|
||||
|
||||
@@ -11,15 +11,31 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum amount of time we should wait when reading a response from BIRD.
|
||||
responseTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// New creates a BIRDClient.
|
||||
func New(socket string) (*BIRDClient, error) {
|
||||
return newWithTimeout(socket, responseTimeout)
|
||||
}
|
||||
|
||||
func newWithTimeout(socket string, timeout time.Duration) (*BIRDClient, error) {
|
||||
conn, err := net.Dial("unix", socket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to BIRD: %w", err)
|
||||
}
|
||||
b := &BIRDClient{socket: socket, conn: conn, scanner: bufio.NewScanner(conn)}
|
||||
b := &BIRDClient{
|
||||
socket: socket,
|
||||
conn: conn,
|
||||
scanner: bufio.NewScanner(conn),
|
||||
timeNow: time.Now,
|
||||
timeout: timeout,
|
||||
}
|
||||
// Read and discard the first line as that is the welcome message.
|
||||
if _, err := b.readResponse(); err != nil {
|
||||
return nil, err
|
||||
@@ -32,6 +48,8 @@ type BIRDClient struct {
|
||||
socket string
|
||||
conn net.Conn
|
||||
scanner *bufio.Scanner
|
||||
timeNow func() time.Time
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// Close closes the underlying connection to BIRD.
|
||||
@@ -81,10 +99,15 @@ func (b *BIRDClient) EnableProtocol(protocol string) error {
|
||||
// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’.
|
||||
|
||||
func (b *BIRDClient) exec(cmd string, args ...any) (string, error) {
|
||||
if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil {
|
||||
return "", err
|
||||
}
|
||||
fmt.Fprintln(b.conn)
|
||||
if _, err := fmt.Fprintln(b.conn); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return b.readResponse()
|
||||
}
|
||||
|
||||
@@ -105,14 +128,20 @@ func hasResponseCode(s []byte) bool {
|
||||
}
|
||||
|
||||
func (b *BIRDClient) readResponse() (string, error) {
|
||||
// Set the read timeout before we start reading anything.
|
||||
if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var resp strings.Builder
|
||||
var done bool
|
||||
for !done {
|
||||
if !b.scanner.Scan() {
|
||||
return "", fmt.Errorf("reading response from bird failed: %q", resp.String())
|
||||
}
|
||||
if err := b.scanner.Err(); err != nil {
|
||||
return "", err
|
||||
if err := b.scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String())
|
||||
}
|
||||
out := b.scanner.Bytes()
|
||||
if _, err := resp.Write(out); err != nil {
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeBIRD struct {
|
||||
@@ -109,3 +112,82 @@ func TestChirp(t *testing.T) {
|
||||
t.Fatalf("disabling %q succeded", "rando")
|
||||
}
|
||||
}
|
||||
|
||||
type hangingListener struct {
|
||||
net.Listener
|
||||
t *testing.T
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
sock string
|
||||
}
|
||||
|
||||
func newHangingListener(t *testing.T) *hangingListener {
|
||||
sock := filepath.Join(t.TempDir(), "sock")
|
||||
l, err := net.Listen("unix", sock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &hangingListener{
|
||||
Listener: l,
|
||||
t: t,
|
||||
done: make(chan struct{}),
|
||||
sock: sock,
|
||||
}
|
||||
}
|
||||
|
||||
func (hl *hangingListener) Stop() {
|
||||
hl.Close()
|
||||
close(hl.done)
|
||||
hl.wg.Wait()
|
||||
}
|
||||
|
||||
func (hl *hangingListener) listen() error {
|
||||
for {
|
||||
c, err := hl.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
hl.wg.Add(1)
|
||||
go hl.handle(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (hl *hangingListener) handle(c net.Conn) {
|
||||
defer hl.wg.Done()
|
||||
|
||||
// Write our fake first line of response so that we get into the read loop
|
||||
fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.")
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hl.t.Logf("connection still hanging")
|
||||
case <-hl.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChirpTimeout(t *testing.T) {
|
||||
fb := newHangingListener(t)
|
||||
defer fb.Stop()
|
||||
go fb.listen()
|
||||
|
||||
c, err := newWithTimeout(fb.sock, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = c.EnableProtocol("tailscale")
|
||||
if err == nil {
|
||||
t.Fatal("got err=nil, want timeout")
|
||||
}
|
||||
if !os.IsTimeout(err) {
|
||||
t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
@@ -30,17 +31,14 @@ var (
|
||||
cacheFname = rootFlagSet.String("cache-file", "./version-cache.json", "filename for the previous known version hash")
|
||||
timeout = rootFlagSet.Duration("timeout", 5*time.Minute, "timeout for the entire CI run")
|
||||
githubSyntax = rootFlagSet.Bool("github-syntax", true, "use GitHub Action error syntax (https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#setting-an-error-message)")
|
||||
|
||||
modifiedExternallyFailure = make(chan struct{}, 1)
|
||||
)
|
||||
|
||||
func modifiedExternallyError() {
|
||||
if *githubSyntax {
|
||||
fmt.Printf("::error file=%s,line=1,col=1,title=Policy File Modified Externally::The policy file was modified externally in the admin console.\n", *policyFname)
|
||||
fmt.Printf("::warning file=%s,line=1,col=1,title=Policy File Modified Externally::The policy file was modified externally in the admin console.\n", *policyFname)
|
||||
} else {
|
||||
fmt.Printf("The policy file was modified externally in the admin console.\n")
|
||||
}
|
||||
modifiedExternallyFailure <- struct{}{}
|
||||
}
|
||||
|
||||
func apply(cache *Cache, tailnet, apiKey string) func(context.Context, []string) error {
|
||||
@@ -207,10 +205,6 @@ func main() {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(modifiedExternallyFailure) != 0 {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func sumFile(fname string) (string, error) {
|
||||
@@ -271,13 +265,16 @@ func applyNewACL(ctx context.Context, tailnet, apiKey, policyFname, oldEtag stri
|
||||
}
|
||||
|
||||
func testNewACLs(ctx context.Context, tailnet, apiKey, policyFname string) error {
|
||||
fin, err := os.Open(policyFname)
|
||||
data, err := os.ReadFile(policyFname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err = hujson.Standardize(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fin.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://api.tailscale.com/api/v2/tailnet/%s/acl/validate", tailnet), fin)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://api.tailscale.com/api/v2/tailnet/%s/acl/validate", tailnet), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -75,12 +75,7 @@ func main() {
|
||||
log.Printf("can't extract tailnet name from hostname %q", info.Node.Name)
|
||||
return
|
||||
}
|
||||
tailnet, _, ok = strings.Cut(tailnet, ".beta.tailscale.net")
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
log.Printf("can't extract tailnet name from hostname %q", info.Node.Name)
|
||||
return
|
||||
}
|
||||
tailnet = strings.TrimSuffix(tailnet, ".beta.tailscale.net")
|
||||
}
|
||||
|
||||
if expectedTailnet := r.Header.Get("Expected-Tailnet"); expectedTailnet != "" && expectedTailnet != tailnet {
|
||||
|
||||
@@ -29,7 +29,7 @@ var certCmd = &ffcli.Command{
|
||||
FlagSet: (func() *flag.FlagSet {
|
||||
fs := newFlagSet("cert")
|
||||
fs.StringVar(&certArgs.certFile, "cert-file", "", "output cert file or \"-\" for stdout; defaults to DOMAIN.crt if --cert-file and --key-file are both unset")
|
||||
fs.StringVar(&certArgs.keyFile, "key-file", "", "output cert file or \"-\" for stdout; defaults to DOMAIN.key if --cert-file and --key-file are both unset")
|
||||
fs.StringVar(&certArgs.keyFile, "key-file", "", "output key file or \"-\" for stdout; defaults to DOMAIN.key if --cert-file and --key-file are both unset")
|
||||
fs.BoolVar(&certArgs.serve, "serve-demo", false, "if true, serve on port :443 using the cert as a demo, instead of writing out the files to disk")
|
||||
return fs
|
||||
})(),
|
||||
|
||||
@@ -290,12 +290,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
W tailscale.com/wf from tailscale.com/cmd/tailscaled
|
||||
tailscale.com/wgengine from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/wgengine/filter from tailscale.com/control/controlclient+
|
||||
tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+
|
||||
💣 tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/wgengine/monitor from tailscale.com/control/controlclient+
|
||||
tailscale.com/wgengine/netstack from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/wgengine/router from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal
|
||||
💣 tailscale.com/wgengine/wgint from tailscale.com/wgengine
|
||||
tailscale.com/wgengine/wglog from tailscale.com/wgengine
|
||||
W 💣 tailscale.com/wgengine/winnet from tailscale.com/wgengine/router
|
||||
golang.org/x/crypto/acme from tailscale.com/ipn/localapi
|
||||
@@ -404,6 +405,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
mime/quotedprintable from mime/multipart
|
||||
net from crypto/tls+
|
||||
net/http from expvar+
|
||||
net/http/httptest from tailscale.com/control/controlclient
|
||||
net/http/httptrace from github.com/tcnksm/go-httpstat+
|
||||
net/http/httputil from github.com/aws/smithy-go/transport/http+
|
||||
net/http/internal from net/http+
|
||||
|
||||
@@ -5,9 +5,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
esbuild "github.com/evanw/esbuild/pkg/api"
|
||||
"github.com/tailscale/hujson"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
func runBuildPkg() {
|
||||
@@ -41,4 +47,33 @@ func runBuildPkg() {
|
||||
log.Fatalf("Type generation failed: %v", err)
|
||||
}
|
||||
|
||||
if err := updateVersion(); err != nil {
|
||||
log.Fatalf("Cannot update version: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("Built package version %s", version.Long)
|
||||
}
|
||||
|
||||
func updateVersion() error {
|
||||
packageJSONBytes, err := os.ReadFile("package.json.tmpl")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not read package.json: %w", err)
|
||||
}
|
||||
|
||||
var packageJSON map[string]any
|
||||
packageJSONBytes, err = hujson.Standardize(packageJSONBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not standardize template package.json: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil {
|
||||
return fmt.Errorf("Could not unmarshal package.json: %w", err)
|
||||
}
|
||||
packageJSON["version"] = version.Long
|
||||
|
||||
packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not marshal package.json: %w", err)
|
||||
}
|
||||
|
||||
return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644)
|
||||
}
|
||||
|
||||
17
cmd/tsconnect/package.json.tmpl
Normal file
17
cmd/tsconnect/package.json.tmpl
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Template for the package.json that is generated by the build-pkg command.
|
||||
// The version number will be replaced by the current Tailscale client version
|
||||
// number.
|
||||
{
|
||||
"author": "Tailscale Inc.",
|
||||
"description": "Tailscale Connect SDK",
|
||||
"license": "BSD-3-Clause",
|
||||
"name": "tailscale-connect",
|
||||
"type": "module",
|
||||
"main": "./pkg.js",
|
||||
"types": "./pkg.d.ts",
|
||||
"version": "AUTO_GENERATED"
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"author": "Tailscale Inc.",
|
||||
"description": "Tailscale Connect SDK",
|
||||
"license": "BSD-3-Clause",
|
||||
"name": "@tailscale/connect",
|
||||
"type": "module",
|
||||
"main": "./pkg.js",
|
||||
"types": "./pkg.d.ts",
|
||||
"version": "0.0.5"
|
||||
}
|
||||
@@ -27,26 +27,39 @@ export function runSSHSession(
|
||||
|
||||
term.focus()
|
||||
|
||||
let resizeObserver: ResizeObserver | undefined
|
||||
let handleBeforeUnload: ((e: BeforeUnloadEvent) => void) | undefined
|
||||
|
||||
const sshSession = ipn.ssh(def.hostname, def.username, {
|
||||
writeFn: (input) => term.write(input),
|
||||
setReadFn: (hook) => (onDataHook = hook),
|
||||
writeFn(input) {
|
||||
term.write(input)
|
||||
},
|
||||
writeErrorFn(err) {
|
||||
console.error(err)
|
||||
term.write(err)
|
||||
},
|
||||
setReadFn(hook) {
|
||||
onDataHook = hook
|
||||
},
|
||||
rows: term.rows,
|
||||
cols: term.cols,
|
||||
onDone: () => {
|
||||
resizeObserver.disconnect()
|
||||
onDone() {
|
||||
resizeObserver?.disconnect()
|
||||
term.dispose()
|
||||
window.removeEventListener("beforeunload", handleBeforeUnload)
|
||||
if (handleBeforeUnload) {
|
||||
window.removeEventListener("beforeunload", handleBeforeUnload)
|
||||
}
|
||||
onDone()
|
||||
},
|
||||
})
|
||||
|
||||
// Make terminal and SSH session track the size of the containing DOM node.
|
||||
const resizeObserver = new ResizeObserver(() => fitAddon.fit())
|
||||
resizeObserver = new ResizeObserver(() => fitAddon.fit())
|
||||
resizeObserver.observe(termContainerNode)
|
||||
term.onResize(({ rows, cols }) => sshSession.resize(rows, cols))
|
||||
|
||||
// Close the session if the user closes the window without an explicit
|
||||
// exit.
|
||||
const handleBeforeUnload = () => sshSession.close()
|
||||
handleBeforeUnload = () => sshSession.close()
|
||||
window.addEventListener("beforeunload", handleBeforeUnload)
|
||||
}
|
||||
|
||||
1
cmd/tsconnect/src/types/wasm_js.d.ts
vendored
1
cmd/tsconnect/src/types/wasm_js.d.ts
vendored
@@ -19,6 +19,7 @@ declare global {
|
||||
username: string,
|
||||
termConfig: {
|
||||
writeFn: (data: string) => void
|
||||
writeErrorFn: (err: string) => void
|
||||
setReadFn: (readFn: (data: string) => void) => void
|
||||
rows: number
|
||||
cols: number
|
||||
|
||||
@@ -347,6 +347,7 @@ type jsSSHSession struct {
|
||||
|
||||
func (s *jsSSHSession) Run() {
|
||||
writeFn := s.termConfig.Get("writeFn")
|
||||
writeErrorFn := s.termConfig.Get("writeErrorFn")
|
||||
setReadFn := s.termConfig.Get("setReadFn")
|
||||
rows := s.termConfig.Get("rows").Int()
|
||||
cols := s.termConfig.Get("cols").Int()
|
||||
@@ -357,7 +358,7 @@ func (s *jsSSHSession) Run() {
|
||||
writeFn.Invoke(s)
|
||||
}
|
||||
writeError := func(label string, err error) {
|
||||
write(fmt.Sprintf("%s Error: %v\r\n", label, err))
|
||||
writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package controlclient
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -73,6 +75,7 @@ type Direct struct {
|
||||
skipIPForwardingCheck bool
|
||||
pinger Pinger
|
||||
popBrowser func(url string) // or nil
|
||||
c2nHandler http.Handler // or nil
|
||||
|
||||
mu sync.Mutex // mutex guards the following fields
|
||||
serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key
|
||||
@@ -108,6 +111,7 @@ type Options struct {
|
||||
LinkMonitor *monitor.Mon // optional link monitor
|
||||
PopBrowserURL func(url string) // optional func to open browser
|
||||
Dialer *tsdial.Dialer // non-nil
|
||||
C2NHandler http.Handler // or nil
|
||||
|
||||
// GetNLPublicKey specifies an optional function to use
|
||||
// Network Lock. If nil, it's not used.
|
||||
@@ -129,6 +133,12 @@ type Options struct {
|
||||
// MapResponse.PingRequest queries from the control plane.
|
||||
// If nil, PingRequest queries are not answered.
|
||||
Pinger Pinger
|
||||
|
||||
// GetTailscaleRoutes is a function that should return any Tailscale
|
||||
// routes that are currently known; if any are returned, we test the IP
|
||||
// address of the control server against these routes and use our
|
||||
// fallback DNS server in those cases.
|
||||
GetTailscaleRoutes func() []netip.Prefix
|
||||
}
|
||||
|
||||
// Pinger is the LocalBackend.Ping method.
|
||||
@@ -210,6 +220,7 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
skipIPForwardingCheck: opts.SkipIPForwardingCheck,
|
||||
pinger: opts.Pinger,
|
||||
popBrowser: opts.PopBrowserURL,
|
||||
c2nHandler: opts.C2NHandler,
|
||||
dialer: opts.Dialer,
|
||||
}
|
||||
if opts.Hostinfo == nil {
|
||||
@@ -1205,7 +1216,8 @@ func (c *Direct) isUniquePingRequest(pr *tailcfg.PingRequest) bool {
|
||||
|
||||
func (c *Direct) answerPing(pr *tailcfg.PingRequest) {
|
||||
httpc := c.httpc
|
||||
if pr.URLIsNoise {
|
||||
useNoise := pr.URLIsNoise || pr.Types == "c2n" && c.noiseConfigured()
|
||||
if useNoise {
|
||||
nc, err := c.getNoiseClient()
|
||||
if err != nil {
|
||||
c.logf("failed to get noise client for ping request: %v", err)
|
||||
@@ -1217,9 +1229,17 @@ func (c *Direct) answerPing(pr *tailcfg.PingRequest) {
|
||||
c.logf("invalid PingRequest with no URL")
|
||||
return
|
||||
}
|
||||
if pr.Types == "" {
|
||||
switch pr.Types {
|
||||
case "":
|
||||
answerHeadPing(c.logf, httpc, pr)
|
||||
return
|
||||
case "c2n":
|
||||
if !useNoise && !envknob.Bool("TS_DEBUG_PERMIT_HTTP_C2N") {
|
||||
c.logf("refusing to answer c2n ping without noise")
|
||||
return
|
||||
}
|
||||
answerC2NPing(c.logf, c.c2nHandler, httpc, pr)
|
||||
return
|
||||
}
|
||||
for _, t := range strings.Split(pr.Types, ",") {
|
||||
switch pt := tailcfg.PingType(t); pt {
|
||||
@@ -1253,6 +1273,54 @@ func answerHeadPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr *tailcfg.PingRequest) {
|
||||
if c2nHandler == nil {
|
||||
logf("answerC2NPing: c2nHandler not defined")
|
||||
return
|
||||
}
|
||||
hreq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload)))
|
||||
if err != nil {
|
||||
logf("answerC2NPing: ReadRequest: %v", err)
|
||||
return
|
||||
}
|
||||
if pr.Log {
|
||||
logf("answerC2NPing: got c2n request for %v ...", hreq.RequestURI)
|
||||
}
|
||||
handlerTimeout := time.Minute
|
||||
if v := hreq.Header.Get("C2n-Handler-Timeout"); v != "" {
|
||||
handlerTimeout, _ = time.ParseDuration(v)
|
||||
}
|
||||
handlerCtx, cancel := context.WithTimeout(context.Background(), handlerTimeout)
|
||||
defer cancel()
|
||||
hreq = hreq.WithContext(handlerCtx)
|
||||
rec := httptest.NewRecorder()
|
||||
c2nHandler.ServeHTTP(rec, hreq)
|
||||
cancel()
|
||||
|
||||
c2nResBuf := new(bytes.Buffer)
|
||||
rec.Result().Write(c2nResBuf)
|
||||
|
||||
replyCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(replyCtx, "POST", pr.URL, c2nResBuf)
|
||||
if err != nil {
|
||||
logf("answerC2NPing: NewRequestWithContext: %v", err)
|
||||
return
|
||||
}
|
||||
if pr.Log {
|
||||
logf("answerC2NPing: sending POST ping to %v ...", pr.URL)
|
||||
}
|
||||
t0 := time.Now()
|
||||
_, err = c.Do(req)
|
||||
d := time.Since(t0).Round(time.Millisecond)
|
||||
if err != nil {
|
||||
logf("answerC2NPing error: %v to %v (after %v)", err, pr.URL, d)
|
||||
} else if pr.Log {
|
||||
logf("answerC2NPing complete to %v (after %v)", pr.URL, d)
|
||||
}
|
||||
}
|
||||
|
||||
func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- struct{}, d time.Duration) error {
|
||||
const maxSleep = 5 * time.Minute
|
||||
if d > maxSleep {
|
||||
|
||||
@@ -18,7 +18,7 @@ spec:
|
||||
command: ["/bin/sh"]
|
||||
args:
|
||||
- -c
|
||||
- sysctl -w net.ipv4.ip_forward=1
|
||||
- sysctl -w net.ipv4.ip_forward=1 -w net.ipv6.conf.all.forwarding=1
|
||||
resources:
|
||||
requests:
|
||||
cpu: 1m
|
||||
|
||||
4
go.mod
4
go.mod
@@ -64,7 +64,7 @@ require (
|
||||
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11
|
||||
golang.org/x/tools v0.1.11
|
||||
golang.zx2c4.com/wireguard v0.0.0-20220703234212-c31a7b1ab478
|
||||
golang.zx2c4.com/wireguard/windows v0.4.10
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
gvisor.dev/gvisor v0.0.0-20220801230058-850e42eb4444
|
||||
honnef.co/go/tools v0.4.0-0.dev.0.20220404092545-59d7a2877f83
|
||||
inet.af/peercred v0.0.0-20210906144145-0893ea02156a
|
||||
@@ -266,7 +266,7 @@ require (
|
||||
github.com/yeya24/promlinter v0.1.0 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20220328175248-053ad81199eb // indirect
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
|
||||
google.golang.org/protobuf v1.28.0 // indirect
|
||||
gopkg.in/ini.v1 v1.66.2 // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -729,8 +729,6 @@ github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/logrusorgru/aurora v0.0.0-20181002194514-a7b3b318ed4e/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/lxn/walk v0.0.0-20210112085537-c389da54e794/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ=
|
||||
github.com/lxn/win v0.0.0-20210218163916-a377121e959e/go.mod h1:KxxjdtRkfNoYDCUP5ryK7XJJNTnpC8atvtmTheChOtk=
|
||||
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/magiconair/properties v1.8.4/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
|
||||
@@ -1352,7 +1350,6 @@ golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qx
|
||||
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20210903162142-ad29c8ab022f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
@@ -1449,7 +1446,6 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201101102859-da207088b7d1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201109165425-215b40eba54c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -1508,8 +1504,9 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2 h1:GLw7MR8AfAG2GmGcmVgObFOHXYypgGjnGno25RDwn3Y=
|
||||
golang.org/x/text v0.3.8-0.20211105212822-18b340fc7af2/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0=
|
||||
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
@@ -1636,11 +1633,10 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
|
||||
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20210905140043-2ef39d47540c/go.mod h1:laHzsbfMhGSobUmruXWAyMKKHSqvIcrqZJMyHD+/3O8=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20220703234212-c31a7b1ab478 h1:vDy//hdR+GnROE3OdYbQKt9rdtNdHkDtONvpRwmls/0=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20220703234212-c31a7b1ab478/go.mod h1:bVQfyl2sCM/QIIGHpWbFGfHPuDvqnCNkT6MQLTCjO/U=
|
||||
golang.zx2c4.com/wireguard/windows v0.4.10 h1:HmjzJnb+G4NCdX+sfjsQlsxGPuYaThxRbZUZFLyR0/s=
|
||||
golang.zx2c4.com/wireguard/windows v0.4.10/go.mod h1:v7w/8FC48tTBm1IzScDVPEEb0/GjLta+T0ybpP9UWRg=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
||||
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
|
||||
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
|
||||
|
||||
21
ipn/ipnlocal/c2n.go
Normal file
21
ipn/ipnlocal/c2n.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ipnlocal
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/echo":
|
||||
// Test handler.
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Write(body)
|
||||
default:
|
||||
http.Error(w, "unknown c2n path", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
@@ -132,6 +132,7 @@ type LocalBackend struct {
|
||||
|
||||
filterAtomic atomic.Pointer[filter.Filter]
|
||||
containsViaIPFuncAtomic syncs.AtomicValue[func(netip.Addr) bool]
|
||||
tailscaleRoutesAtomic syncs.AtomicValue[[]netip.Prefix]
|
||||
|
||||
// The mutex protects the following elements.
|
||||
mu sync.Mutex
|
||||
@@ -884,6 +885,10 @@ func (b *LocalBackend) getNewControlClientFunc() clientGen {
|
||||
return b.ccGen
|
||||
}
|
||||
|
||||
func (b *LocalBackend) getTailscaleRoutes() []netip.Prefix {
|
||||
return b.tailscaleRoutesAtomic.Load()
|
||||
}
|
||||
|
||||
// startIsNoopLocked reports whether a Start call on this LocalBackend
|
||||
// with the provided Start Options would be a useless no-op.
|
||||
//
|
||||
@@ -955,6 +960,8 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
|
||||
hostinfo := hostinfo.New()
|
||||
hostinfo.BackendLogID = b.backendLogID
|
||||
hostinfo.FrontendLogID = opts.FrontendLogID
|
||||
hostinfo.Userspace.Set(wgengine.IsNetstack(b.e))
|
||||
hostinfo.UserspaceRouter.Set(wgengine.IsNetstackRouter(b.e))
|
||||
|
||||
if b.cc != nil {
|
||||
// TODO(apenwarr): avoid the need to reinit controlclient.
|
||||
@@ -1075,6 +1082,8 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
|
||||
PopBrowserURL: b.tellClientToBrowseToURL,
|
||||
Dialer: b.Dialer(),
|
||||
Status: b.setClientStatus,
|
||||
C2NHandler: http.HandlerFunc(b.handleC2N),
|
||||
GetTailscaleRoutes: b.getTailscaleRoutes,
|
||||
|
||||
// Don't warn about broken Linux IP forwarding when
|
||||
// netstack is being used.
|
||||
@@ -2312,6 +2321,7 @@ func (b *LocalBackend) authReconfig() {
|
||||
}
|
||||
b.logf("[v1] authReconfig: ra=%v dns=%v 0x%02x: %v", prefs.RouteAll, prefs.CorpDNS, flags, err)
|
||||
|
||||
b.tailscaleRoutesAtomic.Store(rcfg.Routes)
|
||||
b.initPeerAPIListener()
|
||||
}
|
||||
|
||||
|
||||
@@ -147,7 +147,7 @@ func signNodeKey(nodeInfo tailcfg.TKASignInfo, signer key.NLPrivate) (*tka.NodeK
|
||||
SigKind: tka.SigDirect,
|
||||
KeyID: signer.KeyID(),
|
||||
Pubkey: p,
|
||||
RotationPubkey: nodeInfo.RotationPubkey,
|
||||
WrappingPubkey: nodeInfo.RotationPubkey,
|
||||
}
|
||||
sig.Signature, err = signer.SignNKS(sig.SigHash())
|
||||
if err != nil {
|
||||
|
||||
@@ -505,6 +505,8 @@ func osEmoji(os string) string {
|
||||
return "👿"
|
||||
case "openbsd":
|
||||
return "🐡"
|
||||
case "illumos":
|
||||
return "☀️"
|
||||
}
|
||||
return "👽"
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
|
||||
@@ -185,7 +186,10 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acmeKey: %w", err)
|
||||
}
|
||||
ac := &acme.Client{Key: key}
|
||||
ac := &acme.Client{
|
||||
Key: key,
|
||||
UserAgent: "tailscaled/" + version.Long,
|
||||
}
|
||||
|
||||
a, err := ac.GetReg(ctx, "" /* pre-RFC param */)
|
||||
switch {
|
||||
|
||||
@@ -59,7 +59,7 @@ Client][]. See also the dependencies in the [Tailscale CLI][].
|
||||
- [golang.org/x/sync/errgroup](https://pkg.go.dev/golang.org/x/sync/errgroup) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/0de741cf:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/c0bba94a:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/03fcf44c:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.3.7:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/18b340fc:LICENSE))
|
||||
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/f0f3c7e8:LICENSE))
|
||||
- [golang.zx2c4.com/wireguard](https://pkg.go.dev/golang.zx2c4.com/wireguard) ([MIT](https://git.zx2c4.com/wireguard-go/tree/LICENSE?id=c31a7b1ab478))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/850e42eb4444/LICENSE))
|
||||
|
||||
@@ -42,7 +42,7 @@ and [iOS][]. See also the dependencies in the [Tailscale CLI][].
|
||||
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/c690dde0:LICENSE))
|
||||
- [golang.org/x/sync/errgroup](https://pkg.go.dev/golang.org/x/sync/errgroup) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/0de741cf:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/c0bba94a:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.3.7:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/18b340fc:LICENSE))
|
||||
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/f0f3c7e8:LICENSE))
|
||||
- [golang.zx2c4.com/wireguard](https://pkg.go.dev/golang.zx2c4.com/wireguard) ([MIT](https://git.zx2c4.com/wireguard-go/tree/LICENSE?id=c31a7b1ab478))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/850e42eb4444/LICENSE))
|
||||
|
||||
@@ -75,11 +75,11 @@ Some packages may only be included on certain architectures or operating systems
|
||||
- [golang.org/x/sync/errgroup](https://pkg.go.dev/golang.org/x/sync/errgroup) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/0de741cf:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/c0bba94a:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/03fcf44c:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.3.7:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/18b340fc:LICENSE))
|
||||
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/f0f3c7e8:LICENSE))
|
||||
- [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=415007cec224))
|
||||
- [golang.zx2c4.com/wireguard](https://pkg.go.dev/golang.zx2c4.com/wireguard) ([MIT](https://git.zx2c4.com/wireguard-go/tree/LICENSE?id=c31a7b1ab478))
|
||||
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.4.10))
|
||||
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/850e42eb4444/LICENSE))
|
||||
- [inet.af/peercred](https://pkg.go.dev/inet.af/peercred) ([BSD-3-Clause](https://github.com/inetaf/peercred/blob/0893ea02156a/LICENSE))
|
||||
- [inet.af/wf](https://pkg.go.dev/inet.af/wf) ([BSD-3-Clause](https://github.com/inetaf/wf/blob/50d96caab2f6/LICENSE))
|
||||
|
||||
@@ -36,7 +36,7 @@ Windows][]. See also the dependencies in the [Tailscale CLI][].
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/c0bba94a:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/03fcf44c:LICENSE))
|
||||
- [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=415007cec224))
|
||||
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.4.10))
|
||||
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3))
|
||||
- [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE))
|
||||
- [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE))
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -108,6 +109,10 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger {
|
||||
procID = 7
|
||||
}
|
||||
}
|
||||
|
||||
stdLogf := func(f string, a ...any) {
|
||||
fmt.Fprintf(cfg.Stderr, strings.TrimSuffix(f, "\n")+"\n", a...)
|
||||
}
|
||||
l := &Logger{
|
||||
privateID: cfg.PrivateID,
|
||||
stderr: cfg.Stderr,
|
||||
@@ -121,7 +126,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger {
|
||||
sentinel: make(chan int32, 16),
|
||||
drainLogs: cfg.DrainLogs,
|
||||
timeNow: cfg.TimeNow,
|
||||
bo: backoff.NewBackoff("logtail", logf, 30*time.Second),
|
||||
bo: backoff.NewBackoff("logtail", stdLogf, 30*time.Second),
|
||||
metricsDelta: cfg.MetricsDelta,
|
||||
|
||||
procID: procID,
|
||||
|
||||
@@ -91,6 +91,30 @@ func (c Config) hasDefaultIPResolversOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// hasHostsWithoutSplitDNSRoutes reports whether c contains any Host entries
|
||||
// that aren't covered by a SplitDNS route suffix.
|
||||
func (c Config) hasHostsWithoutSplitDNSRoutes() bool {
|
||||
// TODO(bradfitz): this could be more efficient, but we imagine
|
||||
// the number of SplitDNS routes and/or hosts will be small.
|
||||
for host := range c.Hosts {
|
||||
if !c.hasSplitDNSRouteForHost(host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasSplitDNSRouteForHost reports whether c contains a SplitDNS route
|
||||
// that contains hosts.
|
||||
func (c Config) hasSplitDNSRouteForHost(host dnsname.FQDN) bool {
|
||||
for route := range c.Routes {
|
||||
if route.Contains(host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Config) hasDefaultResolvers() bool {
|
||||
return len(c.DefaultResolvers) > 0
|
||||
}
|
||||
|
||||
@@ -207,9 +207,14 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig
|
||||
// case where cfg is entirely zero, in which case these
|
||||
// configs clear all Tailscale DNS settings.
|
||||
return rcfg, ocfg, nil
|
||||
case cfg.hasDefaultIPResolversOnly():
|
||||
// Trivial CorpDNS configuration, just override the OS
|
||||
// resolver.
|
||||
case cfg.hasDefaultIPResolversOnly() && !cfg.hasHostsWithoutSplitDNSRoutes():
|
||||
// Trivial CorpDNS configuration, just override the OS resolver.
|
||||
//
|
||||
// If there are hosts (ExtraRecords) that are not covered by an existing
|
||||
// SplitDNS route, then we don't go into this path so that we fall into
|
||||
// the next case and send the extra record hosts queries through
|
||||
// 100.100.100.100 instead where we can answer them.
|
||||
//
|
||||
// TODO: for OSes that support it, pass IP:port and DoH
|
||||
// addresses directly to OS.
|
||||
// https://github.com/tailscale/tailscale/issues/1666
|
||||
|
||||
@@ -199,6 +199,71 @@ func TestManager(t *testing.T) {
|
||||
"bradfitz.ts.com.", "2.3.4.5"),
|
||||
},
|
||||
},
|
||||
{
|
||||
// If Hosts are specified (i.e. ExtraRecords) that aren't a split
|
||||
// DNS route and a global resolver is specified, then make
|
||||
// everything go via 100.100.100.100.
|
||||
name: "hosts-with-global-dns-uses-quad100",
|
||||
split: true,
|
||||
in: Config{
|
||||
DefaultResolvers: mustRes("1.1.1.1", "9.9.9.9"),
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
os: OSConfig{
|
||||
Nameservers: mustIPs("100.100.100.100"),
|
||||
},
|
||||
rs: resolver.Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
Routes: upstreams(".", "1.1.1.1", "9.9.9.9"),
|
||||
},
|
||||
},
|
||||
{
|
||||
// This is the above hosts-with-global-dns-uses-quad100 test but
|
||||
// verifying that if global DNS servers aren't set (the 1.1.1.1 and
|
||||
// 9.9.9.9 above), then we don't configure 100.100.100.100 as the
|
||||
// resolver.
|
||||
name: "hosts-without-global-dns-not-use-quad100",
|
||||
split: true,
|
||||
in: Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
os: OSConfig{},
|
||||
rs: resolver.Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
},
|
||||
{
|
||||
// This tests that ExtraRecords (foo.tld and bar.tld here) don't trigger forcing
|
||||
// traffic through 100.100.100.100 if there's Split DNS support and the extra
|
||||
// records are part of a split DNS route.
|
||||
name: "hosts-with-extrarecord-hosts-with-routes-no-quad100",
|
||||
split: true,
|
||||
in: Config{
|
||||
Routes: upstreams(
|
||||
"tld.", "4.4.4.4",
|
||||
),
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
os: OSConfig{
|
||||
Nameservers: mustIPs("4.4.4.4"),
|
||||
MatchDomains: fqdns("tld."),
|
||||
},
|
||||
rs: resolver.Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "corp",
|
||||
in: Config{
|
||||
|
||||
@@ -388,10 +388,10 @@ func (m windowsManager) Close() error {
|
||||
// Windows DHCP client from sending dynamic DNS updates for our interface to
|
||||
// AD domain controllers.
|
||||
func (m windowsManager) disableDynamicUpdates() error {
|
||||
if err := m.setSingleDWORD(winutil.IPv4TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil {
|
||||
if err := m.setSingleDWORD(winutil.IPv4TCPIPInterfacePrefix, "DisableDynamicUpdate", 1); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.setSingleDWORD(winutil.IPv6TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil {
|
||||
if err := m.setSingleDWORD(winutil.IPv6TCPIPInterfacePrefix, "DisableDynamicUpdate", 1); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -485,11 +485,7 @@ func (m windowsManager) getBasePrimaryResolver() (resolvers []netip.Addr, err er
|
||||
}
|
||||
|
||||
ipLoop:
|
||||
for _, stdip := range ips {
|
||||
ip, ok := netip.AddrFromSlice(stdip)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, ip := range ips {
|
||||
ip = ip.Unmap()
|
||||
// Skip IPv6 site-local resolvers. These are an ancient
|
||||
// and obsolete IPv6 RFC, which Windows still faithfully
|
||||
|
||||
@@ -6,7 +6,6 @@ package interfaces
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -54,22 +53,21 @@ func likelyHomeRouterIPWindows() (ret netip.Addr, ok bool) {
|
||||
return
|
||||
}
|
||||
|
||||
unspec := net.IPv4(0, 0, 0, 0)
|
||||
v4unspec := netip.IPv4Unspecified()
|
||||
var best *winipcfg.MibIPforwardRow2 // best (lowest metric) found so far, or nil
|
||||
|
||||
for i := range rs {
|
||||
r := &rs[i]
|
||||
if r.Loopback || r.DestinationPrefix.PrefixLength != 0 || !r.DestinationPrefix.Prefix.IP().Equal(unspec) {
|
||||
if r.Loopback || r.DestinationPrefix.PrefixLength != 0 || r.DestinationPrefix.Prefix().Addr().Unmap() != v4unspec {
|
||||
// Not a default route, so skip
|
||||
continue
|
||||
}
|
||||
|
||||
ip, ok := netip.AddrFromSlice(r.NextHop.IP())
|
||||
if !ok {
|
||||
ip := r.NextHop.Addr().Unmap()
|
||||
if !ip.IsValid() {
|
||||
// Not a valid gateway, so skip (won't happen though)
|
||||
continue
|
||||
}
|
||||
ip = ip.Unmap()
|
||||
|
||||
if best == nil {
|
||||
best = r
|
||||
|
||||
@@ -77,7 +77,8 @@ type CapabilityVersion int
|
||||
// 38: 2022-08-11: added PingRequest.URLIsNoise
|
||||
// 39: 2022-08-15: clients can talk Noise over arbitrary HTTPS port
|
||||
// 40: 2022-08-22: added Node.KeySignature, PeersChangedPatch.KeySignature
|
||||
const CurrentCapabilityVersion CapabilityVersion = 40
|
||||
// 41: 2022-08-30: uses 100.100.100.100 for route-less ExtraRecords if global nameservers is set
|
||||
const CurrentCapabilityVersion CapabilityVersion = 41
|
||||
|
||||
type StableID string
|
||||
|
||||
@@ -464,25 +465,27 @@ type Service struct {
|
||||
// Because it contains pointers (slices), this type should not be used
|
||||
// as a value type.
|
||||
type Hostinfo struct {
|
||||
IPNVersion string `json:",omitempty"` // version of this code
|
||||
FrontendLogID string `json:",omitempty"` // logtail ID of frontend instance
|
||||
BackendLogID string `json:",omitempty"` // logtail ID of backend instance
|
||||
OS string `json:",omitempty"` // operating system the client runs on (a version.OS value)
|
||||
OSVersion string `json:",omitempty"` // operating system version, with optional distro prefix ("Debian 10.4", "Windows 10 Pro 10.0.19041")
|
||||
Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux
|
||||
Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown)
|
||||
DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3")
|
||||
Hostname string `json:",omitempty"` // name of the host the client runs on
|
||||
ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections
|
||||
ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user
|
||||
GoArch string `json:",omitempty"` // the host's GOARCH value (of the running binary)
|
||||
GoVersion string `json:",omitempty"` // Go version binary was built with
|
||||
RoutableIPs []netip.Prefix `json:",omitempty"` // set of IP ranges this client can route
|
||||
RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim
|
||||
Services []Service `json:",omitempty"` // services advertised by this machine
|
||||
NetInfo *NetInfo `json:",omitempty"`
|
||||
SSH_HostKeys []string `json:"sshHostKeys,omitempty"` // if advertised
|
||||
Cloud string `json:",omitempty"`
|
||||
IPNVersion string `json:",omitempty"` // version of this code
|
||||
FrontendLogID string `json:",omitempty"` // logtail ID of frontend instance
|
||||
BackendLogID string `json:",omitempty"` // logtail ID of backend instance
|
||||
OS string `json:",omitempty"` // operating system the client runs on (a version.OS value)
|
||||
OSVersion string `json:",omitempty"` // operating system version, with optional distro prefix ("Debian 10.4", "Windows 10 Pro 10.0.19041")
|
||||
Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux
|
||||
Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown)
|
||||
DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3")
|
||||
Hostname string `json:",omitempty"` // name of the host the client runs on
|
||||
ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections
|
||||
ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user
|
||||
GoArch string `json:",omitempty"` // the host's GOARCH value (of the running binary)
|
||||
GoVersion string `json:",omitempty"` // Go version binary was built with
|
||||
RoutableIPs []netip.Prefix `json:",omitempty"` // set of IP ranges this client can route
|
||||
RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim
|
||||
Services []Service `json:",omitempty"` // services advertised by this machine
|
||||
NetInfo *NetInfo `json:",omitempty"`
|
||||
SSH_HostKeys []string `json:"sshHostKeys,omitempty"` // if advertised
|
||||
Cloud string `json:",omitempty"`
|
||||
Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode
|
||||
UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode
|
||||
|
||||
// NOTE: any new fields containing pointers in this type
|
||||
// require changes to Hostinfo.Equal.
|
||||
@@ -1155,12 +1158,15 @@ const (
|
||||
// PingRequest with Types and IP, will send a ping to the IP and send a POST
|
||||
// request containing a PingResponse to the URL containing results.
|
||||
type PingRequest struct {
|
||||
// URL is the URL to send a HEAD request to.
|
||||
// URL is the URL to reply to the PingRequest to.
|
||||
// It will be a unique URL each time. No auth headers are necessary.
|
||||
//
|
||||
// If the client sees multiple PingRequests with the same URL,
|
||||
// subsequent ones should be ignored.
|
||||
// If Types and IP are defined, then URL is the URL to send a POST request to.
|
||||
//
|
||||
// The HTTP method that the node should make back to URL depends on the other
|
||||
// fields of the PingRequest. If Types is defined, then URL is the URL to
|
||||
// send a POST request to. Otherwise, the node should just make a HEAD
|
||||
// request to URL.
|
||||
URL string
|
||||
|
||||
// URLIsNoise, if true, means that the client should hit URL over the Noise
|
||||
@@ -1173,11 +1179,22 @@ type PingRequest struct {
|
||||
|
||||
// Types is the types of ping that are initiated. Can be any PingType, comma
|
||||
// separated, e.g. "disco,TSMP"
|
||||
Types string
|
||||
//
|
||||
// As a special case, if Types is "c2n", then this PingRequest is a
|
||||
// client-to-node HTTP request. The HTTP request should be handled by this
|
||||
// node's c2n handler and the HTTP response sent in a POST to URL. For c2n,
|
||||
// the value of URLIsNoise is ignored and only the Noise transport (back to
|
||||
// the control plane) will be used, as if URLIsNoise were true.
|
||||
Types string `json:",omitempty"`
|
||||
|
||||
// IP is the ping target.
|
||||
// It is used in TSMP pings, if IP is invalid or empty then do a HEAD request to the URL.
|
||||
// IP is the ping target, when needed by the PingType(s) given in Types.
|
||||
IP netip.Addr
|
||||
|
||||
// Payload is the ping payload.
|
||||
//
|
||||
// It is only used for c2n requests, in which case it's an HTTP/1.0 or
|
||||
// HTTP/1.1-formatted HTTP request as parsable with http.ReadRequest.
|
||||
Payload []byte `json:",omitempty"`
|
||||
}
|
||||
|
||||
// PingResponse provides result information for a TSMP or Disco PingRequest.
|
||||
|
||||
@@ -115,25 +115,27 @@ func (src *Hostinfo) Clone() *Hostinfo {
|
||||
|
||||
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
|
||||
var _HostinfoCloneNeedsRegeneration = Hostinfo(struct {
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
Userspace opt.Bool
|
||||
UserspaceRouter opt.Bool
|
||||
}{})
|
||||
|
||||
// Clone makes a deep copy of NetInfo.
|
||||
|
||||
@@ -37,6 +37,7 @@ func TestHostinfoEqual(t *testing.T) {
|
||||
"GoArch", "GoVersion",
|
||||
"RoutableIPs", "RequestTags",
|
||||
"Services", "NetInfo", "SSH_HostKeys", "Cloud",
|
||||
"Userspace", "UserspaceRouter",
|
||||
}
|
||||
if have := fieldsOf(reflect.TypeOf(Hostinfo{})); !reflect.DeepEqual(have, hiHandles) {
|
||||
t.Errorf("Hostinfo.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
||||
|
||||
@@ -271,29 +271,33 @@ func (v HostinfoView) Services() views.Slice[Service] { return views.SliceOf(
|
||||
func (v HostinfoView) NetInfo() NetInfoView { return v.ж.NetInfo.View() }
|
||||
func (v HostinfoView) SSH_HostKeys() views.Slice[string] { return views.SliceOf(v.ж.SSH_HostKeys) }
|
||||
func (v HostinfoView) Cloud() string { return v.ж.Cloud }
|
||||
func (v HostinfoView) Userspace() opt.Bool { return v.ж.Userspace }
|
||||
func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.UserspaceRouter }
|
||||
func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) }
|
||||
|
||||
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
|
||||
var _HostinfoViewNeedsRegeneration = Hostinfo(struct {
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
Userspace opt.Bool
|
||||
UserspaceRouter opt.Bool
|
||||
}{})
|
||||
|
||||
// View returns a readonly view of NetInfo.
|
||||
|
||||
80
tka/sig.go
80
tka/sig.go
@@ -33,6 +33,19 @@ const (
|
||||
// SigRotation signature and sign it again with their rotation key. That
|
||||
// way, SigRotation nesting should only be 2 deep in the common case.
|
||||
SigRotation
|
||||
// SigCredential describes a signature over a specifi public key, signed
|
||||
// by a key in the tailnet key authority referenced by the specified keyID.
|
||||
// In effect, SigCredential delegates the ability to make a signature to
|
||||
// a different public/private key pair.
|
||||
//
|
||||
// It is intended that a different public/private key pair be generated
|
||||
// for each different SigCredential that is created. Implementors must
|
||||
// take care that the private side is only known to the entity that needs
|
||||
// to generate the wrapping SigRotation signature, and it is immediately
|
||||
// discarded after use.
|
||||
//
|
||||
// SigCredential is expected to be nested in a SigRotation signature.
|
||||
SigCredential
|
||||
)
|
||||
|
||||
func (s SigKind) String() string {
|
||||
@@ -43,6 +56,8 @@ func (s SigKind) String() string {
|
||||
return "direct"
|
||||
case SigRotation:
|
||||
return "rotation"
|
||||
case SigCredential:
|
||||
return "credential"
|
||||
default:
|
||||
return fmt.Sprintf("Sig?<%d>", int(s))
|
||||
}
|
||||
@@ -53,8 +68,9 @@ func (s SigKind) String() string {
|
||||
type NodeKeySignature struct {
|
||||
// SigKind identifies the variety of signature.
|
||||
SigKind SigKind `cbor:"1,keyasint"`
|
||||
// Pubkey identifies the public key which is being authorized.
|
||||
Pubkey []byte `cbor:"2,keyasint"`
|
||||
// Pubkey identifies the key.NodePublic which is being authorized.
|
||||
// SigCredential signatures do not use this field.
|
||||
Pubkey []byte `cbor:"2,keyasint,omitempty"`
|
||||
|
||||
// KeyID identifies which key in the tailnet key authority should
|
||||
// be used to verify this signature. Only set for SigDirect and
|
||||
@@ -69,19 +85,23 @@ type NodeKeySignature struct {
|
||||
// used as Pubkey. Only used for SigRotation signatures.
|
||||
Nested *NodeKeySignature `cbor:"5,keyasint,omitempty"`
|
||||
|
||||
// RotationPubkey specifies the ed25519 public key which may sign a
|
||||
// SigRotation signature, which embeds this one.
|
||||
// WrappingPubkey specifies the ed25519 public key which must be used
|
||||
// to sign a Signature which embeds this one.
|
||||
//
|
||||
// Intermediate SigRotation signatures may omit this value to use the
|
||||
// parent one.
|
||||
RotationPubkey []byte `cbor:"6,keyasint,omitempty"`
|
||||
// For SigRotation signatures multiple levels deep, intermediate
|
||||
// signatures may omit this value, in which case the parent WrappingPubkey
|
||||
// is used.
|
||||
//
|
||||
// SigCredential signatures use this field to specify the public key
|
||||
// they are certifying, following the usual semanticsfor WrappingPubkey.
|
||||
WrappingPubkey []byte `cbor:"6,keyasint,omitempty"`
|
||||
}
|
||||
|
||||
// rotationPublic returns the public key which must sign a SigRotation
|
||||
// signature that embeds this signature, if any.
|
||||
func (s NodeKeySignature) rotationPublic() (pub ed25519.PublicKey, ok bool) {
|
||||
if len(s.RotationPubkey) > 0 {
|
||||
return ed25519.PublicKey(s.RotationPubkey), true
|
||||
// wrappingPublic returns the public key which must sign a signature which
|
||||
// embeds this one, if any.
|
||||
func (s NodeKeySignature) wrappingPublic() (pub ed25519.PublicKey, ok bool) {
|
||||
if len(s.WrappingPubkey) > 0 {
|
||||
return ed25519.PublicKey(s.WrappingPubkey), true
|
||||
}
|
||||
|
||||
switch s.SigKind {
|
||||
@@ -89,7 +109,7 @@ func (s NodeKeySignature) rotationPublic() (pub ed25519.PublicKey, ok bool) {
|
||||
if s.Nested == nil {
|
||||
return nil, false
|
||||
}
|
||||
return s.Nested.rotationPublic()
|
||||
return s.Nested.wrappingPublic()
|
||||
|
||||
default:
|
||||
return nil, false
|
||||
@@ -138,15 +158,18 @@ func (s *NodeKeySignature) Unserialize(data []byte) error {
|
||||
return dec.Unmarshal(data, s)
|
||||
}
|
||||
|
||||
// verifySignature checks that the NodeKeySignature is authentic, certified
|
||||
// by the given verificationKey, and authorizes the given nodeKey.
|
||||
// verifySignature checks that the NodeKeySignature is authentic & certified
|
||||
// by the given verificationKey. Additionally, SigDirect and SigRotation
|
||||
// signatures are checked to ensure they authorize the given nodeKey.
|
||||
func (s *NodeKeySignature) verifySignature(nodeKey key.NodePublic, verificationKey Key) error {
|
||||
nodeBytes, err := nodeKey.MarshalBinary()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling pubkey: %v", err)
|
||||
}
|
||||
if !bytes.Equal(nodeBytes, s.Pubkey) {
|
||||
return errors.New("signature does not authorize nodeKey")
|
||||
if s.SigKind != SigCredential {
|
||||
nodeBytes, err := nodeKey.MarshalBinary()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling pubkey: %v", err)
|
||||
}
|
||||
if !bytes.Equal(nodeBytes, s.Pubkey) {
|
||||
return errors.New("signature does not authorize nodeKey")
|
||||
}
|
||||
}
|
||||
|
||||
sigHash := s.SigHash()
|
||||
@@ -157,7 +180,7 @@ func (s *NodeKeySignature) verifySignature(nodeKey key.NodePublic, verificationK
|
||||
}
|
||||
|
||||
// Verify the signature using the nested rotation key.
|
||||
verifyPub, ok := s.Nested.rotationPublic()
|
||||
verifyPub, ok := s.Nested.wrappingPublic()
|
||||
if !ok {
|
||||
return errors.New("missing rotation key")
|
||||
}
|
||||
@@ -167,15 +190,22 @@ func (s *NodeKeySignature) verifySignature(nodeKey key.NodePublic, verificationK
|
||||
|
||||
// Recurse to verify the signature on the nested structure.
|
||||
var nestedPub key.NodePublic
|
||||
if err := nestedPub.UnmarshalBinary(s.Nested.Pubkey); err != nil {
|
||||
return fmt.Errorf("nested pubkey: %v", err)
|
||||
// SigCredential signatures certify an indirection key rather than a node
|
||||
// key, so theres no need to check the node key.
|
||||
if s.Nested.SigKind != SigCredential {
|
||||
if err := nestedPub.UnmarshalBinary(s.Nested.Pubkey); err != nil {
|
||||
return fmt.Errorf("nested pubkey: %v", err)
|
||||
}
|
||||
}
|
||||
if err := s.Nested.verifySignature(nestedPub, verificationKey); err != nil {
|
||||
return fmt.Errorf("nested: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
case SigDirect:
|
||||
case SigDirect, SigCredential:
|
||||
if s.Nested != nil {
|
||||
return fmt.Errorf("invalid signature: signatures of type %v cannot nest another signature", s.SigKind)
|
||||
}
|
||||
switch verificationKey.Kind {
|
||||
case Key25519:
|
||||
if ed25519consensus.Verify(ed25519.PublicKey(verificationKey.Public), sigHash[:], s.Signature) {
|
||||
|
||||
@@ -67,7 +67,7 @@ func TestSigNested(t *testing.T) {
|
||||
SigKind: SigDirect,
|
||||
KeyID: k.ID(),
|
||||
Pubkey: oldPub,
|
||||
RotationPubkey: rPub,
|
||||
WrappingPubkey: rPub,
|
||||
}
|
||||
sigHash := nestedSig.SigHash()
|
||||
nestedSig.Signature = ed25519.Sign(priv, sigHash[:])
|
||||
@@ -110,6 +110,13 @@ func TestSigNested(t *testing.T) {
|
||||
if err := sig.verifySignature(node.Public(), k); err == nil {
|
||||
t.Error("verifySignature(node) succeeded with bad outer signature")
|
||||
}
|
||||
|
||||
// Test verification fails if the outer signature is signed with a
|
||||
// different public key to whats specified in WrappingPubkey
|
||||
sig.Signature = ed25519.Sign(priv, sigHash[:])
|
||||
if err := sig.verifySignature(node.Public(), k); err == nil {
|
||||
t.Error("verifySignature(node) succeeded with different signature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigNested_DeepNesting(t *testing.T) {
|
||||
@@ -128,7 +135,7 @@ func TestSigNested_DeepNesting(t *testing.T) {
|
||||
SigKind: SigDirect,
|
||||
KeyID: k.ID(),
|
||||
Pubkey: oldPub,
|
||||
RotationPubkey: rPub,
|
||||
WrappingPubkey: rPub,
|
||||
}
|
||||
sigHash := nestedSig.SigHash()
|
||||
nestedSig.Signature = ed25519.Sign(priv, sigHash[:])
|
||||
@@ -175,6 +182,91 @@ func TestSigNested_DeepNesting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigCredential(t *testing.T) {
|
||||
// Network-lock key (the key used to sign the nested sig)
|
||||
pub, priv := testingKey25519(t, 1)
|
||||
k := Key{Kind: Key25519, Public: pub, Votes: 2}
|
||||
// 'credential' key (the one being delegated to)
|
||||
cPub, cPriv := testingKey25519(t, 2)
|
||||
// The node key being certified
|
||||
node := key.NewNode()
|
||||
nodeKeyPub, _ := node.Public().MarshalBinary()
|
||||
|
||||
// The signature certifying delegated trust to another
|
||||
// public key.
|
||||
nestedSig := NodeKeySignature{
|
||||
SigKind: SigCredential,
|
||||
KeyID: k.ID(),
|
||||
WrappingPubkey: cPub,
|
||||
}
|
||||
sigHash := nestedSig.SigHash()
|
||||
nestedSig.Signature = ed25519.Sign(priv, sigHash[:])
|
||||
|
||||
// The signature authorizing the node key, signed by the
|
||||
// delegated key & embedding the original signature.
|
||||
sig := NodeKeySignature{
|
||||
SigKind: SigRotation,
|
||||
KeyID: k.ID(),
|
||||
Pubkey: nodeKeyPub,
|
||||
Nested: &nestedSig,
|
||||
}
|
||||
sigHash = sig.SigHash()
|
||||
sig.Signature = ed25519.Sign(cPriv, sigHash[:])
|
||||
if err := sig.verifySignature(node.Public(), k); err != nil {
|
||||
t.Fatalf("verifySignature(node) failed: %v", err)
|
||||
}
|
||||
|
||||
// Test verification fails if the wrong verification key is provided
|
||||
kBad := Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}, Votes: 2}
|
||||
if err := sig.verifySignature(node.Public(), kBad); err == nil {
|
||||
t.Error("verifySignature() did not error for wrong verification key")
|
||||
}
|
||||
|
||||
// Test someone can't misuse our public API for verifying node-keys
|
||||
a, _ := Open(newTestchain(t, "G1\nG1.template = genesis",
|
||||
optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{
|
||||
Keys: []Key{k},
|
||||
DisablementSecrets: [][]byte{disablementKDF([]byte{1, 2, 3})},
|
||||
}})).Chonk())
|
||||
if err := a.NodeKeyAuthorized(node.Public(), nestedSig.Serialize()); err == nil {
|
||||
t.Error("NodeKeyAuthorized(SigCredential, node) did not fail")
|
||||
}
|
||||
// but that they can use it properly (nested in a SigRotation)
|
||||
if err := a.NodeKeyAuthorized(node.Public(), sig.Serialize()); err != nil {
|
||||
t.Errorf("NodeKeyAuthorized(SigRotation{SigCredential}, node) failed: %v", err)
|
||||
}
|
||||
|
||||
// Test verification fails if the inner signature is invalid
|
||||
tmp := make([]byte, ed25519.SignatureSize)
|
||||
copy(tmp, nestedSig.Signature)
|
||||
copy(nestedSig.Signature, []byte{1, 2, 3, 4})
|
||||
if err := sig.verifySignature(node.Public(), k); err == nil {
|
||||
t.Error("verifySignature(node) succeeded with bad inner signature")
|
||||
}
|
||||
copy(nestedSig.Signature, tmp)
|
||||
|
||||
// Test verification fails if the outer signature is invalid
|
||||
copy(tmp, sig.Signature)
|
||||
copy(sig.Signature, []byte{1, 2, 3, 4})
|
||||
if err := sig.verifySignature(node.Public(), k); err == nil {
|
||||
t.Error("verifySignature(node) succeeded with bad outer signature")
|
||||
}
|
||||
copy(sig.Signature, tmp)
|
||||
|
||||
// Test verification fails if we attempt to check a different node-key
|
||||
otherNode := key.NewNode()
|
||||
if err := sig.verifySignature(otherNode.Public(), k); err == nil {
|
||||
t.Error("verifySignature(otherNode) succeeded with different principal")
|
||||
}
|
||||
|
||||
// Test verification fails if the outer signature is signed with a
|
||||
// different public key to whats specified in WrappingPubkey
|
||||
sig.Signature = ed25519.Sign(priv, sigHash[:])
|
||||
if err := sig.verifySignature(node.Public(), k); err == nil {
|
||||
t.Error("verifySignature(node) succeeded with different signature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigSerializeUnserialize(t *testing.T) {
|
||||
nodeKeyPub := []byte{1, 2, 3, 4}
|
||||
pub, priv := testingKey25519(t, 1)
|
||||
|
||||
@@ -29,8 +29,6 @@ type State struct {
|
||||
|
||||
// DisablementSecrets are KDF-derived values which can be used
|
||||
// to turn off the TKA in the event of a consensus-breaking bug.
|
||||
// An AUM of type DisableNL should contain a secret when results
|
||||
// in one of these values when run through the disablement KDF.
|
||||
//
|
||||
// TODO(tom): This is an alpha feature, remove this mechanism once
|
||||
// we have confidence in our implementation.
|
||||
@@ -169,6 +167,9 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) {
|
||||
if update.Meta != nil {
|
||||
k.Meta = update.Meta
|
||||
}
|
||||
if err := k.StaticValidate(); err != nil {
|
||||
return State{}, fmt.Errorf("updated key fails validation: %v", err)
|
||||
}
|
||||
out := s.cloneForUpdate(&update)
|
||||
for i := range out.Keys {
|
||||
if bytes.Equal(out.Keys[i].ID(), update.KeyID) {
|
||||
|
||||
@@ -181,6 +181,7 @@ func TestApplyUpdatesChain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApplyUpdateErrors(t *testing.T) {
|
||||
tooLargeVotes := uint(99999)
|
||||
tcs := []struct {
|
||||
Name string
|
||||
Updates []AUM
|
||||
@@ -205,6 +206,12 @@ func TestApplyUpdateErrors(t *testing.T) {
|
||||
State{},
|
||||
ErrNoSuchKey,
|
||||
},
|
||||
{
|
||||
"UpdateKey now fails validation",
|
||||
[]AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}},
|
||||
State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}},
|
||||
errors.New("updated key fails validation: excessive key weight: 99999 > 4096"),
|
||||
},
|
||||
{
|
||||
"Bad lastAUMHash",
|
||||
[]AUM{
|
||||
|
||||
@@ -673,6 +673,10 @@ func (a *Authority) NodeKeyAuthorized(nodeKey key.NodePublic, nodeKeySignature t
|
||||
if err := decoded.Unserialize(nodeKeySignature); err != nil {
|
||||
return fmt.Errorf("unserialize: %v", err)
|
||||
}
|
||||
if decoded.SigKind == SigCredential {
|
||||
return errors.New("credential signatures cannot authorize nodes on their own")
|
||||
}
|
||||
|
||||
key, err := a.state.GetKey(decoded.KeyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("key: %v", err)
|
||||
|
||||
@@ -374,6 +374,74 @@ func TestAddPingRequest(t *testing.T) {
|
||||
t.Error("all ping attempts failed")
|
||||
}
|
||||
|
||||
func TestC2NPingRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
env := newTestEnv(t)
|
||||
n1 := newTestNode(t, env)
|
||||
n1.StartDaemon()
|
||||
|
||||
n1.AwaitListening()
|
||||
n1.MustUp()
|
||||
n1.AwaitRunning()
|
||||
|
||||
gotPing := make(chan bool, 1)
|
||||
waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("unexpected ping method %q", r.Method)
|
||||
}
|
||||
got, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("ping body read error: %v", err)
|
||||
}
|
||||
const want = "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nabc"
|
||||
if string(got) != want {
|
||||
t.Errorf("body error\n got: %q\nwant: %q", got, want)
|
||||
}
|
||||
gotPing <- true
|
||||
}))
|
||||
defer waitPing.Close()
|
||||
|
||||
nodes := env.Control.AllNodes()
|
||||
if len(nodes) != 1 {
|
||||
t.Fatalf("expected 1 node, got %d nodes", len(nodes))
|
||||
}
|
||||
|
||||
nodeKey := nodes[0].Key
|
||||
|
||||
// Check that we get at least one ping reply after 10 tries.
|
||||
for try := 1; try <= 10; try++ {
|
||||
t.Logf("ping %v ...", try)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if err := env.Control.AwaitNodeInMapRequest(ctx, nodeKey); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cancel()
|
||||
|
||||
pr := &tailcfg.PingRequest{
|
||||
URL: fmt.Sprintf("%s/ping-%d", waitPing.URL, try),
|
||||
Log: true,
|
||||
Types: "c2n",
|
||||
Payload: []byte("POST /echo HTTP/1.0\r\nContent-Length: 3\r\n\r\nabc"),
|
||||
}
|
||||
if !env.Control.AddPingRequest(nodeKey, pr) {
|
||||
t.Logf("failed to AddPingRequest")
|
||||
continue
|
||||
}
|
||||
|
||||
// Wait for PingRequest to come back
|
||||
pingTimeout := time.NewTimer(2 * time.Second)
|
||||
defer pingTimeout.Stop()
|
||||
select {
|
||||
case <-gotPing:
|
||||
t.Logf("got ping; success")
|
||||
return
|
||||
case <-pingTimeout.C:
|
||||
// Try again.
|
||||
}
|
||||
}
|
||||
t.Error("all ping attempts failed")
|
||||
}
|
||||
|
||||
// Issue 2434: when "down" (WantRunning false), tailscaled shouldn't
|
||||
// be connected to control.
|
||||
func TestNoControlConnWhenDown(t *testing.T) {
|
||||
@@ -737,6 +805,7 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon {
|
||||
cmd.Args = append(cmd.Args, "-verbose=2")
|
||||
}
|
||||
cmd.Env = append(os.Environ(),
|
||||
"TS_DEBUG_PERMIT_HTTP_C2N=1",
|
||||
"TS_LOG_TARGET="+n.env.LogCatcherServer.URL,
|
||||
"HTTP_PROXY="+n.env.TrafficTrapServer.URL,
|
||||
"HTTPS_PROXY="+n.env.TrafficTrapServer.URL,
|
||||
|
||||
@@ -9,12 +9,13 @@ package logger
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func rusageMaxRSS() float64 {
|
||||
var ru syscall.Rusage
|
||||
err := syscall.Getrusage(syscall.RUSAGE_SELF, &ru)
|
||||
var ru unix.Rusage
|
||||
err := unix.Getrusage(unix.RUSAGE_SELF, &ru)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
179
util/cstruct/cstruct.go
Normal file
179
util/cstruct/cstruct.go
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package cstruct provides a helper for decoding binary data that is in the
|
||||
// form of a padded C structure.
|
||||
package cstruct
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"tailscale.com/util/endian"
|
||||
)
|
||||
|
||||
// Size of a pointer-typed value, in bits
|
||||
const pointerSize = 32 << (^uintptr(0) >> 63)
|
||||
|
||||
// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on
|
||||
// a 16- or 8-bit architecture any time soon.
|
||||
const is64Bit = pointerSize == 64
|
||||
|
||||
// Decoder reads and decodes padded fields from a slice of bytes. All fields
|
||||
// are decoded with native endianness.
|
||||
//
|
||||
// Methods of a Decoder do not return errors, but rather store any error within
|
||||
// the Decoder. The first error can be obtained via the Err method; after the
|
||||
// first error, methods will return the zero value for their type.
|
||||
type Decoder struct {
|
||||
b []byte
|
||||
off int
|
||||
err error
|
||||
dbuf [8]byte // for decoding
|
||||
}
|
||||
|
||||
// NewDecoder creates a Decoder from a byte slice.
|
||||
func NewDecoder(b []byte) *Decoder {
|
||||
return &Decoder{b: b}
|
||||
}
|
||||
|
||||
var errUnsupportedSize = errors.New("unsupported size")
|
||||
|
||||
func padBytes(offset, size int) int {
|
||||
if offset == 0 || size == 1 {
|
||||
return 0
|
||||
}
|
||||
remainder := offset % size
|
||||
return size - remainder
|
||||
}
|
||||
|
||||
func (d *Decoder) getField(b []byte) error {
|
||||
size := len(b)
|
||||
|
||||
// We only support fields that are multiples of 2 (or 1-sized)
|
||||
if size != 1 && size&1 == 1 {
|
||||
return errUnsupportedSize
|
||||
}
|
||||
|
||||
// Fields are aligned to their size
|
||||
padBytes := padBytes(d.off, size)
|
||||
if d.off+size+padBytes > len(d.b) {
|
||||
return io.EOF
|
||||
}
|
||||
d.off += padBytes
|
||||
|
||||
copy(b, d.b[d.off:d.off+size])
|
||||
d.off += size
|
||||
return nil
|
||||
}
|
||||
|
||||
// Err returns the first error that was encountered by this Decoder.
|
||||
func (d *Decoder) Err() error {
|
||||
return d.err
|
||||
}
|
||||
|
||||
// Offset returns the current read offset for data in the buffer.
|
||||
func (d *Decoder) Offset() int {
|
||||
return d.off
|
||||
}
|
||||
|
||||
// Byte returns a single byte from the buffer.
|
||||
func (d *Decoder) Byte() byte {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:1]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return d.dbuf[0]
|
||||
}
|
||||
|
||||
// Byte returns a number of bytes from the buffer based on the size of the
|
||||
// input slice. No padding is applied.
|
||||
//
|
||||
// If an error is encountered or this Decoder has previously encountered an
|
||||
// error, no changes are made to the provided buffer.
|
||||
func (d *Decoder) Bytes(b []byte) {
|
||||
if d.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// No padding for byte slices
|
||||
size := len(b)
|
||||
if d.off+size >= len(d.b) {
|
||||
d.err = io.EOF
|
||||
return
|
||||
}
|
||||
copy(b, d.b[d.off:d.off+size])
|
||||
d.off += size
|
||||
}
|
||||
|
||||
// Uint16 returns a uint16 decoded from the buffer.
|
||||
func (d *Decoder) Uint16() uint16 {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:2]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return endian.Native.Uint16(d.dbuf[0:2])
|
||||
}
|
||||
|
||||
// Uint32 returns a uint32 decoded from the buffer.
|
||||
func (d *Decoder) Uint32() uint32 {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:4]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return endian.Native.Uint32(d.dbuf[0:4])
|
||||
}
|
||||
|
||||
// Uint64 returns a uint64 decoded from the buffer.
|
||||
func (d *Decoder) Uint64() uint64 {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:8]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return endian.Native.Uint64(d.dbuf[0:8])
|
||||
}
|
||||
|
||||
// Uintptr returns a uintptr decoded from the buffer.
|
||||
func (d *Decoder) Uintptr() uintptr {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if is64Bit {
|
||||
return uintptr(d.Uint64())
|
||||
} else {
|
||||
return uintptr(d.Uint32())
|
||||
}
|
||||
}
|
||||
|
||||
// Int16 returns a int16 decoded from the buffer.
|
||||
func (d *Decoder) Int16() int16 {
|
||||
return int16(d.Uint16())
|
||||
}
|
||||
|
||||
// Int32 returns a int32 decoded from the buffer.
|
||||
func (d *Decoder) Int32() int32 {
|
||||
return int32(d.Uint32())
|
||||
}
|
||||
|
||||
// Int64 returns a int64 decoded from the buffer.
|
||||
func (d *Decoder) Int64() int64 {
|
||||
return int64(d.Uint64())
|
||||
}
|
||||
75
util/cstruct/cstruct_example_test.go
Normal file
75
util/cstruct/cstruct_example_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Only built on 64-bit platforms to avoid complexity
|
||||
|
||||
//go:build amd64 || arm64 || mips64le || ppc64le || riscv64
|
||||
// +build amd64 arm64 mips64le ppc64le riscv64
|
||||
|
||||
package cstruct
|
||||
|
||||
import "fmt"
|
||||
|
||||
// This test provides a semi-realistic example of how you can
|
||||
// use this package to decode a C structure.
|
||||
func ExampleDecoder() {
|
||||
// Our example C structure:
|
||||
// struct mystruct {
|
||||
// char *p;
|
||||
// char c;
|
||||
// /* implicit: char _pad[3]; */
|
||||
// int x;
|
||||
// };
|
||||
//
|
||||
// The Go structure definition:
|
||||
type myStruct struct {
|
||||
Ptr uintptr
|
||||
Ch byte
|
||||
Intval uint32
|
||||
}
|
||||
|
||||
// Our "in-memory" version of the above structure
|
||||
buf := []byte{
|
||||
1, 2, 3, 4, 0, 0, 0, 0, // ptr
|
||||
5, // ch
|
||||
99, 99, 99, // padding
|
||||
78, 6, 0, 0, // x
|
||||
}
|
||||
d := NewDecoder(buf)
|
||||
|
||||
// Decode the structure; if one of these function returns an error,
|
||||
// then subsequent decoder functions will return the zero value.
|
||||
var x myStruct
|
||||
x.Ptr = d.Uintptr()
|
||||
x.Ch = d.Byte()
|
||||
x.Intval = d.Uint32()
|
||||
|
||||
// Note that per the Go language spec:
|
||||
// [...] when evaluating the operands of an expression, assignment,
|
||||
// or return statement, all function calls, method calls, and
|
||||
// (channel) communication operations are evaluated in lexical
|
||||
// left-to-right order
|
||||
//
|
||||
// Since each field is assigned via a function call, one could use the
|
||||
// following snippet to decode the struct.
|
||||
// x := myStruct{
|
||||
// Ptr: d.Uintptr(),
|
||||
// Ch: d.Byte(),
|
||||
// Intval: d.Uint32(),
|
||||
// }
|
||||
//
|
||||
// However, this means that reordering the fields in the initialization
|
||||
// statement–normally a semantically identical operation–would change
|
||||
// the way the structure is parsed. Thus we do it as above with
|
||||
// explicit ordering.
|
||||
|
||||
// After finishing with the decoder, check errors
|
||||
if err := d.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Print the decoder offset and structure
|
||||
fmt.Printf("off=%d struct=%#v\n", d.Offset(), x)
|
||||
// Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e}
|
||||
}
|
||||
151
util/cstruct/cstruct_test.go
Normal file
151
util/cstruct/cstruct_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cstruct
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPadBytes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
offset int
|
||||
size int
|
||||
want int
|
||||
}{
|
||||
// No padding at beginning of structure
|
||||
{0, 1, 0},
|
||||
{0, 2, 0},
|
||||
{0, 4, 0},
|
||||
{0, 8, 0},
|
||||
|
||||
// No padding for single bytes
|
||||
{1, 1, 0},
|
||||
|
||||
// Single byte padding
|
||||
{1, 2, 1},
|
||||
{3, 4, 1},
|
||||
|
||||
// Multi-byte padding
|
||||
{1, 4, 3},
|
||||
{2, 8, 6},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%d_%d_%d", tc.offset, tc.size, tc.want), func(t *testing.T) {
|
||||
got := padBytes(tc.offset, tc.size)
|
||||
if got != tc.want {
|
||||
t.Errorf("got=%d; want=%d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoder(t *testing.T) {
|
||||
t.Run("UnsignedTypes", func(t *testing.T) {
|
||||
dec := func(n int) *Decoder {
|
||||
buf := make([]byte, n)
|
||||
buf[0] = 1
|
||||
|
||||
d := NewDecoder(buf)
|
||||
|
||||
// Use t.Cleanup to perform an assertion on this
|
||||
// decoder after the test code is finished with it.
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
return d
|
||||
}
|
||||
if got := dec(2).Uint16(); got != 1 {
|
||||
t.Errorf("uint16: got=%d; want=1", got)
|
||||
}
|
||||
if got := dec(4).Uint32(); got != 1 {
|
||||
t.Errorf("uint32: got=%d; want=1", got)
|
||||
}
|
||||
if got := dec(8).Uint64(); got != 1 {
|
||||
t.Errorf("uint64: got=%d; want=1", got)
|
||||
}
|
||||
if got := dec(pointerSize / 8).Uintptr(); got != 1 {
|
||||
t.Errorf("uintptr: got=%d; want=1", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SignedTypes", func(t *testing.T) {
|
||||
dec := func(n int) *Decoder {
|
||||
// Make a buffer of the exact size that consists of 0xff bytes
|
||||
buf := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
buf[i] = 0xff
|
||||
}
|
||||
|
||||
d := NewDecoder(buf)
|
||||
|
||||
// Use t.Cleanup to perform an assertion on this
|
||||
// decoder after the test code is finished with it.
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
return d
|
||||
}
|
||||
if got := dec(2).Int16(); got != -1 {
|
||||
t.Errorf("int16: got=%d; want=-1", got)
|
||||
}
|
||||
if got := dec(4).Int32(); got != -1 {
|
||||
t.Errorf("int32: got=%d; want=-1", got)
|
||||
}
|
||||
if got := dec(8).Int64(); got != -1 {
|
||||
t.Errorf("int64: got=%d; want=-1", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsufficientData", func(t *testing.T) {
|
||||
dec := func(n int) *Decoder {
|
||||
// Make a buffer that's too small and contains arbitrary bytes
|
||||
buf := make([]byte, n-1)
|
||||
for i := 0; i < n-1; i++ {
|
||||
buf[i] = 0xAD
|
||||
}
|
||||
|
||||
// Use t.Cleanup to perform an assertion on this
|
||||
// decoder after the test code is finished with it.
|
||||
d := NewDecoder(buf)
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err == nil || !errors.Is(err, io.EOF) {
|
||||
t.Errorf("(n=%d) expected io.EOF; got=%v", n, err)
|
||||
}
|
||||
})
|
||||
return d
|
||||
}
|
||||
|
||||
dec(2).Uint16()
|
||||
dec(4).Uint32()
|
||||
dec(8).Uint64()
|
||||
dec(pointerSize / 8).Uintptr()
|
||||
|
||||
dec(2).Int16()
|
||||
dec(4).Int32()
|
||||
dec(8).Int64()
|
||||
})
|
||||
|
||||
t.Run("Bytes", func(t *testing.T) {
|
||||
d := NewDecoder([]byte("hello worldasdf"))
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
buf := make([]byte, 11)
|
||||
d.Bytes(buf)
|
||||
if got := string(buf); got != "hello world" {
|
||||
t.Errorf("bytes: got=%q; want=%q", got, "hello world")
|
||||
}
|
||||
})
|
||||
}
|
||||
38
util/deephash/debug.go
Normal file
38
util/deephash/debug.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build deephash_debug
|
||||
|
||||
package deephash
|
||||
|
||||
import "fmt"
|
||||
|
||||
func (h *hasher) HashBytes(b []byte) {
|
||||
fmt.Printf("B(%q)+", b)
|
||||
h.Block512.HashBytes(b)
|
||||
}
|
||||
func (h *hasher) HashString(s string) {
|
||||
fmt.Printf("S(%q)+", s)
|
||||
h.Block512.HashString(s)
|
||||
}
|
||||
func (h *hasher) HashUint8(n uint8) {
|
||||
fmt.Printf("U8(%d)+", n)
|
||||
h.Block512.HashUint8(n)
|
||||
}
|
||||
func (h *hasher) HashUint16(n uint16) {
|
||||
fmt.Printf("U16(%d)+", n)
|
||||
h.Block512.HashUint16(n)
|
||||
}
|
||||
func (h *hasher) HashUint32(n uint32) {
|
||||
fmt.Printf("U32(%d)+", n)
|
||||
h.Block512.HashUint32(n)
|
||||
}
|
||||
func (h *hasher) HashUint64(n uint64) {
|
||||
fmt.Printf("U64(%d)+", n)
|
||||
h.Block512.HashUint64(n)
|
||||
}
|
||||
func (h *hasher) Sum(b []byte) []byte {
|
||||
fmt.Println("FIN")
|
||||
return h.Block512.Sum(b)
|
||||
}
|
||||
@@ -6,7 +6,7 @@
|
||||
// without looping. The hash is only valid within the lifetime of a program.
|
||||
// Users should not store the hash on disk or send it over the network.
|
||||
// The hash is sufficiently strong and unique such that
|
||||
// Hash(x) == Hash(y) is an appropriate replacement for x == y.
|
||||
// Hash(&x) == Hash(&y) is an appropriate replacement for x == y.
|
||||
//
|
||||
// The definition of equality is identical to reflect.DeepEqual except:
|
||||
// - Floating-point values are compared based on the raw bits,
|
||||
@@ -24,11 +24,9 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"tailscale.com/util/hashx"
|
||||
)
|
||||
@@ -41,9 +39,10 @@ import (
|
||||
//
|
||||
// The logic below hashes a value by printing it to a hash.Hash.
|
||||
// To be parsable, it assumes that we know the Go type of each value:
|
||||
// * scalar types (e.g., bool or int32) are printed as fixed-width fields.
|
||||
// * list types (e.g., strings, slices, and AppendTo buffers) are prefixed
|
||||
// by a fixed-width length field, followed by the contents of the list.
|
||||
// * scalar types (e.g., bool or int32) are directly printed as their
|
||||
// underlying memory representation.
|
||||
// * list types (e.g., strings and slices) are prefixed by a
|
||||
// fixed-width length field, followed by the contents of the list.
|
||||
// * slices, arrays, and structs print each element/field consecutively.
|
||||
// * interfaces print with a 1-byte prefix indicating whether it is nil.
|
||||
// If non-nil, it is followed by a fixed-width field of the type index,
|
||||
@@ -55,34 +54,46 @@ import (
|
||||
// * maps print with a 1-byte prefix indicating whether the map pointer is
|
||||
// 1) nil, 2) previously seen, or 3) newly seen. Previously seen pointers
|
||||
// are followed by a fixed-width field of the index of the previous pointer.
|
||||
// Newly seen maps are printed as a fixed-width field with the XOR of the
|
||||
// hash of every map entry. With a sufficiently strong hash, this value is
|
||||
// theoretically "parsable" by looking up the hash in a magical map that
|
||||
// returns the set of entries for that given hash.
|
||||
|
||||
// addressableValue is a reflect.Value that is guaranteed to be addressable
|
||||
// such that calling the Addr and Set methods do not panic.
|
||||
//
|
||||
// There is no compile magic that enforces this property,
|
||||
// but rather the need to construct this type makes it easier to examine each
|
||||
// construction site to ensure that this property is upheld.
|
||||
type addressableValue struct{ reflect.Value }
|
||||
|
||||
// newAddressableValue constructs a new addressable value of type t.
|
||||
func newAddressableValue(t reflect.Type) addressableValue {
|
||||
return addressableValue{reflect.New(t).Elem()} // dereferenced pointer is always addressable
|
||||
}
|
||||
|
||||
const scratchSize = 128
|
||||
// Newly seen maps are printed with a fixed-width length field, followed by
|
||||
// a fixed-width field with the XOR of the hash of every map entry.
|
||||
// With a sufficiently strong hash, this value is theoretically "parsable"
|
||||
// by looking up the hash in a magical map that returns the set of entries
|
||||
// for that given hash.
|
||||
|
||||
// hasher is reusable state for hashing a value.
|
||||
// Get one via hasherPool.
|
||||
type hasher struct {
|
||||
hashx.Block512
|
||||
scratch [scratchSize]byte
|
||||
visitStack visitStack
|
||||
}
|
||||
|
||||
var hasherPool = &sync.Pool{
|
||||
New: func() any { return new(hasher) },
|
||||
}
|
||||
|
||||
func (h *hasher) reset() {
|
||||
if h.Block512.Hash == nil {
|
||||
h.Block512.Hash = sha256.New()
|
||||
}
|
||||
h.Block512.Reset()
|
||||
}
|
||||
|
||||
// hashType hashes a reflect.Type.
|
||||
// The hash is only consistent within the lifetime of a program.
|
||||
func (h *hasher) hashType(t reflect.Type) {
|
||||
// This approach relies on reflect.Type always being backed by a unique
|
||||
// *reflect.rtype pointer. A safer approach is to use a global sync.Map
|
||||
// that maps reflect.Type to some arbitrary and unique index.
|
||||
// While safer, it requires global state with memory that can never be GC'd.
|
||||
rtypeAddr := reflect.ValueOf(t).Pointer() // address of *reflect.rtype
|
||||
h.HashUint64(uint64(rtypeAddr))
|
||||
}
|
||||
|
||||
func (h *hasher) sum() (s Sum) {
|
||||
h.Sum(s.sum[:0])
|
||||
return s
|
||||
}
|
||||
|
||||
// Sum is an opaque checksum type that is comparable.
|
||||
type Sum struct {
|
||||
sum [sha256.Size]byte
|
||||
@@ -107,92 +118,57 @@ func initSeed() {
|
||||
seed = uint64(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (h *hasher) Reset() {
|
||||
if h.Block512.Hash == nil {
|
||||
h.Block512.Hash = sha256.New()
|
||||
}
|
||||
h.Block512.Reset()
|
||||
}
|
||||
|
||||
func (h *hasher) sum() (s Sum) {
|
||||
h.Sum(s.sum[:0])
|
||||
return s
|
||||
}
|
||||
|
||||
var hasherPool = &sync.Pool{
|
||||
New: func() any { return new(hasher) },
|
||||
}
|
||||
|
||||
// Hash returns the hash of v.
|
||||
// For performance, this should be a non-nil pointer.
|
||||
func Hash(v any) (s Sum) {
|
||||
func Hash[T any](v *T) Sum {
|
||||
h := hasherPool.Get().(*hasher)
|
||||
defer hasherPool.Put(h)
|
||||
h.Reset()
|
||||
h.reset()
|
||||
seedOnce.Do(initSeed)
|
||||
h.HashUint64(seed)
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.IsValid() {
|
||||
var va addressableValue
|
||||
if rv.Kind() == reflect.Pointer && !rv.IsNil() {
|
||||
va = addressableValue{rv.Elem()} // dereferenced pointer is always addressable
|
||||
} else {
|
||||
va = newAddressableValue(rv.Type())
|
||||
va.Set(rv)
|
||||
}
|
||||
|
||||
// Always treat the Hash input as an interface (it is), including hashing
|
||||
// its type, otherwise two Hash calls of different types could hash to the
|
||||
// same bytes off the different types and get equivalent Sum values. This is
|
||||
// the same thing that we do for reflect.Kind Interface in hashValue, but
|
||||
// the initial reflect.ValueOf from an interface value effectively strips
|
||||
// the interface box off so we have to do it at the top level by hand.
|
||||
h.hashType(va.Type())
|
||||
ti := getTypeInfo(va.Type())
|
||||
ti.hasher()(h, va)
|
||||
// Always treat the Hash input as if it were an interface by including
|
||||
// a hash of the type. This ensures that hashing of two different types
|
||||
// but with the same value structure produces different hashes.
|
||||
t := reflect.TypeOf(v).Elem()
|
||||
h.hashType(t)
|
||||
if v == nil {
|
||||
h.HashUint8(0) // indicates nil
|
||||
} else {
|
||||
h.HashUint8(1) // indicates visiting pointer element
|
||||
p := pointerOf(reflect.ValueOf(v))
|
||||
hash := lookupTypeHasher(t)
|
||||
hash(h, p)
|
||||
}
|
||||
return h.sum()
|
||||
}
|
||||
|
||||
// HasherForType is like Hash, but it returns a Hash func that's specialized for
|
||||
// the provided reflect type, avoiding a map lookup per value.
|
||||
func HasherForType[T any]() func(T) Sum {
|
||||
var zeroT T
|
||||
t := reflect.TypeOf(zeroT)
|
||||
ti := getTypeInfo(t)
|
||||
var tiElem *typeInfo
|
||||
if t.Kind() == reflect.Pointer {
|
||||
tiElem = getTypeInfo(t.Elem())
|
||||
}
|
||||
// HasherForType returns a hash that is specialized for the provided type.
|
||||
func HasherForType[T any]() func(*T) Sum {
|
||||
var v *T
|
||||
seedOnce.Do(initSeed)
|
||||
|
||||
return func(v T) (s Sum) {
|
||||
t := reflect.TypeOf(v).Elem()
|
||||
hash := lookupTypeHasher(t)
|
||||
return func(v *T) (s Sum) {
|
||||
// This logic is identical to Hash, but pull out a few statements.
|
||||
h := hasherPool.Get().(*hasher)
|
||||
defer hasherPool.Put(h)
|
||||
h.Reset()
|
||||
h.reset()
|
||||
h.HashUint64(seed)
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
if rv.IsValid() {
|
||||
if rv.Kind() == reflect.Pointer && !rv.IsNil() {
|
||||
va := addressableValue{rv.Elem()} // dereferenced pointer is always addressable
|
||||
h.hashType(va.Type())
|
||||
tiElem.hasher()(h, va)
|
||||
} else {
|
||||
va := newAddressableValue(rv.Type())
|
||||
va.Set(rv)
|
||||
h.hashType(va.Type())
|
||||
ti.hasher()(h, va)
|
||||
}
|
||||
h.hashType(t)
|
||||
if v == nil {
|
||||
h.HashUint8(0) // indicates nil
|
||||
} else {
|
||||
h.HashUint8(1) // indicates visiting pointer element
|
||||
p := pointerOf(reflect.ValueOf(v))
|
||||
hash(h, p)
|
||||
}
|
||||
return h.sum()
|
||||
}
|
||||
}
|
||||
|
||||
// Update sets last to the hash of v and reports whether its value changed.
|
||||
func Update(last *Sum, v any) (changed bool) {
|
||||
func Update[T any](last *Sum, v *T) (changed bool) {
|
||||
sum := Hash(v)
|
||||
changed = sum != *last
|
||||
if changed {
|
||||
@@ -201,129 +177,29 @@ func Update(last *Sum, v any) (changed bool) {
|
||||
return changed
|
||||
}
|
||||
|
||||
// typeInfo describes properties of a type.
|
||||
//
|
||||
// A non-nil typeInfo is populated into the typeHasher map
|
||||
// when its type is first requested, before its func is created.
|
||||
// Its func field fn is only populated once the type has been created.
|
||||
// This is used for recursive types.
|
||||
type typeInfo struct {
|
||||
rtype reflect.Type
|
||||
isRecursive bool
|
||||
// typeHasherFunc hashes the value pointed at by p for a given type.
|
||||
// For example, if t is a bool, then p is a *bool.
|
||||
// The provided pointer must always be non-nil.
|
||||
type typeHasherFunc func(h *hasher, p pointer)
|
||||
|
||||
// elemTypeInfo is the element type's typeInfo.
|
||||
// It's set when rtype is of Kind Ptr, Slice, Array, Map.
|
||||
elemTypeInfo *typeInfo
|
||||
var typeHasherCache sync.Map // map[reflect.Type]typeHasherFunc
|
||||
|
||||
// keyTypeInfo is the map key type's typeInfo.
|
||||
// It's set when rtype is of Kind Map.
|
||||
keyTypeInfo *typeInfo
|
||||
|
||||
hashFuncOnce sync.Once
|
||||
hashFuncLazy typeHasherFunc // nil until created
|
||||
}
|
||||
|
||||
type typeHasherFunc func(h *hasher, v addressableValue)
|
||||
|
||||
var typeInfoMap sync.Map // map[reflect.Type]*typeInfo
|
||||
var typeInfoMapPopulate sync.Mutex // just for adding to typeInfoMap
|
||||
|
||||
func (ti *typeInfo) hasher() typeHasherFunc {
|
||||
ti.hashFuncOnce.Do(ti.buildHashFuncOnce)
|
||||
return ti.hashFuncLazy
|
||||
}
|
||||
|
||||
func (ti *typeInfo) buildHashFuncOnce() {
|
||||
ti.hashFuncLazy = genTypeHasher(ti)
|
||||
}
|
||||
|
||||
// fieldInfo describes a struct field.
|
||||
type fieldInfo struct {
|
||||
index int // index of field for reflect.Value.Field(n); -1 if invalid
|
||||
typeInfo *typeInfo
|
||||
canMemHash bool
|
||||
offset uintptr // when we can memhash the field
|
||||
size uintptr // when we can memhash the field
|
||||
}
|
||||
|
||||
// mergeContiguousFieldsCopy returns a copy of f with contiguous memhashable fields
|
||||
// merged together. Such fields get a bogus index and fu value.
|
||||
func mergeContiguousFieldsCopy(in []fieldInfo) []fieldInfo {
|
||||
ret := make([]fieldInfo, 0, len(in))
|
||||
var last *fieldInfo
|
||||
for _, f := range in {
|
||||
// Combine two fields if they're both contiguous & memhash-able.
|
||||
if f.canMemHash && last != nil && last.canMemHash && last.offset+last.size == f.offset {
|
||||
last.size += f.size
|
||||
last.index = -1
|
||||
last.typeInfo = nil
|
||||
} else {
|
||||
ret = append(ret, f)
|
||||
last = &ret[len(ret)-1]
|
||||
}
|
||||
func lookupTypeHasher(t reflect.Type) typeHasherFunc {
|
||||
if v, ok := typeHasherCache.Load(t); ok {
|
||||
return v.(typeHasherFunc)
|
||||
}
|
||||
return ret
|
||||
hash := makeTypeHasher(t)
|
||||
v, _ := typeHasherCache.LoadOrStore(t, hash)
|
||||
return v.(typeHasherFunc)
|
||||
}
|
||||
|
||||
// genHashStructFields generates a typeHasherFunc for t, which must be of kind Struct.
|
||||
func genHashStructFields(t reflect.Type) typeHasherFunc {
|
||||
fields := make([]fieldInfo, 0, t.NumField())
|
||||
for i, n := 0, t.NumField(); i < n; i++ {
|
||||
sf := t.Field(i)
|
||||
if sf.Type.Size() == 0 {
|
||||
continue
|
||||
}
|
||||
fields = append(fields, fieldInfo{
|
||||
index: i,
|
||||
typeInfo: getTypeInfo(sf.Type),
|
||||
canMemHash: typeIsMemHashable(sf.Type),
|
||||
offset: sf.Offset,
|
||||
size: sf.Type.Size(),
|
||||
})
|
||||
}
|
||||
fields = mergeContiguousFieldsCopy(fields)
|
||||
return structHasher{fields}.hash
|
||||
}
|
||||
|
||||
type structHasher struct {
|
||||
fields []fieldInfo
|
||||
}
|
||||
|
||||
func (sh structHasher) hash(h *hasher, v addressableValue) {
|
||||
base := v.Addr().UnsafePointer()
|
||||
for _, f := range sh.fields {
|
||||
if f.canMemHash {
|
||||
h.HashBytes(unsafe.Slice((*byte)(unsafe.Pointer(uintptr(base)+f.offset)), f.size))
|
||||
continue
|
||||
}
|
||||
va := addressableValue{v.Field(f.index)} // field is addressable if parent struct is addressable
|
||||
f.typeInfo.hasher()(h, va)
|
||||
}
|
||||
}
|
||||
|
||||
// genHashPtrToMemoryRange returns a hasher where the reflect.Value is a Ptr to
|
||||
// the provided eleType.
|
||||
func genHashPtrToMemoryRange(eleType reflect.Type) typeHasherFunc {
|
||||
size := eleType.Size()
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
} else {
|
||||
h.HashUint8(1) // indicates visiting a pointer
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), size))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func genTypeHasher(ti *typeInfo) typeHasherFunc {
|
||||
t := ti.rtype
|
||||
|
||||
func makeTypeHasher(t reflect.Type) typeHasherFunc {
|
||||
// Types with specific hashing.
|
||||
switch t {
|
||||
case timeTimeType:
|
||||
return (*hasher).hashTimev
|
||||
return hashTime
|
||||
case netipAddrType:
|
||||
return (*hasher).hashAddrv
|
||||
return hashAddr
|
||||
}
|
||||
|
||||
// Types that can have their memory representation directly hashed.
|
||||
@@ -333,107 +209,40 @@ func genTypeHasher(ti *typeInfo) typeHasherFunc {
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
return (*hasher).hashString
|
||||
case reflect.Slice:
|
||||
et := t.Elem()
|
||||
if typeIsMemHashable(et) {
|
||||
return (*hasher).hashSliceMem
|
||||
}
|
||||
eti := getTypeInfo(et)
|
||||
return genHashSliceElements(eti)
|
||||
return hashString
|
||||
case reflect.Array:
|
||||
et := t.Elem()
|
||||
eti := getTypeInfo(et)
|
||||
return genHashArray(t, eti)
|
||||
return makeArrayHasher(t)
|
||||
case reflect.Slice:
|
||||
return makeSliceHasher(t)
|
||||
case reflect.Struct:
|
||||
return genHashStructFields(t)
|
||||
return makeStructHasher(t)
|
||||
case reflect.Map:
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
if ti.isRecursive {
|
||||
ptr := pointerOf(v)
|
||||
if idx, ok := h.visitStack.seen(ptr); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(ptr)
|
||||
defer h.visitStack.pop(ptr)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting a map
|
||||
h.hashMap(v, ti, ti.isRecursive)
|
||||
}
|
||||
return makeMapHasher(t)
|
||||
case reflect.Pointer:
|
||||
et := t.Elem()
|
||||
if typeIsMemHashable(et) {
|
||||
return genHashPtrToMemoryRange(et)
|
||||
}
|
||||
eti := getTypeInfo(et)
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
if ti.isRecursive {
|
||||
ptr := pointerOf(v)
|
||||
if idx, ok := h.visitStack.seen(ptr); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(ptr)
|
||||
defer h.visitStack.pop(ptr)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting a pointer
|
||||
va := addressableValue{v.Elem()} // dereferenced pointer is always addressable
|
||||
eti.hasher()(h, va)
|
||||
}
|
||||
return makePointerHasher(t)
|
||||
case reflect.Interface:
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
va := newAddressableValue(v.Elem().Type())
|
||||
va.Set(v.Elem())
|
||||
|
||||
h.HashUint8(1) // indicates visiting interface value
|
||||
h.hashType(va.Type())
|
||||
ti := getTypeInfo(va.Type())
|
||||
ti.hasher()(h, va)
|
||||
}
|
||||
return makeInterfaceHasher(t)
|
||||
default: // Func, Chan, UnsafePointer
|
||||
return noopHasherFunc
|
||||
return func(*hasher, pointer) {}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hasher) hashString(v addressableValue) {
|
||||
s := v.String()
|
||||
h.HashUint64(uint64(len(s)))
|
||||
h.HashString(s)
|
||||
}
|
||||
|
||||
// hashTimev hashes v, of kind time.Time.
|
||||
func (h *hasher) hashTimev(v addressableValue) {
|
||||
func hashTime(h *hasher, p pointer) {
|
||||
// Include the zone offset (but not the name) to keep
|
||||
// Hash(t1) == Hash(t2) being semantically equivalent to
|
||||
// t1.Format(time.RFC3339Nano) == t2.Format(time.RFC3339Nano).
|
||||
t := *(*time.Time)(v.Addr().UnsafePointer())
|
||||
t := *p.asTime()
|
||||
_, offset := t.Zone()
|
||||
h.HashUint64(uint64(t.Unix()))
|
||||
h.HashUint32(uint32(t.Nanosecond()))
|
||||
h.HashUint32(uint32(offset))
|
||||
}
|
||||
|
||||
// hashAddrv hashes v, of type netip.Addr.
|
||||
func (h *hasher) hashAddrv(v addressableValue) {
|
||||
func hashAddr(h *hasher, p pointer) {
|
||||
// The formatting of netip.Addr covers the
|
||||
// IP version, the address, and the optional zone name (for v6).
|
||||
// This is equivalent to a1.MarshalBinary() == a2.MarshalBinary().
|
||||
ip := *(*netip.Addr)(v.Addr().UnsafePointer())
|
||||
ip := *p.asAddr()
|
||||
switch {
|
||||
case !ip.IsValid():
|
||||
h.HashUint64(0)
|
||||
@@ -451,121 +260,254 @@ func (h *hasher) hashAddrv(v addressableValue) {
|
||||
}
|
||||
}
|
||||
|
||||
func hashString(h *hasher, p pointer) {
|
||||
s := *p.asString()
|
||||
h.HashUint64(uint64(len(s)))
|
||||
h.HashString(s)
|
||||
}
|
||||
|
||||
func makeMemHasher(n uintptr) typeHasherFunc {
|
||||
return func(h *hasher, v addressableValue) {
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), n))
|
||||
return func(h *hasher, p pointer) {
|
||||
h.HashBytes(p.asMemory(n))
|
||||
}
|
||||
}
|
||||
|
||||
// hashSliceMem hashes v, of kind Slice, with a memhash-able element type.
|
||||
func (h *hasher) hashSliceMem(v addressableValue) {
|
||||
vLen := v.Len()
|
||||
h.HashUint64(uint64(vLen))
|
||||
if vLen == 0 {
|
||||
return
|
||||
func makeArrayHasher(t reflect.Type) typeHasherFunc {
|
||||
var once sync.Once
|
||||
var hashElem typeHasherFunc
|
||||
init := func() {
|
||||
hashElem = lookupTypeHasher(t.Elem())
|
||||
}
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), v.Type().Elem().Size()*uintptr(vLen)))
|
||||
}
|
||||
|
||||
func genHashArrayMem(n int, arraySize uintptr, efu *typeInfo) typeHasherFunc {
|
||||
return func(h *hasher, v addressableValue) {
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), arraySize))
|
||||
}
|
||||
}
|
||||
|
||||
func genHashArrayElements(n int, eti *typeInfo) typeHasherFunc {
|
||||
return func(h *hasher, v addressableValue) {
|
||||
n := t.Len() // number of array elements
|
||||
nb := t.Elem().Size() // byte size of each array element
|
||||
return func(h *hasher, p pointer) {
|
||||
once.Do(init)
|
||||
for i := 0; i < n; i++ {
|
||||
va := addressableValue{v.Index(i)} // element is addressable if parent array is addressable
|
||||
eti.hasher()(h, va)
|
||||
hashElem(h, p.arrayIndex(i, nb))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func noopHasherFunc(h *hasher, v addressableValue) {}
|
||||
|
||||
func genHashArray(t reflect.Type, eti *typeInfo) typeHasherFunc {
|
||||
if t.Size() == 0 {
|
||||
return noopHasherFunc
|
||||
func makeSliceHasher(t reflect.Type) typeHasherFunc {
|
||||
nb := t.Elem().Size() // byte size of each slice element
|
||||
if typeIsMemHashable(t.Elem()) {
|
||||
return func(h *hasher, p pointer) {
|
||||
pa := p.sliceArray()
|
||||
if pa.isNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting slice
|
||||
n := p.sliceLen()
|
||||
b := pa.asMemory(uintptr(n) * nb)
|
||||
h.HashUint64(uint64(n))
|
||||
h.HashBytes(b)
|
||||
}
|
||||
}
|
||||
et := t.Elem()
|
||||
if typeIsMemHashable(et) {
|
||||
return genHashArrayMem(t.Len(), t.Size(), eti)
|
||||
|
||||
var once sync.Once
|
||||
var hashElem typeHasherFunc
|
||||
init := func() {
|
||||
hashElem = lookupTypeHasher(t.Elem())
|
||||
if typeIsRecursive(t) {
|
||||
hashElemDefault := hashElem
|
||||
hashElem = func(h *hasher, p pointer) {
|
||||
if idx, ok := h.visitStack.seen(p.p); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting slice element
|
||||
h.visitStack.push(p.p)
|
||||
defer h.visitStack.pop(p.p)
|
||||
hashElemDefault(h, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
n := t.Len()
|
||||
return genHashArrayElements(n, eti)
|
||||
}
|
||||
|
||||
func genHashSliceElements(eti *typeInfo) typeHasherFunc {
|
||||
return sliceElementHasher{eti}.hash
|
||||
}
|
||||
|
||||
type sliceElementHasher struct {
|
||||
eti *typeInfo
|
||||
}
|
||||
|
||||
func (seh sliceElementHasher) hash(h *hasher, v addressableValue) {
|
||||
vLen := v.Len()
|
||||
h.HashUint64(uint64(vLen))
|
||||
for i := 0; i < vLen; i++ {
|
||||
va := addressableValue{v.Index(i)} // slice elements are always addressable
|
||||
seh.eti.hasher()(h, va)
|
||||
return func(h *hasher, p pointer) {
|
||||
pa := p.sliceArray()
|
||||
if pa.isNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
once.Do(init)
|
||||
h.HashUint8(1) // indicates visiting slice
|
||||
n := p.sliceLen()
|
||||
h.HashUint64(uint64(n))
|
||||
for i := 0; i < n; i++ {
|
||||
pe := pa.arrayIndex(i, nb)
|
||||
hashElem(h, pe)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTypeInfo(t reflect.Type) *typeInfo {
|
||||
if f, ok := typeInfoMap.Load(t); ok {
|
||||
return f.(*typeInfo)
|
||||
func makeStructHasher(t reflect.Type) typeHasherFunc {
|
||||
type fieldHasher struct {
|
||||
idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable
|
||||
hash typeHasherFunc // only valid if idx is not negative
|
||||
offset uintptr
|
||||
size uintptr
|
||||
}
|
||||
typeInfoMapPopulate.Lock()
|
||||
defer typeInfoMapPopulate.Unlock()
|
||||
newTypes := map[reflect.Type]*typeInfo{}
|
||||
ti := getTypeInfoLocked(t, newTypes)
|
||||
for t, ti := range newTypes {
|
||||
typeInfoMap.Store(t, ti)
|
||||
var once sync.Once
|
||||
var fields []fieldHasher
|
||||
init := func() {
|
||||
for i, numField := 0, t.NumField(); i < numField; i++ {
|
||||
sf := t.Field(i)
|
||||
f := fieldHasher{i, nil, sf.Offset, sf.Type.Size()}
|
||||
if typeIsMemHashable(sf.Type) {
|
||||
f.idx = -1
|
||||
}
|
||||
|
||||
// Combine with previous field if both contiguous and mem-hashable.
|
||||
if f.idx < 0 && len(fields) > 0 {
|
||||
if last := &fields[len(fields)-1]; last.idx < 0 && last.offset+last.size == f.offset {
|
||||
last.size += f.size
|
||||
continue
|
||||
}
|
||||
}
|
||||
fields = append(fields, f)
|
||||
}
|
||||
|
||||
for i, f := range fields {
|
||||
if f.idx >= 0 {
|
||||
fields[i].hash = lookupTypeHasher(t.Field(f.idx).Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return func(h *hasher, p pointer) {
|
||||
once.Do(init)
|
||||
for _, field := range fields {
|
||||
pf := p.structField(field.idx, field.offset, field.size)
|
||||
if field.idx < 0 {
|
||||
h.HashBytes(pf.asMemory(field.size))
|
||||
} else {
|
||||
field.hash(h, pf)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ti
|
||||
}
|
||||
|
||||
func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *typeInfo {
|
||||
if v, ok := typeInfoMap.Load(t); ok {
|
||||
return v.(*typeInfo)
|
||||
}
|
||||
if ti, ok := incomplete[t]; ok {
|
||||
return ti
|
||||
}
|
||||
ti := &typeInfo{
|
||||
rtype: t,
|
||||
isRecursive: typeIsRecursive(t),
|
||||
}
|
||||
incomplete[t] = ti
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Map:
|
||||
ti.keyTypeInfo = getTypeInfoLocked(t.Key(), incomplete)
|
||||
fallthrough
|
||||
case reflect.Ptr, reflect.Slice, reflect.Array:
|
||||
ti.elemTypeInfo = getTypeInfoLocked(t.Elem(), incomplete)
|
||||
func makeMapHasher(t reflect.Type) typeHasherFunc {
|
||||
var once sync.Once
|
||||
var hashKey, hashValue typeHasherFunc
|
||||
var isRecursive bool
|
||||
init := func() {
|
||||
hashKey = lookupTypeHasher(t.Key())
|
||||
hashValue = lookupTypeHasher(t.Elem())
|
||||
isRecursive = typeIsRecursive(t)
|
||||
}
|
||||
|
||||
return ti
|
||||
return func(h *hasher, p pointer) {
|
||||
v := p.asValue(t).Elem() // reflect.Map kind
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
once.Do(init)
|
||||
if isRecursive {
|
||||
pm := v.UnsafePointer() // underlying pointer of map
|
||||
if idx, ok := h.visitStack.seen(pm); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(pm)
|
||||
defer h.visitStack.pop(pm)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting map entries
|
||||
h.HashUint64(uint64(v.Len()))
|
||||
|
||||
mh := mapHasherPool.Get().(*mapHasher)
|
||||
defer mapHasherPool.Put(mh)
|
||||
|
||||
// Hash a map in a sort-free mannar.
|
||||
// It relies on a map being a an unordered set of KV entries.
|
||||
// So long as we hash each KV entry together, we can XOR all the
|
||||
// individual hashes to produce a unique hash for the entire map.
|
||||
k := mh.valKey.get(v.Type().Key())
|
||||
e := mh.valElem.get(v.Type().Elem())
|
||||
mh.sum = Sum{}
|
||||
mh.h.visitStack = h.visitStack // always use the parent's visit stack to avoid cycles
|
||||
for iter := v.MapRange(); iter.Next(); {
|
||||
k.SetIterKey(iter)
|
||||
e.SetIterValue(iter)
|
||||
mh.h.reset()
|
||||
hashKey(&mh.h, pointerOf(k.Addr()))
|
||||
hashValue(&mh.h, pointerOf(e.Addr()))
|
||||
mh.sum.xor(mh.h.sum())
|
||||
}
|
||||
h.HashBytes(mh.sum.sum[:])
|
||||
}
|
||||
}
|
||||
|
||||
func makePointerHasher(t reflect.Type) typeHasherFunc {
|
||||
var once sync.Once
|
||||
var hashElem typeHasherFunc
|
||||
var isRecursive bool
|
||||
init := func() {
|
||||
hashElem = lookupTypeHasher(t.Elem())
|
||||
isRecursive = typeIsRecursive(t)
|
||||
}
|
||||
return func(h *hasher, p pointer) {
|
||||
pe := p.pointerElem()
|
||||
if pe.isNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
once.Do(init)
|
||||
if isRecursive {
|
||||
if idx, ok := h.visitStack.seen(pe.p); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(pe.p)
|
||||
defer h.visitStack.pop(pe.p)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting a pointer element
|
||||
hashElem(h, pe)
|
||||
}
|
||||
}
|
||||
|
||||
func makeInterfaceHasher(t reflect.Type) typeHasherFunc {
|
||||
return func(h *hasher, p pointer) {
|
||||
v := p.asValue(t).Elem() // reflect.Interface kind
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting an interface value
|
||||
v = v.Elem()
|
||||
t := v.Type()
|
||||
h.hashType(t)
|
||||
va := reflect.New(t).Elem()
|
||||
va.Set(v)
|
||||
hashElem := lookupTypeHasher(t)
|
||||
hashElem(h, pointerOf(va.Addr()))
|
||||
}
|
||||
}
|
||||
|
||||
type mapHasher struct {
|
||||
h hasher
|
||||
valKey, valElem valueCache // re-usable values for map iteration
|
||||
h hasher
|
||||
valKey valueCache
|
||||
valElem valueCache
|
||||
sum Sum
|
||||
}
|
||||
|
||||
var mapHasherPool = &sync.Pool{
|
||||
New: func() any { return new(mapHasher) },
|
||||
}
|
||||
|
||||
type valueCache map[reflect.Type]addressableValue
|
||||
type valueCache map[reflect.Type]reflect.Value
|
||||
|
||||
func (c *valueCache) get(t reflect.Type) addressableValue {
|
||||
// get returns an addressable reflect.Value for the given type.
|
||||
func (c *valueCache) get(t reflect.Type) reflect.Value {
|
||||
v, ok := (*c)[t]
|
||||
if !ok {
|
||||
v = newAddressableValue(t)
|
||||
v = reflect.New(t).Elem()
|
||||
if *c == nil {
|
||||
*c = make(valueCache)
|
||||
}
|
||||
@@ -573,72 +515,3 @@ func (c *valueCache) get(t reflect.Type) addressableValue {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// hashMap hashes a map in a sort-free manner.
|
||||
// It relies on a map being a functionally an unordered set of KV entries.
|
||||
// So long as we hash each KV entry together, we can XOR all
|
||||
// of the individual hashes to produce a unique hash for the entire map.
|
||||
func (h *hasher) hashMap(v addressableValue, ti *typeInfo, checkCycles bool) {
|
||||
mh := mapHasherPool.Get().(*mapHasher)
|
||||
defer mapHasherPool.Put(mh)
|
||||
|
||||
var sum Sum
|
||||
if v.IsNil() {
|
||||
sum.sum[0] = 1 // something non-zero
|
||||
}
|
||||
|
||||
k := mh.valKey.get(v.Type().Key())
|
||||
e := mh.valElem.get(v.Type().Elem())
|
||||
mh.h.visitStack = h.visitStack // always use the parent's visit stack to avoid cycles
|
||||
for iter := v.MapRange(); iter.Next(); {
|
||||
k.SetIterKey(iter)
|
||||
e.SetIterValue(iter)
|
||||
mh.h.Reset()
|
||||
ti.keyTypeInfo.hasher()(&mh.h, k)
|
||||
ti.elemTypeInfo.hasher()(&mh.h, e)
|
||||
sum.xor(mh.h.sum())
|
||||
}
|
||||
h.HashBytes(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation
|
||||
}
|
||||
|
||||
// visitStack is a stack of pointers visited.
|
||||
// Pointers are pushed onto the stack when visited, and popped when leaving.
|
||||
// The integer value is the depth at which the pointer was visited.
|
||||
// The length of this stack should be zero after every hashing operation.
|
||||
type visitStack map[pointer]int
|
||||
|
||||
func (v visitStack) seen(p pointer) (int, bool) {
|
||||
idx, ok := v[p]
|
||||
return idx, ok
|
||||
}
|
||||
|
||||
func (v *visitStack) push(p pointer) {
|
||||
if *v == nil {
|
||||
*v = make(map[pointer]int)
|
||||
}
|
||||
(*v)[p] = len(*v)
|
||||
}
|
||||
|
||||
func (v visitStack) pop(p pointer) {
|
||||
delete(v, p)
|
||||
}
|
||||
|
||||
// pointer is a thin wrapper over unsafe.Pointer.
|
||||
// We only rely on comparability of pointers; we cannot rely on uintptr since
|
||||
// that would break if Go ever switched to a moving GC.
|
||||
type pointer struct{ p unsafe.Pointer }
|
||||
|
||||
func pointerOf(v addressableValue) pointer {
|
||||
return pointer{unsafe.Pointer(v.Value.Pointer())}
|
||||
}
|
||||
|
||||
// hashType hashes a reflect.Type.
|
||||
// The hash is only consistent within the lifetime of a program.
|
||||
func (h *hasher) hashType(t reflect.Type) {
|
||||
// This approach relies on reflect.Type always being backed by a unique
|
||||
// *reflect.rtype pointer. A safer approach is to use a global sync.Map
|
||||
// that maps reflect.Type to some arbitrary and unique index.
|
||||
// While safer, it requires global state with memory that can never be GC'd.
|
||||
rtypeAddr := reflect.ValueOf(t).Pointer() // address of *reflect.rtype
|
||||
h.HashUint64(uint64(rtypeAddr))
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
qt "github.com/frankban/quicktest"
|
||||
"go4.org/mem"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -94,6 +95,12 @@ func TestHash(t *testing.T) {
|
||||
{in: tuple{scalars{F64: float64(math.NaN())}, scalars{F64: float64(math.NaN())}}, wantEq: true},
|
||||
{in: tuple{scalars{C64: 32 + 32i}, scalars{C64: complex(math.Nextafter32(32, 0), 32)}}, wantEq: false},
|
||||
{in: tuple{scalars{C128: 64 + 64i}, scalars{C128: complex(math.Nextafter(64, 0), 64)}}, wantEq: false},
|
||||
{in: tuple{[]int(nil), []int(nil)}, wantEq: true},
|
||||
{in: tuple{[]int{}, []int(nil)}, wantEq: false},
|
||||
{in: tuple{[]int{}, []int{}}, wantEq: true},
|
||||
{in: tuple{[]string(nil), []string(nil)}, wantEq: true},
|
||||
{in: tuple{[]string{}, []string(nil)}, wantEq: false},
|
||||
{in: tuple{[]string{}, []string{}}, wantEq: true},
|
||||
{in: tuple{[]appendBytes{{}, {0, 0, 0, 0, 0, 0, 0, 1}}, []appendBytes{{}, {0, 0, 0, 0, 0, 0, 0, 1}}}, wantEq: true},
|
||||
{in: tuple{[]appendBytes{{}, {0, 0, 0, 0, 0, 0, 0, 1}}, []appendBytes{{0, 0, 0, 0, 0, 0, 0, 1}, {}}}, wantEq: false},
|
||||
{in: tuple{iface{MyBool(true)}, iface{MyBool(true)}}, wantEq: true},
|
||||
@@ -159,7 +166,7 @@ func TestHash(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEq := Hash(tt.in[0]) == Hash(tt.in[1])
|
||||
gotEq := Hash(&tt.in[0]) == Hash(&tt.in[1])
|
||||
if gotEq != tt.wantEq {
|
||||
t.Errorf("(Hash(%T %v) == Hash(%T %v)) = %v, want %v", tt.in[0], tt.in[0], tt.in[1], tt.in[1], gotEq, tt.wantEq)
|
||||
}
|
||||
@@ -170,11 +177,11 @@ func TestDeepHash(t *testing.T) {
|
||||
// v contains the types of values we care about for our current callers.
|
||||
// Mostly we're just testing that we don't panic on handled types.
|
||||
v := getVal()
|
||||
|
||||
hash1 := Hash(v)
|
||||
t.Logf("hash: %v", hash1)
|
||||
for i := 0; i < 20; i++ {
|
||||
hash2 := Hash(getVal())
|
||||
v := getVal()
|
||||
hash2 := Hash(v)
|
||||
if hash1 != hash2 {
|
||||
t.Error("second hash didn't match")
|
||||
}
|
||||
@@ -185,7 +192,7 @@ func TestDeepHash(t *testing.T) {
|
||||
func TestIssue4868(t *testing.T) {
|
||||
m1 := map[int]string{1: "foo"}
|
||||
m2 := map[int]string{1: "bar"}
|
||||
if Hash(m1) == Hash(m2) {
|
||||
if Hash(&m1) == Hash(&m2) {
|
||||
t.Error("bogus")
|
||||
}
|
||||
}
|
||||
@@ -193,7 +200,7 @@ func TestIssue4868(t *testing.T) {
|
||||
func TestIssue4871(t *testing.T) {
|
||||
m1 := map[string]string{"": "", "x": "foo"}
|
||||
m2 := map[string]string{}
|
||||
if h1, h2 := Hash(m1), Hash(m2); h1 == h2 {
|
||||
if h1, h2 := Hash(&m1), Hash(&m2); h1 == h2 {
|
||||
t.Errorf("bogus: h1=%x, h2=%x", h1, h2)
|
||||
}
|
||||
}
|
||||
@@ -201,7 +208,7 @@ func TestIssue4871(t *testing.T) {
|
||||
func TestNilVsEmptymap(t *testing.T) {
|
||||
m1 := map[string]string(nil)
|
||||
m2 := map[string]string{}
|
||||
if h1, h2 := Hash(m1), Hash(m2); h1 == h2 {
|
||||
if h1, h2 := Hash(&m1), Hash(&m2); h1 == h2 {
|
||||
t.Errorf("bogus: h1=%x, h2=%x", h1, h2)
|
||||
}
|
||||
}
|
||||
@@ -209,7 +216,7 @@ func TestNilVsEmptymap(t *testing.T) {
|
||||
func TestMapFraming(t *testing.T) {
|
||||
m1 := map[string]string{"foo": "", "fo": "o"}
|
||||
m2 := map[string]string{}
|
||||
if h1, h2 := Hash(m1), Hash(m2); h1 == h2 {
|
||||
if h1, h2 := Hash(&m1), Hash(&m2); h1 == h2 {
|
||||
t.Errorf("bogus: h1=%x, h2=%x", h1, h2)
|
||||
}
|
||||
}
|
||||
@@ -217,23 +224,25 @@ func TestMapFraming(t *testing.T) {
|
||||
func TestQuick(t *testing.T) {
|
||||
initSeed()
|
||||
err := quick.Check(func(v, w map[string]string) bool {
|
||||
return (Hash(v) == Hash(w)) == reflect.DeepEqual(v, w)
|
||||
return (Hash(&v) == Hash(&w)) == reflect.DeepEqual(v, w)
|
||||
}, &quick.Config{MaxCount: 1000, Rand: rand.New(rand.NewSource(int64(seed)))})
|
||||
if err != nil {
|
||||
t.Fatalf("seed=%v, err=%v", seed, err)
|
||||
}
|
||||
}
|
||||
|
||||
func getVal() any {
|
||||
return &struct {
|
||||
WGConfig *wgcfg.Config
|
||||
RouterConfig *router.Config
|
||||
MapFQDNAddrs map[dnsname.FQDN][]netip.Addr
|
||||
MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort
|
||||
MapDiscoPublics map[key.DiscoPublic]bool
|
||||
MapResponse *tailcfg.MapResponse
|
||||
FilterMatch filter.Match
|
||||
}{
|
||||
type tailscaleTypes struct {
|
||||
WGConfig *wgcfg.Config
|
||||
RouterConfig *router.Config
|
||||
MapFQDNAddrs map[dnsname.FQDN][]netip.Addr
|
||||
MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort
|
||||
MapDiscoPublics map[key.DiscoPublic]bool
|
||||
MapResponse *tailcfg.MapResponse
|
||||
FilterMatch filter.Match
|
||||
}
|
||||
|
||||
func getVal() *tailscaleTypes {
|
||||
return &tailscaleTypes{
|
||||
&wgcfg.Config{
|
||||
Name: "foo",
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)},
|
||||
@@ -410,13 +419,13 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
{
|
||||
name: "string_slice",
|
||||
val: []string{"foo", "bar"},
|
||||
out: "\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x03\x00\x00\x00\x00\x00\x00\x00bar",
|
||||
out: "\x01\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x03\x00\x00\x00\x00\x00\x00\x00bar",
|
||||
},
|
||||
{
|
||||
name: "int_slice",
|
||||
val: []int{1, 0, -1},
|
||||
out: "\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff",
|
||||
out32: "\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff",
|
||||
out: "\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff",
|
||||
out32: "\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff",
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
@@ -451,8 +460,8 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
{
|
||||
name: "packet_filter",
|
||||
val: filterRules,
|
||||
out: "\x04\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x00\x00\x00\x00\x01\x00\x02\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
out32: "\x04\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x01\x00\x02\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
out: "\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x00\x00\x00\x00\x01\x00\x02\x00\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00",
|
||||
out32: "\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x01\x00\x02\x00\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00",
|
||||
},
|
||||
{
|
||||
name: "netip.Addr",
|
||||
@@ -566,19 +575,19 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
{
|
||||
name: "tailcfg.Node",
|
||||
val: &tailcfg.Node{},
|
||||
out: "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + u64(uint64(time.Time{}.Unix())) + u64(0) + u32(0) + u32(0) + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + u64(uint64(time.Time{}.Unix())) + u32(0) + u32(0) + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
out: "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\tn\x88\xf1\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\tn\x88\xf1\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rv := reflect.ValueOf(tt.val)
|
||||
va := newAddressableValue(rv.Type())
|
||||
va := reflect.New(rv.Type()).Elem()
|
||||
va.Set(rv)
|
||||
fn := getTypeInfo(va.Type()).hasher()
|
||||
fn := lookupTypeHasher(va.Type())
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
fn(h, va)
|
||||
fn(h, pointerOf(va.Addr()))
|
||||
const ptrSize = 32 << uintptr(^uintptr(0)>>63)
|
||||
if tt.out32 != "" && ptrSize == 32 {
|
||||
tt.out = tt.out32
|
||||
@@ -591,6 +600,138 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceCycle(t *testing.T) {
|
||||
type S []S
|
||||
c := qt.New(t)
|
||||
|
||||
a := make(S, 1) // cylic graph of 1 node
|
||||
a[0] = a
|
||||
b := make(S, 1) // cylic graph of 1 node
|
||||
b[0] = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := make(S, 1) // cyclic graph of 2 nodes
|
||||
c2 := make(S, 1) // cyclic graph of 2 nodes
|
||||
c1[0] = c2
|
||||
c2[0] = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := make(S, 1) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
c3[0] = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
|
||||
c4 := make(S, 2) // cyclic graph of 3 nodes
|
||||
c5 := make(S, 2) // cyclic graph of 3 nodes
|
||||
c4[0] = nil
|
||||
c4[1] = c4
|
||||
c5[0] = c5
|
||||
c5[1] = nil
|
||||
hc4 := Hash(&c4)
|
||||
hc5 := Hash(&c5)
|
||||
c.Assert(hc4, qt.Not(qt.Equals), hc5) // cycle occurs through different indexes
|
||||
}
|
||||
|
||||
func TestMapCycle(t *testing.T) {
|
||||
type M map[string]M
|
||||
c := qt.New(t)
|
||||
|
||||
a := make(M) // cylic graph of 1 node
|
||||
a["self"] = a
|
||||
b := make(M) // cylic graph of 1 node
|
||||
b["self"] = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := make(M) // cyclic graph of 2 nodes
|
||||
c2 := make(M) // cyclic graph of 2 nodes
|
||||
c1["peer"] = c2
|
||||
c2["peer"] = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := make(M) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
c3["child"] = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
|
||||
c4 := make(M) // cyclic graph of 3 nodes
|
||||
c5 := make(M) // cyclic graph of 3 nodes
|
||||
c4["0"] = nil
|
||||
c4["1"] = c4
|
||||
c5["0"] = c5
|
||||
c5["1"] = nil
|
||||
hc4 := Hash(&c4)
|
||||
hc5 := Hash(&c5)
|
||||
c.Assert(hc4, qt.Not(qt.Equals), hc5) // cycle occurs through different keys
|
||||
}
|
||||
|
||||
func TestPointerCycle(t *testing.T) {
|
||||
type P *P
|
||||
c := qt.New(t)
|
||||
|
||||
a := new(P) // cyclic graph of 1 node
|
||||
*a = a
|
||||
b := new(P) // cyclic graph of 1 node
|
||||
*b = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := new(P) // cyclic graph of 2 nodes
|
||||
c2 := new(P) // cyclic graph of 2 nodes
|
||||
*c1 = c2
|
||||
*c2 = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := new(P) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
*c3 = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
}
|
||||
|
||||
func TestInterfaceCycle(t *testing.T) {
|
||||
type I struct{ v any }
|
||||
c := qt.New(t)
|
||||
|
||||
a := new(I) // cyclic graph of 1 node
|
||||
a.v = a
|
||||
b := new(I) // cyclic graph of 1 node
|
||||
b.v = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := new(I) // cyclic graph of 2 nodes
|
||||
c2 := new(I) // cyclic graph of 2 nodes
|
||||
c1.v = c2
|
||||
c2.v = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := new(I) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
c3.v = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
}
|
||||
|
||||
var sink Sum
|
||||
|
||||
func BenchmarkHash(b *testing.B) {
|
||||
@@ -647,9 +788,8 @@ var filterRules = []tailcfg.FilterRule{
|
||||
func BenchmarkHashPacketFilter(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
hash := HasherForType[*[]tailcfg.FilterRule]()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sink = hash(&filterRules)
|
||||
sink = Hash(&filterRules)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -662,14 +802,13 @@ func TestHashMapAcyclic(t *testing.T) {
|
||||
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
|
||||
ti := getTypeInfo(reflect.TypeOf(m))
|
||||
|
||||
hash := lookupTypeHasher(reflect.TypeOf(m))
|
||||
for i := 0; i < 20; i++ {
|
||||
v := addressableValue{reflect.ValueOf(&m).Elem()}
|
||||
va := reflect.ValueOf(&m).Elem()
|
||||
hb.Reset()
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
h.hashMap(v, ti, false)
|
||||
hash(h, pointerOf(va.Addr()))
|
||||
h.sum()
|
||||
if got[string(hb.B)] {
|
||||
continue
|
||||
@@ -689,9 +828,9 @@ func TestPrintArray(t *testing.T) {
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
v := addressableValue{reflect.ValueOf(&x).Elem()}
|
||||
ti := getTypeInfo(v.Type())
|
||||
ti.hasher()(h, v)
|
||||
va := reflect.ValueOf(&x).Elem()
|
||||
hash := lookupTypeHasher(va.Type())
|
||||
hash(h, pointerOf(va.Addr()))
|
||||
h.sum()
|
||||
const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f"
|
||||
if got := hb.B; string(got) != want {
|
||||
@@ -707,15 +846,15 @@ func BenchmarkHashMapAcyclic(b *testing.B) {
|
||||
}
|
||||
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
v := addressableValue{reflect.ValueOf(&m).Elem()}
|
||||
ti := getTypeInfo(v.Type())
|
||||
va := reflect.ValueOf(&m).Elem()
|
||||
hash := lookupTypeHasher(va.Type())
|
||||
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h.Reset()
|
||||
h.hashMap(v, ti, false)
|
||||
hash(h, pointerOf(va.Addr()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -731,7 +870,7 @@ func BenchmarkTailcfgNode(b *testing.B) {
|
||||
func TestExhaustive(t *testing.T) {
|
||||
seen := make(map[Sum]bool)
|
||||
for i := 0; i < 100000; i++ {
|
||||
s := Hash(i)
|
||||
s := Hash(&i)
|
||||
if seen[s] {
|
||||
t.Fatalf("hash collision %v", i)
|
||||
}
|
||||
|
||||
115
util/deephash/pointer.go
Normal file
115
util/deephash/pointer.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package deephash
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// unsafePointer is an untyped pointer.
|
||||
// It is the caller's responsibility to call operations on the correct type.
|
||||
//
|
||||
// This pointer only ever points to a small set of kinds or types:
|
||||
// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface,
|
||||
// or a pointer to memory that is directly hashable.
|
||||
//
|
||||
// Arrays are represented as pointers to the first element.
|
||||
// Structs are represented as pointers to the first field.
|
||||
// Slices are represented as pointers to a slice header.
|
||||
// Pointers are represented as pointers to a pointer.
|
||||
//
|
||||
// We do not support direct operations on maps and interfaces, and instead
|
||||
// rely on pointer.asValue to convert the pointer back to a reflect.Value.
|
||||
// Conversion of an unsafe.Pointer to reflect.Value guarantees that the
|
||||
// read-only flag in the reflect.Value is unpopulated, avoiding panics that may
|
||||
// othewise have occurred since the value was obtained from an unexported field.
|
||||
type unsafePointer struct{ p unsafe.Pointer }
|
||||
|
||||
func unsafePointerOf(v reflect.Value) unsafePointer {
|
||||
return unsafePointer{v.UnsafePointer()}
|
||||
}
|
||||
func (p unsafePointer) isNil() bool {
|
||||
return p.p == nil
|
||||
}
|
||||
|
||||
// pointerElem dereferences a pointer.
|
||||
// p must point to a pointer.
|
||||
func (p unsafePointer) pointerElem() unsafePointer {
|
||||
return unsafePointer{*(*unsafe.Pointer)(p.p)}
|
||||
}
|
||||
|
||||
// sliceLen returns the slice length.
|
||||
// p must point to a slice.
|
||||
func (p unsafePointer) sliceLen() int {
|
||||
return (*reflect.SliceHeader)(p.p).Len
|
||||
}
|
||||
|
||||
// sliceArray returns a pointer to the underlying slice array.
|
||||
// p must point to a slice.
|
||||
func (p unsafePointer) sliceArray() unsafePointer {
|
||||
return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)}
|
||||
}
|
||||
|
||||
// arrayIndex returns a pointer to an element in the array.
|
||||
// p must point to an array.
|
||||
func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer {
|
||||
return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)}
|
||||
}
|
||||
|
||||
// structField returns a pointer to a field in a struct.
|
||||
// p must pointer to a struct.
|
||||
func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer {
|
||||
return unsafePointer{unsafe.Add(p.p, offset)}
|
||||
}
|
||||
|
||||
// asString casts p as a *string.
|
||||
func (p unsafePointer) asString() *string {
|
||||
return (*string)(p.p)
|
||||
}
|
||||
|
||||
// asTime casts p as a *time.Time.
|
||||
func (p unsafePointer) asTime() *time.Time {
|
||||
return (*time.Time)(p.p)
|
||||
}
|
||||
|
||||
// asAddr casts p as a *netip.Addr.
|
||||
func (p unsafePointer) asAddr() *netip.Addr {
|
||||
return (*netip.Addr)(p.p)
|
||||
}
|
||||
|
||||
// asValue casts p as a reflect.Value containing a pointer to value of t.
|
||||
func (p unsafePointer) asValue(typ reflect.Type) reflect.Value {
|
||||
return reflect.NewAt(typ, p.p)
|
||||
}
|
||||
|
||||
// asMemory returns the memory pointer at by p for a specified size.
|
||||
func (p unsafePointer) asMemory(size uintptr) []byte {
|
||||
return unsafe.Slice((*byte)(p.p), size)
|
||||
}
|
||||
|
||||
// visitStack is a stack of pointers visited.
|
||||
// Pointers are pushed onto the stack when visited, and popped when leaving.
|
||||
// The integer value is the depth at which the pointer was visited.
|
||||
// The length of this stack should be zero after every hashing operation.
|
||||
type visitStack map[unsafe.Pointer]int
|
||||
|
||||
func (v visitStack) seen(p unsafe.Pointer) (int, bool) {
|
||||
idx, ok := v[p]
|
||||
return idx, ok
|
||||
}
|
||||
|
||||
func (v *visitStack) push(p unsafe.Pointer) {
|
||||
if *v == nil {
|
||||
*v = make(map[unsafe.Pointer]int)
|
||||
}
|
||||
(*v)[p] = len(*v)
|
||||
}
|
||||
|
||||
func (v visitStack) pop(p unsafe.Pointer) {
|
||||
delete(v, p)
|
||||
}
|
||||
14
util/deephash/pointer_norace.go
Normal file
14
util/deephash/pointer_norace.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !race
|
||||
|
||||
package deephash
|
||||
|
||||
import "reflect"
|
||||
|
||||
type pointer = unsafePointer
|
||||
|
||||
// pointerOf returns a pointer from v, which must be a reflect.Pointer.
|
||||
func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) }
|
||||
100
util/deephash/pointer_race.go
Normal file
100
util/deephash/pointer_race.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build race
|
||||
|
||||
package deephash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// pointer is a typed pointer that performs safety checks for every operation.
|
||||
type pointer struct {
|
||||
unsafePointer
|
||||
t reflect.Type // type of pointed-at value; may be nil
|
||||
n uintptr // size of valid memory after p
|
||||
}
|
||||
|
||||
// pointerOf returns a pointer from v, which must be a reflect.Pointer.
|
||||
func pointerOf(v reflect.Value) pointer {
|
||||
assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind())
|
||||
te := v.Type().Elem()
|
||||
return pointer{unsafePointerOf(v), te, te.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) pointerElem() pointer {
|
||||
assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind())
|
||||
te := p.t.Elem()
|
||||
return pointer{p.unsafePointer.pointerElem(), te, te.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) sliceLen() int {
|
||||
assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind())
|
||||
return p.unsafePointer.sliceLen()
|
||||
}
|
||||
|
||||
func (p pointer) sliceArray() pointer {
|
||||
assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind())
|
||||
n := p.sliceLen()
|
||||
assert(n >= 0, "got negative slice length %d", n)
|
||||
ta := reflect.ArrayOf(n, p.t.Elem())
|
||||
return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) arrayIndex(index int, size uintptr) pointer {
|
||||
assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind())
|
||||
assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index)
|
||||
assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size)
|
||||
te := p.t.Elem()
|
||||
return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) structField(index int, offset, size uintptr) pointer {
|
||||
assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind())
|
||||
assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset)
|
||||
assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size)
|
||||
if index < 0 {
|
||||
return pointer{p.unsafePointer.structField(index, offset, size), nil, size}
|
||||
}
|
||||
sf := p.t.Field(index)
|
||||
t := sf.Type
|
||||
assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset)
|
||||
assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size)
|
||||
return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) asString() *string {
|
||||
assert(p.t.Kind() == reflect.String, "got %v, want string", p.t)
|
||||
return p.unsafePointer.asString()
|
||||
}
|
||||
|
||||
func (p pointer) asTime() *time.Time {
|
||||
assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType)
|
||||
return p.unsafePointer.asTime()
|
||||
}
|
||||
|
||||
func (p pointer) asAddr() *netip.Addr {
|
||||
assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType)
|
||||
return p.unsafePointer.asAddr()
|
||||
}
|
||||
|
||||
func (p pointer) asValue(typ reflect.Type) reflect.Value {
|
||||
assert(p.t == typ, "got %v, want %v", p.t, typ)
|
||||
return p.unsafePointer.asValue(typ)
|
||||
}
|
||||
|
||||
func (p pointer) asMemory(size uintptr) []byte {
|
||||
assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size)
|
||||
return p.unsafePointer.asMemory(size)
|
||||
}
|
||||
|
||||
func assert(b bool, f string, a ...any) {
|
||||
if !b {
|
||||
panic(fmt.Sprintf(f, a...))
|
||||
}
|
||||
}
|
||||
@@ -6,60 +6,58 @@
|
||||
// It is similar to the unix command uniq.
|
||||
package uniq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type badTypeError struct {
|
||||
typ reflect.Type
|
||||
}
|
||||
|
||||
func (e badTypeError) Error() string {
|
||||
return fmt.Sprintf("uniq.ModifySlice's first argument must have type *[]T, got %v", e.typ)
|
||||
}
|
||||
|
||||
// ModifySlice removes adjacent duplicate elements from the slice pointed to by sliceptr.
|
||||
// It adjusts the length of the slice appropriately and zeros the tail.
|
||||
// eq reports whether (*sliceptr)[i] and (*sliceptr)[j] are equal.
|
||||
// ModifySlice does O(len(*sliceptr)) operations.
|
||||
func ModifySlice(sliceptr any, eq func(i, j int) bool) {
|
||||
rvp := reflect.ValueOf(sliceptr)
|
||||
if rvp.Type().Kind() != reflect.Ptr {
|
||||
panic(badTypeError{rvp.Type()})
|
||||
}
|
||||
rv := rvp.Elem()
|
||||
if rv.Type().Kind() != reflect.Slice {
|
||||
panic(badTypeError{rvp.Type()})
|
||||
}
|
||||
|
||||
length := rv.Len()
|
||||
// ModifySlice removes adjacent duplicate elements from the given slice. It
|
||||
// adjusts the length of the slice appropriately and zeros the tail.
|
||||
//
|
||||
// ModifySlice does O(len(*slice)) operations.
|
||||
func ModifySlice[E comparable](slice *[]E) {
|
||||
// Remove duplicates
|
||||
dst := 0
|
||||
for i := 1; i < length; i++ {
|
||||
if eq(dst, i) {
|
||||
for i := 1; i < len(*slice); i++ {
|
||||
if (*slice)[i] == (*slice)[dst] {
|
||||
continue
|
||||
}
|
||||
dst++
|
||||
// slice[dst] = slice[i]
|
||||
rv.Index(dst).Set(rv.Index(i))
|
||||
(*slice)[dst] = (*slice)[i]
|
||||
}
|
||||
|
||||
// Zero out the elements we removed at the end of the slice
|
||||
end := dst + 1
|
||||
var zero reflect.Value
|
||||
if end < length {
|
||||
zero = reflect.Zero(rv.Type().Elem())
|
||||
var zero E
|
||||
for i := end; i < len(*slice); i++ {
|
||||
(*slice)[i] = zero
|
||||
}
|
||||
|
||||
// for i := range slice[end:] {
|
||||
// size[i] = 0/nil/{}
|
||||
// }
|
||||
for i := end; i < length; i++ {
|
||||
// slice[i] = 0/nil/{}
|
||||
rv.Index(i).Set(zero)
|
||||
}
|
||||
|
||||
// slice = slice[:end]
|
||||
if end < length {
|
||||
rv.SetLen(end)
|
||||
// Truncate the slice
|
||||
if end < len(*slice) {
|
||||
*slice = (*slice)[:end]
|
||||
}
|
||||
}
|
||||
|
||||
// ModifySliceFunc is the same as ModifySlice except that it allows using a
|
||||
// custom comparison function.
|
||||
//
|
||||
// eq should report whether the two provided elements are equal.
|
||||
func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) {
|
||||
// Remove duplicates
|
||||
dst := 0
|
||||
for i := 1; i < len(*slice); i++ {
|
||||
if eq((*slice)[dst], (*slice)[i]) {
|
||||
continue
|
||||
}
|
||||
dst++
|
||||
(*slice)[dst] = (*slice)[i]
|
||||
}
|
||||
|
||||
// Zero out the elements we removed at the end of the slice
|
||||
end := dst + 1
|
||||
var zero E
|
||||
for i := end; i < len(*slice); i++ {
|
||||
(*slice)[i] = zero
|
||||
}
|
||||
|
||||
// Truncate the slice
|
||||
if end < len(*slice) {
|
||||
*slice = (*slice)[:end]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,24 +12,25 @@ import (
|
||||
"tailscale.com/util/uniq"
|
||||
)
|
||||
|
||||
func TestModifySlice(t *testing.T) {
|
||||
func runTests(t *testing.T, cb func(*[]uint32)) {
|
||||
tests := []struct {
|
||||
in []int
|
||||
want []int
|
||||
// Use uint32 to be different from an int-typed slice index
|
||||
in []uint32
|
||||
want []uint32
|
||||
}{
|
||||
{in: []int{0, 1, 2}, want: []int{0, 1, 2}},
|
||||
{in: []int{0, 1, 2, 2}, want: []int{0, 1, 2}},
|
||||
{in: []int{0, 0, 1, 2}, want: []int{0, 1, 2}},
|
||||
{in: []int{0, 1, 0, 2}, want: []int{0, 1, 0, 2}},
|
||||
{in: []int{0}, want: []int{0}},
|
||||
{in: []int{0, 0}, want: []int{0}},
|
||||
{in: []int{}, want: []int{}},
|
||||
{in: []uint32{0, 1, 2}, want: []uint32{0, 1, 2}},
|
||||
{in: []uint32{0, 1, 2, 2}, want: []uint32{0, 1, 2}},
|
||||
{in: []uint32{0, 0, 1, 2}, want: []uint32{0, 1, 2}},
|
||||
{in: []uint32{0, 1, 0, 2}, want: []uint32{0, 1, 0, 2}},
|
||||
{in: []uint32{0}, want: []uint32{0}},
|
||||
{in: []uint32{0, 0}, want: []uint32{0}},
|
||||
{in: []uint32{}, want: []uint32{}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
in := make([]int, len(test.in))
|
||||
in := make([]uint32, len(test.in))
|
||||
copy(in, test.in)
|
||||
uniq.ModifySlice(&test.in, func(i, j int) bool { return test.in[i] == test.in[j] })
|
||||
cb(&test.in)
|
||||
if !reflect.DeepEqual(test.in, test.want) {
|
||||
t.Errorf("uniq.Slice(%v) = %v, want %v", in, test.in, test.want)
|
||||
}
|
||||
@@ -43,6 +44,20 @@ func TestModifySlice(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifySlice(t *testing.T) {
|
||||
runTests(t, func(slice *[]uint32) {
|
||||
uniq.ModifySlice(slice)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModifySliceFunc(t *testing.T) {
|
||||
runTests(t, func(slice *[]uint32) {
|
||||
uniq.ModifySliceFunc(slice, func(i, j uint32) bool {
|
||||
return i == j
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark(b *testing.B) {
|
||||
benches := []struct {
|
||||
name string
|
||||
@@ -83,6 +98,6 @@ func benchmark(b *testing.B, size int64, reset func(s []byte)) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
s = s[:size]
|
||||
reset(s)
|
||||
uniq.ModifySlice(&s, func(i, j int) bool { return s[i] == s[j] })
|
||||
uniq.ModifySlice(&s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,8 @@ var Long = ""
|
||||
|
||||
// Short is a short version number for this build, of the form
|
||||
// "x.y.z" for builds stamped in the usual way (see
|
||||
// build_dist.sh in the root) or, for binaries built by hand with the // go tool, it's like Long's dev form, but ending at the date part,
|
||||
// build_dist.sh in the root) or, for binaries built by hand with the
|
||||
// go tool, it's like Long's dev form, but ending at the date part,
|
||||
// of the form "1.23.0-dev20220316".
|
||||
var Short = ""
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -60,6 +61,16 @@ import (
|
||||
"tailscale.com/wgengine/monitor"
|
||||
)
|
||||
|
||||
const (
|
||||
// These are disco.Magic in big-endian form, 4 then 2 bytes. The
|
||||
// BPF filters need the magic in this format to match on it. Used
|
||||
// only in magicsock_linux.go, but defined here so that the test
|
||||
// which verifies this is the correct magic doesn't also need a
|
||||
// _linux variant.
|
||||
discoMagic1 = 0x5453f09f
|
||||
discoMagic2 = 0x92ac
|
||||
)
|
||||
|
||||
// useDerpRoute reports whether magicsock should enable the DERP
|
||||
// return path optimization (Issue 150).
|
||||
func useDerpRoute() bool {
|
||||
@@ -254,6 +265,12 @@ type Conn struct {
|
||||
pconn4 *RebindingUDPConn
|
||||
pconn6 *RebindingUDPConn
|
||||
|
||||
// closeDisco4 and closeDisco6 are io.Closers to shut down the raw
|
||||
// disco packet receivers. If nil, no raw disco receiver is
|
||||
// running for the given family.
|
||||
closeDisco4 io.Closer
|
||||
closeDisco6 io.Closer
|
||||
|
||||
// netChecker is the prober that discovers local network
|
||||
// conditions, including the closest DERP relay and NAT mappings.
|
||||
netChecker *netcheck.Client
|
||||
@@ -572,6 +589,19 @@ func NewConn(opts Options) (*Conn, error) {
|
||||
|
||||
c.ignoreSTUNPackets()
|
||||
|
||||
if d4, err := c.listenRawDisco("ip4"); err == nil {
|
||||
c.logf("[v1] using BPF disco receiver for IPv4")
|
||||
c.closeDisco4 = d4
|
||||
} else {
|
||||
c.logf("[v1] couldn't create raw v4 disco listener, using regular listener instead: %v", err)
|
||||
}
|
||||
if d6, err := c.listenRawDisco("ip6"); err == nil {
|
||||
c.logf("[v1] using BPF disco receiver for IPv6")
|
||||
c.closeDisco6 = d6
|
||||
} else {
|
||||
c.logf("[v1] couldn't create raw v6 disco listener, using regular listener instead: %v", err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -1638,7 +1668,7 @@ func (c *Conn) receiveIPv6(b []byte) (int, conn.Endpoint, error) {
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint6); ok {
|
||||
if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint6, c.closeDisco6 == nil); ok {
|
||||
metricRecvDataIPv6.Add(1)
|
||||
return n, ep, nil
|
||||
}
|
||||
@@ -1654,7 +1684,7 @@ func (c *Conn) receiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4); ok {
|
||||
if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4, c.closeDisco4 == nil); ok {
|
||||
metricRecvDataIPv4.Add(1)
|
||||
return n, ep, nil
|
||||
}
|
||||
@@ -1665,12 +1695,18 @@ func (c *Conn) receiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
//
|
||||
// ok is whether this read should be reported up to wireguard-go (our
|
||||
// caller).
|
||||
func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) (ep *endpoint, ok bool) {
|
||||
func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache, checkDisco bool) (ep *endpoint, ok bool) {
|
||||
if stun.Is(b) {
|
||||
c.stunReceiveFunc.Load()(b, ipp)
|
||||
return nil, false
|
||||
}
|
||||
if c.handleDiscoMessage(b, ipp, key.NodePublic{}) {
|
||||
if checkDisco {
|
||||
if c.handleDiscoMessage(b, ipp, key.NodePublic{}) {
|
||||
return nil, false
|
||||
}
|
||||
} else if disco.LooksLikeDiscoWrapper(b) {
|
||||
// Caller told us to ignore disco traffic, don't let it fall
|
||||
// through to wireguard-go.
|
||||
return nil, false
|
||||
}
|
||||
if !c.havePrivateKey.Load() {
|
||||
@@ -2094,13 +2130,11 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) {
|
||||
|
||||
if !c.lastEndpointsTime.After(time.Now().Add(-endpointsFreshEnoughDuration)) {
|
||||
c.logf("[v1] magicsock: want call-me-maybe but endpoints stale; restunning")
|
||||
if c.onEndpointRefreshed == nil {
|
||||
c.onEndpointRefreshed = map[*endpoint]func(){}
|
||||
}
|
||||
c.onEndpointRefreshed[de] = func() {
|
||||
|
||||
mak.Set(&c.onEndpointRefreshed, de, func() {
|
||||
c.logf("[v1] magicsock: STUN done; sending call-me-maybe to %v %v", de.discoShort, de.publicKey.ShortString())
|
||||
c.enqueueCallMeMaybe(derpAddr, de)
|
||||
}
|
||||
})
|
||||
// TODO(bradfitz): make a new 'reSTUNQuickly' method
|
||||
// that passes down a do-a-lite-netcheck flag down to
|
||||
// netcheck that does 1 (or 2 max) STUN queries
|
||||
@@ -2632,6 +2666,12 @@ func (c *connBind) Close() error {
|
||||
if c.pconn6 != nil {
|
||||
c.pconn6.Close()
|
||||
}
|
||||
if c.closeDisco4 != nil {
|
||||
c.closeDisco4.Close()
|
||||
}
|
||||
if c.closeDisco6 != nil {
|
||||
c.closeDisco6.Close()
|
||||
}
|
||||
// Send an empty read result to unblock receiveDERP,
|
||||
// which will then check connBind.Closed.
|
||||
// connBind.Closed takes c.mu, but c.derpRecvCh is buffered.
|
||||
@@ -2843,7 +2883,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
|
||||
}
|
||||
ports = append(ports, 0)
|
||||
// Remove duplicates. (All duplicates are consecutive.)
|
||||
uniq.ModifySlice(&ports, func(i, j int) bool { return ports[i] == ports[j] })
|
||||
uniq.ModifySlice(&ports)
|
||||
|
||||
var pconn nettype.PacketConn
|
||||
for _, port := range ports {
|
||||
@@ -4192,4 +4232,8 @@ var (
|
||||
// metricDERPHomeChange is how many times our DERP home region DI has
|
||||
// changed from non-zero to a different non-zero.
|
||||
metricDERPHomeChange = clientmetric.NewCounter("derp_home_change")
|
||||
|
||||
// Disco packets received bpf read path
|
||||
metricRecvDiscoPacketIPv4 = clientmetric.NewCounter("magicsock_disco_recv_bpf_ipv4")
|
||||
metricRecvDiscoPacketIPv6 = clientmetric.NewCounter("magicsock_disco_recv_bpf_ipv6")
|
||||
)
|
||||
|
||||
17
wgengine/magicsock/magicsock_default.go
Normal file
17
wgengine/magicsock/magicsock_default.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
|
||||
return nil, errors.New("raw disco listening not supported on this OS")
|
||||
}
|
||||
260
wgengine/magicsock/magicsock_linux.go
Normal file
260
wgengine/magicsock/magicsock_linux.go
Normal file
@@ -0,0 +1,260 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/bpf"
|
||||
"golang.org/x/sys/unix"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const (
|
||||
udpHeaderSize = 8
|
||||
ipv6FragmentHeaderSize = 8
|
||||
)
|
||||
|
||||
// Enable/disable using raw sockets to receive disco traffic.
|
||||
var debugDisableRawDisco = envknob.Bool("TS_DEBUG_DISABLE_RAW_DISCO")
|
||||
|
||||
// These are our BPF filters that we use for testing packets.
|
||||
var (
|
||||
magicsockFilterV4 = []bpf.Instruction{
|
||||
// For raw UDPv4 sockets, BPF receives the entire IP packet to
|
||||
// inspect.
|
||||
|
||||
// Disco packets are so small they should never get
|
||||
// fragmented, and we don't want to handle reassembly.
|
||||
bpf.LoadAbsolute{Off: 6, Size: 2},
|
||||
// More Fragments bit set means this is part of a fragmented packet.
|
||||
bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x2000, SkipTrue: 7, SkipFalse: 0},
|
||||
// Non-zero fragment offset with MF=0 means this is the last
|
||||
// fragment of packet.
|
||||
bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x1fff, SkipTrue: 6, SkipFalse: 0},
|
||||
|
||||
// Load IP header length into X register.
|
||||
bpf.LoadMemShift{Off: 0},
|
||||
|
||||
// Get the first 4 bytes of the UDP packet, compare with our magic number
|
||||
bpf.LoadIndirect{Off: udpHeaderSize, Size: 4},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3},
|
||||
|
||||
// Compare the next 2 bytes
|
||||
bpf.LoadIndirect{Off: udpHeaderSize + 4, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(discoMagic2), SkipTrue: 0, SkipFalse: 1},
|
||||
|
||||
// Accept the whole packet
|
||||
bpf.RetConstant{Val: 0xFFFFFFFF},
|
||||
|
||||
// Skip the packet
|
||||
bpf.RetConstant{Val: 0x0},
|
||||
}
|
||||
|
||||
// IPv6 is more complicated to filter, since we can have 0-to-N
|
||||
// extension headers following the IPv6 header. Since BPF can't
|
||||
// loop, we can't really parse these in a general way; instead, we
|
||||
// simply handle the case where we have no extension headers; any
|
||||
// packets with headers will be skipped. IPv6 extension headers
|
||||
// are sufficiently uncommon that we're willing to accept false
|
||||
// negatives here.
|
||||
//
|
||||
// The "proper" way to handle this would be to do minimal parsing in
|
||||
// BPF and more in-depth parsing of all IPv6 packets in userspace, but
|
||||
// on systems with a high volume of UDP that would be unacceptably slow
|
||||
// and thus we'd rather be conservative here and possibly not receive
|
||||
// disco packets rather than slow down the system.
|
||||
magicsockFilterV6 = []bpf.Instruction{
|
||||
// For raw UDPv6 sockets, BPF receives _only_ the UDP header onwards, not an entire IP packet.
|
||||
//
|
||||
// https://stackoverflow.com/questions/24514333/using-bpf-with-sock-dgram-on-linux-machine
|
||||
// https://blog.cloudflare.com/epbf_sockets_hop_distance/
|
||||
//
|
||||
// This is especially confusing because this *isn't* true for
|
||||
// IPv4; see the following code from the 'ping' utility that
|
||||
// corroborates this:
|
||||
//
|
||||
// https://github.com/iputils/iputils/blob/1ab5fa/ping/ping.c#L1667-L1676
|
||||
// https://github.com/iputils/iputils/blob/1ab5fa/ping/ping6_common.c#L933-L941
|
||||
|
||||
// Compare with our magic number. Start by loading and
|
||||
// comparing the first 4 bytes of the UDP payload.
|
||||
bpf.LoadAbsolute{Off: udpHeaderSize, Size: 4},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3},
|
||||
|
||||
// Compare the next 2 bytes
|
||||
bpf.LoadAbsolute{Off: udpHeaderSize + 4, Size: 2},
|
||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic2, SkipTrue: 0, SkipFalse: 1},
|
||||
|
||||
// Accept the whole packet
|
||||
bpf.RetConstant{Val: 0xFFFFFFFF},
|
||||
|
||||
// Skip the packet
|
||||
bpf.RetConstant{Val: 0x0},
|
||||
}
|
||||
|
||||
testDiscoPacket = []byte{
|
||||
// Disco magic
|
||||
0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac,
|
||||
// Sender key
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
// Nonce
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
}
|
||||
)
|
||||
|
||||
// listenRawDisco starts listening for disco packets on the given
|
||||
// address family, which must be "ip4" or "ip6", using a raw socket
|
||||
// and BPF filter.
|
||||
// https://github.com/tailscale/tailscale/issues/3824
|
||||
func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
|
||||
if debugDisableRawDisco {
|
||||
return nil, errors.New("raw disco listening disabled by debug flag")
|
||||
}
|
||||
|
||||
var (
|
||||
network string
|
||||
addr string
|
||||
testAddr string
|
||||
prog []bpf.Instruction
|
||||
)
|
||||
switch family {
|
||||
case "ip4":
|
||||
network = "ip4:17"
|
||||
addr = "0.0.0.0"
|
||||
testAddr = "127.0.0.1:1"
|
||||
prog = magicsockFilterV4
|
||||
case "ip6":
|
||||
network = "ip6:17"
|
||||
addr = "::"
|
||||
testAddr = "[::1]:1"
|
||||
prog = magicsockFilterV6
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported address family %q", family)
|
||||
}
|
||||
|
||||
asm, err := bpf.Assemble(prog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("assembling filter: %w", err)
|
||||
}
|
||||
|
||||
pc, err := net.ListenPacket(network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating packet conn: %w", err)
|
||||
}
|
||||
|
||||
if err := setBPF(pc, asm); err != nil {
|
||||
pc.Close()
|
||||
return nil, fmt.Errorf("installing BPF filter: %w", err)
|
||||
}
|
||||
|
||||
// If all the above succeeds, we should be ready to receive. Just
|
||||
// out of paranoia, check that we do receive a well-formed disco
|
||||
// packet.
|
||||
tc, err := net.ListenPacket("udp", net.JoinHostPort(addr, "0"))
|
||||
if err != nil {
|
||||
pc.Close()
|
||||
return nil, fmt.Errorf("creating disco test socket: %w", err)
|
||||
}
|
||||
defer tc.Close()
|
||||
if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, netip.MustParseAddrPort(testAddr)); err != nil {
|
||||
pc.Close()
|
||||
return nil, fmt.Errorf("writing disco test packet: %w", err)
|
||||
}
|
||||
pc.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
var buf [1500]byte
|
||||
for {
|
||||
n, _, err := pc.ReadFrom(buf[:])
|
||||
if err != nil {
|
||||
pc.Close()
|
||||
return nil, fmt.Errorf("reading during raw disco self-test: %w", err)
|
||||
}
|
||||
if n < udpHeaderSize {
|
||||
continue
|
||||
}
|
||||
if !bytes.Equal(buf[udpHeaderSize:n], testDiscoPacket) {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
pc.SetReadDeadline(time.Time{})
|
||||
|
||||
go c.receiveDisco(pc)
|
||||
return pc, nil
|
||||
}
|
||||
|
||||
func (c *Conn) receiveDisco(pc net.PacketConn) {
|
||||
var buf [1500]byte
|
||||
for {
|
||||
n, src, err := pc.ReadFrom(buf[:])
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
} else if err != nil {
|
||||
c.logf("disco raw reader failed: %v", err)
|
||||
return
|
||||
}
|
||||
if n < udpHeaderSize {
|
||||
// Too small to be a valid UDP datagram, drop.
|
||||
continue
|
||||
}
|
||||
srcIP, ok := netip.AddrFromSlice(src.(*net.IPAddr).IP)
|
||||
if !ok {
|
||||
c.logf("[unexpected] PacketConn.ReadFrom returned not-an-IP %v in from", src)
|
||||
continue
|
||||
}
|
||||
srcPort := binary.BigEndian.Uint16(buf[:2])
|
||||
|
||||
if srcIP.Is4() {
|
||||
metricRecvDiscoPacketIPv4.Add(1)
|
||||
} else {
|
||||
metricRecvDiscoPacketIPv6.Add(1)
|
||||
}
|
||||
|
||||
c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{})
|
||||
}
|
||||
}
|
||||
|
||||
// setBPF installs filter as the BPF filter on conn.
|
||||
// Ideally we would just use SetBPF as implemented in x/net/ipv4,
|
||||
// but x/net/ipv6 doesn't implement it. And once you've written
|
||||
// this code once, it turns out to be address family agnostic, so
|
||||
// we might as well use it on both and get to use a net.PacketConn
|
||||
// directly for both families instead of being stuck with
|
||||
// different types.
|
||||
func setBPF(conn net.PacketConn, filter []bpf.RawInstruction) error {
|
||||
sc, err := conn.(*net.IPConn).SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prog := &unix.SockFprog{
|
||||
Len: uint16(len(filter)),
|
||||
Filter: (*unix.SockFilter)(unsafe.Pointer(&filter[0])),
|
||||
}
|
||||
var setErr error
|
||||
err = sc.Control(func(fd uintptr) {
|
||||
setErr = unix.SetsockoptSockFprog(int(fd), unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, prog)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if setErr != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/stun/stuntest"
|
||||
@@ -1799,3 +1800,21 @@ func TestBlockForeverConnUnblocks(t *testing.T) {
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoMagicMatches(t *testing.T) {
|
||||
// Convert our disco magic number into a uint32 and uint16 to test
|
||||
// against. We panic on an incorrect length here rather than try to be
|
||||
// generic with our BPF instructions below.
|
||||
//
|
||||
// Note that BPF uses network byte order (big-endian) when loading data
|
||||
// from a packet, so that is what we use to generate our magic numbers.
|
||||
if len(disco.Magic) != 6 {
|
||||
t.Fatalf("expected disco.Magic to be of length 6")
|
||||
}
|
||||
if m1 := binary.BigEndian.Uint32([]byte(disco.Magic[:4])); m1 != discoMagic1 {
|
||||
t.Errorf("first 4 bytes of disco magic don't match, got %v want %v", discoMagic1, m1)
|
||||
}
|
||||
if m2 := binary.BigEndian.Uint16([]byte(disco.Magic[4:6])); m2 != discoMagic2 {
|
||||
t.Errorf("last 2 bytes of disco magic don't match, got %v want %v", discoMagic2, m2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,13 +7,11 @@ package monitor
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
@@ -133,7 +131,7 @@ func (m *winMon) Receive() (message, error) {
|
||||
// unicastAddressChanged is the callback we register with Windows to call when unicast address changes.
|
||||
func (m *winMon) unicastAddressChanged(_ winipcfg.MibNotificationType, row *winipcfg.MibUnicastIPAddressRow) {
|
||||
what := "addr"
|
||||
if ip, ok := netip.AddrFromSlice(row.Address.IP()); ok && tsaddr.IsTailscaleIP(ip.Unmap()) {
|
||||
if ip := row.Address.Addr(); ip.IsValid() && tsaddr.IsTailscaleIP(ip.Unmap()) {
|
||||
what = "tsaddr"
|
||||
}
|
||||
|
||||
@@ -144,8 +142,8 @@ func (m *winMon) unicastAddressChanged(_ winipcfg.MibNotificationType, row *wini
|
||||
// routeChanged is the callback we register with Windows to call when route changes.
|
||||
func (m *winMon) routeChanged(_ winipcfg.MibNotificationType, row *winipcfg.MibIPforwardRow2) {
|
||||
what := "route"
|
||||
ipn := row.DestinationPrefix.IPNet()
|
||||
if cidr, ok := netaddr.FromStdIPNet(&ipn); ok && tsaddr.IsTailscaleIP(cidr.Addr()) {
|
||||
ip := row.DestinationPrefix.Prefix().Addr().Unmap()
|
||||
if ip.IsValid() && tsaddr.IsTailscaleIP(ip) {
|
||||
what = "tsroute"
|
||||
}
|
||||
// start a goroutine to finish our work, to return to Windows out of this callback
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/flowtrack"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@@ -157,28 +156,8 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) {
|
||||
return
|
||||
}
|
||||
|
||||
// We don't care if this information is perfectly up-to-date, since
|
||||
// we're just using it to print debug information.
|
||||
//
|
||||
// In tailscale/coral#72, we see a goroutine profile with thousands of
|
||||
// goroutines blocked on the mutex in getStatus here, so we wrap it in
|
||||
// a singleflight and accept stale information to reduce contention.
|
||||
st, err, _ := e.getStatusSf.Do(struct{}{}, e.getStatus)
|
||||
|
||||
var ps *ipnstate.PeerStatusLite
|
||||
if err == nil {
|
||||
for _, v := range st.Peers {
|
||||
if v.NodeKey == n.Key {
|
||||
v := v // copy
|
||||
ps = &v
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
e.logf("open-conn-track: timeout opening %v to node %v; failed to get engine status: %v", flow, n.Key.ShortString(), err)
|
||||
return
|
||||
}
|
||||
if ps == nil {
|
||||
ps, found := e.getPeerStatusLite(n.Key)
|
||||
if !found {
|
||||
onlyZeroRoute := true // whether peerForIP returned n only because its /0 route matched
|
||||
for _, r := range n.AllowedIPs {
|
||||
if r.Bits() != 0 && r.Contains(flow.Dst.Addr()) {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -18,12 +17,12 @@ import (
|
||||
|
||||
ole "github.com/go-ole/go-ole"
|
||||
"go4.org/netipx"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/util/multierr"
|
||||
"tailscale.com/wgengine/winnet"
|
||||
@@ -324,25 +323,23 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) (retErr error) {
|
||||
// ours where the nexthop is meaningless, you're supposed to use
|
||||
// one of the local IP addresses of the interface. Find an IPv4
|
||||
// and IPv6 address we can use for this purpose.
|
||||
var firstGateway4 *net.IP
|
||||
var firstGateway6 *net.IP
|
||||
addresses := make([]*net.IPNet, 0, len(cfg.LocalAddrs))
|
||||
var firstGateway4 netip.Addr
|
||||
var firstGateway6 netip.Addr
|
||||
addresses := make([]netip.Prefix, 0, len(cfg.LocalAddrs))
|
||||
for _, addr := range cfg.LocalAddrs {
|
||||
if (addr.Addr().Is4() && ipif4 == nil) || (addr.Addr().Is6() && ipif6 == nil) {
|
||||
// Can't program addresses for disabled protocol.
|
||||
continue
|
||||
}
|
||||
ipnet := netipx.PrefixIPNet(addr)
|
||||
addresses = append(addresses, ipnet)
|
||||
gateway := ipnet.IP
|
||||
if addr.Addr().Is4() && firstGateway4 == nil {
|
||||
firstGateway4 = &gateway
|
||||
} else if addr.Addr().Is6() && firstGateway6 == nil {
|
||||
firstGateway6 = &gateway
|
||||
addresses = append(addresses, addr)
|
||||
if addr.Addr().Is4() && !firstGateway4.IsValid() {
|
||||
firstGateway4 = addr.Addr()
|
||||
} else if addr.Addr().Is6() && !firstGateway6.IsValid() {
|
||||
firstGateway6 = addr.Addr()
|
||||
}
|
||||
}
|
||||
|
||||
var routes []winipcfg.RouteData
|
||||
var routes []*winipcfg.RouteData
|
||||
foundDefault4 := false
|
||||
foundDefault6 := false
|
||||
for _, route := range cfg.Routes {
|
||||
@@ -351,37 +348,33 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) (retErr error) {
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Addr().Is6() && firstGateway6 == nil {
|
||||
if route.Addr().Is6() && !firstGateway6.IsValid() {
|
||||
// Windows won't let us set IPv6 routes without having an
|
||||
// IPv6 local address set. However, when we've configured
|
||||
// a default route, we want to forcibly grab IPv6 traffic
|
||||
// even if the v6 overlay network isn't configured. To do
|
||||
// that, we add a dummy local IPv6 address to serve as a
|
||||
// route source.
|
||||
ipnet := &net.IPNet{tsaddr.Tailscale4To6Placeholder().AsSlice(), net.CIDRMask(128, 128)}
|
||||
addresses = append(addresses, ipnet)
|
||||
firstGateway6 = &ipnet.IP
|
||||
} else if route.Addr().Is4() && firstGateway4 == nil {
|
||||
ip := tsaddr.Tailscale4To6Placeholder()
|
||||
addresses = append(addresses, netip.PrefixFrom(ip, ip.BitLen()))
|
||||
firstGateway6 = ip
|
||||
} else if route.Addr().Is4() && !firstGateway4.IsValid() {
|
||||
// TODO: do same dummy behavior as v6?
|
||||
return errors.New("due to a Windows limitation, one cannot have interface routes without an interface address")
|
||||
}
|
||||
|
||||
ipn := netipx.PrefixIPNet(route)
|
||||
var gateway net.IP
|
||||
var gateway netip.Addr
|
||||
if route.Addr().Is4() {
|
||||
gateway = *firstGateway4
|
||||
gateway = firstGateway4
|
||||
} else if route.Addr().Is6() {
|
||||
gateway = *firstGateway6
|
||||
gateway = firstGateway6
|
||||
}
|
||||
r := winipcfg.RouteData{
|
||||
Destination: net.IPNet{
|
||||
IP: ipn.IP.Mask(ipn.Mask),
|
||||
Mask: ipn.Mask,
|
||||
},
|
||||
NextHop: gateway,
|
||||
Metric: 0,
|
||||
r := &winipcfg.RouteData{
|
||||
Destination: route,
|
||||
NextHop: gateway,
|
||||
Metric: 0,
|
||||
}
|
||||
if net.IP.Equal(r.Destination.IP, gateway) {
|
||||
if r.Destination.Addr().Unmap() == gateway {
|
||||
// no need to add a route for the interface's
|
||||
// own IP. The kernel does that for us.
|
||||
// If we try to replace it, we'll fail to
|
||||
@@ -393,12 +386,12 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) (retErr error) {
|
||||
if route.Bits() == 0 {
|
||||
foundDefault4 = true
|
||||
}
|
||||
r.NextHop = *firstGateway4
|
||||
r.NextHop = firstGateway4
|
||||
} else if route.Addr().Is6() {
|
||||
if route.Bits() == 0 {
|
||||
foundDefault6 = true
|
||||
}
|
||||
r.NextHop = *firstGateway6
|
||||
r.NextHop = firstGateway6
|
||||
}
|
||||
routes = append(routes, r)
|
||||
}
|
||||
@@ -408,18 +401,16 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) (retErr error) {
|
||||
return fmt.Errorf("syncAddresses: %w", err)
|
||||
}
|
||||
|
||||
sort.Slice(routes, func(i, j int) bool { return routeLess(&routes[i], &routes[j]) })
|
||||
slices.SortFunc(routes, routeDataLess)
|
||||
|
||||
deduplicatedRoutes := []*winipcfg.RouteData{}
|
||||
for i := 0; i < len(routes); i++ {
|
||||
// There's only one way to get to a given IP+Mask, so delete
|
||||
// all matches after the first.
|
||||
if i > 0 &&
|
||||
net.IP.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
|
||||
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
|
||||
if i > 0 && routes[i].Destination == routes[i-1].Destination {
|
||||
continue
|
||||
}
|
||||
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
|
||||
deduplicatedRoutes = append(deduplicatedRoutes, routes[i])
|
||||
}
|
||||
|
||||
// Re-read interface after syncAddresses.
|
||||
@@ -484,28 +475,6 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) (retErr error) {
|
||||
return errAcc
|
||||
}
|
||||
|
||||
// routeLess reports whether ri should sort before rj.
|
||||
// The actual sort order doesn't appear to matter. The caller just
|
||||
// wants them sorted to be able to de-dup.
|
||||
func routeLess(ri, rj *winipcfg.RouteData) bool {
|
||||
if v := bytes.Compare(ri.Destination.IP, rj.Destination.IP); v != 0 {
|
||||
return v == -1
|
||||
}
|
||||
if v := bytes.Compare(ri.Destination.Mask, rj.Destination.Mask); v != 0 {
|
||||
// Narrower masks first
|
||||
return v == 1
|
||||
}
|
||||
if ri.Metric != rj.Metric {
|
||||
// Lower metrics first
|
||||
return ri.Metric < rj.Metric
|
||||
}
|
||||
if v := bytes.Compare(ri.NextHop, rj.NextHop); v != 0 {
|
||||
// No nexthop before non-empty nexthop.
|
||||
return v == -1
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// unwrapIP returns the shortest version of ip.
|
||||
func unwrapIP(ip net.IP) net.IP {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
@@ -521,40 +490,40 @@ func v4Mask(m net.IPMask) net.IPMask {
|
||||
return m
|
||||
}
|
||||
|
||||
func netCompare(a, b net.IPNet) int {
|
||||
aip, bip := unwrapIP(a.IP), unwrapIP(b.IP)
|
||||
v := bytes.Compare(aip, bip)
|
||||
func netCompare(a, b netip.Prefix) int {
|
||||
aip, bip := a.Addr().Unmap(), b.Addr().Unmap()
|
||||
v := aip.Compare(bip)
|
||||
if v != 0 {
|
||||
return v
|
||||
}
|
||||
|
||||
amask, bmask := a.Mask, b.Mask
|
||||
if len(aip) == 4 {
|
||||
amask = v4Mask(a.Mask)
|
||||
bmask = v4Mask(b.Mask)
|
||||
if a.Bits() == b.Bits() {
|
||||
return 0
|
||||
}
|
||||
|
||||
// narrower first
|
||||
return -bytes.Compare(amask, bmask)
|
||||
if a.Bits() > b.Bits() {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func sortNets(a []*net.IPNet) {
|
||||
sort.Slice(a, func(i, j int) bool {
|
||||
return netCompare(*a[i], *a[j]) == -1
|
||||
func sortNets(s []netip.Prefix) {
|
||||
sort.Slice(s, func(i, j int) bool {
|
||||
return netCompare(s[i], s[j]) == -1
|
||||
})
|
||||
}
|
||||
|
||||
// deltaNets returns the changes to turn a into b.
|
||||
func deltaNets(a, b []*net.IPNet) (add, del []*net.IPNet) {
|
||||
add = make([]*net.IPNet, 0, len(b))
|
||||
del = make([]*net.IPNet, 0, len(a))
|
||||
func deltaNets(a, b []netip.Prefix) (add, del []netip.Prefix) {
|
||||
add = make([]netip.Prefix, 0, len(b))
|
||||
del = make([]netip.Prefix, 0, len(a))
|
||||
sortNets(a)
|
||||
sortNets(b)
|
||||
|
||||
i := 0
|
||||
j := 0
|
||||
for i < len(a) && j < len(b) {
|
||||
switch netCompare(*a[i], *b[j]) {
|
||||
switch netCompare(a[i], b[j]) {
|
||||
case -1:
|
||||
// a < b, delete
|
||||
del = append(del, a[i])
|
||||
@@ -576,28 +545,21 @@ func deltaNets(a, b []*net.IPNet) (add, del []*net.IPNet) {
|
||||
return
|
||||
}
|
||||
|
||||
func isIPv6LinkLocal(in *net.IPNet) bool {
|
||||
return len(in.IP) == 16 && in.IP.IsLinkLocalUnicast()
|
||||
func isIPv6LinkLocal(a netip.Prefix) bool {
|
||||
return a.Addr().Is6() && a.Addr().IsLinkLocalUnicast()
|
||||
}
|
||||
|
||||
// ipAdapterUnicastAddressToIPNet converts windows.IpAdapterUnicastAddress to net.IPNet.
|
||||
func ipAdapterUnicastAddressToIPNet(u *windows.IpAdapterUnicastAddress) *net.IPNet {
|
||||
ip := u.Address.IP()
|
||||
w := 32
|
||||
if ip.To4() == nil {
|
||||
w = 128
|
||||
}
|
||||
return &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(int(u.OnLinkPrefixLength), w),
|
||||
}
|
||||
// ipAdapterUnicastAddressToPrefix converts windows.IpAdapterUnicastAddress to netip.Prefix
|
||||
func ipAdapterUnicastAddressToPrefix(u *windows.IpAdapterUnicastAddress) netip.Prefix {
|
||||
ip, _ := netip.AddrFromSlice(u.Address.IP())
|
||||
return netip.PrefixFrom(ip.Unmap(), int(u.OnLinkPrefixLength))
|
||||
}
|
||||
|
||||
// unicastIPNets returns all unicast net.IPNet for ifc interface.
|
||||
func unicastIPNets(ifc *winipcfg.IPAdapterAddresses) []*net.IPNet {
|
||||
nets := make([]*net.IPNet, 0)
|
||||
func unicastIPNets(ifc *winipcfg.IPAdapterAddresses) []netip.Prefix {
|
||||
var nets []netip.Prefix
|
||||
for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next {
|
||||
nets = append(nets, ipAdapterUnicastAddressToIPNet(addr))
|
||||
nets = append(nets, ipAdapterUnicastAddressToPrefix(addr))
|
||||
}
|
||||
return nets
|
||||
}
|
||||
@@ -612,13 +574,13 @@ func unicastIPNets(ifc *winipcfg.IPAdapterAddresses) []*net.IPNet {
|
||||
// DNS locally or remotely and from being picked as a source address for
|
||||
// outgoing packets with unspecified sources. See #4647 and
|
||||
// https://web.archive.org/web/20200912120956/https://devblogs.microsoft.com/scripting/use-powershell-to-change-ip-behavior-with-skipassource/
|
||||
func syncAddresses(ifc *winipcfg.IPAdapterAddresses, want []*net.IPNet) error {
|
||||
func syncAddresses(ifc *winipcfg.IPAdapterAddresses, want []netip.Prefix) error {
|
||||
var erracc error
|
||||
|
||||
got := unicastIPNets(ifc)
|
||||
add, del := deltaNets(got, want)
|
||||
|
||||
ll := make([]*net.IPNet, 0)
|
||||
ll := make([]netip.Prefix, 0)
|
||||
for _, a := range del {
|
||||
// do not delete link-local addresses, and collect them for later
|
||||
// applying SkipAsSource.
|
||||
@@ -627,29 +589,29 @@ func syncAddresses(ifc *winipcfg.IPAdapterAddresses, want []*net.IPNet) error {
|
||||
continue
|
||||
}
|
||||
|
||||
err := ifc.LUID.DeleteIPAddress(*a)
|
||||
err := ifc.LUID.DeleteIPAddress(a)
|
||||
if err != nil {
|
||||
erracc = fmt.Errorf("deleting IP %q: %w", *a, err)
|
||||
erracc = fmt.Errorf("deleting IP %q: %w", a, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range add {
|
||||
err := ifc.LUID.AddIPAddress(*a)
|
||||
err := ifc.LUID.AddIPAddress(a)
|
||||
if err != nil {
|
||||
erracc = fmt.Errorf("adding IP %q: %w", *a, err)
|
||||
erracc = fmt.Errorf("adding IP %q: %w", a, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range ll {
|
||||
mib, err := ifc.LUID.IPAddress(a.IP)
|
||||
mib, err := ifc.LUID.IPAddress(a.Addr())
|
||||
if err != nil {
|
||||
erracc = fmt.Errorf("setting skip-as-source on IP %q: unable to retrieve MIB: %w", *a, err)
|
||||
erracc = fmt.Errorf("setting skip-as-source on IP %q: unable to retrieve MIB: %w", a, err)
|
||||
continue
|
||||
}
|
||||
if !mib.SkipAsSource {
|
||||
mib.SkipAsSource = true
|
||||
if err := mib.Set(); err != nil {
|
||||
erracc = fmt.Errorf("setting skip-as-source on IP %q: unable to set MIB: %w", *a, err)
|
||||
erracc = fmt.Errorf("setting skip-as-source on IP %q: unable to set MIB: %w", a, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -657,20 +619,27 @@ func syncAddresses(ifc *winipcfg.IPAdapterAddresses, want []*net.IPNet) error {
|
||||
return erracc
|
||||
}
|
||||
|
||||
func routeDataLess(a, b *winipcfg.RouteData) bool {
|
||||
return routeDataCompare(a, b) < 0
|
||||
}
|
||||
|
||||
func routeDataCompare(a, b *winipcfg.RouteData) int {
|
||||
v := bytes.Compare(a.Destination.IP, b.Destination.IP)
|
||||
v := a.Destination.Addr().Compare(b.Destination.Addr())
|
||||
if v != 0 {
|
||||
return v
|
||||
}
|
||||
|
||||
// Narrower masks first
|
||||
v = bytes.Compare(a.Destination.Mask, b.Destination.Mask)
|
||||
if v != 0 {
|
||||
return -v
|
||||
b1, b2 := a.Destination.Bits(), b.Destination.Bits()
|
||||
if b1 != b2 {
|
||||
if b1 > b2 {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// No nexthop before non-empty nexthop
|
||||
v = bytes.Compare(a.NextHop, b.NextHop)
|
||||
v = a.NextHop.Compare(b.NextHop)
|
||||
if v != 0 {
|
||||
return v
|
||||
}
|
||||
@@ -685,17 +654,11 @@ func routeDataCompare(a, b *winipcfg.RouteData) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func sortRouteData(a []*winipcfg.RouteData) {
|
||||
sort.Slice(a, func(i, j int) bool {
|
||||
return routeDataCompare(a[i], a[j]) < 0
|
||||
})
|
||||
}
|
||||
|
||||
func deltaRouteData(a, b []*winipcfg.RouteData) (add, del []*winipcfg.RouteData) {
|
||||
add = make([]*winipcfg.RouteData, 0, len(b))
|
||||
del = make([]*winipcfg.RouteData, 0, len(a))
|
||||
sortRouteData(a)
|
||||
sortRouteData(b)
|
||||
slices.SortFunc(a, routeDataLess)
|
||||
slices.SortFunc(b, routeDataLess)
|
||||
|
||||
i := 0
|
||||
j := 0
|
||||
@@ -751,15 +714,15 @@ func getAllInterfaceRoutes(ifc *winipcfg.IPAdapterAddresses) ([]*winipcfg.RouteD
|
||||
rd := make([]*winipcfg.RouteData, 0, len(routes4)+len(routes6))
|
||||
for _, r := range routes4 {
|
||||
rd = append(rd, &winipcfg.RouteData{
|
||||
Destination: r.DestinationPrefix.IPNet(),
|
||||
NextHop: r.NextHop.IP(),
|
||||
Destination: r.DestinationPrefix.Prefix(),
|
||||
NextHop: r.NextHop.Addr(),
|
||||
Metric: r.Metric,
|
||||
})
|
||||
}
|
||||
for _, r := range routes6 {
|
||||
rd = append(rd, &winipcfg.RouteData{
|
||||
Destination: r.DestinationPrefix.IPNet(),
|
||||
NextHop: r.NextHop.IP(),
|
||||
Destination: r.DestinationPrefix.Prefix(),
|
||||
NextHop: r.NextHop.Addr(),
|
||||
Metric: r.Metric,
|
||||
})
|
||||
}
|
||||
@@ -777,8 +740,8 @@ func filterRoutes(routes []*winipcfg.RouteData, dontDelete []netip.Prefix) []*wi
|
||||
}
|
||||
for _, r := range routes {
|
||||
// We don't want to touch broadcast routes that Windows adds.
|
||||
nr, ok := netaddr.FromStdIPNet(&r.Destination)
|
||||
if !ok {
|
||||
nr := r.Destination
|
||||
if !nr.IsValid() {
|
||||
continue
|
||||
}
|
||||
if nr.IsSingleIP() {
|
||||
@@ -789,8 +752,8 @@ func filterRoutes(routes []*winipcfg.RouteData, dontDelete []netip.Prefix) []*wi
|
||||
}
|
||||
filtered := make([]*winipcfg.RouteData, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
rr, ok := netaddr.FromStdIPNet(&r.Destination)
|
||||
if ok && ddm[rr] {
|
||||
rr := r.Destination
|
||||
if rr.IsValid() && ddm[rr] {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, r)
|
||||
|
||||
@@ -7,41 +7,30 @@ package router
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go4.org/netipx"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func randIP() net.IP {
|
||||
func randIP() netip.Addr {
|
||||
b := byte(rand.Intn(3))
|
||||
return net.IP{b, b, b, b}
|
||||
return netip.AddrFrom4([4]byte{b, b, b, b})
|
||||
}
|
||||
|
||||
func randRouteData() *winipcfg.RouteData {
|
||||
return &winipcfg.RouteData{
|
||||
Destination: net.IPNet{
|
||||
IP: randIP(),
|
||||
Mask: net.CIDRMask(rand.Intn(3)+1, 32),
|
||||
},
|
||||
NextHop: randIP(),
|
||||
Metric: uint32(rand.Intn(3)),
|
||||
Destination: netip.PrefixFrom(randIP(), rand.Intn(30)+1),
|
||||
NextHop: randIP(),
|
||||
Metric: uint32(rand.Intn(3)),
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteLess(t *testing.T) {
|
||||
type D = winipcfg.RouteData
|
||||
ipnet := func(s string) net.IPNet {
|
||||
ipp, err := netip.ParsePrefix(s)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing test data %q: %v", s, err)
|
||||
}
|
||||
return *netipx.PrefixIPNet(ipp)
|
||||
}
|
||||
|
||||
ipnet := netip.MustParsePrefix
|
||||
tests := []struct {
|
||||
ri, rj *winipcfg.RouteData
|
||||
want bool
|
||||
@@ -72,76 +61,51 @@ func TestRouteLess(t *testing.T) {
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
ri: &D{Destination: ipnet("1.1.0.0/16"), Metric: 1, NextHop: net.ParseIP("3.3.3.3")},
|
||||
rj: &D{Destination: ipnet("1.1.0.0/16"), Metric: 1, NextHop: net.ParseIP("4.4.4.4")},
|
||||
ri: &D{Destination: ipnet("1.1.0.0/16"), Metric: 1, NextHop: netip.MustParseAddr("3.3.3.3")},
|
||||
rj: &D{Destination: ipnet("1.1.0.0/16"), Metric: 1, NextHop: netip.MustParseAddr("4.4.4.4")},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
got := routeLess(tt.ri, tt.rj)
|
||||
got := routeDataLess(tt.ri, tt.rj)
|
||||
if got != tt.want {
|
||||
t.Errorf("%v. less = %v; want %v", i, got, tt.want)
|
||||
}
|
||||
back := routeLess(tt.rj, tt.ri)
|
||||
back := routeDataLess(tt.rj, tt.ri)
|
||||
if back && got {
|
||||
t.Errorf("%v. less both ways", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteLessConsistent(t *testing.T) {
|
||||
func TestRouteDataLessConsistent(t *testing.T) {
|
||||
for i := 0; i < 10000; i++ {
|
||||
ri := randRouteData()
|
||||
rj := randRouteData()
|
||||
if routeLess(ri, rj) && routeLess(rj, ri) {
|
||||
if routeDataLess(ri, rj) && routeDataLess(rj, ri) {
|
||||
t.Fatalf("both compare less to each other:\n\t%#v\nand\n\t%#v", ri, rj)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func equalNetIPs(a, b []*net.IPNet) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if netCompare(*a[i], *b[i]) != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ipnet4(ip string, bits int) *net.IPNet {
|
||||
return &net.IPNet{
|
||||
IP: net.ParseIP(ip),
|
||||
Mask: net.CIDRMask(bits, 32),
|
||||
}
|
||||
}
|
||||
|
||||
// each cidr can end in "[4]" to mean To4 form.
|
||||
func nets(cidrs ...string) (ret []*net.IPNet) {
|
||||
func nets(cidrs ...string) (ret []netip.Prefix) {
|
||||
for _, s := range cidrs {
|
||||
to4 := strings.HasSuffix(s, "[4]")
|
||||
if to4 {
|
||||
s = strings.TrimSuffix(s, "[4]")
|
||||
}
|
||||
ip, ipNet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Bogus CIDR %q in test", s))
|
||||
}
|
||||
if to4 {
|
||||
ip = ip.To4()
|
||||
}
|
||||
ipNet.IP = ip
|
||||
ret = append(ret, ipNet)
|
||||
ret = append(ret, netip.MustParsePrefix(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nilIfEmpty[E any](s []E) []E {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestDeltaNets(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b []*net.IPNet
|
||||
wantAdd, wantDel []*net.IPNet
|
||||
a, b []netip.Prefix
|
||||
wantAdd, wantDel []netip.Prefix
|
||||
}{
|
||||
{
|
||||
a: nets("1.2.3.4/24", "1.2.3.4/31", "1.2.3.3/32", "10.0.1.1/32", "100.0.1.1/32"),
|
||||
@@ -161,30 +125,16 @@ func TestDeltaNets(t *testing.T) {
|
||||
},
|
||||
{
|
||||
a: nets("100.84.36.11/32", "fe80::99d0:ec2d:b2e7:536b/64"),
|
||||
b: nets("100.84.36.11/32[4]"),
|
||||
b: nets("100.84.36.11/32"),
|
||||
wantDel: nets("fe80::99d0:ec2d:b2e7:536b/64"),
|
||||
},
|
||||
{
|
||||
a: []*net.IPNet{
|
||||
{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Mask: net.IPMask{0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
},
|
||||
b: []*net.IPNet{
|
||||
{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Mask: net.IPMask{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
add, del := deltaNets(tt.a, tt.b)
|
||||
if !equalNetIPs(add, tt.wantAdd) {
|
||||
if !reflect.DeepEqual(nilIfEmpty(add), nilIfEmpty(tt.wantAdd)) {
|
||||
t.Errorf("[%d] add:\n got: %v\n want: %v\n", i, add, tt.wantAdd)
|
||||
}
|
||||
if !equalNetIPs(del, tt.wantDel) {
|
||||
if !reflect.DeepEqual(nilIfEmpty(del), nilIfEmpty(tt.wantDel)) {
|
||||
t.Errorf("[%d] del:\n got: %v\n want: %v\n", i, del, tt.wantDel)
|
||||
}
|
||||
}
|
||||
@@ -210,35 +160,40 @@ func equalRouteDatas(a, b []*winipcfg.RouteData) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func ipnet4(ip string, bits int) netip.Prefix {
|
||||
return netip.PrefixFrom(netip.MustParseAddr(ip), bits)
|
||||
}
|
||||
|
||||
func TestFilterRoutes(t *testing.T) {
|
||||
var h0 net.IP
|
||||
var h0 netip.Addr
|
||||
|
||||
in := []*winipcfg.RouteData{
|
||||
// LinkLocal and Loopback routes.
|
||||
{*ipnet4("169.254.0.0", 16), h0, 1},
|
||||
{*ipnet4("169.254.255.255", 32), h0, 1},
|
||||
{*ipnet4("127.0.0.0", 8), h0, 1},
|
||||
{*ipnet4("127.255.255.255", 32), h0, 1},
|
||||
{ipnet4("169.254.0.0", 16), h0, 1},
|
||||
{ipnet4("169.254.255.255", 32), h0, 1},
|
||||
{ipnet4("127.0.0.0", 8), h0, 1},
|
||||
{ipnet4("127.255.255.255", 32), h0, 1},
|
||||
// Local LAN routes.
|
||||
{*ipnet4("192.168.0.0", 24), h0, 1},
|
||||
{*ipnet4("192.168.0.255", 32), h0, 1},
|
||||
{*ipnet4("192.168.1.0", 25), h0, 1},
|
||||
{*ipnet4("192.168.1.127", 32), h0, 1},
|
||||
{ipnet4("192.168.0.0", 24), h0, 1},
|
||||
{ipnet4("192.168.0.255", 32), h0, 1},
|
||||
{ipnet4("192.168.1.0", 25), h0, 1},
|
||||
{ipnet4("192.168.1.127", 32), h0, 1},
|
||||
// Some random other route.
|
||||
{*ipnet4("192.168.2.23", 32), h0, 1},
|
||||
{ipnet4("192.168.2.23", 32), h0, 1},
|
||||
// Our own tailscale address.
|
||||
{*ipnet4("100.100.100.100", 32), h0, 1},
|
||||
{ipnet4("100.100.100.100", 32), h0, 1},
|
||||
// Other tailscale addresses.
|
||||
{*ipnet4("100.100.100.101", 32), h0, 1},
|
||||
{*ipnet4("100.100.100.102", 32), h0, 1},
|
||||
{ipnet4("100.100.100.101", 32), h0, 1},
|
||||
{ipnet4("100.100.100.102", 32), h0, 1},
|
||||
}
|
||||
want := []*winipcfg.RouteData{
|
||||
{*ipnet4("169.254.0.0", 16), h0, 1},
|
||||
{*ipnet4("127.0.0.0", 8), h0, 1},
|
||||
{*ipnet4("192.168.0.0", 24), h0, 1},
|
||||
{*ipnet4("192.168.1.0", 25), h0, 1},
|
||||
{*ipnet4("192.168.2.23", 32), h0, 1},
|
||||
{*ipnet4("100.100.100.101", 32), h0, 1},
|
||||
{*ipnet4("100.100.100.102", 32), h0, 1},
|
||||
{ipnet4("169.254.0.0", 16), h0, 1},
|
||||
{ipnet4("127.0.0.0", 8), h0, 1},
|
||||
{ipnet4("192.168.0.0", 24), h0, 1},
|
||||
{ipnet4("192.168.1.0", 25), h0, 1},
|
||||
{ipnet4("192.168.2.23", 32), h0, 1},
|
||||
{ipnet4("100.100.100.101", 32), h0, 1},
|
||||
{ipnet4("100.100.100.102", 32), h0, 1},
|
||||
}
|
||||
|
||||
got := filterRoutes(in, mustCIDRs("100.100.100.100/32"))
|
||||
@@ -248,29 +203,29 @@ func TestFilterRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeltaRouteData(t *testing.T) {
|
||||
var h0 net.IP
|
||||
h1 := net.ParseIP("99.99.99.99")
|
||||
h2 := net.ParseIP("99.99.9.99")
|
||||
var h0 netip.Addr
|
||||
h1 := netip.MustParseAddr("99.99.99.99")
|
||||
h2 := netip.MustParseAddr("99.99.9.99")
|
||||
|
||||
a := []*winipcfg.RouteData{
|
||||
{*ipnet4("1.2.3.4", 32), h0, 1},
|
||||
{*ipnet4("1.2.3.4", 24), h1, 2},
|
||||
{*ipnet4("1.2.3.4", 24), h2, 1},
|
||||
{*ipnet4("1.2.3.5", 32), h0, 1},
|
||||
{ipnet4("1.2.3.4", 32), h0, 1},
|
||||
{ipnet4("1.2.3.4", 24), h1, 2},
|
||||
{ipnet4("1.2.3.4", 24), h2, 1},
|
||||
{ipnet4("1.2.3.5", 32), h0, 1},
|
||||
}
|
||||
b := []*winipcfg.RouteData{
|
||||
{*ipnet4("1.2.3.5", 32), h0, 1},
|
||||
{*ipnet4("1.2.3.4", 24), h1, 2},
|
||||
{*ipnet4("1.2.3.4", 24), h2, 2},
|
||||
{ipnet4("1.2.3.5", 32), h0, 1},
|
||||
{ipnet4("1.2.3.4", 24), h1, 2},
|
||||
{ipnet4("1.2.3.4", 24), h2, 2},
|
||||
}
|
||||
add, del := deltaRouteData(a, b)
|
||||
|
||||
wantAdd := []*winipcfg.RouteData{
|
||||
{*ipnet4("1.2.3.4", 24), h2, 2},
|
||||
{ipnet4("1.2.3.4", 24), h2, 2},
|
||||
}
|
||||
wantDel := []*winipcfg.RouteData{
|
||||
{*ipnet4("1.2.3.4", 32), h0, 1},
|
||||
{*ipnet4("1.2.3.4", 24), h2, 1},
|
||||
{ipnet4("1.2.3.4", 32), h0, 1},
|
||||
{ipnet4("1.2.3.4", 24), h2, 1},
|
||||
}
|
||||
|
||||
if !equalRouteDatas(add, wantAdd) {
|
||||
|
||||
@@ -6,7 +6,6 @@ package wgengine
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"tailscale.com/control/controlclient"
|
||||
@@ -46,13 +44,13 @@ import (
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/deephash"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/singleflight"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/magicsock"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/wgcfg"
|
||||
"tailscale.com/wgengine/wgint"
|
||||
"tailscale.com/wgengine/wglog"
|
||||
)
|
||||
|
||||
@@ -147,10 +145,6 @@ type userspaceEngine struct {
|
||||
// value of the ICMP identifer and sequence number concatenated.
|
||||
icmpEchoResponseCallback map[uint32]func()
|
||||
|
||||
// this singleflight is used to deduplicate calls to getStatus when we
|
||||
// don't care if the data is perfectly fresh
|
||||
getStatusSf singleflight.Group[struct{}, *Status]
|
||||
|
||||
// Lock ordering: magicsock.Conn.mu, wgLock, then mu.
|
||||
}
|
||||
|
||||
@@ -983,138 +977,51 @@ var singleNewline = []byte{'\n'}
|
||||
|
||||
var ErrEngineClosing = errors.New("engine closing; no status")
|
||||
|
||||
func (e *userspaceEngine) getPeerStatusLite(pk key.NodePublic) (status ipnstate.PeerStatusLite, ok bool) {
|
||||
e.wgLock.Lock()
|
||||
if e.wgdev == nil {
|
||||
e.wgLock.Unlock()
|
||||
return status, false
|
||||
}
|
||||
peer := e.wgdev.LookupPeer(pk.Raw32())
|
||||
e.wgLock.Unlock()
|
||||
if peer == nil {
|
||||
return status, false
|
||||
}
|
||||
status.NodeKey = pk
|
||||
status.RxBytes = int64(wgint.PeerRxBytes(peer))
|
||||
status.TxBytes = int64(wgint.PeerTxBytes(peer))
|
||||
status.LastHandshake = time.Unix(0, wgint.PeerLastHandshakeNano(peer))
|
||||
return status, true
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) getStatus() (*Status, error) {
|
||||
// Grab derpConns before acquiring wgLock to not violate lock ordering;
|
||||
// the DERPs method acquires magicsock.Conn.mu.
|
||||
// (See comment in userspaceEngine's declaration.)
|
||||
derpConns := e.magicConn.DERPs()
|
||||
|
||||
e.wgLock.Lock()
|
||||
defer e.wgLock.Unlock()
|
||||
|
||||
e.mu.Lock()
|
||||
closing := e.closing
|
||||
peerKeys := make([]key.NodePublic, len(e.peerSequence))
|
||||
copy(peerKeys, e.peerSequence)
|
||||
localAddrs := append([]tailcfg.Endpoint(nil), e.endpoints...)
|
||||
e.mu.Unlock()
|
||||
|
||||
if closing {
|
||||
return nil, ErrEngineClosing
|
||||
}
|
||||
|
||||
if e.wgdev == nil {
|
||||
// RequestStatus was invoked before the wgengine has
|
||||
// finished initializing. This can happen when wgegine
|
||||
// provides a callback to magicsock for endpoint
|
||||
// updates that calls RequestStatus.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close() // to unblock writes on error path returns
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
|
||||
// FIXME: get notified of status changes instead of polling.
|
||||
err := e.wgdev.IpcGetOperation(pw)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("IpcGetOperation: %w", err)
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
pp := make(map[key.NodePublic]ipnstate.PeerStatusLite)
|
||||
var p ipnstate.PeerStatusLite
|
||||
|
||||
var hst1, hst2, n int64
|
||||
|
||||
br := e.statusBufioReader
|
||||
if br != nil {
|
||||
br.Reset(pr)
|
||||
} else {
|
||||
br = bufio.NewReaderSize(pr, 1<<10)
|
||||
e.statusBufioReader = br
|
||||
}
|
||||
for {
|
||||
line, err := br.ReadSlice('\n')
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading from UAPI pipe: %w", err)
|
||||
}
|
||||
line = bytes.TrimSuffix(line, singleNewline)
|
||||
k := line
|
||||
var v mem.RO
|
||||
if i := bytes.IndexByte(line, '='); i != -1 {
|
||||
k = line[:i]
|
||||
v = mem.B(line[i+1:])
|
||||
}
|
||||
switch string(k) {
|
||||
case "public_key":
|
||||
pk, err := key.ParseNodePublicUntyped(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: invalid key in line %q", line)
|
||||
}
|
||||
if !p.NodeKey.IsZero() {
|
||||
pp[p.NodeKey] = p
|
||||
}
|
||||
p = ipnstate.PeerStatusLite{NodeKey: pk}
|
||||
case "rx_bytes":
|
||||
n, err = mem.ParseInt(v, 10, 64)
|
||||
p.RxBytes = n
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: rx_bytes invalid: %#v", line)
|
||||
}
|
||||
case "tx_bytes":
|
||||
n, err = mem.ParseInt(v, 10, 64)
|
||||
p.TxBytes = n
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: tx_bytes invalid: %#v", line)
|
||||
}
|
||||
case "last_handshake_time_sec":
|
||||
hst1, err = mem.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: hst1 invalid: %#v", line)
|
||||
}
|
||||
case "last_handshake_time_nsec":
|
||||
hst2, err = mem.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: hst2 invalid: %#v", line)
|
||||
}
|
||||
if hst1 != 0 || hst2 != 0 {
|
||||
p.LastHandshake = time.Unix(hst1, hst2)
|
||||
} // else leave at time.IsZero()
|
||||
}
|
||||
}
|
||||
if !p.NodeKey.IsZero() {
|
||||
pp[p.NodeKey] = p
|
||||
}
|
||||
if err := <-errc; err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: %v", err)
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
// Do two passes, one to calculate size and the other to populate.
|
||||
// This code is sensitive to allocations.
|
||||
npeers := 0
|
||||
for _, pk := range e.peerSequence {
|
||||
if _, ok := pp[pk]; ok { // ignore idle ones not in wireguard-go's config
|
||||
npeers++
|
||||
}
|
||||
}
|
||||
|
||||
peers := make([]ipnstate.PeerStatusLite, 0, npeers)
|
||||
for _, pk := range e.peerSequence {
|
||||
if p, ok := pp[pk]; ok { // ignore idle ones not in wireguard-go's config
|
||||
peers = append(peers, p)
|
||||
peers := make([]ipnstate.PeerStatusLite, 0, len(peerKeys))
|
||||
for _, key := range peerKeys {
|
||||
if status, found := e.getPeerStatusLite(key); found {
|
||||
peers = append(peers, status)
|
||||
}
|
||||
}
|
||||
|
||||
return &Status{
|
||||
AsOf: time.Now(),
|
||||
LocalAddrs: append([]tailcfg.Endpoint(nil), e.endpoints...),
|
||||
LocalAddrs: localAddrs,
|
||||
Peers: peers,
|
||||
DERPs: derpConns,
|
||||
}, nil
|
||||
|
||||
@@ -8,10 +8,12 @@
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/envknob"
|
||||
@@ -38,21 +40,49 @@ func NewWatchdog(e Engine) Engine {
|
||||
return e
|
||||
}
|
||||
return &watchdogEngine{
|
||||
wrap: e,
|
||||
logf: log.Printf,
|
||||
fatalf: log.Fatalf,
|
||||
maxWait: 45 * time.Second,
|
||||
wrap: e,
|
||||
logf: log.Printf,
|
||||
fatalf: log.Fatalf,
|
||||
maxWait: 45 * time.Second,
|
||||
inFlight: make(map[inFlightKey]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
type inFlightKey struct {
|
||||
op string
|
||||
ctr uint64
|
||||
}
|
||||
|
||||
type watchdogEngine struct {
|
||||
wrap Engine
|
||||
logf func(format string, args ...any)
|
||||
fatalf func(format string, args ...any)
|
||||
maxWait time.Duration
|
||||
|
||||
// Track the start time(s) of in-flight operations
|
||||
inFlightMu sync.Mutex
|
||||
inFlight map[inFlightKey]time.Time
|
||||
inFlightCtr uint64
|
||||
}
|
||||
|
||||
func (e *watchdogEngine) watchdogErr(name string, fn func() error) error {
|
||||
// Track all in-flight operations so we can print more useful error
|
||||
// messages on watchdog failure
|
||||
e.inFlightMu.Lock()
|
||||
key := inFlightKey{
|
||||
op: name,
|
||||
ctr: e.inFlightCtr,
|
||||
}
|
||||
e.inFlightCtr++
|
||||
e.inFlight[key] = time.Now()
|
||||
e.inFlightMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
e.inFlightMu.Lock()
|
||||
defer e.inFlightMu.Unlock()
|
||||
delete(e.inFlight, key)
|
||||
}()
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- fn()
|
||||
@@ -66,6 +96,22 @@ func (e *watchdogEngine) watchdogErr(name string, fn func() error) error {
|
||||
buf := new(strings.Builder)
|
||||
pprof.Lookup("goroutine").WriteTo(buf, 1)
|
||||
e.logf("wgengine watchdog stacks:\n%s", buf.String())
|
||||
|
||||
// Collect the list of in-flight operations for debugging.
|
||||
var (
|
||||
b []byte
|
||||
now = time.Now()
|
||||
)
|
||||
e.inFlightMu.Lock()
|
||||
for k, t := range e.inFlight {
|
||||
dur := now.Sub(t).Round(time.Millisecond)
|
||||
b = fmt.Appendf(b, "in-flight[%d]: name=%s duration=%v start=%s\n", k.ctr, k.op, dur, t.Format(time.RFC3339Nano))
|
||||
}
|
||||
e.inFlightMu.Unlock()
|
||||
|
||||
// Print everything as a single string to avoid log
|
||||
// rate limits.
|
||||
e.logf("wgengine watchdog in-flight:\n%s", b)
|
||||
e.fatalf("wgengine: watchdog timeout on %s", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,13 +5,9 @@
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
func TestWatchdog(t *testing.T) {
|
||||
@@ -41,43 +37,4 @@ func TestWatchdog(t *testing.T) {
|
||||
e.RequestStatus()
|
||||
e.Close()
|
||||
})
|
||||
|
||||
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
e, err := NewFakeUserspaceEngine(t.Logf, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(e.Close)
|
||||
usEngine := e.(*userspaceEngine)
|
||||
e = NewWatchdog(e)
|
||||
wdEngine := e.(*watchdogEngine)
|
||||
wdEngine.maxWait = maxWaitMultiple * 100 * time.Millisecond
|
||||
|
||||
logBuf := new(tstest.MemLogger)
|
||||
fatalCalled := make(chan struct{})
|
||||
wdEngine.logf = logBuf.Logf
|
||||
wdEngine.fatalf = func(format string, args ...any) {
|
||||
t.Logf("FATAL: %s", fmt.Sprintf(format, args...))
|
||||
fatalCalled <- struct{}{}
|
||||
}
|
||||
|
||||
usEngine.wgLock.Lock() // blocks getStatus so the watchdog will fire
|
||||
|
||||
go e.RequestStatus()
|
||||
|
||||
select {
|
||||
case <-fatalCalled:
|
||||
if !strings.Contains(logBuf.String(), "goroutine profile: total ") {
|
||||
t.Errorf("fatal called without watchdog stacks, got: %s", logBuf.String())
|
||||
}
|
||||
// expected
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("watchdog failed to fire")
|
||||
}
|
||||
|
||||
usEngine.wgLock.Unlock()
|
||||
wdEngine.fatalf = t.Fatalf
|
||||
wdEngine.Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -16,8 +16,11 @@ import (
|
||||
|
||||
// ToUAPI writes cfg in UAPI format to w.
|
||||
// Prev is the previous device Config.
|
||||
// Prev is required so that we can remove now-defunct peers
|
||||
// without having to remove and re-add all peers.
|
||||
//
|
||||
// Prev is required so that we can remove now-defunct peers without having to
|
||||
// remove and re-add all peers, and so that we can avoid writing information
|
||||
// about peers that have not changed since the previous time we wrote our
|
||||
// Config.
|
||||
func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
var stickyErr error
|
||||
set := func(key, value string) {
|
||||
@@ -49,13 +52,33 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
// Add/configure all new peers.
|
||||
for _, p := range cfg.Peers {
|
||||
oldPeer, wasPresent := old[p.PublicKey]
|
||||
|
||||
// We only want to write the peer header/version if we're about
|
||||
// to change something about that peer, or if it's a new peer.
|
||||
// Figure out up-front whether we'll need to do anything for
|
||||
// this peer, and skip doing anything if not.
|
||||
//
|
||||
// If the peer was not present in the previous config, this
|
||||
// implies that this is a new peer; set all of these to 'true'
|
||||
// to ensure that we're writing the full peer configuration.
|
||||
willSetEndpoint := oldPeer.WGEndpoint != p.PublicKey || !wasPresent
|
||||
willChangeIPs := !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) || !wasPresent
|
||||
willChangeKeepalive := oldPeer.PersistentKeepalive != p.PersistentKeepalive || !wasPresent
|
||||
|
||||
if !willSetEndpoint && !willChangeIPs && !willChangeKeepalive {
|
||||
// It's safe to skip doing anything here; wireguard-go
|
||||
// will not remove a peer if it's unspecified unless we
|
||||
// tell it to (which we do below if necessary).
|
||||
continue
|
||||
}
|
||||
|
||||
setPeer(p)
|
||||
set("protocol_version", "1")
|
||||
|
||||
// Avoid setting endpoints if the correct one is already known
|
||||
// to WireGuard, because doing so generates a bit more work in
|
||||
// calling magicsock's ParseEndpoint for effectively a no-op.
|
||||
if oldPeer.WGEndpoint != p.PublicKey {
|
||||
if willSetEndpoint {
|
||||
if wasPresent {
|
||||
// We had an endpoint, and it was wrong.
|
||||
// By construction, this should not happen.
|
||||
@@ -72,7 +95,7 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
// If p.AllowedIPs is a strict superset of oldPeer.AllowedIPs,
|
||||
// then skip replace_allowed_ips and instead add only
|
||||
// the new ipps with allowed_ip.
|
||||
if !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) {
|
||||
if willChangeIPs {
|
||||
set("replace_allowed_ips", "true")
|
||||
for _, ipp := range p.AllowedIPs {
|
||||
set("allowed_ip", ipp.String())
|
||||
@@ -81,7 +104,7 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
|
||||
// Set PersistentKeepalive after the peer is otherwise configured,
|
||||
// because it can trigger handshake packets.
|
||||
if oldPeer.PersistentKeepalive != p.PersistentKeepalive {
|
||||
if willChangeKeepalive {
|
||||
setUint16("persistent_keepalive_interval", p.PersistentKeepalive)
|
||||
}
|
||||
}
|
||||
|
||||
59
wgengine/wgint/wgint.go
Normal file
59
wgengine/wgint/wgint.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package wgint provides somewhat shady access to wireguard-go
|
||||
// internals that don't (yet) have public APIs.
|
||||
package wgint
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
var (
|
||||
offHandshake = getPeerStatsOffset("lastHandshakeNano")
|
||||
offRxBytes = getPeerStatsOffset("rxBytes")
|
||||
offTxBytes = getPeerStatsOffset("txBytes")
|
||||
)
|
||||
|
||||
func getPeerStatsOffset(name string) uintptr {
|
||||
peerType := reflect.TypeOf(device.Peer{})
|
||||
sf, ok := peerType.FieldByName("stats")
|
||||
if !ok {
|
||||
panic("no stats field in device.Peer")
|
||||
}
|
||||
if sf.Type.Kind() != reflect.Struct {
|
||||
panic("stats field is not a struct")
|
||||
}
|
||||
base := sf.Offset
|
||||
|
||||
st := sf.Type
|
||||
field, ok := st.FieldByName(name)
|
||||
if !ok {
|
||||
panic("no " + name + " field in device.Peer.stats")
|
||||
}
|
||||
if field.Type.Kind() != reflect.Int64 && field.Type.Kind() != reflect.Uint64 {
|
||||
panic("unexpected kind of " + name + " field in device.Peer.stats")
|
||||
}
|
||||
return base + field.Offset
|
||||
}
|
||||
|
||||
// PeerLastHandshakeNano returns the last handshake time in nanoseconds since the
|
||||
// unix epoch.
|
||||
func PeerLastHandshakeNano(peer *device.Peer) int64 {
|
||||
return atomic.LoadInt64((*int64)(unsafe.Add(unsafe.Pointer(peer), offHandshake)))
|
||||
}
|
||||
|
||||
// PeerRxBytes returns the number of bytes received from this peer.
|
||||
func PeerRxBytes(peer *device.Peer) uint64 {
|
||||
return atomic.LoadUint64((*uint64)(unsafe.Add(unsafe.Pointer(peer), offRxBytes)))
|
||||
}
|
||||
|
||||
// PeerTxBytes returns the number of bytes sent to this peer.
|
||||
func PeerTxBytes(peer *device.Peer) uint64 {
|
||||
return atomic.LoadUint64((*uint64)(unsafe.Add(unsafe.Pointer(peer), offTxBytes)))
|
||||
}
|
||||
24
wgengine/wgint/wgint_test.go
Normal file
24
wgengine/wgint/wgint_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgint
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func TestPeerStats(t *testing.T) {
|
||||
peer := new(device.Peer)
|
||||
if got := PeerLastHandshakeNano(peer); got != 0 {
|
||||
t.Errorf("PeerLastHandshakeNano = %v, want 0", got)
|
||||
}
|
||||
if got := PeerRxBytes(peer); got != 0 {
|
||||
t.Errorf("PeerRxBytes = %v, want 0", got)
|
||||
}
|
||||
if got := PeerTxBytes(peer); got != 0 {
|
||||
t.Errorf("PeerTxBytes = %v, want 0", got)
|
||||
}
|
||||
}
|
||||
@@ -113,6 +113,11 @@ bee
|
||||
bearded
|
||||
beardie
|
||||
pogona
|
||||
chicken
|
||||
hen
|
||||
rooster
|
||||
quail
|
||||
grouse
|
||||
|
||||
# Musical scales
|
||||
acoustic
|
||||
|
||||
@@ -217,3 +217,8 @@ dragon
|
||||
bearded
|
||||
beardie
|
||||
pogona
|
||||
chicken
|
||||
hen
|
||||
rooster
|
||||
quail
|
||||
grouse
|
||||
|
||||
Reference in New Issue
Block a user