Compare commits
36 Commits
ip6tables
...
gitops-1.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b14e31831a | ||
|
|
0b00b7a135 | ||
|
|
70ed22ccf9 | ||
|
|
7ca17b6bdb | ||
|
|
e945d87d76 | ||
|
|
1ac4a26fee | ||
|
|
761163815c | ||
|
|
9f6c8517e0 | ||
|
|
27f36f77c3 | ||
|
|
122bd667dc | ||
|
|
21cd402204 | ||
|
|
0ae0439668 | ||
|
|
6dcc6313a6 | ||
|
|
78dbb59a00 | ||
|
|
7e40071571 | ||
|
|
90dc0e1702 | ||
|
|
2c18517121 | ||
|
|
d6c3588ed3 | ||
|
|
81dba3738e | ||
|
|
ad1cc6cff9 | ||
|
|
68d9d161f4 | ||
|
|
c66f99fcdc | ||
|
|
08b3f5f070 | ||
|
|
66d7d2549f | ||
|
|
d20392d413 | ||
|
|
58cc049a9f | ||
|
|
9b77ac128a | ||
|
|
e1738ea78e | ||
|
|
9bf13fc3d1 | ||
|
|
ab7e6f3f11 | ||
|
|
c5b1565337 | ||
|
|
d2e2d8438b | ||
|
|
23c3831ff9 | ||
|
|
296b008b9f | ||
|
|
31bf3874d6 | ||
|
|
e0c5ac1f02 |
7
.github/workflows/cross-darwin.yml
vendored
7
.github/workflows/cross-darwin.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: macOS build cmd
|
||||
env:
|
||||
GOOS: darwin
|
||||
|
||||
7
.github/workflows/cross-freebsd.yml
vendored
7
.github/workflows/cross-freebsd.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: FreeBSD build cmd
|
||||
env:
|
||||
GOOS: freebsd
|
||||
|
||||
7
.github/workflows/cross-openbsd.yml
vendored
7
.github/workflows/cross-openbsd.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: OpenBSD build cmd
|
||||
env:
|
||||
GOOS: openbsd
|
||||
|
||||
7
.github/workflows/cross-wasm.yml
vendored
7
.github/workflows/cross-wasm.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Wasm client build
|
||||
env:
|
||||
GOOS: js
|
||||
|
||||
7
.github/workflows/cross-windows.yml
vendored
7
.github/workflows/cross-windows.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Windows build cmd
|
||||
env:
|
||||
GOOS: windows
|
||||
|
||||
8
.github/workflows/depaware.yml
vendored
8
.github/workflows/depaware.yml
vendored
@@ -17,13 +17,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: depaware
|
||||
run: go run github.com/tailscale/depaware --check
|
||||
|
||||
10
.github/workflows/go_generate.yml
vendored
10
.github/workflows/go_generate.yml
vendored
@@ -18,16 +18,16 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: check 'go generate' is clean
|
||||
run: |
|
||||
if [[ "${{github.ref}}" == release-branch/* ]]
|
||||
|
||||
35
.github/workflows/go_mod_tidy.yml
vendored
Normal file
35
.github/workflows/go_mod_tidy.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: go mod tidy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- "*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: check 'go mod tidy' is clean
|
||||
run: |
|
||||
go mod tidy
|
||||
echo
|
||||
echo
|
||||
git diff --name-only --exit-code || (echo "Please run 'go mod tidy'."; exit 1)
|
||||
8
.github/workflows/license.yml
vendored
8
.github/workflows/license.yml
vendored
@@ -17,13 +17,13 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run license checker
|
||||
run: ./scripts/check_license_headers.sh .
|
||||
|
||||
7
.github/workflows/linux-race.yml
vendored
7
.github/workflows/linux-race.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Basic build
|
||||
run: go build ./cmd/...
|
||||
|
||||
|
||||
7
.github/workflows/linux.yml
vendored
7
.github/workflows/linux.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Basic build
|
||||
run: go build ./cmd/...
|
||||
|
||||
|
||||
7
.github/workflows/linux32.yml
vendored
7
.github/workflows/linux32.yml
vendored
@@ -19,16 +19,15 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
go-version-file: go.mod
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Basic build
|
||||
run: GOARCH=386 go build ./cmd/...
|
||||
|
||||
|
||||
6
.github/workflows/static-analysis.yml
vendored
6
.github/workflows/static-analysis.yml
vendored
@@ -16,12 +16,12 @@ jobs:
|
||||
gofmt:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
- name: Run gofmt (goimports)
|
||||
run: go run golang.org/x/tools/cmd/goimports -d --format-only .
|
||||
- uses: k0kubun/action-slack@v2.0.0
|
||||
|
||||
30
.github/workflows/tsconnect-pkg-publish.yml
vendored
Normal file
30
.github/workflows/tsconnect-pkg-publish.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: "@tailscale/connect npm publish"
|
||||
|
||||
on: workflow_dispatch
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up node
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: "16.x"
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
|
||||
- name: Build package
|
||||
# Build with build_dist.sh to ensure that version information is embedded.
|
||||
# GOROOT is specified so that the Go/Wasm that is trigged by build-pk
|
||||
# also picks up our custom Go toolchain.
|
||||
run: |
|
||||
./build_dist.sh tailscale.com/cmd/tsconnect
|
||||
GOROOT="${HOME}/.cache/tailscale-go" ./tsconnect build-pkg
|
||||
|
||||
- name: Publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.TSCONNECT_NPM_PUBLISH_AUTH_TOKEN }}
|
||||
run: ./tool/yarn --cwd ./cmd/tsconnect/pkg publish --access public
|
||||
8
.github/workflows/vm.yml
vendored
8
.github/workflows/vm.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
- name: Set GOPATH
|
||||
run: echo "GOPATH=$HOME/go" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run VM tests
|
||||
run: go test ./tstest/integration/vms -v -no-s3 -run-vm-tests -run=TestRunUbuntu2004
|
||||
|
||||
7
.github/workflows/windows.yml
vendored
7
.github/workflows/windows.yml
vendored
@@ -19,14 +19,13 @@ jobs:
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Restore Cache
|
||||
uses: actions/cache@v3
|
||||
|
||||
@@ -1 +1 @@
|
||||
1.29.0
|
||||
1.30.0
|
||||
|
||||
@@ -11,15 +11,31 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum amount of time we should wait when reading a response from BIRD.
|
||||
responseTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// New creates a BIRDClient.
|
||||
func New(socket string) (*BIRDClient, error) {
|
||||
return newWithTimeout(socket, responseTimeout)
|
||||
}
|
||||
|
||||
func newWithTimeout(socket string, timeout time.Duration) (*BIRDClient, error) {
|
||||
conn, err := net.Dial("unix", socket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to BIRD: %w", err)
|
||||
}
|
||||
b := &BIRDClient{socket: socket, conn: conn, scanner: bufio.NewScanner(conn)}
|
||||
b := &BIRDClient{
|
||||
socket: socket,
|
||||
conn: conn,
|
||||
scanner: bufio.NewScanner(conn),
|
||||
timeNow: time.Now,
|
||||
timeout: timeout,
|
||||
}
|
||||
// Read and discard the first line as that is the welcome message.
|
||||
if _, err := b.readResponse(); err != nil {
|
||||
return nil, err
|
||||
@@ -32,6 +48,8 @@ type BIRDClient struct {
|
||||
socket string
|
||||
conn net.Conn
|
||||
scanner *bufio.Scanner
|
||||
timeNow func() time.Time
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// Close closes the underlying connection to BIRD.
|
||||
@@ -81,10 +99,15 @@ func (b *BIRDClient) EnableProtocol(protocol string) error {
|
||||
// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’.
|
||||
|
||||
func (b *BIRDClient) exec(cmd string, args ...any) (string, error) {
|
||||
if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil {
|
||||
return "", err
|
||||
}
|
||||
fmt.Fprintln(b.conn)
|
||||
if _, err := fmt.Fprintln(b.conn); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return b.readResponse()
|
||||
}
|
||||
|
||||
@@ -105,14 +128,20 @@ func hasResponseCode(s []byte) bool {
|
||||
}
|
||||
|
||||
func (b *BIRDClient) readResponse() (string, error) {
|
||||
// Set the read timeout before we start reading anything.
|
||||
if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var resp strings.Builder
|
||||
var done bool
|
||||
for !done {
|
||||
if !b.scanner.Scan() {
|
||||
return "", fmt.Errorf("reading response from bird failed: %q", resp.String())
|
||||
}
|
||||
if err := b.scanner.Err(); err != nil {
|
||||
return "", err
|
||||
if err := b.scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String())
|
||||
}
|
||||
out := b.scanner.Bytes()
|
||||
if _, err := resp.Write(out); err != nil {
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeBIRD struct {
|
||||
@@ -109,3 +112,82 @@ func TestChirp(t *testing.T) {
|
||||
t.Fatalf("disabling %q succeded", "rando")
|
||||
}
|
||||
}
|
||||
|
||||
type hangingListener struct {
|
||||
net.Listener
|
||||
t *testing.T
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
sock string
|
||||
}
|
||||
|
||||
func newHangingListener(t *testing.T) *hangingListener {
|
||||
sock := filepath.Join(t.TempDir(), "sock")
|
||||
l, err := net.Listen("unix", sock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &hangingListener{
|
||||
Listener: l,
|
||||
t: t,
|
||||
done: make(chan struct{}),
|
||||
sock: sock,
|
||||
}
|
||||
}
|
||||
|
||||
func (hl *hangingListener) Stop() {
|
||||
hl.Close()
|
||||
close(hl.done)
|
||||
hl.wg.Wait()
|
||||
}
|
||||
|
||||
func (hl *hangingListener) listen() error {
|
||||
for {
|
||||
c, err := hl.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
hl.wg.Add(1)
|
||||
go hl.handle(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (hl *hangingListener) handle(c net.Conn) {
|
||||
defer hl.wg.Done()
|
||||
|
||||
// Write our fake first line of response so that we get into the read loop
|
||||
fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.")
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hl.t.Logf("connection still hanging")
|
||||
case <-hl.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChirpTimeout(t *testing.T) {
|
||||
fb := newHangingListener(t)
|
||||
defer fb.Stop()
|
||||
go fb.listen()
|
||||
|
||||
c, err := newWithTimeout(fb.sock, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = c.EnableProtocol("tailscale")
|
||||
if err == nil {
|
||||
t.Fatal("got err=nil, want timeout")
|
||||
}
|
||||
if !os.IsTimeout(err) {
|
||||
t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
@@ -30,17 +31,14 @@ var (
|
||||
cacheFname = rootFlagSet.String("cache-file", "./version-cache.json", "filename for the previous known version hash")
|
||||
timeout = rootFlagSet.Duration("timeout", 5*time.Minute, "timeout for the entire CI run")
|
||||
githubSyntax = rootFlagSet.Bool("github-syntax", true, "use GitHub Action error syntax (https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#setting-an-error-message)")
|
||||
|
||||
modifiedExternallyFailure = make(chan struct{}, 1)
|
||||
)
|
||||
|
||||
func modifiedExternallyError() {
|
||||
if *githubSyntax {
|
||||
fmt.Printf("::error file=%s,line=1,col=1,title=Policy File Modified Externally::The policy file was modified externally in the admin console.\n", *policyFname)
|
||||
fmt.Printf("::warning file=%s,line=1,col=1,title=Policy File Modified Externally::The policy file was modified externally in the admin console.\n", *policyFname)
|
||||
} else {
|
||||
fmt.Printf("The policy file was modified externally in the admin console.\n")
|
||||
}
|
||||
modifiedExternallyFailure <- struct{}{}
|
||||
}
|
||||
|
||||
func apply(cache *Cache, tailnet, apiKey string) func(context.Context, []string) error {
|
||||
@@ -207,10 +205,6 @@ func main() {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(modifiedExternallyFailure) != 0 {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func sumFile(fname string) (string, error) {
|
||||
@@ -271,13 +265,16 @@ func applyNewACL(ctx context.Context, tailnet, apiKey, policyFname, oldEtag stri
|
||||
}
|
||||
|
||||
func testNewACLs(ctx context.Context, tailnet, apiKey, policyFname string) error {
|
||||
fin, err := os.Open(policyFname)
|
||||
data, err := os.ReadFile(policyFname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err = hujson.Standardize(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fin.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://api.tailscale.com/api/v2/tailnet/%s/acl/validate", tailnet), fin)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://api.tailscale.com/api/v2/tailnet/%s/acl/validate", tailnet), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -296,6 +296,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/wgengine/router from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal
|
||||
💣 tailscale.com/wgengine/wgint from tailscale.com/wgengine
|
||||
tailscale.com/wgengine/wglog from tailscale.com/wgengine
|
||||
W 💣 tailscale.com/wgengine/winnet from tailscale.com/wgengine/router
|
||||
golang.org/x/crypto/acme from tailscale.com/ipn/localapi
|
||||
@@ -404,6 +405,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
mime/quotedprintable from mime/multipart
|
||||
net from crypto/tls+
|
||||
net/http from expvar+
|
||||
net/http/httptest from tailscale.com/control/controlclient
|
||||
net/http/httptrace from github.com/tcnksm/go-httpstat+
|
||||
net/http/httputil from github.com/aws/smithy-go/transport/http+
|
||||
net/http/internal from net/http+
|
||||
|
||||
@@ -5,9 +5,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
esbuild "github.com/evanw/esbuild/pkg/api"
|
||||
"github.com/tailscale/hujson"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
func runBuildPkg() {
|
||||
@@ -41,4 +47,33 @@ func runBuildPkg() {
|
||||
log.Fatalf("Type generation failed: %v", err)
|
||||
}
|
||||
|
||||
if err := updateVersion(); err != nil {
|
||||
log.Fatalf("Cannot update version: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("Built package version %s", version.Long)
|
||||
}
|
||||
|
||||
func updateVersion() error {
|
||||
packageJSONBytes, err := os.ReadFile("package.json.tmpl")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not read package.json: %w", err)
|
||||
}
|
||||
|
||||
var packageJSON map[string]any
|
||||
packageJSONBytes, err = hujson.Standardize(packageJSONBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not standardize template package.json: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil {
|
||||
return fmt.Errorf("Could not unmarshal package.json: %w", err)
|
||||
}
|
||||
packageJSON["version"] = version.Long
|
||||
|
||||
packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not marshal package.json: %w", err)
|
||||
}
|
||||
|
||||
return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644)
|
||||
}
|
||||
|
||||
17
cmd/tsconnect/package.json.tmpl
Normal file
17
cmd/tsconnect/package.json.tmpl
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Template for the package.json that is generated by the build-pkg command.
|
||||
// The version number will be replaced by the current Tailscale client version
|
||||
// number.
|
||||
{
|
||||
"author": "Tailscale Inc.",
|
||||
"description": "Tailscale Connect SDK",
|
||||
"license": "BSD-3-Clause",
|
||||
"name": "tailscale-connect",
|
||||
"type": "module",
|
||||
"main": "./pkg.js",
|
||||
"types": "./pkg.d.ts",
|
||||
"version": "AUTO_GENERATED"
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
{
|
||||
"author": "Tailscale Inc.",
|
||||
"description": "Tailscale Connect SDK",
|
||||
"license": "BSD-3-Clause",
|
||||
"name": "@tailscale/connect",
|
||||
"type": "module",
|
||||
"main": "./pkg.js",
|
||||
"types": "./pkg.d.ts",
|
||||
"version": "0.0.5"
|
||||
}
|
||||
@@ -27,26 +27,39 @@ export function runSSHSession(
|
||||
|
||||
term.focus()
|
||||
|
||||
const sshSession = ipn.ssh(def.hostname, def.username, {
|
||||
writeFn: (input) => term.write(input),
|
||||
setReadFn: (hook) => (onDataHook = hook),
|
||||
let resizeObserver: ResizeObserver | undefined
|
||||
let handleBeforeUnload: ((e: BeforeUnloadEvent) => void) | undefined
|
||||
|
||||
const sshSession = ipn.ssh(def.hostname + "2", def.username, {
|
||||
writeFn(input) {
|
||||
term.write(input)
|
||||
},
|
||||
writeErrorFn(err) {
|
||||
console.error(err)
|
||||
term.write(err)
|
||||
},
|
||||
setReadFn(hook) {
|
||||
onDataHook = hook
|
||||
},
|
||||
rows: term.rows,
|
||||
cols: term.cols,
|
||||
onDone: () => {
|
||||
resizeObserver.disconnect()
|
||||
onDone() {
|
||||
resizeObserver?.disconnect()
|
||||
term.dispose()
|
||||
window.removeEventListener("beforeunload", handleBeforeUnload)
|
||||
if (handleBeforeUnload) {
|
||||
window.removeEventListener("beforeunload", handleBeforeUnload)
|
||||
}
|
||||
onDone()
|
||||
},
|
||||
})
|
||||
|
||||
// Make terminal and SSH session track the size of the containing DOM node.
|
||||
const resizeObserver = new ResizeObserver(() => fitAddon.fit())
|
||||
resizeObserver = new ResizeObserver(() => fitAddon.fit())
|
||||
resizeObserver.observe(termContainerNode)
|
||||
term.onResize(({ rows, cols }) => sshSession.resize(rows, cols))
|
||||
|
||||
// Close the session if the user closes the window without an explicit
|
||||
// exit.
|
||||
const handleBeforeUnload = () => sshSession.close()
|
||||
handleBeforeUnload = () => sshSession.close()
|
||||
window.addEventListener("beforeunload", handleBeforeUnload)
|
||||
}
|
||||
|
||||
1
cmd/tsconnect/src/types/wasm_js.d.ts
vendored
1
cmd/tsconnect/src/types/wasm_js.d.ts
vendored
@@ -19,6 +19,7 @@ declare global {
|
||||
username: string,
|
||||
termConfig: {
|
||||
writeFn: (data: string) => void
|
||||
writeErrorFn: (err: string) => void
|
||||
setReadFn: (readFn: (data: string) => void) => void
|
||||
rows: number
|
||||
cols: number
|
||||
|
||||
@@ -347,6 +347,7 @@ type jsSSHSession struct {
|
||||
|
||||
func (s *jsSSHSession) Run() {
|
||||
writeFn := s.termConfig.Get("writeFn")
|
||||
writeErrorFn := s.termConfig.Get("writeErrorFn")
|
||||
setReadFn := s.termConfig.Get("setReadFn")
|
||||
rows := s.termConfig.Get("rows").Int()
|
||||
cols := s.termConfig.Get("cols").Int()
|
||||
@@ -357,7 +358,7 @@ func (s *jsSSHSession) Run() {
|
||||
writeFn.Invoke(s)
|
||||
}
|
||||
writeError := func(label string, err error) {
|
||||
write(fmt.Sprintf("%s Error: %v\r\n", label, err))
|
||||
writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package controlclient
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -73,6 +75,7 @@ type Direct struct {
|
||||
skipIPForwardingCheck bool
|
||||
pinger Pinger
|
||||
popBrowser func(url string) // or nil
|
||||
c2nHandler http.Handler // or nil
|
||||
|
||||
mu sync.Mutex // mutex guards the following fields
|
||||
serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key
|
||||
@@ -108,6 +111,7 @@ type Options struct {
|
||||
LinkMonitor *monitor.Mon // optional link monitor
|
||||
PopBrowserURL func(url string) // optional func to open browser
|
||||
Dialer *tsdial.Dialer // non-nil
|
||||
C2NHandler http.Handler // or nil
|
||||
|
||||
// GetNLPublicKey specifies an optional function to use
|
||||
// Network Lock. If nil, it's not used.
|
||||
@@ -210,6 +214,7 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
skipIPForwardingCheck: opts.SkipIPForwardingCheck,
|
||||
pinger: opts.Pinger,
|
||||
popBrowser: opts.PopBrowserURL,
|
||||
c2nHandler: opts.C2NHandler,
|
||||
dialer: opts.Dialer,
|
||||
}
|
||||
if opts.Hostinfo == nil {
|
||||
@@ -1205,7 +1210,8 @@ func (c *Direct) isUniquePingRequest(pr *tailcfg.PingRequest) bool {
|
||||
|
||||
func (c *Direct) answerPing(pr *tailcfg.PingRequest) {
|
||||
httpc := c.httpc
|
||||
if pr.URLIsNoise {
|
||||
useNoise := pr.URLIsNoise || pr.Types == "c2n" && c.noiseConfigured()
|
||||
if useNoise {
|
||||
nc, err := c.getNoiseClient()
|
||||
if err != nil {
|
||||
c.logf("failed to get noise client for ping request: %v", err)
|
||||
@@ -1217,9 +1223,17 @@ func (c *Direct) answerPing(pr *tailcfg.PingRequest) {
|
||||
c.logf("invalid PingRequest with no URL")
|
||||
return
|
||||
}
|
||||
if pr.Types == "" {
|
||||
switch pr.Types {
|
||||
case "":
|
||||
answerHeadPing(c.logf, httpc, pr)
|
||||
return
|
||||
case "c2n":
|
||||
if !useNoise && !envknob.Bool("TS_DEBUG_PERMIT_HTTP_C2N") {
|
||||
c.logf("refusing to answer c2n ping without noise")
|
||||
return
|
||||
}
|
||||
answerC2NPing(c.logf, c.c2nHandler, httpc, pr)
|
||||
return
|
||||
}
|
||||
for _, t := range strings.Split(pr.Types, ",") {
|
||||
switch pt := tailcfg.PingType(t); pt {
|
||||
@@ -1253,6 +1267,54 @@ func answerHeadPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr *tailcfg.PingRequest) {
|
||||
if c2nHandler == nil {
|
||||
logf("answerC2NPing: c2nHandler not defined")
|
||||
return
|
||||
}
|
||||
hreq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload)))
|
||||
if err != nil {
|
||||
logf("answerC2NPing: ReadRequest: %v", err)
|
||||
return
|
||||
}
|
||||
if pr.Log {
|
||||
logf("answerC2NPing: got c2n request for %v ...", hreq.RequestURI)
|
||||
}
|
||||
handlerTimeout := time.Minute
|
||||
if v := hreq.Header.Get("C2n-Handler-Timeout"); v != "" {
|
||||
handlerTimeout, _ = time.ParseDuration(v)
|
||||
}
|
||||
handlerCtx, cancel := context.WithTimeout(context.Background(), handlerTimeout)
|
||||
defer cancel()
|
||||
hreq = hreq.WithContext(handlerCtx)
|
||||
rec := httptest.NewRecorder()
|
||||
c2nHandler.ServeHTTP(rec, hreq)
|
||||
cancel()
|
||||
|
||||
c2nResBuf := new(bytes.Buffer)
|
||||
rec.Result().Write(c2nResBuf)
|
||||
|
||||
replyCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(replyCtx, "POST", pr.URL, c2nResBuf)
|
||||
if err != nil {
|
||||
logf("answerC2NPing: NewRequestWithContext: %v", err)
|
||||
return
|
||||
}
|
||||
if pr.Log {
|
||||
logf("answerC2NPing: sending POST ping to %v ...", pr.URL)
|
||||
}
|
||||
t0 := time.Now()
|
||||
_, err = c.Do(req)
|
||||
d := time.Since(t0).Round(time.Millisecond)
|
||||
if err != nil {
|
||||
logf("answerC2NPing error: %v to %v (after %v)", err, pr.URL, d)
|
||||
} else if pr.Log {
|
||||
logf("answerC2NPing complete to %v (after %v)", pr.URL, d)
|
||||
}
|
||||
}
|
||||
|
||||
func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- struct{}, d time.Duration) error {
|
||||
const maxSleep = 5 * time.Minute
|
||||
if d > maxSleep {
|
||||
|
||||
@@ -18,7 +18,7 @@ spec:
|
||||
command: ["/bin/sh"]
|
||||
args:
|
||||
- -c
|
||||
- sysctl -w net.ipv4.ip_forward=1
|
||||
- sysctl -w net.ipv4.ip_forward=1 -w net.ipv6.conf.all.forwarding=1
|
||||
resources:
|
||||
requests:
|
||||
cpu: 1m
|
||||
|
||||
21
ipn/ipnlocal/c2n.go
Normal file
21
ipn/ipnlocal/c2n.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ipnlocal
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (b *LocalBackend) handleC2N(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/echo":
|
||||
// Test handler.
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Write(body)
|
||||
default:
|
||||
http.Error(w, "unknown c2n path", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
@@ -955,6 +955,8 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
|
||||
hostinfo := hostinfo.New()
|
||||
hostinfo.BackendLogID = b.backendLogID
|
||||
hostinfo.FrontendLogID = opts.FrontendLogID
|
||||
hostinfo.Userspace.Set(wgengine.IsNetstack(b.e))
|
||||
hostinfo.UserspaceRouter.Set(wgengine.IsNetstackRouter(b.e))
|
||||
|
||||
if b.cc != nil {
|
||||
// TODO(apenwarr): avoid the need to reinit controlclient.
|
||||
@@ -1075,6 +1077,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
|
||||
PopBrowserURL: b.tellClientToBrowseToURL,
|
||||
Dialer: b.Dialer(),
|
||||
Status: b.setClientStatus,
|
||||
C2NHandler: http.HandlerFunc(b.handleC2N),
|
||||
|
||||
// Don't warn about broken Linux IP forwarding when
|
||||
// netstack is being used.
|
||||
|
||||
@@ -505,6 +505,8 @@ func osEmoji(os string) string {
|
||||
return "👿"
|
||||
case "openbsd":
|
||||
return "🐡"
|
||||
case "illumos":
|
||||
return "☀️"
|
||||
}
|
||||
return "👽"
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
|
||||
@@ -185,7 +186,10 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acmeKey: %w", err)
|
||||
}
|
||||
ac := &acme.Client{Key: key}
|
||||
ac := &acme.Client{
|
||||
Key: key,
|
||||
UserAgent: "tailscaled/" + version.Long,
|
||||
}
|
||||
|
||||
a, err := ac.GetReg(ctx, "" /* pre-RFC param */)
|
||||
switch {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -108,6 +109,10 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger {
|
||||
procID = 7
|
||||
}
|
||||
}
|
||||
|
||||
stdLogf := func(f string, a ...any) {
|
||||
fmt.Fprintf(cfg.Stderr, strings.TrimSuffix(f, "\n")+"\n", a...)
|
||||
}
|
||||
l := &Logger{
|
||||
privateID: cfg.PrivateID,
|
||||
stderr: cfg.Stderr,
|
||||
@@ -121,7 +126,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger {
|
||||
sentinel: make(chan int32, 16),
|
||||
drainLogs: cfg.DrainLogs,
|
||||
timeNow: cfg.TimeNow,
|
||||
bo: backoff.NewBackoff("logtail", logf, 30*time.Second),
|
||||
bo: backoff.NewBackoff("logtail", stdLogf, 30*time.Second),
|
||||
metricsDelta: cfg.MetricsDelta,
|
||||
|
||||
procID: procID,
|
||||
|
||||
@@ -91,6 +91,30 @@ func (c Config) hasDefaultIPResolversOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// hasHostsWithoutSplitDNSRoutes reports whether c contains any Host entries
|
||||
// that aren't covered by a SplitDNS route suffix.
|
||||
func (c Config) hasHostsWithoutSplitDNSRoutes() bool {
|
||||
// TODO(bradfitz): this could be more efficient, but we imagine
|
||||
// the number of SplitDNS routes and/or hosts will be small.
|
||||
for host := range c.Hosts {
|
||||
if !c.hasSplitDNSRouteForHost(host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasSplitDNSRouteForHost reports whether c contains a SplitDNS route
|
||||
// that contains hosts.
|
||||
func (c Config) hasSplitDNSRouteForHost(host dnsname.FQDN) bool {
|
||||
for route := range c.Routes {
|
||||
if route.Contains(host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Config) hasDefaultResolvers() bool {
|
||||
return len(c.DefaultResolvers) > 0
|
||||
}
|
||||
|
||||
@@ -207,9 +207,14 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig
|
||||
// case where cfg is entirely zero, in which case these
|
||||
// configs clear all Tailscale DNS settings.
|
||||
return rcfg, ocfg, nil
|
||||
case cfg.hasDefaultIPResolversOnly():
|
||||
// Trivial CorpDNS configuration, just override the OS
|
||||
// resolver.
|
||||
case cfg.hasDefaultIPResolversOnly() && !cfg.hasHostsWithoutSplitDNSRoutes():
|
||||
// Trivial CorpDNS configuration, just override the OS resolver.
|
||||
//
|
||||
// If there are hosts (ExtraRecords) that are not covered by an existing
|
||||
// SplitDNS route, then we don't go into this path so that we fall into
|
||||
// the next case and send the extra record hosts queries through
|
||||
// 100.100.100.100 instead where we can answer them.
|
||||
//
|
||||
// TODO: for OSes that support it, pass IP:port and DoH
|
||||
// addresses directly to OS.
|
||||
// https://github.com/tailscale/tailscale/issues/1666
|
||||
|
||||
@@ -199,6 +199,71 @@ func TestManager(t *testing.T) {
|
||||
"bradfitz.ts.com.", "2.3.4.5"),
|
||||
},
|
||||
},
|
||||
{
|
||||
// If Hosts are specified (i.e. ExtraRecords) that aren't a split
|
||||
// DNS route and a global resolver is specified, then make
|
||||
// everything go via 100.100.100.100.
|
||||
name: "hosts-with-global-dns-uses-quad100",
|
||||
split: true,
|
||||
in: Config{
|
||||
DefaultResolvers: mustRes("1.1.1.1", "9.9.9.9"),
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
os: OSConfig{
|
||||
Nameservers: mustIPs("100.100.100.100"),
|
||||
},
|
||||
rs: resolver.Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
Routes: upstreams(".", "1.1.1.1", "9.9.9.9"),
|
||||
},
|
||||
},
|
||||
{
|
||||
// This is the above hosts-with-global-dns-uses-quad100 test but
|
||||
// verifying that if global DNS servers aren't set (the 1.1.1.1 and
|
||||
// 9.9.9.9 above), then we don't configure 100.100.100.100 as the
|
||||
// resolver.
|
||||
name: "hosts-without-global-dns-not-use-quad100",
|
||||
split: true,
|
||||
in: Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
os: OSConfig{},
|
||||
rs: resolver.Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
},
|
||||
{
|
||||
// This tests that ExtraRecords (foo.tld and bar.tld here) don't trigger forcing
|
||||
// traffic through 100.100.100.100 if there's Split DNS support and the extra
|
||||
// records are part of a split DNS route.
|
||||
name: "hosts-with-extrarecord-hosts-with-routes-no-quad100",
|
||||
split: true,
|
||||
in: Config{
|
||||
Routes: upstreams(
|
||||
"tld.", "4.4.4.4",
|
||||
),
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
os: OSConfig{
|
||||
Nameservers: mustIPs("4.4.4.4"),
|
||||
MatchDomains: fqdns("tld."),
|
||||
},
|
||||
rs: resolver.Config{
|
||||
Hosts: hosts(
|
||||
"foo.tld.", "1.2.3.4",
|
||||
"bar.tld.", "2.3.4.5"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "corp",
|
||||
in: Config{
|
||||
|
||||
@@ -388,10 +388,10 @@ func (m windowsManager) Close() error {
|
||||
// Windows DHCP client from sending dynamic DNS updates for our interface to
|
||||
// AD domain controllers.
|
||||
func (m windowsManager) disableDynamicUpdates() error {
|
||||
if err := m.setSingleDWORD(winutil.IPv4TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil {
|
||||
if err := m.setSingleDWORD(winutil.IPv4TCPIPInterfacePrefix, "DisableDynamicUpdate", 1); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.setSingleDWORD(winutil.IPv6TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil {
|
||||
if err := m.setSingleDWORD(winutil.IPv6TCPIPInterfacePrefix, "DisableDynamicUpdate", 1); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -77,7 +77,8 @@ type CapabilityVersion int
|
||||
// 38: 2022-08-11: added PingRequest.URLIsNoise
|
||||
// 39: 2022-08-15: clients can talk Noise over arbitrary HTTPS port
|
||||
// 40: 2022-08-22: added Node.KeySignature, PeersChangedPatch.KeySignature
|
||||
const CurrentCapabilityVersion CapabilityVersion = 40
|
||||
// 41: 2022-08-30: uses 100.100.100.100 for route-less ExtraRecords if global nameservers is set
|
||||
const CurrentCapabilityVersion CapabilityVersion = 41
|
||||
|
||||
type StableID string
|
||||
|
||||
@@ -464,25 +465,27 @@ type Service struct {
|
||||
// Because it contains pointers (slices), this type should not be used
|
||||
// as a value type.
|
||||
type Hostinfo struct {
|
||||
IPNVersion string `json:",omitempty"` // version of this code
|
||||
FrontendLogID string `json:",omitempty"` // logtail ID of frontend instance
|
||||
BackendLogID string `json:",omitempty"` // logtail ID of backend instance
|
||||
OS string `json:",omitempty"` // operating system the client runs on (a version.OS value)
|
||||
OSVersion string `json:",omitempty"` // operating system version, with optional distro prefix ("Debian 10.4", "Windows 10 Pro 10.0.19041")
|
||||
Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux
|
||||
Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown)
|
||||
DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3")
|
||||
Hostname string `json:",omitempty"` // name of the host the client runs on
|
||||
ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections
|
||||
ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user
|
||||
GoArch string `json:",omitempty"` // the host's GOARCH value (of the running binary)
|
||||
GoVersion string `json:",omitempty"` // Go version binary was built with
|
||||
RoutableIPs []netip.Prefix `json:",omitempty"` // set of IP ranges this client can route
|
||||
RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim
|
||||
Services []Service `json:",omitempty"` // services advertised by this machine
|
||||
NetInfo *NetInfo `json:",omitempty"`
|
||||
SSH_HostKeys []string `json:"sshHostKeys,omitempty"` // if advertised
|
||||
Cloud string `json:",omitempty"`
|
||||
IPNVersion string `json:",omitempty"` // version of this code
|
||||
FrontendLogID string `json:",omitempty"` // logtail ID of frontend instance
|
||||
BackendLogID string `json:",omitempty"` // logtail ID of backend instance
|
||||
OS string `json:",omitempty"` // operating system the client runs on (a version.OS value)
|
||||
OSVersion string `json:",omitempty"` // operating system version, with optional distro prefix ("Debian 10.4", "Windows 10 Pro 10.0.19041")
|
||||
Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux
|
||||
Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown)
|
||||
DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3")
|
||||
Hostname string `json:",omitempty"` // name of the host the client runs on
|
||||
ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections
|
||||
ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user
|
||||
GoArch string `json:",omitempty"` // the host's GOARCH value (of the running binary)
|
||||
GoVersion string `json:",omitempty"` // Go version binary was built with
|
||||
RoutableIPs []netip.Prefix `json:",omitempty"` // set of IP ranges this client can route
|
||||
RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim
|
||||
Services []Service `json:",omitempty"` // services advertised by this machine
|
||||
NetInfo *NetInfo `json:",omitempty"`
|
||||
SSH_HostKeys []string `json:"sshHostKeys,omitempty"` // if advertised
|
||||
Cloud string `json:",omitempty"`
|
||||
Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode
|
||||
UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode
|
||||
|
||||
// NOTE: any new fields containing pointers in this type
|
||||
// require changes to Hostinfo.Equal.
|
||||
@@ -1155,12 +1158,15 @@ const (
|
||||
// PingRequest with Types and IP, will send a ping to the IP and send a POST
|
||||
// request containing a PingResponse to the URL containing results.
|
||||
type PingRequest struct {
|
||||
// URL is the URL to send a HEAD request to.
|
||||
// URL is the URL to reply to the PingRequest to.
|
||||
// It will be a unique URL each time. No auth headers are necessary.
|
||||
//
|
||||
// If the client sees multiple PingRequests with the same URL,
|
||||
// subsequent ones should be ignored.
|
||||
// If Types and IP are defined, then URL is the URL to send a POST request to.
|
||||
//
|
||||
// The HTTP method that the node should make back to URL depends on the other
|
||||
// fields of the PingRequest. If Types is defined, then URL is the URL to
|
||||
// send a POST request to. Otherwise, the node should just make a HEAD
|
||||
// request to URL.
|
||||
URL string
|
||||
|
||||
// URLIsNoise, if true, means that the client should hit URL over the Noise
|
||||
@@ -1173,11 +1179,22 @@ type PingRequest struct {
|
||||
|
||||
// Types is the types of ping that are initiated. Can be any PingType, comma
|
||||
// separated, e.g. "disco,TSMP"
|
||||
Types string
|
||||
//
|
||||
// As a special case, if Types is "c2n", then this PingRequest is a
|
||||
// client-to-node HTTP request. The HTTP request should be handled by this
|
||||
// node's c2n handler and the HTTP response sent in a POST to URL. For c2n,
|
||||
// the value of URLIsNoise is ignored and only the Noise transport (back to
|
||||
// the control plane) will be used, as if URLIsNoise were true.
|
||||
Types string `json:",omitempty"`
|
||||
|
||||
// IP is the ping target.
|
||||
// It is used in TSMP pings, if IP is invalid or empty then do a HEAD request to the URL.
|
||||
// IP is the ping target, when needed by the PingType(s) given in Types.
|
||||
IP netip.Addr
|
||||
|
||||
// Payload is the ping payload.
|
||||
//
|
||||
// It is only used for c2n requests, in which case it's an HTTP/1.0 or
|
||||
// HTTP/1.1-formatted HTTP request as parsable with http.ReadRequest.
|
||||
Payload []byte `json:",omitempty"`
|
||||
}
|
||||
|
||||
// PingResponse provides result information for a TSMP or Disco PingRequest.
|
||||
|
||||
@@ -115,25 +115,27 @@ func (src *Hostinfo) Clone() *Hostinfo {
|
||||
|
||||
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
|
||||
var _HostinfoCloneNeedsRegeneration = Hostinfo(struct {
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
Userspace opt.Bool
|
||||
UserspaceRouter opt.Bool
|
||||
}{})
|
||||
|
||||
// Clone makes a deep copy of NetInfo.
|
||||
|
||||
@@ -37,6 +37,7 @@ func TestHostinfoEqual(t *testing.T) {
|
||||
"GoArch", "GoVersion",
|
||||
"RoutableIPs", "RequestTags",
|
||||
"Services", "NetInfo", "SSH_HostKeys", "Cloud",
|
||||
"Userspace", "UserspaceRouter",
|
||||
}
|
||||
if have := fieldsOf(reflect.TypeOf(Hostinfo{})); !reflect.DeepEqual(have, hiHandles) {
|
||||
t.Errorf("Hostinfo.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
||||
|
||||
@@ -271,29 +271,33 @@ func (v HostinfoView) Services() views.Slice[Service] { return views.SliceOf(
|
||||
func (v HostinfoView) NetInfo() NetInfoView { return v.ж.NetInfo.View() }
|
||||
func (v HostinfoView) SSH_HostKeys() views.Slice[string] { return views.SliceOf(v.ж.SSH_HostKeys) }
|
||||
func (v HostinfoView) Cloud() string { return v.ж.Cloud }
|
||||
func (v HostinfoView) Userspace() opt.Bool { return v.ж.Userspace }
|
||||
func (v HostinfoView) UserspaceRouter() opt.Bool { return v.ж.UserspaceRouter }
|
||||
func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.ж.Equal(v2.ж) }
|
||||
|
||||
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
|
||||
var _HostinfoViewNeedsRegeneration = Hostinfo(struct {
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
IPNVersion string
|
||||
FrontendLogID string
|
||||
BackendLogID string
|
||||
OS string
|
||||
OSVersion string
|
||||
Desktop opt.Bool
|
||||
Package string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
ShareeNode bool
|
||||
GoArch string
|
||||
GoVersion string
|
||||
RoutableIPs []netip.Prefix
|
||||
RequestTags []string
|
||||
Services []Service
|
||||
NetInfo *NetInfo
|
||||
SSH_HostKeys []string
|
||||
Cloud string
|
||||
Userspace opt.Bool
|
||||
UserspaceRouter opt.Bool
|
||||
}{})
|
||||
|
||||
// View returns a readonly view of NetInfo.
|
||||
|
||||
@@ -29,8 +29,6 @@ type State struct {
|
||||
|
||||
// DisablementSecrets are KDF-derived values which can be used
|
||||
// to turn off the TKA in the event of a consensus-breaking bug.
|
||||
// An AUM of type DisableNL should contain a secret when results
|
||||
// in one of these values when run through the disablement KDF.
|
||||
//
|
||||
// TODO(tom): This is an alpha feature, remove this mechanism once
|
||||
// we have confidence in our implementation.
|
||||
@@ -169,6 +167,9 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) {
|
||||
if update.Meta != nil {
|
||||
k.Meta = update.Meta
|
||||
}
|
||||
if err := k.StaticValidate(); err != nil {
|
||||
return State{}, fmt.Errorf("updated key fails validation: %v", err)
|
||||
}
|
||||
out := s.cloneForUpdate(&update)
|
||||
for i := range out.Keys {
|
||||
if bytes.Equal(out.Keys[i].ID(), update.KeyID) {
|
||||
|
||||
@@ -181,6 +181,7 @@ func TestApplyUpdatesChain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApplyUpdateErrors(t *testing.T) {
|
||||
tooLargeVotes := uint(99999)
|
||||
tcs := []struct {
|
||||
Name string
|
||||
Updates []AUM
|
||||
@@ -205,6 +206,12 @@ func TestApplyUpdateErrors(t *testing.T) {
|
||||
State{},
|
||||
ErrNoSuchKey,
|
||||
},
|
||||
{
|
||||
"UpdateKey now fails validation",
|
||||
[]AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}},
|
||||
State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}},
|
||||
errors.New("updated key fails validation: excessive key weight: 99999 > 4096"),
|
||||
},
|
||||
{
|
||||
"Bad lastAUMHash",
|
||||
[]AUM{
|
||||
|
||||
@@ -374,6 +374,74 @@ func TestAddPingRequest(t *testing.T) {
|
||||
t.Error("all ping attempts failed")
|
||||
}
|
||||
|
||||
func TestC2NPingRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
env := newTestEnv(t)
|
||||
n1 := newTestNode(t, env)
|
||||
n1.StartDaemon()
|
||||
|
||||
n1.AwaitListening()
|
||||
n1.MustUp()
|
||||
n1.AwaitRunning()
|
||||
|
||||
gotPing := make(chan bool, 1)
|
||||
waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("unexpected ping method %q", r.Method)
|
||||
}
|
||||
got, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("ping body read error: %v", err)
|
||||
}
|
||||
const want = "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nabc"
|
||||
if string(got) != want {
|
||||
t.Errorf("body error\n got: %q\nwant: %q", got, want)
|
||||
}
|
||||
gotPing <- true
|
||||
}))
|
||||
defer waitPing.Close()
|
||||
|
||||
nodes := env.Control.AllNodes()
|
||||
if len(nodes) != 1 {
|
||||
t.Fatalf("expected 1 node, got %d nodes", len(nodes))
|
||||
}
|
||||
|
||||
nodeKey := nodes[0].Key
|
||||
|
||||
// Check that we get at least one ping reply after 10 tries.
|
||||
for try := 1; try <= 10; try++ {
|
||||
t.Logf("ping %v ...", try)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if err := env.Control.AwaitNodeInMapRequest(ctx, nodeKey); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cancel()
|
||||
|
||||
pr := &tailcfg.PingRequest{
|
||||
URL: fmt.Sprintf("%s/ping-%d", waitPing.URL, try),
|
||||
Log: true,
|
||||
Types: "c2n",
|
||||
Payload: []byte("POST /echo HTTP/1.0\r\nContent-Length: 3\r\n\r\nabc"),
|
||||
}
|
||||
if !env.Control.AddPingRequest(nodeKey, pr) {
|
||||
t.Logf("failed to AddPingRequest")
|
||||
continue
|
||||
}
|
||||
|
||||
// Wait for PingRequest to come back
|
||||
pingTimeout := time.NewTimer(2 * time.Second)
|
||||
defer pingTimeout.Stop()
|
||||
select {
|
||||
case <-gotPing:
|
||||
t.Logf("got ping; success")
|
||||
return
|
||||
case <-pingTimeout.C:
|
||||
// Try again.
|
||||
}
|
||||
}
|
||||
t.Error("all ping attempts failed")
|
||||
}
|
||||
|
||||
// Issue 2434: when "down" (WantRunning false), tailscaled shouldn't
|
||||
// be connected to control.
|
||||
func TestNoControlConnWhenDown(t *testing.T) {
|
||||
@@ -737,6 +805,7 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon {
|
||||
cmd.Args = append(cmd.Args, "-verbose=2")
|
||||
}
|
||||
cmd.Env = append(os.Environ(),
|
||||
"TS_DEBUG_PERMIT_HTTP_C2N=1",
|
||||
"TS_LOG_TARGET="+n.env.LogCatcherServer.URL,
|
||||
"HTTP_PROXY="+n.env.TrafficTrapServer.URL,
|
||||
"HTTPS_PROXY="+n.env.TrafficTrapServer.URL,
|
||||
|
||||
@@ -9,12 +9,13 @@ package logger
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func rusageMaxRSS() float64 {
|
||||
var ru syscall.Rusage
|
||||
err := syscall.Getrusage(syscall.RUSAGE_SELF, &ru)
|
||||
var ru unix.Rusage
|
||||
err := unix.Getrusage(unix.RUSAGE_SELF, &ru)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
179
util/cstruct/cstruct.go
Normal file
179
util/cstruct/cstruct.go
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package cstruct provides a helper for decoding binary data that is in the
|
||||
// form of a padded C structure.
|
||||
package cstruct
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"tailscale.com/util/endian"
|
||||
)
|
||||
|
||||
// Size of a pointer-typed value, in bits
|
||||
const pointerSize = 32 << (^uintptr(0) >> 63)
|
||||
|
||||
// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on
|
||||
// a 16- or 8-bit architecture any time soon.
|
||||
const is64Bit = pointerSize == 64
|
||||
|
||||
// Decoder reads and decodes padded fields from a slice of bytes. All fields
|
||||
// are decoded with native endianness.
|
||||
//
|
||||
// Methods of a Decoder do not return errors, but rather store any error within
|
||||
// the Decoder. The first error can be obtained via the Err method; after the
|
||||
// first error, methods will return the zero value for their type.
|
||||
type Decoder struct {
|
||||
b []byte
|
||||
off int
|
||||
err error
|
||||
dbuf [8]byte // for decoding
|
||||
}
|
||||
|
||||
// NewDecoder creates a Decoder from a byte slice.
|
||||
func NewDecoder(b []byte) *Decoder {
|
||||
return &Decoder{b: b}
|
||||
}
|
||||
|
||||
var errUnsupportedSize = errors.New("unsupported size")
|
||||
|
||||
func padBytes(offset, size int) int {
|
||||
if offset == 0 || size == 1 {
|
||||
return 0
|
||||
}
|
||||
remainder := offset % size
|
||||
return size - remainder
|
||||
}
|
||||
|
||||
func (d *Decoder) getField(b []byte) error {
|
||||
size := len(b)
|
||||
|
||||
// We only support fields that are multiples of 2 (or 1-sized)
|
||||
if size != 1 && size&1 == 1 {
|
||||
return errUnsupportedSize
|
||||
}
|
||||
|
||||
// Fields are aligned to their size
|
||||
padBytes := padBytes(d.off, size)
|
||||
if d.off+size+padBytes > len(d.b) {
|
||||
return io.EOF
|
||||
}
|
||||
d.off += padBytes
|
||||
|
||||
copy(b, d.b[d.off:d.off+size])
|
||||
d.off += size
|
||||
return nil
|
||||
}
|
||||
|
||||
// Err returns the first error that was encountered by this Decoder.
|
||||
func (d *Decoder) Err() error {
|
||||
return d.err
|
||||
}
|
||||
|
||||
// Offset returns the current read offset for data in the buffer.
|
||||
func (d *Decoder) Offset() int {
|
||||
return d.off
|
||||
}
|
||||
|
||||
// Byte returns a single byte from the buffer.
|
||||
func (d *Decoder) Byte() byte {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:1]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return d.dbuf[0]
|
||||
}
|
||||
|
||||
// Byte returns a number of bytes from the buffer based on the size of the
|
||||
// input slice. No padding is applied.
|
||||
//
|
||||
// If an error is encountered or this Decoder has previously encountered an
|
||||
// error, no changes are made to the provided buffer.
|
||||
func (d *Decoder) Bytes(b []byte) {
|
||||
if d.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// No padding for byte slices
|
||||
size := len(b)
|
||||
if d.off+size >= len(d.b) {
|
||||
d.err = io.EOF
|
||||
return
|
||||
}
|
||||
copy(b, d.b[d.off:d.off+size])
|
||||
d.off += size
|
||||
}
|
||||
|
||||
// Uint16 returns a uint16 decoded from the buffer.
|
||||
func (d *Decoder) Uint16() uint16 {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:2]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return endian.Native.Uint16(d.dbuf[0:2])
|
||||
}
|
||||
|
||||
// Uint32 returns a uint32 decoded from the buffer.
|
||||
func (d *Decoder) Uint32() uint32 {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:4]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return endian.Native.Uint32(d.dbuf[0:4])
|
||||
}
|
||||
|
||||
// Uint64 returns a uint64 decoded from the buffer.
|
||||
func (d *Decoder) Uint64() uint64 {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if err := d.getField(d.dbuf[0:8]); err != nil {
|
||||
d.err = err
|
||||
return 0
|
||||
}
|
||||
return endian.Native.Uint64(d.dbuf[0:8])
|
||||
}
|
||||
|
||||
// Uintptr returns a uintptr decoded from the buffer.
|
||||
func (d *Decoder) Uintptr() uintptr {
|
||||
if d.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if is64Bit {
|
||||
return uintptr(d.Uint64())
|
||||
} else {
|
||||
return uintptr(d.Uint32())
|
||||
}
|
||||
}
|
||||
|
||||
// Int16 returns a int16 decoded from the buffer.
|
||||
func (d *Decoder) Int16() int16 {
|
||||
return int16(d.Uint16())
|
||||
}
|
||||
|
||||
// Int32 returns a int32 decoded from the buffer.
|
||||
func (d *Decoder) Int32() int32 {
|
||||
return int32(d.Uint32())
|
||||
}
|
||||
|
||||
// Int64 returns a int64 decoded from the buffer.
|
||||
func (d *Decoder) Int64() int64 {
|
||||
return int64(d.Uint64())
|
||||
}
|
||||
75
util/cstruct/cstruct_example_test.go
Normal file
75
util/cstruct/cstruct_example_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Only built on 64-bit platforms to avoid complexity
|
||||
|
||||
//go:build amd64 || arm64 || mips64le || ppc64le || riscv64
|
||||
// +build amd64 arm64 mips64le ppc64le riscv64
|
||||
|
||||
package cstruct
|
||||
|
||||
import "fmt"
|
||||
|
||||
// This test provides a semi-realistic example of how you can
|
||||
// use this package to decode a C structure.
|
||||
func ExampleDecoder() {
|
||||
// Our example C structure:
|
||||
// struct mystruct {
|
||||
// char *p;
|
||||
// char c;
|
||||
// /* implicit: char _pad[3]; */
|
||||
// int x;
|
||||
// };
|
||||
//
|
||||
// The Go structure definition:
|
||||
type myStruct struct {
|
||||
Ptr uintptr
|
||||
Ch byte
|
||||
Intval uint32
|
||||
}
|
||||
|
||||
// Our "in-memory" version of the above structure
|
||||
buf := []byte{
|
||||
1, 2, 3, 4, 0, 0, 0, 0, // ptr
|
||||
5, // ch
|
||||
99, 99, 99, // padding
|
||||
78, 6, 0, 0, // x
|
||||
}
|
||||
d := NewDecoder(buf)
|
||||
|
||||
// Decode the structure; if one of these function returns an error,
|
||||
// then subsequent decoder functions will return the zero value.
|
||||
var x myStruct
|
||||
x.Ptr = d.Uintptr()
|
||||
x.Ch = d.Byte()
|
||||
x.Intval = d.Uint32()
|
||||
|
||||
// Note that per the Go language spec:
|
||||
// [...] when evaluating the operands of an expression, assignment,
|
||||
// or return statement, all function calls, method calls, and
|
||||
// (channel) communication operations are evaluated in lexical
|
||||
// left-to-right order
|
||||
//
|
||||
// Since each field is assigned via a function call, one could use the
|
||||
// following snippet to decode the struct.
|
||||
// x := myStruct{
|
||||
// Ptr: d.Uintptr(),
|
||||
// Ch: d.Byte(),
|
||||
// Intval: d.Uint32(),
|
||||
// }
|
||||
//
|
||||
// However, this means that reordering the fields in the initialization
|
||||
// statement–normally a semantically identical operation–would change
|
||||
// the way the structure is parsed. Thus we do it as above with
|
||||
// explicit ordering.
|
||||
|
||||
// After finishing with the decoder, check errors
|
||||
if err := d.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Print the decoder offset and structure
|
||||
fmt.Printf("off=%d struct=%#v\n", d.Offset(), x)
|
||||
// Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e}
|
||||
}
|
||||
151
util/cstruct/cstruct_test.go
Normal file
151
util/cstruct/cstruct_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cstruct
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPadBytes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
offset int
|
||||
size int
|
||||
want int
|
||||
}{
|
||||
// No padding at beginning of structure
|
||||
{0, 1, 0},
|
||||
{0, 2, 0},
|
||||
{0, 4, 0},
|
||||
{0, 8, 0},
|
||||
|
||||
// No padding for single bytes
|
||||
{1, 1, 0},
|
||||
|
||||
// Single byte padding
|
||||
{1, 2, 1},
|
||||
{3, 4, 1},
|
||||
|
||||
// Multi-byte padding
|
||||
{1, 4, 3},
|
||||
{2, 8, 6},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%d_%d_%d", tc.offset, tc.size, tc.want), func(t *testing.T) {
|
||||
got := padBytes(tc.offset, tc.size)
|
||||
if got != tc.want {
|
||||
t.Errorf("got=%d; want=%d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoder(t *testing.T) {
|
||||
t.Run("UnsignedTypes", func(t *testing.T) {
|
||||
dec := func(n int) *Decoder {
|
||||
buf := make([]byte, n)
|
||||
buf[0] = 1
|
||||
|
||||
d := NewDecoder(buf)
|
||||
|
||||
// Use t.Cleanup to perform an assertion on this
|
||||
// decoder after the test code is finished with it.
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
return d
|
||||
}
|
||||
if got := dec(2).Uint16(); got != 1 {
|
||||
t.Errorf("uint16: got=%d; want=1", got)
|
||||
}
|
||||
if got := dec(4).Uint32(); got != 1 {
|
||||
t.Errorf("uint32: got=%d; want=1", got)
|
||||
}
|
||||
if got := dec(8).Uint64(); got != 1 {
|
||||
t.Errorf("uint64: got=%d; want=1", got)
|
||||
}
|
||||
if got := dec(pointerSize / 8).Uintptr(); got != 1 {
|
||||
t.Errorf("uintptr: got=%d; want=1", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SignedTypes", func(t *testing.T) {
|
||||
dec := func(n int) *Decoder {
|
||||
// Make a buffer of the exact size that consists of 0xff bytes
|
||||
buf := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
buf[i] = 0xff
|
||||
}
|
||||
|
||||
d := NewDecoder(buf)
|
||||
|
||||
// Use t.Cleanup to perform an assertion on this
|
||||
// decoder after the test code is finished with it.
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
return d
|
||||
}
|
||||
if got := dec(2).Int16(); got != -1 {
|
||||
t.Errorf("int16: got=%d; want=-1", got)
|
||||
}
|
||||
if got := dec(4).Int32(); got != -1 {
|
||||
t.Errorf("int32: got=%d; want=-1", got)
|
||||
}
|
||||
if got := dec(8).Int64(); got != -1 {
|
||||
t.Errorf("int64: got=%d; want=-1", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsufficientData", func(t *testing.T) {
|
||||
dec := func(n int) *Decoder {
|
||||
// Make a buffer that's too small and contains arbitrary bytes
|
||||
buf := make([]byte, n-1)
|
||||
for i := 0; i < n-1; i++ {
|
||||
buf[i] = 0xAD
|
||||
}
|
||||
|
||||
// Use t.Cleanup to perform an assertion on this
|
||||
// decoder after the test code is finished with it.
|
||||
d := NewDecoder(buf)
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err == nil || !errors.Is(err, io.EOF) {
|
||||
t.Errorf("(n=%d) expected io.EOF; got=%v", n, err)
|
||||
}
|
||||
})
|
||||
return d
|
||||
}
|
||||
|
||||
dec(2).Uint16()
|
||||
dec(4).Uint32()
|
||||
dec(8).Uint64()
|
||||
dec(pointerSize / 8).Uintptr()
|
||||
|
||||
dec(2).Int16()
|
||||
dec(4).Int32()
|
||||
dec(8).Int64()
|
||||
})
|
||||
|
||||
t.Run("Bytes", func(t *testing.T) {
|
||||
d := NewDecoder([]byte("hello worldasdf"))
|
||||
t.Cleanup(func() {
|
||||
if err := d.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
buf := make([]byte, 11)
|
||||
d.Bytes(buf)
|
||||
if got := string(buf); got != "hello world" {
|
||||
t.Errorf("bytes: got=%q; want=%q", got, "hello world")
|
||||
}
|
||||
})
|
||||
}
|
||||
38
util/deephash/debug.go
Normal file
38
util/deephash/debug.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build deephash_debug
|
||||
|
||||
package deephash
|
||||
|
||||
import "fmt"
|
||||
|
||||
func (h *hasher) HashBytes(b []byte) {
|
||||
fmt.Printf("B(%q)+", b)
|
||||
h.Block512.HashBytes(b)
|
||||
}
|
||||
func (h *hasher) HashString(s string) {
|
||||
fmt.Printf("S(%q)+", s)
|
||||
h.Block512.HashString(s)
|
||||
}
|
||||
func (h *hasher) HashUint8(n uint8) {
|
||||
fmt.Printf("U8(%d)+", n)
|
||||
h.Block512.HashUint8(n)
|
||||
}
|
||||
func (h *hasher) HashUint16(n uint16) {
|
||||
fmt.Printf("U16(%d)+", n)
|
||||
h.Block512.HashUint16(n)
|
||||
}
|
||||
func (h *hasher) HashUint32(n uint32) {
|
||||
fmt.Printf("U32(%d)+", n)
|
||||
h.Block512.HashUint32(n)
|
||||
}
|
||||
func (h *hasher) HashUint64(n uint64) {
|
||||
fmt.Printf("U64(%d)+", n)
|
||||
h.Block512.HashUint64(n)
|
||||
}
|
||||
func (h *hasher) Sum(b []byte) []byte {
|
||||
fmt.Println("FIN")
|
||||
return h.Block512.Sum(b)
|
||||
}
|
||||
@@ -6,7 +6,7 @@
|
||||
// without looping. The hash is only valid within the lifetime of a program.
|
||||
// Users should not store the hash on disk or send it over the network.
|
||||
// The hash is sufficiently strong and unique such that
|
||||
// Hash(x) == Hash(y) is an appropriate replacement for x == y.
|
||||
// Hash(&x) == Hash(&y) is an appropriate replacement for x == y.
|
||||
//
|
||||
// The definition of equality is identical to reflect.DeepEqual except:
|
||||
// - Floating-point values are compared based on the raw bits,
|
||||
@@ -24,11 +24,9 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"tailscale.com/util/hashx"
|
||||
)
|
||||
@@ -41,9 +39,10 @@ import (
|
||||
//
|
||||
// The logic below hashes a value by printing it to a hash.Hash.
|
||||
// To be parsable, it assumes that we know the Go type of each value:
|
||||
// * scalar types (e.g., bool or int32) are printed as fixed-width fields.
|
||||
// * list types (e.g., strings, slices, and AppendTo buffers) are prefixed
|
||||
// by a fixed-width length field, followed by the contents of the list.
|
||||
// * scalar types (e.g., bool or int32) are directly printed as their
|
||||
// underlying memory representation.
|
||||
// * list types (e.g., strings and slices) are prefixed by a
|
||||
// fixed-width length field, followed by the contents of the list.
|
||||
// * slices, arrays, and structs print each element/field consecutively.
|
||||
// * interfaces print with a 1-byte prefix indicating whether it is nil.
|
||||
// If non-nil, it is followed by a fixed-width field of the type index,
|
||||
@@ -55,34 +54,46 @@ import (
|
||||
// * maps print with a 1-byte prefix indicating whether the map pointer is
|
||||
// 1) nil, 2) previously seen, or 3) newly seen. Previously seen pointers
|
||||
// are followed by a fixed-width field of the index of the previous pointer.
|
||||
// Newly seen maps are printed as a fixed-width field with the XOR of the
|
||||
// hash of every map entry. With a sufficiently strong hash, this value is
|
||||
// theoretically "parsable" by looking up the hash in a magical map that
|
||||
// returns the set of entries for that given hash.
|
||||
|
||||
// addressableValue is a reflect.Value that is guaranteed to be addressable
|
||||
// such that calling the Addr and Set methods do not panic.
|
||||
//
|
||||
// There is no compile magic that enforces this property,
|
||||
// but rather the need to construct this type makes it easier to examine each
|
||||
// construction site to ensure that this property is upheld.
|
||||
type addressableValue struct{ reflect.Value }
|
||||
|
||||
// newAddressableValue constructs a new addressable value of type t.
|
||||
func newAddressableValue(t reflect.Type) addressableValue {
|
||||
return addressableValue{reflect.New(t).Elem()} // dereferenced pointer is always addressable
|
||||
}
|
||||
|
||||
const scratchSize = 128
|
||||
// Newly seen maps are printed with a fixed-width length field, followed by
|
||||
// a fixed-width field with the XOR of the hash of every map entry.
|
||||
// With a sufficiently strong hash, this value is theoretically "parsable"
|
||||
// by looking up the hash in a magical map that returns the set of entries
|
||||
// for that given hash.
|
||||
|
||||
// hasher is reusable state for hashing a value.
|
||||
// Get one via hasherPool.
|
||||
type hasher struct {
|
||||
hashx.Block512
|
||||
scratch [scratchSize]byte
|
||||
visitStack visitStack
|
||||
}
|
||||
|
||||
var hasherPool = &sync.Pool{
|
||||
New: func() any { return new(hasher) },
|
||||
}
|
||||
|
||||
func (h *hasher) reset() {
|
||||
if h.Block512.Hash == nil {
|
||||
h.Block512.Hash = sha256.New()
|
||||
}
|
||||
h.Block512.Reset()
|
||||
}
|
||||
|
||||
// hashType hashes a reflect.Type.
|
||||
// The hash is only consistent within the lifetime of a program.
|
||||
func (h *hasher) hashType(t reflect.Type) {
|
||||
// This approach relies on reflect.Type always being backed by a unique
|
||||
// *reflect.rtype pointer. A safer approach is to use a global sync.Map
|
||||
// that maps reflect.Type to some arbitrary and unique index.
|
||||
// While safer, it requires global state with memory that can never be GC'd.
|
||||
rtypeAddr := reflect.ValueOf(t).Pointer() // address of *reflect.rtype
|
||||
h.HashUint64(uint64(rtypeAddr))
|
||||
}
|
||||
|
||||
func (h *hasher) sum() (s Sum) {
|
||||
h.Sum(s.sum[:0])
|
||||
return s
|
||||
}
|
||||
|
||||
// Sum is an opaque checksum type that is comparable.
|
||||
type Sum struct {
|
||||
sum [sha256.Size]byte
|
||||
@@ -107,92 +118,57 @@ func initSeed() {
|
||||
seed = uint64(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (h *hasher) Reset() {
|
||||
if h.Block512.Hash == nil {
|
||||
h.Block512.Hash = sha256.New()
|
||||
}
|
||||
h.Block512.Reset()
|
||||
}
|
||||
|
||||
func (h *hasher) sum() (s Sum) {
|
||||
h.Sum(s.sum[:0])
|
||||
return s
|
||||
}
|
||||
|
||||
var hasherPool = &sync.Pool{
|
||||
New: func() any { return new(hasher) },
|
||||
}
|
||||
|
||||
// Hash returns the hash of v.
|
||||
// For performance, this should be a non-nil pointer.
|
||||
func Hash(v any) (s Sum) {
|
||||
func Hash[T any](v *T) Sum {
|
||||
h := hasherPool.Get().(*hasher)
|
||||
defer hasherPool.Put(h)
|
||||
h.Reset()
|
||||
h.reset()
|
||||
seedOnce.Do(initSeed)
|
||||
h.HashUint64(seed)
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.IsValid() {
|
||||
var va addressableValue
|
||||
if rv.Kind() == reflect.Pointer && !rv.IsNil() {
|
||||
va = addressableValue{rv.Elem()} // dereferenced pointer is always addressable
|
||||
} else {
|
||||
va = newAddressableValue(rv.Type())
|
||||
va.Set(rv)
|
||||
}
|
||||
|
||||
// Always treat the Hash input as an interface (it is), including hashing
|
||||
// its type, otherwise two Hash calls of different types could hash to the
|
||||
// same bytes off the different types and get equivalent Sum values. This is
|
||||
// the same thing that we do for reflect.Kind Interface in hashValue, but
|
||||
// the initial reflect.ValueOf from an interface value effectively strips
|
||||
// the interface box off so we have to do it at the top level by hand.
|
||||
h.hashType(va.Type())
|
||||
ti := getTypeInfo(va.Type())
|
||||
ti.hasher()(h, va)
|
||||
// Always treat the Hash input as if it were an interface by including
|
||||
// a hash of the type. This ensures that hashing of two different types
|
||||
// but with the same value structure produces different hashes.
|
||||
t := reflect.TypeOf(v).Elem()
|
||||
h.hashType(t)
|
||||
if v == nil {
|
||||
h.HashUint8(0) // indicates nil
|
||||
} else {
|
||||
h.HashUint8(1) // indicates visiting pointer element
|
||||
p := pointerOf(reflect.ValueOf(v))
|
||||
hash := lookupTypeHasher(t)
|
||||
hash(h, p)
|
||||
}
|
||||
return h.sum()
|
||||
}
|
||||
|
||||
// HasherForType is like Hash, but it returns a Hash func that's specialized for
|
||||
// the provided reflect type, avoiding a map lookup per value.
|
||||
func HasherForType[T any]() func(T) Sum {
|
||||
var zeroT T
|
||||
t := reflect.TypeOf(zeroT)
|
||||
ti := getTypeInfo(t)
|
||||
var tiElem *typeInfo
|
||||
if t.Kind() == reflect.Pointer {
|
||||
tiElem = getTypeInfo(t.Elem())
|
||||
}
|
||||
// HasherForType returns a hash that is specialized for the provided type.
|
||||
func HasherForType[T any]() func(*T) Sum {
|
||||
var v *T
|
||||
seedOnce.Do(initSeed)
|
||||
|
||||
return func(v T) (s Sum) {
|
||||
t := reflect.TypeOf(v).Elem()
|
||||
hash := lookupTypeHasher(t)
|
||||
return func(v *T) (s Sum) {
|
||||
// This logic is identical to Hash, but pull out a few statements.
|
||||
h := hasherPool.Get().(*hasher)
|
||||
defer hasherPool.Put(h)
|
||||
h.Reset()
|
||||
h.reset()
|
||||
h.HashUint64(seed)
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
if rv.IsValid() {
|
||||
if rv.Kind() == reflect.Pointer && !rv.IsNil() {
|
||||
va := addressableValue{rv.Elem()} // dereferenced pointer is always addressable
|
||||
h.hashType(va.Type())
|
||||
tiElem.hasher()(h, va)
|
||||
} else {
|
||||
va := newAddressableValue(rv.Type())
|
||||
va.Set(rv)
|
||||
h.hashType(va.Type())
|
||||
ti.hasher()(h, va)
|
||||
}
|
||||
h.hashType(t)
|
||||
if v == nil {
|
||||
h.HashUint8(0) // indicates nil
|
||||
} else {
|
||||
h.HashUint8(1) // indicates visiting pointer element
|
||||
p := pointerOf(reflect.ValueOf(v))
|
||||
hash(h, p)
|
||||
}
|
||||
return h.sum()
|
||||
}
|
||||
}
|
||||
|
||||
// Update sets last to the hash of v and reports whether its value changed.
|
||||
func Update(last *Sum, v any) (changed bool) {
|
||||
func Update[T any](last *Sum, v *T) (changed bool) {
|
||||
sum := Hash(v)
|
||||
changed = sum != *last
|
||||
if changed {
|
||||
@@ -201,129 +177,29 @@ func Update(last *Sum, v any) (changed bool) {
|
||||
return changed
|
||||
}
|
||||
|
||||
// typeInfo describes properties of a type.
|
||||
//
|
||||
// A non-nil typeInfo is populated into the typeHasher map
|
||||
// when its type is first requested, before its func is created.
|
||||
// Its func field fn is only populated once the type has been created.
|
||||
// This is used for recursive types.
|
||||
type typeInfo struct {
|
||||
rtype reflect.Type
|
||||
isRecursive bool
|
||||
// typeHasherFunc hashes the value pointed at by p for a given type.
|
||||
// For example, if t is a bool, then p is a *bool.
|
||||
// The provided pointer must always be non-nil.
|
||||
type typeHasherFunc func(h *hasher, p pointer)
|
||||
|
||||
// elemTypeInfo is the element type's typeInfo.
|
||||
// It's set when rtype is of Kind Ptr, Slice, Array, Map.
|
||||
elemTypeInfo *typeInfo
|
||||
var typeHasherCache sync.Map // map[reflect.Type]typeHasherFunc
|
||||
|
||||
// keyTypeInfo is the map key type's typeInfo.
|
||||
// It's set when rtype is of Kind Map.
|
||||
keyTypeInfo *typeInfo
|
||||
|
||||
hashFuncOnce sync.Once
|
||||
hashFuncLazy typeHasherFunc // nil until created
|
||||
}
|
||||
|
||||
type typeHasherFunc func(h *hasher, v addressableValue)
|
||||
|
||||
var typeInfoMap sync.Map // map[reflect.Type]*typeInfo
|
||||
var typeInfoMapPopulate sync.Mutex // just for adding to typeInfoMap
|
||||
|
||||
func (ti *typeInfo) hasher() typeHasherFunc {
|
||||
ti.hashFuncOnce.Do(ti.buildHashFuncOnce)
|
||||
return ti.hashFuncLazy
|
||||
}
|
||||
|
||||
func (ti *typeInfo) buildHashFuncOnce() {
|
||||
ti.hashFuncLazy = genTypeHasher(ti)
|
||||
}
|
||||
|
||||
// fieldInfo describes a struct field.
|
||||
type fieldInfo struct {
|
||||
index int // index of field for reflect.Value.Field(n); -1 if invalid
|
||||
typeInfo *typeInfo
|
||||
canMemHash bool
|
||||
offset uintptr // when we can memhash the field
|
||||
size uintptr // when we can memhash the field
|
||||
}
|
||||
|
||||
// mergeContiguousFieldsCopy returns a copy of f with contiguous memhashable fields
|
||||
// merged together. Such fields get a bogus index and fu value.
|
||||
func mergeContiguousFieldsCopy(in []fieldInfo) []fieldInfo {
|
||||
ret := make([]fieldInfo, 0, len(in))
|
||||
var last *fieldInfo
|
||||
for _, f := range in {
|
||||
// Combine two fields if they're both contiguous & memhash-able.
|
||||
if f.canMemHash && last != nil && last.canMemHash && last.offset+last.size == f.offset {
|
||||
last.size += f.size
|
||||
last.index = -1
|
||||
last.typeInfo = nil
|
||||
} else {
|
||||
ret = append(ret, f)
|
||||
last = &ret[len(ret)-1]
|
||||
}
|
||||
func lookupTypeHasher(t reflect.Type) typeHasherFunc {
|
||||
if v, ok := typeHasherCache.Load(t); ok {
|
||||
return v.(typeHasherFunc)
|
||||
}
|
||||
return ret
|
||||
hash := makeTypeHasher(t)
|
||||
v, _ := typeHasherCache.LoadOrStore(t, hash)
|
||||
return v.(typeHasherFunc)
|
||||
}
|
||||
|
||||
// genHashStructFields generates a typeHasherFunc for t, which must be of kind Struct.
|
||||
func genHashStructFields(t reflect.Type) typeHasherFunc {
|
||||
fields := make([]fieldInfo, 0, t.NumField())
|
||||
for i, n := 0, t.NumField(); i < n; i++ {
|
||||
sf := t.Field(i)
|
||||
if sf.Type.Size() == 0 {
|
||||
continue
|
||||
}
|
||||
fields = append(fields, fieldInfo{
|
||||
index: i,
|
||||
typeInfo: getTypeInfo(sf.Type),
|
||||
canMemHash: typeIsMemHashable(sf.Type),
|
||||
offset: sf.Offset,
|
||||
size: sf.Type.Size(),
|
||||
})
|
||||
}
|
||||
fields = mergeContiguousFieldsCopy(fields)
|
||||
return structHasher{fields}.hash
|
||||
}
|
||||
|
||||
type structHasher struct {
|
||||
fields []fieldInfo
|
||||
}
|
||||
|
||||
func (sh structHasher) hash(h *hasher, v addressableValue) {
|
||||
base := v.Addr().UnsafePointer()
|
||||
for _, f := range sh.fields {
|
||||
if f.canMemHash {
|
||||
h.HashBytes(unsafe.Slice((*byte)(unsafe.Pointer(uintptr(base)+f.offset)), f.size))
|
||||
continue
|
||||
}
|
||||
va := addressableValue{v.Field(f.index)} // field is addressable if parent struct is addressable
|
||||
f.typeInfo.hasher()(h, va)
|
||||
}
|
||||
}
|
||||
|
||||
// genHashPtrToMemoryRange returns a hasher where the reflect.Value is a Ptr to
|
||||
// the provided eleType.
|
||||
func genHashPtrToMemoryRange(eleType reflect.Type) typeHasherFunc {
|
||||
size := eleType.Size()
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
} else {
|
||||
h.HashUint8(1) // indicates visiting a pointer
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), size))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func genTypeHasher(ti *typeInfo) typeHasherFunc {
|
||||
t := ti.rtype
|
||||
|
||||
func makeTypeHasher(t reflect.Type) typeHasherFunc {
|
||||
// Types with specific hashing.
|
||||
switch t {
|
||||
case timeTimeType:
|
||||
return (*hasher).hashTimev
|
||||
return hashTime
|
||||
case netipAddrType:
|
||||
return (*hasher).hashAddrv
|
||||
return hashAddr
|
||||
}
|
||||
|
||||
// Types that can have their memory representation directly hashed.
|
||||
@@ -333,107 +209,40 @@ func genTypeHasher(ti *typeInfo) typeHasherFunc {
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
return (*hasher).hashString
|
||||
case reflect.Slice:
|
||||
et := t.Elem()
|
||||
if typeIsMemHashable(et) {
|
||||
return (*hasher).hashSliceMem
|
||||
}
|
||||
eti := getTypeInfo(et)
|
||||
return genHashSliceElements(eti)
|
||||
return hashString
|
||||
case reflect.Array:
|
||||
et := t.Elem()
|
||||
eti := getTypeInfo(et)
|
||||
return genHashArray(t, eti)
|
||||
return makeArrayHasher(t)
|
||||
case reflect.Slice:
|
||||
return makeSliceHasher(t)
|
||||
case reflect.Struct:
|
||||
return genHashStructFields(t)
|
||||
return makeStructHasher(t)
|
||||
case reflect.Map:
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
if ti.isRecursive {
|
||||
ptr := pointerOf(v)
|
||||
if idx, ok := h.visitStack.seen(ptr); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(ptr)
|
||||
defer h.visitStack.pop(ptr)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting a map
|
||||
h.hashMap(v, ti, ti.isRecursive)
|
||||
}
|
||||
return makeMapHasher(t)
|
||||
case reflect.Pointer:
|
||||
et := t.Elem()
|
||||
if typeIsMemHashable(et) {
|
||||
return genHashPtrToMemoryRange(et)
|
||||
}
|
||||
eti := getTypeInfo(et)
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
if ti.isRecursive {
|
||||
ptr := pointerOf(v)
|
||||
if idx, ok := h.visitStack.seen(ptr); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(ptr)
|
||||
defer h.visitStack.pop(ptr)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting a pointer
|
||||
va := addressableValue{v.Elem()} // dereferenced pointer is always addressable
|
||||
eti.hasher()(h, va)
|
||||
}
|
||||
return makePointerHasher(t)
|
||||
case reflect.Interface:
|
||||
return func(h *hasher, v addressableValue) {
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
va := newAddressableValue(v.Elem().Type())
|
||||
va.Set(v.Elem())
|
||||
|
||||
h.HashUint8(1) // indicates visiting interface value
|
||||
h.hashType(va.Type())
|
||||
ti := getTypeInfo(va.Type())
|
||||
ti.hasher()(h, va)
|
||||
}
|
||||
return makeInterfaceHasher(t)
|
||||
default: // Func, Chan, UnsafePointer
|
||||
return noopHasherFunc
|
||||
return func(*hasher, pointer) {}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hasher) hashString(v addressableValue) {
|
||||
s := v.String()
|
||||
h.HashUint64(uint64(len(s)))
|
||||
h.HashString(s)
|
||||
}
|
||||
|
||||
// hashTimev hashes v, of kind time.Time.
|
||||
func (h *hasher) hashTimev(v addressableValue) {
|
||||
func hashTime(h *hasher, p pointer) {
|
||||
// Include the zone offset (but not the name) to keep
|
||||
// Hash(t1) == Hash(t2) being semantically equivalent to
|
||||
// t1.Format(time.RFC3339Nano) == t2.Format(time.RFC3339Nano).
|
||||
t := *(*time.Time)(v.Addr().UnsafePointer())
|
||||
t := *p.asTime()
|
||||
_, offset := t.Zone()
|
||||
h.HashUint64(uint64(t.Unix()))
|
||||
h.HashUint32(uint32(t.Nanosecond()))
|
||||
h.HashUint32(uint32(offset))
|
||||
}
|
||||
|
||||
// hashAddrv hashes v, of type netip.Addr.
|
||||
func (h *hasher) hashAddrv(v addressableValue) {
|
||||
func hashAddr(h *hasher, p pointer) {
|
||||
// The formatting of netip.Addr covers the
|
||||
// IP version, the address, and the optional zone name (for v6).
|
||||
// This is equivalent to a1.MarshalBinary() == a2.MarshalBinary().
|
||||
ip := *(*netip.Addr)(v.Addr().UnsafePointer())
|
||||
ip := *p.asAddr()
|
||||
switch {
|
||||
case !ip.IsValid():
|
||||
h.HashUint64(0)
|
||||
@@ -451,121 +260,254 @@ func (h *hasher) hashAddrv(v addressableValue) {
|
||||
}
|
||||
}
|
||||
|
||||
func hashString(h *hasher, p pointer) {
|
||||
s := *p.asString()
|
||||
h.HashUint64(uint64(len(s)))
|
||||
h.HashString(s)
|
||||
}
|
||||
|
||||
func makeMemHasher(n uintptr) typeHasherFunc {
|
||||
return func(h *hasher, v addressableValue) {
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), n))
|
||||
return func(h *hasher, p pointer) {
|
||||
h.HashBytes(p.asMemory(n))
|
||||
}
|
||||
}
|
||||
|
||||
// hashSliceMem hashes v, of kind Slice, with a memhash-able element type.
|
||||
func (h *hasher) hashSliceMem(v addressableValue) {
|
||||
vLen := v.Len()
|
||||
h.HashUint64(uint64(vLen))
|
||||
if vLen == 0 {
|
||||
return
|
||||
func makeArrayHasher(t reflect.Type) typeHasherFunc {
|
||||
var once sync.Once
|
||||
var hashElem typeHasherFunc
|
||||
init := func() {
|
||||
hashElem = lookupTypeHasher(t.Elem())
|
||||
}
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.UnsafePointer()), v.Type().Elem().Size()*uintptr(vLen)))
|
||||
}
|
||||
|
||||
func genHashArrayMem(n int, arraySize uintptr, efu *typeInfo) typeHasherFunc {
|
||||
return func(h *hasher, v addressableValue) {
|
||||
h.HashBytes(unsafe.Slice((*byte)(v.Addr().UnsafePointer()), arraySize))
|
||||
}
|
||||
}
|
||||
|
||||
func genHashArrayElements(n int, eti *typeInfo) typeHasherFunc {
|
||||
return func(h *hasher, v addressableValue) {
|
||||
n := t.Len() // number of array elements
|
||||
nb := t.Elem().Size() // byte size of each array element
|
||||
return func(h *hasher, p pointer) {
|
||||
once.Do(init)
|
||||
for i := 0; i < n; i++ {
|
||||
va := addressableValue{v.Index(i)} // element is addressable if parent array is addressable
|
||||
eti.hasher()(h, va)
|
||||
hashElem(h, p.arrayIndex(i, nb))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func noopHasherFunc(h *hasher, v addressableValue) {}
|
||||
|
||||
func genHashArray(t reflect.Type, eti *typeInfo) typeHasherFunc {
|
||||
if t.Size() == 0 {
|
||||
return noopHasherFunc
|
||||
func makeSliceHasher(t reflect.Type) typeHasherFunc {
|
||||
nb := t.Elem().Size() // byte size of each slice element
|
||||
if typeIsMemHashable(t.Elem()) {
|
||||
return func(h *hasher, p pointer) {
|
||||
pa := p.sliceArray()
|
||||
if pa.isNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting slice
|
||||
n := p.sliceLen()
|
||||
b := pa.asMemory(uintptr(n) * nb)
|
||||
h.HashUint64(uint64(n))
|
||||
h.HashBytes(b)
|
||||
}
|
||||
}
|
||||
et := t.Elem()
|
||||
if typeIsMemHashable(et) {
|
||||
return genHashArrayMem(t.Len(), t.Size(), eti)
|
||||
|
||||
var once sync.Once
|
||||
var hashElem typeHasherFunc
|
||||
init := func() {
|
||||
hashElem = lookupTypeHasher(t.Elem())
|
||||
if typeIsRecursive(t) {
|
||||
hashElemDefault := hashElem
|
||||
hashElem = func(h *hasher, p pointer) {
|
||||
if idx, ok := h.visitStack.seen(p.p); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting slice element
|
||||
h.visitStack.push(p.p)
|
||||
defer h.visitStack.pop(p.p)
|
||||
hashElemDefault(h, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
n := t.Len()
|
||||
return genHashArrayElements(n, eti)
|
||||
}
|
||||
|
||||
func genHashSliceElements(eti *typeInfo) typeHasherFunc {
|
||||
return sliceElementHasher{eti}.hash
|
||||
}
|
||||
|
||||
type sliceElementHasher struct {
|
||||
eti *typeInfo
|
||||
}
|
||||
|
||||
func (seh sliceElementHasher) hash(h *hasher, v addressableValue) {
|
||||
vLen := v.Len()
|
||||
h.HashUint64(uint64(vLen))
|
||||
for i := 0; i < vLen; i++ {
|
||||
va := addressableValue{v.Index(i)} // slice elements are always addressable
|
||||
seh.eti.hasher()(h, va)
|
||||
return func(h *hasher, p pointer) {
|
||||
pa := p.sliceArray()
|
||||
if pa.isNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
once.Do(init)
|
||||
h.HashUint8(1) // indicates visiting slice
|
||||
n := p.sliceLen()
|
||||
h.HashUint64(uint64(n))
|
||||
for i := 0; i < n; i++ {
|
||||
pe := pa.arrayIndex(i, nb)
|
||||
hashElem(h, pe)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTypeInfo(t reflect.Type) *typeInfo {
|
||||
if f, ok := typeInfoMap.Load(t); ok {
|
||||
return f.(*typeInfo)
|
||||
func makeStructHasher(t reflect.Type) typeHasherFunc {
|
||||
type fieldHasher struct {
|
||||
idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable
|
||||
hash typeHasherFunc // only valid if idx is not negative
|
||||
offset uintptr
|
||||
size uintptr
|
||||
}
|
||||
typeInfoMapPopulate.Lock()
|
||||
defer typeInfoMapPopulate.Unlock()
|
||||
newTypes := map[reflect.Type]*typeInfo{}
|
||||
ti := getTypeInfoLocked(t, newTypes)
|
||||
for t, ti := range newTypes {
|
||||
typeInfoMap.Store(t, ti)
|
||||
var once sync.Once
|
||||
var fields []fieldHasher
|
||||
init := func() {
|
||||
for i, numField := 0, t.NumField(); i < numField; i++ {
|
||||
sf := t.Field(i)
|
||||
f := fieldHasher{i, nil, sf.Offset, sf.Type.Size()}
|
||||
if typeIsMemHashable(sf.Type) {
|
||||
f.idx = -1
|
||||
}
|
||||
|
||||
// Combine with previous field if both contiguous and mem-hashable.
|
||||
if f.idx < 0 && len(fields) > 0 {
|
||||
if last := &fields[len(fields)-1]; last.idx < 0 && last.offset+last.size == f.offset {
|
||||
last.size += f.size
|
||||
continue
|
||||
}
|
||||
}
|
||||
fields = append(fields, f)
|
||||
}
|
||||
|
||||
for i, f := range fields {
|
||||
if f.idx >= 0 {
|
||||
fields[i].hash = lookupTypeHasher(t.Field(f.idx).Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return func(h *hasher, p pointer) {
|
||||
once.Do(init)
|
||||
for _, field := range fields {
|
||||
pf := p.structField(field.idx, field.offset, field.size)
|
||||
if field.idx < 0 {
|
||||
h.HashBytes(pf.asMemory(field.size))
|
||||
} else {
|
||||
field.hash(h, pf)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ti
|
||||
}
|
||||
|
||||
func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *typeInfo {
|
||||
if v, ok := typeInfoMap.Load(t); ok {
|
||||
return v.(*typeInfo)
|
||||
}
|
||||
if ti, ok := incomplete[t]; ok {
|
||||
return ti
|
||||
}
|
||||
ti := &typeInfo{
|
||||
rtype: t,
|
||||
isRecursive: typeIsRecursive(t),
|
||||
}
|
||||
incomplete[t] = ti
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Map:
|
||||
ti.keyTypeInfo = getTypeInfoLocked(t.Key(), incomplete)
|
||||
fallthrough
|
||||
case reflect.Ptr, reflect.Slice, reflect.Array:
|
||||
ti.elemTypeInfo = getTypeInfoLocked(t.Elem(), incomplete)
|
||||
func makeMapHasher(t reflect.Type) typeHasherFunc {
|
||||
var once sync.Once
|
||||
var hashKey, hashValue typeHasherFunc
|
||||
var isRecursive bool
|
||||
init := func() {
|
||||
hashKey = lookupTypeHasher(t.Key())
|
||||
hashValue = lookupTypeHasher(t.Elem())
|
||||
isRecursive = typeIsRecursive(t)
|
||||
}
|
||||
|
||||
return ti
|
||||
return func(h *hasher, p pointer) {
|
||||
v := p.asValue(t).Elem() // reflect.Map kind
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
once.Do(init)
|
||||
if isRecursive {
|
||||
pm := v.UnsafePointer() // underlying pointer of map
|
||||
if idx, ok := h.visitStack.seen(pm); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(pm)
|
||||
defer h.visitStack.pop(pm)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting map entries
|
||||
h.HashUint64(uint64(v.Len()))
|
||||
|
||||
mh := mapHasherPool.Get().(*mapHasher)
|
||||
defer mapHasherPool.Put(mh)
|
||||
|
||||
// Hash a map in a sort-free mannar.
|
||||
// It relies on a map being a an unordered set of KV entries.
|
||||
// So long as we hash each KV entry together, we can XOR all the
|
||||
// individual hashes to produce a unique hash for the entire map.
|
||||
k := mh.valKey.get(v.Type().Key())
|
||||
e := mh.valElem.get(v.Type().Elem())
|
||||
mh.sum = Sum{}
|
||||
mh.h.visitStack = h.visitStack // always use the parent's visit stack to avoid cycles
|
||||
for iter := v.MapRange(); iter.Next(); {
|
||||
k.SetIterKey(iter)
|
||||
e.SetIterValue(iter)
|
||||
mh.h.reset()
|
||||
hashKey(&mh.h, pointerOf(k.Addr()))
|
||||
hashValue(&mh.h, pointerOf(e.Addr()))
|
||||
mh.sum.xor(mh.h.sum())
|
||||
}
|
||||
h.HashBytes(mh.sum.sum[:])
|
||||
}
|
||||
}
|
||||
|
||||
func makePointerHasher(t reflect.Type) typeHasherFunc {
|
||||
var once sync.Once
|
||||
var hashElem typeHasherFunc
|
||||
var isRecursive bool
|
||||
init := func() {
|
||||
hashElem = lookupTypeHasher(t.Elem())
|
||||
isRecursive = typeIsRecursive(t)
|
||||
}
|
||||
return func(h *hasher, p pointer) {
|
||||
pe := p.pointerElem()
|
||||
if pe.isNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
once.Do(init)
|
||||
if isRecursive {
|
||||
if idx, ok := h.visitStack.seen(pe.p); ok {
|
||||
h.HashUint8(2) // indicates cycle
|
||||
h.HashUint64(uint64(idx))
|
||||
return
|
||||
}
|
||||
h.visitStack.push(pe.p)
|
||||
defer h.visitStack.pop(pe.p)
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting a pointer element
|
||||
hashElem(h, pe)
|
||||
}
|
||||
}
|
||||
|
||||
func makeInterfaceHasher(t reflect.Type) typeHasherFunc {
|
||||
return func(h *hasher, p pointer) {
|
||||
v := p.asValue(t).Elem() // reflect.Interface kind
|
||||
if v.IsNil() {
|
||||
h.HashUint8(0) // indicates nil
|
||||
return
|
||||
}
|
||||
h.HashUint8(1) // indicates visiting an interface value
|
||||
v = v.Elem()
|
||||
t := v.Type()
|
||||
h.hashType(t)
|
||||
va := reflect.New(t).Elem()
|
||||
va.Set(v)
|
||||
hashElem := lookupTypeHasher(t)
|
||||
hashElem(h, pointerOf(va.Addr()))
|
||||
}
|
||||
}
|
||||
|
||||
type mapHasher struct {
|
||||
h hasher
|
||||
valKey, valElem valueCache // re-usable values for map iteration
|
||||
h hasher
|
||||
valKey valueCache
|
||||
valElem valueCache
|
||||
sum Sum
|
||||
}
|
||||
|
||||
var mapHasherPool = &sync.Pool{
|
||||
New: func() any { return new(mapHasher) },
|
||||
}
|
||||
|
||||
type valueCache map[reflect.Type]addressableValue
|
||||
type valueCache map[reflect.Type]reflect.Value
|
||||
|
||||
func (c *valueCache) get(t reflect.Type) addressableValue {
|
||||
// get returns an addressable reflect.Value for the given type.
|
||||
func (c *valueCache) get(t reflect.Type) reflect.Value {
|
||||
v, ok := (*c)[t]
|
||||
if !ok {
|
||||
v = newAddressableValue(t)
|
||||
v = reflect.New(t).Elem()
|
||||
if *c == nil {
|
||||
*c = make(valueCache)
|
||||
}
|
||||
@@ -573,72 +515,3 @@ func (c *valueCache) get(t reflect.Type) addressableValue {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// hashMap hashes a map in a sort-free manner.
|
||||
// It relies on a map being a functionally an unordered set of KV entries.
|
||||
// So long as we hash each KV entry together, we can XOR all
|
||||
// of the individual hashes to produce a unique hash for the entire map.
|
||||
func (h *hasher) hashMap(v addressableValue, ti *typeInfo, checkCycles bool) {
|
||||
mh := mapHasherPool.Get().(*mapHasher)
|
||||
defer mapHasherPool.Put(mh)
|
||||
|
||||
var sum Sum
|
||||
if v.IsNil() {
|
||||
sum.sum[0] = 1 // something non-zero
|
||||
}
|
||||
|
||||
k := mh.valKey.get(v.Type().Key())
|
||||
e := mh.valElem.get(v.Type().Elem())
|
||||
mh.h.visitStack = h.visitStack // always use the parent's visit stack to avoid cycles
|
||||
for iter := v.MapRange(); iter.Next(); {
|
||||
k.SetIterKey(iter)
|
||||
e.SetIterValue(iter)
|
||||
mh.h.Reset()
|
||||
ti.keyTypeInfo.hasher()(&mh.h, k)
|
||||
ti.elemTypeInfo.hasher()(&mh.h, e)
|
||||
sum.xor(mh.h.sum())
|
||||
}
|
||||
h.HashBytes(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation
|
||||
}
|
||||
|
||||
// visitStack is a stack of pointers visited.
|
||||
// Pointers are pushed onto the stack when visited, and popped when leaving.
|
||||
// The integer value is the depth at which the pointer was visited.
|
||||
// The length of this stack should be zero after every hashing operation.
|
||||
type visitStack map[pointer]int
|
||||
|
||||
func (v visitStack) seen(p pointer) (int, bool) {
|
||||
idx, ok := v[p]
|
||||
return idx, ok
|
||||
}
|
||||
|
||||
func (v *visitStack) push(p pointer) {
|
||||
if *v == nil {
|
||||
*v = make(map[pointer]int)
|
||||
}
|
||||
(*v)[p] = len(*v)
|
||||
}
|
||||
|
||||
func (v visitStack) pop(p pointer) {
|
||||
delete(v, p)
|
||||
}
|
||||
|
||||
// pointer is a thin wrapper over unsafe.Pointer.
|
||||
// We only rely on comparability of pointers; we cannot rely on uintptr since
|
||||
// that would break if Go ever switched to a moving GC.
|
||||
type pointer struct{ p unsafe.Pointer }
|
||||
|
||||
func pointerOf(v addressableValue) pointer {
|
||||
return pointer{unsafe.Pointer(v.Value.Pointer())}
|
||||
}
|
||||
|
||||
// hashType hashes a reflect.Type.
|
||||
// The hash is only consistent within the lifetime of a program.
|
||||
func (h *hasher) hashType(t reflect.Type) {
|
||||
// This approach relies on reflect.Type always being backed by a unique
|
||||
// *reflect.rtype pointer. A safer approach is to use a global sync.Map
|
||||
// that maps reflect.Type to some arbitrary and unique index.
|
||||
// While safer, it requires global state with memory that can never be GC'd.
|
||||
rtypeAddr := reflect.ValueOf(t).Pointer() // address of *reflect.rtype
|
||||
h.HashUint64(uint64(rtypeAddr))
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
qt "github.com/frankban/quicktest"
|
||||
"go4.org/mem"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -94,6 +95,12 @@ func TestHash(t *testing.T) {
|
||||
{in: tuple{scalars{F64: float64(math.NaN())}, scalars{F64: float64(math.NaN())}}, wantEq: true},
|
||||
{in: tuple{scalars{C64: 32 + 32i}, scalars{C64: complex(math.Nextafter32(32, 0), 32)}}, wantEq: false},
|
||||
{in: tuple{scalars{C128: 64 + 64i}, scalars{C128: complex(math.Nextafter(64, 0), 64)}}, wantEq: false},
|
||||
{in: tuple{[]int(nil), []int(nil)}, wantEq: true},
|
||||
{in: tuple{[]int{}, []int(nil)}, wantEq: false},
|
||||
{in: tuple{[]int{}, []int{}}, wantEq: true},
|
||||
{in: tuple{[]string(nil), []string(nil)}, wantEq: true},
|
||||
{in: tuple{[]string{}, []string(nil)}, wantEq: false},
|
||||
{in: tuple{[]string{}, []string{}}, wantEq: true},
|
||||
{in: tuple{[]appendBytes{{}, {0, 0, 0, 0, 0, 0, 0, 1}}, []appendBytes{{}, {0, 0, 0, 0, 0, 0, 0, 1}}}, wantEq: true},
|
||||
{in: tuple{[]appendBytes{{}, {0, 0, 0, 0, 0, 0, 0, 1}}, []appendBytes{{0, 0, 0, 0, 0, 0, 0, 1}, {}}}, wantEq: false},
|
||||
{in: tuple{iface{MyBool(true)}, iface{MyBool(true)}}, wantEq: true},
|
||||
@@ -159,7 +166,7 @@ func TestHash(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
gotEq := Hash(tt.in[0]) == Hash(tt.in[1])
|
||||
gotEq := Hash(&tt.in[0]) == Hash(&tt.in[1])
|
||||
if gotEq != tt.wantEq {
|
||||
t.Errorf("(Hash(%T %v) == Hash(%T %v)) = %v, want %v", tt.in[0], tt.in[0], tt.in[1], tt.in[1], gotEq, tt.wantEq)
|
||||
}
|
||||
@@ -170,11 +177,11 @@ func TestDeepHash(t *testing.T) {
|
||||
// v contains the types of values we care about for our current callers.
|
||||
// Mostly we're just testing that we don't panic on handled types.
|
||||
v := getVal()
|
||||
|
||||
hash1 := Hash(v)
|
||||
t.Logf("hash: %v", hash1)
|
||||
for i := 0; i < 20; i++ {
|
||||
hash2 := Hash(getVal())
|
||||
v := getVal()
|
||||
hash2 := Hash(v)
|
||||
if hash1 != hash2 {
|
||||
t.Error("second hash didn't match")
|
||||
}
|
||||
@@ -185,7 +192,7 @@ func TestDeepHash(t *testing.T) {
|
||||
func TestIssue4868(t *testing.T) {
|
||||
m1 := map[int]string{1: "foo"}
|
||||
m2 := map[int]string{1: "bar"}
|
||||
if Hash(m1) == Hash(m2) {
|
||||
if Hash(&m1) == Hash(&m2) {
|
||||
t.Error("bogus")
|
||||
}
|
||||
}
|
||||
@@ -193,7 +200,7 @@ func TestIssue4868(t *testing.T) {
|
||||
func TestIssue4871(t *testing.T) {
|
||||
m1 := map[string]string{"": "", "x": "foo"}
|
||||
m2 := map[string]string{}
|
||||
if h1, h2 := Hash(m1), Hash(m2); h1 == h2 {
|
||||
if h1, h2 := Hash(&m1), Hash(&m2); h1 == h2 {
|
||||
t.Errorf("bogus: h1=%x, h2=%x", h1, h2)
|
||||
}
|
||||
}
|
||||
@@ -201,7 +208,7 @@ func TestIssue4871(t *testing.T) {
|
||||
func TestNilVsEmptymap(t *testing.T) {
|
||||
m1 := map[string]string(nil)
|
||||
m2 := map[string]string{}
|
||||
if h1, h2 := Hash(m1), Hash(m2); h1 == h2 {
|
||||
if h1, h2 := Hash(&m1), Hash(&m2); h1 == h2 {
|
||||
t.Errorf("bogus: h1=%x, h2=%x", h1, h2)
|
||||
}
|
||||
}
|
||||
@@ -209,7 +216,7 @@ func TestNilVsEmptymap(t *testing.T) {
|
||||
func TestMapFraming(t *testing.T) {
|
||||
m1 := map[string]string{"foo": "", "fo": "o"}
|
||||
m2 := map[string]string{}
|
||||
if h1, h2 := Hash(m1), Hash(m2); h1 == h2 {
|
||||
if h1, h2 := Hash(&m1), Hash(&m2); h1 == h2 {
|
||||
t.Errorf("bogus: h1=%x, h2=%x", h1, h2)
|
||||
}
|
||||
}
|
||||
@@ -217,23 +224,25 @@ func TestMapFraming(t *testing.T) {
|
||||
func TestQuick(t *testing.T) {
|
||||
initSeed()
|
||||
err := quick.Check(func(v, w map[string]string) bool {
|
||||
return (Hash(v) == Hash(w)) == reflect.DeepEqual(v, w)
|
||||
return (Hash(&v) == Hash(&w)) == reflect.DeepEqual(v, w)
|
||||
}, &quick.Config{MaxCount: 1000, Rand: rand.New(rand.NewSource(int64(seed)))})
|
||||
if err != nil {
|
||||
t.Fatalf("seed=%v, err=%v", seed, err)
|
||||
}
|
||||
}
|
||||
|
||||
func getVal() any {
|
||||
return &struct {
|
||||
WGConfig *wgcfg.Config
|
||||
RouterConfig *router.Config
|
||||
MapFQDNAddrs map[dnsname.FQDN][]netip.Addr
|
||||
MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort
|
||||
MapDiscoPublics map[key.DiscoPublic]bool
|
||||
MapResponse *tailcfg.MapResponse
|
||||
FilterMatch filter.Match
|
||||
}{
|
||||
type tailscaleTypes struct {
|
||||
WGConfig *wgcfg.Config
|
||||
RouterConfig *router.Config
|
||||
MapFQDNAddrs map[dnsname.FQDN][]netip.Addr
|
||||
MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort
|
||||
MapDiscoPublics map[key.DiscoPublic]bool
|
||||
MapResponse *tailcfg.MapResponse
|
||||
FilterMatch filter.Match
|
||||
}
|
||||
|
||||
func getVal() *tailscaleTypes {
|
||||
return &tailscaleTypes{
|
||||
&wgcfg.Config{
|
||||
Name: "foo",
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)},
|
||||
@@ -410,13 +419,13 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
{
|
||||
name: "string_slice",
|
||||
val: []string{"foo", "bar"},
|
||||
out: "\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x03\x00\x00\x00\x00\x00\x00\x00bar",
|
||||
out: "\x01\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x03\x00\x00\x00\x00\x00\x00\x00bar",
|
||||
},
|
||||
{
|
||||
name: "int_slice",
|
||||
val: []int{1, 0, -1},
|
||||
out: "\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff",
|
||||
out32: "\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff",
|
||||
out: "\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff",
|
||||
out32: "\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff",
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
@@ -451,8 +460,8 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
{
|
||||
name: "packet_filter",
|
||||
val: filterRules,
|
||||
out: "\x04\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x00\x00\x00\x00\x01\x00\x02\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
out32: "\x04\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x01\x00\x02\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
out: "\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x00\x00\x00\x00\x01\x00\x02\x00\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00",
|
||||
out32: "\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x01\x00\x02\x00\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04 \x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00",
|
||||
},
|
||||
{
|
||||
name: "netip.Addr",
|
||||
@@ -566,19 +575,19 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
{
|
||||
name: "tailcfg.Node",
|
||||
val: &tailcfg.Node{},
|
||||
out: "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + u64(uint64(time.Time{}.Unix())) + u64(0) + u32(0) + u32(0) + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + u64(uint64(time.Time{}.Unix())) + u32(0) + u32(0) + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
out: "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\tn\x88\xf1\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\tn\x88\xf1\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rv := reflect.ValueOf(tt.val)
|
||||
va := newAddressableValue(rv.Type())
|
||||
va := reflect.New(rv.Type()).Elem()
|
||||
va.Set(rv)
|
||||
fn := getTypeInfo(va.Type()).hasher()
|
||||
fn := lookupTypeHasher(va.Type())
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
fn(h, va)
|
||||
fn(h, pointerOf(va.Addr()))
|
||||
const ptrSize = 32 << uintptr(^uintptr(0)>>63)
|
||||
if tt.out32 != "" && ptrSize == 32 {
|
||||
tt.out = tt.out32
|
||||
@@ -591,6 +600,138 @@ func TestGetTypeHasher(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceCycle(t *testing.T) {
|
||||
type S []S
|
||||
c := qt.New(t)
|
||||
|
||||
a := make(S, 1) // cylic graph of 1 node
|
||||
a[0] = a
|
||||
b := make(S, 1) // cylic graph of 1 node
|
||||
b[0] = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := make(S, 1) // cyclic graph of 2 nodes
|
||||
c2 := make(S, 1) // cyclic graph of 2 nodes
|
||||
c1[0] = c2
|
||||
c2[0] = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := make(S, 1) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
c3[0] = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
|
||||
c4 := make(S, 2) // cyclic graph of 3 nodes
|
||||
c5 := make(S, 2) // cyclic graph of 3 nodes
|
||||
c4[0] = nil
|
||||
c4[1] = c4
|
||||
c5[0] = c5
|
||||
c5[1] = nil
|
||||
hc4 := Hash(&c4)
|
||||
hc5 := Hash(&c5)
|
||||
c.Assert(hc4, qt.Not(qt.Equals), hc5) // cycle occurs through different indexes
|
||||
}
|
||||
|
||||
func TestMapCycle(t *testing.T) {
|
||||
type M map[string]M
|
||||
c := qt.New(t)
|
||||
|
||||
a := make(M) // cylic graph of 1 node
|
||||
a["self"] = a
|
||||
b := make(M) // cylic graph of 1 node
|
||||
b["self"] = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := make(M) // cyclic graph of 2 nodes
|
||||
c2 := make(M) // cyclic graph of 2 nodes
|
||||
c1["peer"] = c2
|
||||
c2["peer"] = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := make(M) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
c3["child"] = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
|
||||
c4 := make(M) // cyclic graph of 3 nodes
|
||||
c5 := make(M) // cyclic graph of 3 nodes
|
||||
c4["0"] = nil
|
||||
c4["1"] = c4
|
||||
c5["0"] = c5
|
||||
c5["1"] = nil
|
||||
hc4 := Hash(&c4)
|
||||
hc5 := Hash(&c5)
|
||||
c.Assert(hc4, qt.Not(qt.Equals), hc5) // cycle occurs through different keys
|
||||
}
|
||||
|
||||
func TestPointerCycle(t *testing.T) {
|
||||
type P *P
|
||||
c := qt.New(t)
|
||||
|
||||
a := new(P) // cyclic graph of 1 node
|
||||
*a = a
|
||||
b := new(P) // cyclic graph of 1 node
|
||||
*b = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := new(P) // cyclic graph of 2 nodes
|
||||
c2 := new(P) // cyclic graph of 2 nodes
|
||||
*c1 = c2
|
||||
*c2 = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := new(P) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
*c3 = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
}
|
||||
|
||||
func TestInterfaceCycle(t *testing.T) {
|
||||
type I struct{ v any }
|
||||
c := qt.New(t)
|
||||
|
||||
a := new(I) // cyclic graph of 1 node
|
||||
a.v = a
|
||||
b := new(I) // cyclic graph of 1 node
|
||||
b.v = b
|
||||
ha := Hash(&a)
|
||||
hb := Hash(&b)
|
||||
c.Assert(ha, qt.Equals, hb)
|
||||
|
||||
c1 := new(I) // cyclic graph of 2 nodes
|
||||
c2 := new(I) // cyclic graph of 2 nodes
|
||||
c1.v = c2
|
||||
c2.v = c1
|
||||
hc1 := Hash(&c1)
|
||||
hc2 := Hash(&c2)
|
||||
c.Assert(hc1, qt.Equals, hc2)
|
||||
c.Assert(ha, qt.Not(qt.Equals), hc1)
|
||||
c.Assert(hb, qt.Not(qt.Equals), hc2)
|
||||
|
||||
c3 := new(I) // graph of 1 node pointing to cyclic graph of 2 nodes
|
||||
c3.v = c1
|
||||
hc3 := Hash(&c3)
|
||||
c.Assert(hc1, qt.Not(qt.Equals), hc3)
|
||||
}
|
||||
|
||||
var sink Sum
|
||||
|
||||
func BenchmarkHash(b *testing.B) {
|
||||
@@ -647,9 +788,8 @@ var filterRules = []tailcfg.FilterRule{
|
||||
func BenchmarkHashPacketFilter(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
hash := HasherForType[*[]tailcfg.FilterRule]()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sink = hash(&filterRules)
|
||||
sink = Hash(&filterRules)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -662,14 +802,13 @@ func TestHashMapAcyclic(t *testing.T) {
|
||||
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
|
||||
ti := getTypeInfo(reflect.TypeOf(m))
|
||||
|
||||
hash := lookupTypeHasher(reflect.TypeOf(m))
|
||||
for i := 0; i < 20; i++ {
|
||||
v := addressableValue{reflect.ValueOf(&m).Elem()}
|
||||
va := reflect.ValueOf(&m).Elem()
|
||||
hb.Reset()
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
h.hashMap(v, ti, false)
|
||||
hash(h, pointerOf(va.Addr()))
|
||||
h.sum()
|
||||
if got[string(hb.B)] {
|
||||
continue
|
||||
@@ -689,9 +828,9 @@ func TestPrintArray(t *testing.T) {
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
v := addressableValue{reflect.ValueOf(&x).Elem()}
|
||||
ti := getTypeInfo(v.Type())
|
||||
ti.hasher()(h, v)
|
||||
va := reflect.ValueOf(&x).Elem()
|
||||
hash := lookupTypeHasher(va.Type())
|
||||
hash(h, pointerOf(va.Addr()))
|
||||
h.sum()
|
||||
const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f"
|
||||
if got := hb.B; string(got) != want {
|
||||
@@ -707,15 +846,15 @@ func BenchmarkHashMapAcyclic(b *testing.B) {
|
||||
}
|
||||
|
||||
hb := &hashBuffer{Hash: sha256.New()}
|
||||
v := addressableValue{reflect.ValueOf(&m).Elem()}
|
||||
ti := getTypeInfo(v.Type())
|
||||
va := reflect.ValueOf(&m).Elem()
|
||||
hash := lookupTypeHasher(va.Type())
|
||||
|
||||
h := new(hasher)
|
||||
h.Block512.Hash = hb
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
h.Reset()
|
||||
h.hashMap(v, ti, false)
|
||||
hash(h, pointerOf(va.Addr()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -731,7 +870,7 @@ func BenchmarkTailcfgNode(b *testing.B) {
|
||||
func TestExhaustive(t *testing.T) {
|
||||
seen := make(map[Sum]bool)
|
||||
for i := 0; i < 100000; i++ {
|
||||
s := Hash(i)
|
||||
s := Hash(&i)
|
||||
if seen[s] {
|
||||
t.Fatalf("hash collision %v", i)
|
||||
}
|
||||
|
||||
115
util/deephash/pointer.go
Normal file
115
util/deephash/pointer.go
Normal file
@@ -0,0 +1,115 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package deephash
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// unsafePointer is an untyped pointer.
|
||||
// It is the caller's responsibility to call operations on the correct type.
|
||||
//
|
||||
// This pointer only ever points to a small set of kinds or types:
|
||||
// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface,
|
||||
// or a pointer to memory that is directly hashable.
|
||||
//
|
||||
// Arrays are represented as pointers to the first element.
|
||||
// Structs are represented as pointers to the first field.
|
||||
// Slices are represented as pointers to a slice header.
|
||||
// Pointers are represented as pointers to a pointer.
|
||||
//
|
||||
// We do not support direct operations on maps and interfaces, and instead
|
||||
// rely on pointer.asValue to convert the pointer back to a reflect.Value.
|
||||
// Conversion of an unsafe.Pointer to reflect.Value guarantees that the
|
||||
// read-only flag in the reflect.Value is unpopulated, avoiding panics that may
|
||||
// othewise have occurred since the value was obtained from an unexported field.
|
||||
type unsafePointer struct{ p unsafe.Pointer }
|
||||
|
||||
func unsafePointerOf(v reflect.Value) unsafePointer {
|
||||
return unsafePointer{v.UnsafePointer()}
|
||||
}
|
||||
func (p unsafePointer) isNil() bool {
|
||||
return p.p == nil
|
||||
}
|
||||
|
||||
// pointerElem dereferences a pointer.
|
||||
// p must point to a pointer.
|
||||
func (p unsafePointer) pointerElem() unsafePointer {
|
||||
return unsafePointer{*(*unsafe.Pointer)(p.p)}
|
||||
}
|
||||
|
||||
// sliceLen returns the slice length.
|
||||
// p must point to a slice.
|
||||
func (p unsafePointer) sliceLen() int {
|
||||
return (*reflect.SliceHeader)(p.p).Len
|
||||
}
|
||||
|
||||
// sliceArray returns a pointer to the underlying slice array.
|
||||
// p must point to a slice.
|
||||
func (p unsafePointer) sliceArray() unsafePointer {
|
||||
return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)}
|
||||
}
|
||||
|
||||
// arrayIndex returns a pointer to an element in the array.
|
||||
// p must point to an array.
|
||||
func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer {
|
||||
return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)}
|
||||
}
|
||||
|
||||
// structField returns a pointer to a field in a struct.
|
||||
// p must pointer to a struct.
|
||||
func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer {
|
||||
return unsafePointer{unsafe.Add(p.p, offset)}
|
||||
}
|
||||
|
||||
// asString casts p as a *string.
|
||||
func (p unsafePointer) asString() *string {
|
||||
return (*string)(p.p)
|
||||
}
|
||||
|
||||
// asTime casts p as a *time.Time.
|
||||
func (p unsafePointer) asTime() *time.Time {
|
||||
return (*time.Time)(p.p)
|
||||
}
|
||||
|
||||
// asAddr casts p as a *netip.Addr.
|
||||
func (p unsafePointer) asAddr() *netip.Addr {
|
||||
return (*netip.Addr)(p.p)
|
||||
}
|
||||
|
||||
// asValue casts p as a reflect.Value containing a pointer to value of t.
|
||||
func (p unsafePointer) asValue(typ reflect.Type) reflect.Value {
|
||||
return reflect.NewAt(typ, p.p)
|
||||
}
|
||||
|
||||
// asMemory returns the memory pointer at by p for a specified size.
|
||||
func (p unsafePointer) asMemory(size uintptr) []byte {
|
||||
return unsafe.Slice((*byte)(p.p), size)
|
||||
}
|
||||
|
||||
// visitStack is a stack of pointers visited.
|
||||
// Pointers are pushed onto the stack when visited, and popped when leaving.
|
||||
// The integer value is the depth at which the pointer was visited.
|
||||
// The length of this stack should be zero after every hashing operation.
|
||||
type visitStack map[unsafe.Pointer]int
|
||||
|
||||
func (v visitStack) seen(p unsafe.Pointer) (int, bool) {
|
||||
idx, ok := v[p]
|
||||
return idx, ok
|
||||
}
|
||||
|
||||
func (v *visitStack) push(p unsafe.Pointer) {
|
||||
if *v == nil {
|
||||
*v = make(map[unsafe.Pointer]int)
|
||||
}
|
||||
(*v)[p] = len(*v)
|
||||
}
|
||||
|
||||
func (v visitStack) pop(p unsafe.Pointer) {
|
||||
delete(v, p)
|
||||
}
|
||||
14
util/deephash/pointer_norace.go
Normal file
14
util/deephash/pointer_norace.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !race
|
||||
|
||||
package deephash
|
||||
|
||||
import "reflect"
|
||||
|
||||
type pointer = unsafePointer
|
||||
|
||||
// pointerOf returns a pointer from v, which must be a reflect.Pointer.
|
||||
func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) }
|
||||
100
util/deephash/pointer_race.go
Normal file
100
util/deephash/pointer_race.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build race
|
||||
|
||||
package deephash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// pointer is a typed pointer that performs safety checks for every operation.
|
||||
type pointer struct {
|
||||
unsafePointer
|
||||
t reflect.Type // type of pointed-at value; may be nil
|
||||
n uintptr // size of valid memory after p
|
||||
}
|
||||
|
||||
// pointerOf returns a pointer from v, which must be a reflect.Pointer.
|
||||
func pointerOf(v reflect.Value) pointer {
|
||||
assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind())
|
||||
te := v.Type().Elem()
|
||||
return pointer{unsafePointerOf(v), te, te.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) pointerElem() pointer {
|
||||
assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind())
|
||||
te := p.t.Elem()
|
||||
return pointer{p.unsafePointer.pointerElem(), te, te.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) sliceLen() int {
|
||||
assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind())
|
||||
return p.unsafePointer.sliceLen()
|
||||
}
|
||||
|
||||
func (p pointer) sliceArray() pointer {
|
||||
assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind())
|
||||
n := p.sliceLen()
|
||||
assert(n >= 0, "got negative slice length %d", n)
|
||||
ta := reflect.ArrayOf(n, p.t.Elem())
|
||||
return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) arrayIndex(index int, size uintptr) pointer {
|
||||
assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind())
|
||||
assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index)
|
||||
assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size)
|
||||
te := p.t.Elem()
|
||||
return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) structField(index int, offset, size uintptr) pointer {
|
||||
assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind())
|
||||
assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset)
|
||||
assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size)
|
||||
if index < 0 {
|
||||
return pointer{p.unsafePointer.structField(index, offset, size), nil, size}
|
||||
}
|
||||
sf := p.t.Field(index)
|
||||
t := sf.Type
|
||||
assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset)
|
||||
assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size)
|
||||
return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()}
|
||||
}
|
||||
|
||||
func (p pointer) asString() *string {
|
||||
assert(p.t.Kind() == reflect.String, "got %v, want string", p.t)
|
||||
return p.unsafePointer.asString()
|
||||
}
|
||||
|
||||
func (p pointer) asTime() *time.Time {
|
||||
assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType)
|
||||
return p.unsafePointer.asTime()
|
||||
}
|
||||
|
||||
func (p pointer) asAddr() *netip.Addr {
|
||||
assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType)
|
||||
return p.unsafePointer.asAddr()
|
||||
}
|
||||
|
||||
func (p pointer) asValue(typ reflect.Type) reflect.Value {
|
||||
assert(p.t == typ, "got %v, want %v", p.t, typ)
|
||||
return p.unsafePointer.asValue(typ)
|
||||
}
|
||||
|
||||
func (p pointer) asMemory(size uintptr) []byte {
|
||||
assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size)
|
||||
return p.unsafePointer.asMemory(size)
|
||||
}
|
||||
|
||||
func assert(b bool, f string, a ...any) {
|
||||
if !b {
|
||||
panic(fmt.Sprintf(f, a...))
|
||||
}
|
||||
}
|
||||
@@ -6,60 +6,58 @@
|
||||
// It is similar to the unix command uniq.
|
||||
package uniq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type badTypeError struct {
|
||||
typ reflect.Type
|
||||
}
|
||||
|
||||
func (e badTypeError) Error() string {
|
||||
return fmt.Sprintf("uniq.ModifySlice's first argument must have type *[]T, got %v", e.typ)
|
||||
}
|
||||
|
||||
// ModifySlice removes adjacent duplicate elements from the slice pointed to by sliceptr.
|
||||
// It adjusts the length of the slice appropriately and zeros the tail.
|
||||
// eq reports whether (*sliceptr)[i] and (*sliceptr)[j] are equal.
|
||||
// ModifySlice does O(len(*sliceptr)) operations.
|
||||
func ModifySlice(sliceptr any, eq func(i, j int) bool) {
|
||||
rvp := reflect.ValueOf(sliceptr)
|
||||
if rvp.Type().Kind() != reflect.Ptr {
|
||||
panic(badTypeError{rvp.Type()})
|
||||
}
|
||||
rv := rvp.Elem()
|
||||
if rv.Type().Kind() != reflect.Slice {
|
||||
panic(badTypeError{rvp.Type()})
|
||||
}
|
||||
|
||||
length := rv.Len()
|
||||
// ModifySlice removes adjacent duplicate elements from the given slice. It
|
||||
// adjusts the length of the slice appropriately and zeros the tail.
|
||||
//
|
||||
// ModifySlice does O(len(*slice)) operations.
|
||||
func ModifySlice[E comparable](slice *[]E) {
|
||||
// Remove duplicates
|
||||
dst := 0
|
||||
for i := 1; i < length; i++ {
|
||||
if eq(dst, i) {
|
||||
for i := 1; i < len(*slice); i++ {
|
||||
if (*slice)[i] == (*slice)[dst] {
|
||||
continue
|
||||
}
|
||||
dst++
|
||||
// slice[dst] = slice[i]
|
||||
rv.Index(dst).Set(rv.Index(i))
|
||||
(*slice)[dst] = (*slice)[i]
|
||||
}
|
||||
|
||||
// Zero out the elements we removed at the end of the slice
|
||||
end := dst + 1
|
||||
var zero reflect.Value
|
||||
if end < length {
|
||||
zero = reflect.Zero(rv.Type().Elem())
|
||||
var zero E
|
||||
for i := end; i < len(*slice); i++ {
|
||||
(*slice)[i] = zero
|
||||
}
|
||||
|
||||
// for i := range slice[end:] {
|
||||
// size[i] = 0/nil/{}
|
||||
// }
|
||||
for i := end; i < length; i++ {
|
||||
// slice[i] = 0/nil/{}
|
||||
rv.Index(i).Set(zero)
|
||||
}
|
||||
|
||||
// slice = slice[:end]
|
||||
if end < length {
|
||||
rv.SetLen(end)
|
||||
// Truncate the slice
|
||||
if end < len(*slice) {
|
||||
*slice = (*slice)[:end]
|
||||
}
|
||||
}
|
||||
|
||||
// ModifySliceFunc is the same as ModifySlice except that it allows using a
|
||||
// custom comparison function.
|
||||
//
|
||||
// eq should report whether the two provided elements are equal.
|
||||
func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) {
|
||||
// Remove duplicates
|
||||
dst := 0
|
||||
for i := 1; i < len(*slice); i++ {
|
||||
if eq((*slice)[dst], (*slice)[i]) {
|
||||
continue
|
||||
}
|
||||
dst++
|
||||
(*slice)[dst] = (*slice)[i]
|
||||
}
|
||||
|
||||
// Zero out the elements we removed at the end of the slice
|
||||
end := dst + 1
|
||||
var zero E
|
||||
for i := end; i < len(*slice); i++ {
|
||||
(*slice)[i] = zero
|
||||
}
|
||||
|
||||
// Truncate the slice
|
||||
if end < len(*slice) {
|
||||
*slice = (*slice)[:end]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,24 +12,25 @@ import (
|
||||
"tailscale.com/util/uniq"
|
||||
)
|
||||
|
||||
func TestModifySlice(t *testing.T) {
|
||||
func runTests(t *testing.T, cb func(*[]uint32)) {
|
||||
tests := []struct {
|
||||
in []int
|
||||
want []int
|
||||
// Use uint32 to be different from an int-typed slice index
|
||||
in []uint32
|
||||
want []uint32
|
||||
}{
|
||||
{in: []int{0, 1, 2}, want: []int{0, 1, 2}},
|
||||
{in: []int{0, 1, 2, 2}, want: []int{0, 1, 2}},
|
||||
{in: []int{0, 0, 1, 2}, want: []int{0, 1, 2}},
|
||||
{in: []int{0, 1, 0, 2}, want: []int{0, 1, 0, 2}},
|
||||
{in: []int{0}, want: []int{0}},
|
||||
{in: []int{0, 0}, want: []int{0}},
|
||||
{in: []int{}, want: []int{}},
|
||||
{in: []uint32{0, 1, 2}, want: []uint32{0, 1, 2}},
|
||||
{in: []uint32{0, 1, 2, 2}, want: []uint32{0, 1, 2}},
|
||||
{in: []uint32{0, 0, 1, 2}, want: []uint32{0, 1, 2}},
|
||||
{in: []uint32{0, 1, 0, 2}, want: []uint32{0, 1, 0, 2}},
|
||||
{in: []uint32{0}, want: []uint32{0}},
|
||||
{in: []uint32{0, 0}, want: []uint32{0}},
|
||||
{in: []uint32{}, want: []uint32{}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
in := make([]int, len(test.in))
|
||||
in := make([]uint32, len(test.in))
|
||||
copy(in, test.in)
|
||||
uniq.ModifySlice(&test.in, func(i, j int) bool { return test.in[i] == test.in[j] })
|
||||
cb(&test.in)
|
||||
if !reflect.DeepEqual(test.in, test.want) {
|
||||
t.Errorf("uniq.Slice(%v) = %v, want %v", in, test.in, test.want)
|
||||
}
|
||||
@@ -43,6 +44,20 @@ func TestModifySlice(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModifySlice(t *testing.T) {
|
||||
runTests(t, func(slice *[]uint32) {
|
||||
uniq.ModifySlice(slice)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModifySliceFunc(t *testing.T) {
|
||||
runTests(t, func(slice *[]uint32) {
|
||||
uniq.ModifySliceFunc(slice, func(i, j uint32) bool {
|
||||
return i == j
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark(b *testing.B) {
|
||||
benches := []struct {
|
||||
name string
|
||||
@@ -83,6 +98,6 @@ func benchmark(b *testing.B, size int64, reset func(s []byte)) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
s = s[:size]
|
||||
reset(s)
|
||||
uniq.ModifySlice(&s, func(i, j int) bool { return s[i] == s[j] })
|
||||
uniq.ModifySlice(&s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,8 @@ var Long = ""
|
||||
|
||||
// Short is a short version number for this build, of the form
|
||||
// "x.y.z" for builds stamped in the usual way (see
|
||||
// build_dist.sh in the root) or, for binaries built by hand with the // go tool, it's like Long's dev form, but ending at the date part,
|
||||
// build_dist.sh in the root) or, for binaries built by hand with the
|
||||
// go tool, it's like Long's dev form, but ending at the date part,
|
||||
// of the form "1.23.0-dev20220316".
|
||||
var Short = ""
|
||||
|
||||
|
||||
@@ -2843,7 +2843,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
|
||||
}
|
||||
ports = append(ports, 0)
|
||||
// Remove duplicates. (All duplicates are consecutive.)
|
||||
uniq.ModifySlice(&ports, func(i, j int) bool { return ports[i] == ports[j] })
|
||||
uniq.ModifySlice(&ports)
|
||||
|
||||
var pconn nettype.PacketConn
|
||||
for _, port := range ports {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/flowtrack"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@@ -157,28 +156,8 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) {
|
||||
return
|
||||
}
|
||||
|
||||
// We don't care if this information is perfectly up-to-date, since
|
||||
// we're just using it to print debug information.
|
||||
//
|
||||
// In tailscale/coral#72, we see a goroutine profile with thousands of
|
||||
// goroutines blocked on the mutex in getStatus here, so we wrap it in
|
||||
// a singleflight and accept stale information to reduce contention.
|
||||
st, err, _ := e.getStatusSf.Do(struct{}{}, e.getStatus)
|
||||
|
||||
var ps *ipnstate.PeerStatusLite
|
||||
if err == nil {
|
||||
for _, v := range st.Peers {
|
||||
if v.NodeKey == n.Key {
|
||||
v := v // copy
|
||||
ps = &v
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
e.logf("open-conn-track: timeout opening %v to node %v; failed to get engine status: %v", flow, n.Key.ShortString(), err)
|
||||
return
|
||||
}
|
||||
if ps == nil {
|
||||
ps, found := e.getPeerStatusLite(n.Key)
|
||||
if !found {
|
||||
onlyZeroRoute := true // whether peerForIP returned n only because its /0 route matched
|
||||
for _, r := range n.AllowedIPs {
|
||||
if r.Bits() != 0 && r.Contains(flow.Dst.Addr()) {
|
||||
|
||||
@@ -6,7 +6,6 @@ package wgengine
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"tailscale.com/control/controlclient"
|
||||
@@ -46,13 +44,13 @@ import (
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/deephash"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/singleflight"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/magicsock"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/wgcfg"
|
||||
"tailscale.com/wgengine/wgint"
|
||||
"tailscale.com/wgengine/wglog"
|
||||
)
|
||||
|
||||
@@ -147,10 +145,6 @@ type userspaceEngine struct {
|
||||
// value of the ICMP identifer and sequence number concatenated.
|
||||
icmpEchoResponseCallback map[uint32]func()
|
||||
|
||||
// this singleflight is used to deduplicate calls to getStatus when we
|
||||
// don't care if the data is perfectly fresh
|
||||
getStatusSf singleflight.Group[struct{}, *Status]
|
||||
|
||||
// Lock ordering: magicsock.Conn.mu, wgLock, then mu.
|
||||
}
|
||||
|
||||
@@ -983,132 +977,44 @@ var singleNewline = []byte{'\n'}
|
||||
|
||||
var ErrEngineClosing = errors.New("engine closing; no status")
|
||||
|
||||
func (e *userspaceEngine) getPeerStatusLite(pk key.NodePublic) (status ipnstate.PeerStatusLite, ok bool) {
|
||||
e.wgLock.Lock()
|
||||
if e.wgdev == nil {
|
||||
e.wgLock.Unlock()
|
||||
return status, false
|
||||
}
|
||||
peer := e.wgdev.LookupPeer(pk.Raw32())
|
||||
e.wgLock.Unlock()
|
||||
if peer == nil {
|
||||
return status, false
|
||||
}
|
||||
status.NodeKey = pk
|
||||
status.RxBytes = int64(wgint.PeerRxBytes(peer))
|
||||
status.TxBytes = int64(wgint.PeerTxBytes(peer))
|
||||
status.LastHandshake = time.Unix(0, wgint.PeerLastHandshakeNano(peer))
|
||||
return status, true
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) getStatus() (*Status, error) {
|
||||
// Grab derpConns before acquiring wgLock to not violate lock ordering;
|
||||
// the DERPs method acquires magicsock.Conn.mu.
|
||||
// (See comment in userspaceEngine's declaration.)
|
||||
derpConns := e.magicConn.DERPs()
|
||||
|
||||
e.wgLock.Lock()
|
||||
defer e.wgLock.Unlock()
|
||||
|
||||
e.mu.Lock()
|
||||
closing := e.closing
|
||||
peerKeys := make([]key.NodePublic, len(e.peerSequence))
|
||||
copy(peerKeys, e.peerSequence)
|
||||
e.mu.Unlock()
|
||||
|
||||
if closing {
|
||||
return nil, ErrEngineClosing
|
||||
}
|
||||
|
||||
if e.wgdev == nil {
|
||||
// RequestStatus was invoked before the wgengine has
|
||||
// finished initializing. This can happen when wgegine
|
||||
// provides a callback to magicsock for endpoint
|
||||
// updates that calls RequestStatus.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close() // to unblock writes on error path returns
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
|
||||
// FIXME: get notified of status changes instead of polling.
|
||||
err := e.wgdev.IpcGetOperation(pw)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("IpcGetOperation: %w", err)
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
pp := make(map[key.NodePublic]ipnstate.PeerStatusLite)
|
||||
var p ipnstate.PeerStatusLite
|
||||
|
||||
var hst1, hst2, n int64
|
||||
|
||||
br := e.statusBufioReader
|
||||
if br != nil {
|
||||
br.Reset(pr)
|
||||
} else {
|
||||
br = bufio.NewReaderSize(pr, 1<<10)
|
||||
e.statusBufioReader = br
|
||||
}
|
||||
for {
|
||||
line, err := br.ReadSlice('\n')
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading from UAPI pipe: %w", err)
|
||||
}
|
||||
line = bytes.TrimSuffix(line, singleNewline)
|
||||
k := line
|
||||
var v mem.RO
|
||||
if i := bytes.IndexByte(line, '='); i != -1 {
|
||||
k = line[:i]
|
||||
v = mem.B(line[i+1:])
|
||||
}
|
||||
switch string(k) {
|
||||
case "public_key":
|
||||
pk, err := key.ParseNodePublicUntyped(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: invalid key in line %q", line)
|
||||
}
|
||||
if !p.NodeKey.IsZero() {
|
||||
pp[p.NodeKey] = p
|
||||
}
|
||||
p = ipnstate.PeerStatusLite{NodeKey: pk}
|
||||
case "rx_bytes":
|
||||
n, err = mem.ParseInt(v, 10, 64)
|
||||
p.RxBytes = n
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: rx_bytes invalid: %#v", line)
|
||||
}
|
||||
case "tx_bytes":
|
||||
n, err = mem.ParseInt(v, 10, 64)
|
||||
p.TxBytes = n
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: tx_bytes invalid: %#v", line)
|
||||
}
|
||||
case "last_handshake_time_sec":
|
||||
hst1, err = mem.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: hst1 invalid: %#v", line)
|
||||
}
|
||||
case "last_handshake_time_nsec":
|
||||
hst2, err = mem.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: hst2 invalid: %#v", line)
|
||||
}
|
||||
if hst1 != 0 || hst2 != 0 {
|
||||
p.LastHandshake = time.Unix(hst1, hst2)
|
||||
} // else leave at time.IsZero()
|
||||
}
|
||||
}
|
||||
if !p.NodeKey.IsZero() {
|
||||
pp[p.NodeKey] = p
|
||||
}
|
||||
if err := <-errc; err != nil {
|
||||
return nil, fmt.Errorf("IpcGetOperation: %v", err)
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
// Do two passes, one to calculate size and the other to populate.
|
||||
// This code is sensitive to allocations.
|
||||
npeers := 0
|
||||
for _, pk := range e.peerSequence {
|
||||
if _, ok := pp[pk]; ok { // ignore idle ones not in wireguard-go's config
|
||||
npeers++
|
||||
}
|
||||
}
|
||||
|
||||
peers := make([]ipnstate.PeerStatusLite, 0, npeers)
|
||||
for _, pk := range e.peerSequence {
|
||||
if p, ok := pp[pk]; ok { // ignore idle ones not in wireguard-go's config
|
||||
peers = append(peers, p)
|
||||
peers := make([]ipnstate.PeerStatusLite, 0, len(peerKeys))
|
||||
for _, key := range peerKeys {
|
||||
if status, found := e.getPeerStatusLite(key); found {
|
||||
peers = append(peers, status)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,10 +8,12 @@
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/envknob"
|
||||
@@ -38,21 +40,49 @@ func NewWatchdog(e Engine) Engine {
|
||||
return e
|
||||
}
|
||||
return &watchdogEngine{
|
||||
wrap: e,
|
||||
logf: log.Printf,
|
||||
fatalf: log.Fatalf,
|
||||
maxWait: 45 * time.Second,
|
||||
wrap: e,
|
||||
logf: log.Printf,
|
||||
fatalf: log.Fatalf,
|
||||
maxWait: 45 * time.Second,
|
||||
inFlight: make(map[inFlightKey]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
type inFlightKey struct {
|
||||
op string
|
||||
ctr uint64
|
||||
}
|
||||
|
||||
type watchdogEngine struct {
|
||||
wrap Engine
|
||||
logf func(format string, args ...any)
|
||||
fatalf func(format string, args ...any)
|
||||
maxWait time.Duration
|
||||
|
||||
// Track the start time(s) of in-flight operations
|
||||
inFlightMu sync.Mutex
|
||||
inFlight map[inFlightKey]time.Time
|
||||
inFlightCtr uint64
|
||||
}
|
||||
|
||||
func (e *watchdogEngine) watchdogErr(name string, fn func() error) error {
|
||||
// Track all in-flight operations so we can print more useful error
|
||||
// messages on watchdog failure
|
||||
e.inFlightMu.Lock()
|
||||
key := inFlightKey{
|
||||
op: name,
|
||||
ctr: e.inFlightCtr,
|
||||
}
|
||||
e.inFlightCtr++
|
||||
e.inFlight[key] = time.Now()
|
||||
e.inFlightMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
e.inFlightMu.Lock()
|
||||
defer e.inFlightMu.Unlock()
|
||||
delete(e.inFlight, key)
|
||||
}()
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- fn()
|
||||
@@ -66,6 +96,22 @@ func (e *watchdogEngine) watchdogErr(name string, fn func() error) error {
|
||||
buf := new(strings.Builder)
|
||||
pprof.Lookup("goroutine").WriteTo(buf, 1)
|
||||
e.logf("wgengine watchdog stacks:\n%s", buf.String())
|
||||
|
||||
// Collect the list of in-flight operations for debugging.
|
||||
var (
|
||||
b []byte
|
||||
now = time.Now()
|
||||
)
|
||||
e.inFlightMu.Lock()
|
||||
for k, t := range e.inFlight {
|
||||
dur := now.Sub(t).Round(time.Millisecond)
|
||||
b = fmt.Appendf(b, "in-flight[%d]: name=%s duration=%v start=%s\n", k.ctr, k.op, dur, t.Format(time.RFC3339Nano))
|
||||
}
|
||||
e.inFlightMu.Unlock()
|
||||
|
||||
// Print everything as a single string to avoid log
|
||||
// rate limits.
|
||||
e.logf("wgengine watchdog in-flight:\n%s", b)
|
||||
e.fatalf("wgengine: watchdog timeout on %s", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,13 +5,9 @@
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
func TestWatchdog(t *testing.T) {
|
||||
@@ -41,43 +37,4 @@ func TestWatchdog(t *testing.T) {
|
||||
e.RequestStatus()
|
||||
e.Close()
|
||||
})
|
||||
|
||||
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
e, err := NewFakeUserspaceEngine(t.Logf, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(e.Close)
|
||||
usEngine := e.(*userspaceEngine)
|
||||
e = NewWatchdog(e)
|
||||
wdEngine := e.(*watchdogEngine)
|
||||
wdEngine.maxWait = maxWaitMultiple * 100 * time.Millisecond
|
||||
|
||||
logBuf := new(tstest.MemLogger)
|
||||
fatalCalled := make(chan struct{})
|
||||
wdEngine.logf = logBuf.Logf
|
||||
wdEngine.fatalf = func(format string, args ...any) {
|
||||
t.Logf("FATAL: %s", fmt.Sprintf(format, args...))
|
||||
fatalCalled <- struct{}{}
|
||||
}
|
||||
|
||||
usEngine.wgLock.Lock() // blocks getStatus so the watchdog will fire
|
||||
|
||||
go e.RequestStatus()
|
||||
|
||||
select {
|
||||
case <-fatalCalled:
|
||||
if !strings.Contains(logBuf.String(), "goroutine profile: total ") {
|
||||
t.Errorf("fatal called without watchdog stacks, got: %s", logBuf.String())
|
||||
}
|
||||
// expected
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("watchdog failed to fire")
|
||||
}
|
||||
|
||||
usEngine.wgLock.Unlock()
|
||||
wdEngine.fatalf = t.Fatalf
|
||||
wdEngine.Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -16,8 +16,11 @@ import (
|
||||
|
||||
// ToUAPI writes cfg in UAPI format to w.
|
||||
// Prev is the previous device Config.
|
||||
// Prev is required so that we can remove now-defunct peers
|
||||
// without having to remove and re-add all peers.
|
||||
//
|
||||
// Prev is required so that we can remove now-defunct peers without having to
|
||||
// remove and re-add all peers, and so that we can avoid writing information
|
||||
// about peers that have not changed since the previous time we wrote our
|
||||
// Config.
|
||||
func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
var stickyErr error
|
||||
set := func(key, value string) {
|
||||
@@ -49,13 +52,33 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
// Add/configure all new peers.
|
||||
for _, p := range cfg.Peers {
|
||||
oldPeer, wasPresent := old[p.PublicKey]
|
||||
|
||||
// We only want to write the peer header/version if we're about
|
||||
// to change something about that peer, or if it's a new peer.
|
||||
// Figure out up-front whether we'll need to do anything for
|
||||
// this peer, and skip doing anything if not.
|
||||
//
|
||||
// If the peer was not present in the previous config, this
|
||||
// implies that this is a new peer; set all of these to 'true'
|
||||
// to ensure that we're writing the full peer configuration.
|
||||
willSetEndpoint := oldPeer.WGEndpoint != p.PublicKey || !wasPresent
|
||||
willChangeIPs := !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) || !wasPresent
|
||||
willChangeKeepalive := oldPeer.PersistentKeepalive != p.PersistentKeepalive || !wasPresent
|
||||
|
||||
if !willSetEndpoint && !willChangeIPs && !willChangeKeepalive {
|
||||
// It's safe to skip doing anything here; wireguard-go
|
||||
// will not remove a peer if it's unspecified unless we
|
||||
// tell it to (which we do below if necessary).
|
||||
continue
|
||||
}
|
||||
|
||||
setPeer(p)
|
||||
set("protocol_version", "1")
|
||||
|
||||
// Avoid setting endpoints if the correct one is already known
|
||||
// to WireGuard, because doing so generates a bit more work in
|
||||
// calling magicsock's ParseEndpoint for effectively a no-op.
|
||||
if oldPeer.WGEndpoint != p.PublicKey {
|
||||
if willSetEndpoint {
|
||||
if wasPresent {
|
||||
// We had an endpoint, and it was wrong.
|
||||
// By construction, this should not happen.
|
||||
@@ -72,7 +95,7 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
// If p.AllowedIPs is a strict superset of oldPeer.AllowedIPs,
|
||||
// then skip replace_allowed_ips and instead add only
|
||||
// the new ipps with allowed_ip.
|
||||
if !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) {
|
||||
if willChangeIPs {
|
||||
set("replace_allowed_ips", "true")
|
||||
for _, ipp := range p.AllowedIPs {
|
||||
set("allowed_ip", ipp.String())
|
||||
@@ -81,7 +104,7 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
|
||||
// Set PersistentKeepalive after the peer is otherwise configured,
|
||||
// because it can trigger handshake packets.
|
||||
if oldPeer.PersistentKeepalive != p.PersistentKeepalive {
|
||||
if willChangeKeepalive {
|
||||
setUint16("persistent_keepalive_interval", p.PersistentKeepalive)
|
||||
}
|
||||
}
|
||||
|
||||
59
wgengine/wgint/wgint.go
Normal file
59
wgengine/wgint/wgint.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package wgint provides somewhat shady access to wireguard-go
|
||||
// internals that don't (yet) have public APIs.
|
||||
package wgint
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
var (
|
||||
offHandshake = getPeerStatsOffset("lastHandshakeNano")
|
||||
offRxBytes = getPeerStatsOffset("rxBytes")
|
||||
offTxBytes = getPeerStatsOffset("txBytes")
|
||||
)
|
||||
|
||||
func getPeerStatsOffset(name string) uintptr {
|
||||
peerType := reflect.TypeOf(device.Peer{})
|
||||
sf, ok := peerType.FieldByName("stats")
|
||||
if !ok {
|
||||
panic("no stats field in device.Peer")
|
||||
}
|
||||
if sf.Type.Kind() != reflect.Struct {
|
||||
panic("stats field is not a struct")
|
||||
}
|
||||
base := sf.Offset
|
||||
|
||||
st := sf.Type
|
||||
field, ok := st.FieldByName(name)
|
||||
if !ok {
|
||||
panic("no " + name + " field in device.Peer.stats")
|
||||
}
|
||||
if field.Type.Kind() != reflect.Int64 && field.Type.Kind() != reflect.Uint64 {
|
||||
panic("unexpected kind of " + name + " field in device.Peer.stats")
|
||||
}
|
||||
return base + field.Offset
|
||||
}
|
||||
|
||||
// PeerLastHandshakeNano returns the last handshake time in nanoseconds since the
|
||||
// unix epoch.
|
||||
func PeerLastHandshakeNano(peer *device.Peer) int64 {
|
||||
return atomic.LoadInt64((*int64)(unsafe.Add(unsafe.Pointer(peer), offHandshake)))
|
||||
}
|
||||
|
||||
// PeerRxBytes returns the number of bytes received from this peer.
|
||||
func PeerRxBytes(peer *device.Peer) uint64 {
|
||||
return atomic.LoadUint64((*uint64)(unsafe.Add(unsafe.Pointer(peer), offRxBytes)))
|
||||
}
|
||||
|
||||
// PeerTxBytes returns the number of bytes sent to this peer.
|
||||
func PeerTxBytes(peer *device.Peer) uint64 {
|
||||
return atomic.LoadUint64((*uint64)(unsafe.Add(unsafe.Pointer(peer), offTxBytes)))
|
||||
}
|
||||
24
wgengine/wgint/wgint_test.go
Normal file
24
wgengine/wgint/wgint_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgint
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func TestPeerStats(t *testing.T) {
|
||||
peer := new(device.Peer)
|
||||
if got := PeerLastHandshakeNano(peer); got != 0 {
|
||||
t.Errorf("PeerLastHandshakeNano = %v, want 0", got)
|
||||
}
|
||||
if got := PeerRxBytes(peer); got != 0 {
|
||||
t.Errorf("PeerRxBytes = %v, want 0", got)
|
||||
}
|
||||
if got := PeerTxBytes(peer); got != 0 {
|
||||
t.Errorf("PeerTxBytes = %v, want 0", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user