Compare commits
252 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6dbb4425c | ||
|
|
38b0c3eea2 | ||
|
|
43e2efe441 | ||
|
|
fe68841dc7 | ||
|
|
69f3ceeb7c | ||
|
|
990e2f1ae9 | ||
|
|
961b9c8abf | ||
|
|
e298327ba8 | ||
|
|
be3ca5cbfd | ||
|
|
4970e771ab | ||
|
|
3669296cef | ||
|
|
0a42b0a726 | ||
|
|
16a9cfe2f4 | ||
|
|
5066b824a6 | ||
|
|
648268192b | ||
|
|
a89d610a3d | ||
|
|
318751c486 | ||
|
|
4957360ecd | ||
|
|
dd4e06f383 | ||
|
|
c53ab3111d | ||
|
|
05a79d79ae | ||
|
|
48fc9026e9 | ||
|
|
3b0514ef6d | ||
|
|
32ecdea157 | ||
|
|
2545575dd5 | ||
|
|
189d86cce5 | ||
|
|
218de6d530 | ||
|
|
de11f90d9d | ||
|
|
972a42cb33 | ||
|
|
d60917c0f1 | ||
|
|
f26b409bd5 | ||
|
|
6095a9b423 | ||
|
|
f745e1c058 | ||
|
|
ca2428ecaf | ||
|
|
d8e67ca2ab | ||
|
|
f562c35c0d | ||
|
|
f267a7396f | ||
|
|
c06d2a8513 | ||
|
|
bf195cd3d8 | ||
|
|
7cf50f6c84 | ||
|
|
3efc29d39d | ||
|
|
a3e7252ce6 | ||
|
|
5df6be9d38 | ||
|
|
52969bdfb0 | ||
|
|
a6559a8924 | ||
|
|
75e1cc1dd5 | ||
|
|
10ac066013 | ||
|
|
d74c9aa95b | ||
|
|
c976264bd1 | ||
|
|
f3e2b65637 | ||
|
|
380ee76d00 | ||
|
|
891898525c | ||
|
|
1f923124bf | ||
|
|
852136a03c | ||
|
|
65d2537c05 | ||
|
|
8163521c33 | ||
|
|
a2267aae99 | ||
|
|
cdfea347d0 | ||
|
|
44baa3463f | ||
|
|
45578b47f3 | ||
|
|
723b9eecb0 | ||
|
|
df674d4189 | ||
|
|
d361511512 | ||
|
|
19d77ce6a3 | ||
|
|
7ba148e54e | ||
|
|
19867b2b6d | ||
|
|
60f4982f9b | ||
|
|
bcbd41102c | ||
|
|
c3736250a4 | ||
|
|
d9ac2ada45 | ||
|
|
3b36400e35 | ||
|
|
c9e40abfb8 | ||
|
|
23123907c0 | ||
|
|
2f15894a10 | ||
|
|
fa45d606fa | ||
|
|
30bbbe9467 | ||
|
|
6e8f0860af | ||
|
|
969206fe88 | ||
|
|
e589c76e98 | ||
|
|
39ecb37fd6 | ||
|
|
c1d9e41bef | ||
|
|
f98706bdb3 | ||
|
|
61abab999e | ||
|
|
6255ce55df | ||
|
|
88e8456e9b | ||
|
|
1f7b1a4c6c | ||
|
|
b3d65ba943 | ||
|
|
5eedbcedd1 | ||
|
|
0ed9f62ed0 | ||
|
|
977381f9cc | ||
|
|
6c74065053 | ||
|
|
edcbb5394e | ||
|
|
21d1dbfce0 | ||
|
|
7815633821 | ||
|
|
98ffd78251 | ||
|
|
dba9b96908 | ||
|
|
96994ec431 | ||
|
|
0551bec95b | ||
|
|
96d806789f | ||
|
|
248d28671b | ||
|
|
bd59bba8e6 | ||
|
|
a8b95571fb | ||
|
|
de875a4d87 | ||
|
|
ecf5d69c7c | ||
|
|
3984f9be2f | ||
|
|
5280d039c4 | ||
|
|
0d481030f3 | ||
|
|
67ebba90e1 | ||
|
|
ce1b52bb71 | ||
|
|
4b75a27969 | ||
|
|
c1cabe75dc | ||
|
|
724ad13fe1 | ||
|
|
4db60a8436 | ||
|
|
742b8b44a8 | ||
|
|
5c6d8e3053 | ||
|
|
6196b7e658 | ||
|
|
32156330a8 | ||
|
|
c3c607e78a | ||
|
|
cf74e9039e | ||
|
|
0a5ab533c1 | ||
|
|
b9a95e6ce1 | ||
|
|
0fc15dcbd5 | ||
|
|
5132edacf7 | ||
|
|
9fbe8d7cf2 | ||
|
|
c9089c82e8 | ||
|
|
3f74859bb0 | ||
|
|
630379a1d0 | ||
|
|
0ea51872c9 | ||
|
|
9a8700b02a | ||
|
|
9f930ef2bf | ||
|
|
f5f3885b5b | ||
|
|
e9643ae724 | ||
|
|
16b2bbbbbb | ||
|
|
7883e5c5e7 | ||
|
|
6c70cf7222 | ||
|
|
0aea087766 | ||
|
|
73db7e99ab | ||
|
|
d94593e884 | ||
|
|
d7bc4ec029 | ||
|
|
80a14c49c6 | ||
|
|
c53b154171 | ||
|
|
622c0d0cb3 | ||
|
|
1d4f9852a7 | ||
|
|
771eb05bcb | ||
|
|
f2e5da916a | ||
|
|
9cd4e65191 | ||
|
|
97910ce712 | ||
|
|
14b4213c17 | ||
|
|
3f4f1cfe66 | ||
|
|
a477e70632 | ||
|
|
bb1a9e4700 | ||
|
|
23c93da942 | ||
|
|
c52905abaa | ||
|
|
847b6f039b | ||
|
|
57e8931160 | ||
|
|
0f0ed3dca0 | ||
|
|
056fbee4ef | ||
|
|
6233fd7ac3 | ||
|
|
e03cc2ef57 | ||
|
|
275a20f817 | ||
|
|
77e89c4a72 | ||
|
|
710ee88e94 | ||
|
|
77d3ef36f4 | ||
|
|
9b8ca219a1 | ||
|
|
7b3c0bb7f6 | ||
|
|
47b4a19786 | ||
|
|
f7124c7f06 | ||
|
|
92252b0988 | ||
|
|
2d6e84e19e | ||
|
|
9070aacdee | ||
|
|
e96f22e560 | ||
|
|
790ef2bc5f | ||
|
|
eb4eb34f37 | ||
|
|
7ca911a5c6 | ||
|
|
a83ca9e734 | ||
|
|
a975e86bb8 | ||
|
|
72bfea2ece | ||
|
|
6f73f2c15a | ||
|
|
103c06cc68 | ||
|
|
9258d64261 | ||
|
|
23e74a0f7a | ||
|
|
fe50cd0c48 | ||
|
|
b8edb7a5e9 | ||
|
|
0071888a17 | ||
|
|
4732722b87 | ||
|
|
dd43d9bc5f | ||
|
|
3553512a71 | ||
|
|
36e9cb948f | ||
|
|
894e3bfc96 | ||
|
|
19d95e095a | ||
|
|
5bc29e7388 | ||
|
|
2a8e064705 | ||
|
|
a8635784bc | ||
|
|
b87396b5d9 | ||
|
|
c2682553ff | ||
|
|
6fbd1abcd3 | ||
|
|
de5f6d70a8 | ||
|
|
666d404066 | ||
|
|
00ca17edf4 | ||
|
|
53fb25fc2f | ||
|
|
88c305c8af | ||
|
|
d9054da86a | ||
|
|
0ecaf7b5ed | ||
|
|
401e2ec307 | ||
|
|
58c9591a49 | ||
|
|
10368ef4c0 | ||
|
|
c12d87c54b | ||
|
|
c8cf3169ba | ||
|
|
7cbf6ab771 | ||
|
|
5d4415399b | ||
|
|
6757c990a8 | ||
|
|
08a6eeb55a | ||
|
|
d9fd5db1e1 | ||
|
|
abd79ea368 | ||
|
|
15a23ce65f | ||
|
|
a036c8c718 | ||
|
|
0371848097 | ||
|
|
4c23b5e4ea | ||
|
|
03aa319762 | ||
|
|
9dd3544e84 | ||
|
|
1f4ccae591 | ||
|
|
a447caebf8 | ||
|
|
50b2e5ffe6 | ||
|
|
8edcab04d5 | ||
|
|
51f421946f | ||
|
|
deb113838e | ||
|
|
280e8884dd | ||
|
|
d05b0500ac | ||
|
|
d1a30be275 | ||
|
|
51d176ecff | ||
|
|
07e02ec9d3 | ||
|
|
511840b1f6 | ||
|
|
5e1ee4be53 | ||
|
|
c3f7733f53 | ||
|
|
5c9ddf5e76 | ||
|
|
2ca2389c5f | ||
|
|
07ca0c1c29 | ||
|
|
39f2fe29f7 | ||
|
|
1cb7dab881 | ||
|
|
e441d3218e | ||
|
|
02231e968e | ||
|
|
6f590f5b52 | ||
|
|
1d2e497d47 | ||
|
|
059b1d10bb | ||
|
|
5e0ff494a5 | ||
|
|
4d599d194f | ||
|
|
b33c86b542 | ||
|
|
b663ab4685 | ||
|
|
5798826990 | ||
|
|
e01a4c50ba | ||
|
|
5a32f8e181 | ||
|
|
484b7fc9a3 |
2
.github/workflows/cross-darwin.yml
vendored
2
.github/workflows/cross-darwin.yml
vendored
@@ -3,7 +3,7 @@ name: Darwin-Cross
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
2
.github/workflows/cross-freebsd.yml
vendored
2
.github/workflows/cross-freebsd.yml
vendored
@@ -3,7 +3,7 @@ name: FreeBSD-Cross
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
2
.github/workflows/cross-openbsd.yml
vendored
2
.github/workflows/cross-openbsd.yml
vendored
@@ -3,7 +3,7 @@ name: OpenBSD-Cross
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
2
.github/workflows/cross-windows.yml
vendored
2
.github/workflows/cross-windows.yml
vendored
@@ -3,7 +3,7 @@ name: Windows-Cross
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
2
.github/workflows/license.yml
vendored
2
.github/workflows/license.yml
vendored
@@ -3,7 +3,7 @@ name: license
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
2
.github/workflows/linux.yml
vendored
2
.github/workflows/linux.yml
vendored
@@ -3,7 +3,7 @@ name: Linux
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
48
.github/workflows/linux32.yml
vendored
Normal file
48
.github/workflows/linux32.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
name: Linux 32-bit
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
if: "!contains(github.event.head_commit.message, '[ci skip]')"
|
||||
|
||||
steps:
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v1
|
||||
with:
|
||||
go-version: 1.14
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Basic build
|
||||
run: GOARCH=386 go build ./cmd/...
|
||||
|
||||
- name: Run tests on linux
|
||||
run: GOARCH=386 go test ./...
|
||||
|
||||
- uses: k0kubun/action-slack@v2.0.0
|
||||
with:
|
||||
payload: |
|
||||
{
|
||||
"attachments": [{
|
||||
"text": "${{ job.status }}: ${{ github.workflow }} <https://github.com/${{ github.repository }}/commit/${{ github.sha }}/checks|${{ env.COMMIT_DATE }} #${{ env.COMMIT_NUMBER_OF_DAY }}> " +
|
||||
"(<https://github.com/${{ github.repository }}/commit/${{ github.sha }}|" + "${{ github.sha }}".substring(0, 10) + ">) " +
|
||||
"of ${{ github.repository }}@" + "${{ github.ref }}".split('/').reverse()[0] + " by ${{ github.event.head_commit.committer.name }}",
|
||||
"color": "danger"
|
||||
}]
|
||||
}
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
|
||||
if: failure() && github.event_name == 'push'
|
||||
|
||||
5
.github/workflows/staticcheck.yml
vendored
5
.github/workflows/staticcheck.yml
vendored
@@ -3,7 +3,7 @@ name: staticcheck
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- '*'
|
||||
@@ -21,6 +21,9 @@ jobs:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Print staticcheck version
|
||||
run: go run honnef.co/go/tools/cmd/staticcheck -version
|
||||
|
||||
|
||||
7
Makefile
Normal file
7
Makefile
Normal file
@@ -0,0 +1,7 @@
|
||||
usage:
|
||||
echo "See Makefile"
|
||||
|
||||
check: staticcheck
|
||||
|
||||
staticcheck:
|
||||
go run honnef.co/go/tools/cmd/staticcheck -- $$(go list ./... | grep -v tempfork)
|
||||
@@ -9,20 +9,39 @@
|
||||
package atomicfile // import "tailscale.com/atomicfile"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// WriteFile writes data to filename+some suffix, then renames it
|
||||
// into filename.
|
||||
func WriteFile(filename string, data []byte, perm os.FileMode) error {
|
||||
tmpname := filename + ".new.tmp"
|
||||
if err := ioutil.WriteFile(tmpname, data, perm); err != nil {
|
||||
return fmt.Errorf("%#v: %v", tmpname, err)
|
||||
func WriteFile(filename string, data []byte, perm os.FileMode) (err error) {
|
||||
f, err := ioutil.TempFile(filepath.Dir(filename), filepath.Base(filename)+".tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Rename(tmpname, filename); err != nil {
|
||||
return fmt.Errorf("%#v->%#v: %v", tmpname, filename, err)
|
||||
tmpName := f.Name()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
f.Close()
|
||||
os.Remove(tmpName)
|
||||
}
|
||||
}()
|
||||
if _, err := f.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
if runtime.GOOS != "windows" {
|
||||
if err := f.Chmod(perm); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := f.Sync(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmpName, filename)
|
||||
}
|
||||
|
||||
264
cmd/cloner/cloner.go
Normal file
264
cmd/cloner/cloner.go
Normal file
@@ -0,0 +1,264 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Cloner is a tool to automate the creation of a Clone method.
|
||||
//
|
||||
// The result of the Clone method aliases no memory that can be edited
|
||||
// with the original.
|
||||
//
|
||||
// This tool makes lots of implicit assumptions about the types you feed it.
|
||||
// In particular, it can only write relatively "shallow" Clone methods.
|
||||
// That is, if a type contains another named struct type, cloner assumes that
|
||||
// named type will also have a Clone method.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
var (
|
||||
flagTypes = flag.String("type", "", "comma-separated list of types; required")
|
||||
flagOutput = flag.String("output", "", "output file; required")
|
||||
flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("cloner: ")
|
||||
flag.Parse()
|
||||
if len(*flagTypes) == 0 {
|
||||
flag.Usage()
|
||||
os.Exit(2)
|
||||
}
|
||||
typeNames := strings.Split(*flagTypes, ",")
|
||||
|
||||
cfg := &packages.Config{
|
||||
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
|
||||
Tests: false,
|
||||
}
|
||||
if *flagBuildTags != "" {
|
||||
cfg.BuildFlags = []string{"-tags=" + *flagBuildTags}
|
||||
}
|
||||
pkgs, err := packages.Load(cfg, ".")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if len(pkgs) != 1 {
|
||||
log.Fatalf("wrong number of packages: %d", len(pkgs))
|
||||
}
|
||||
pkg := pkgs[0]
|
||||
buf := new(bytes.Buffer)
|
||||
imports := make(map[string]struct{})
|
||||
for _, typeName := range typeNames {
|
||||
found := false
|
||||
for _, file := range pkg.Syntax {
|
||||
//var fbuf bytes.Buffer
|
||||
//ast.Fprint(&fbuf, pkg.Fset, file, nil)
|
||||
//fmt.Println(fbuf.String())
|
||||
|
||||
for _, d := range file.Decls {
|
||||
decl, ok := d.(*ast.GenDecl)
|
||||
if !ok || decl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
for _, s := range decl.Specs {
|
||||
spec, ok := s.(*ast.TypeSpec)
|
||||
if !ok || spec.Name.Name != typeName {
|
||||
continue
|
||||
}
|
||||
typeNameObj := pkg.TypesInfo.Defs[spec.Name]
|
||||
typ, ok := typeNameObj.Type().(*types.Named)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pkg := typeNameObj.Pkg()
|
||||
gen(buf, imports, typeName, typ, pkg)
|
||||
}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Fatalf("could not find type %s", typeName)
|
||||
}
|
||||
}
|
||||
|
||||
contents := new(bytes.Buffer)
|
||||
fmt.Fprintf(contents, header, *flagTypes, pkg.Name)
|
||||
fmt.Fprintf(contents, "import (\n")
|
||||
for s := range imports {
|
||||
fmt.Fprintf(contents, "\t%q\n", s)
|
||||
}
|
||||
fmt.Fprintf(contents, ")\n\n")
|
||||
contents.Write(buf.Bytes())
|
||||
|
||||
out, err := format.Source(contents.Bytes())
|
||||
if err != nil {
|
||||
log.Fatalf("%s, in source:\n%s", err, contents.Bytes())
|
||||
}
|
||||
|
||||
output := *flagOutput
|
||||
if output == "" {
|
||||
flag.Usage()
|
||||
os.Exit(2)
|
||||
}
|
||||
if err := ioutil.WriteFile(output, out, 0666); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
const header = `// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Code generated by tailscale.com/cmd/cloner -type %s; DO NOT EDIT.
|
||||
|
||||
package %s
|
||||
|
||||
`
|
||||
|
||||
func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types.Named, thisPkg *types.Package) {
|
||||
pkgQual := func(pkg *types.Package) string {
|
||||
if thisPkg == pkg {
|
||||
return ""
|
||||
}
|
||||
imports[pkg.Path()] = struct{}{}
|
||||
return pkg.Name()
|
||||
}
|
||||
importedName := func(t types.Type) string {
|
||||
return types.TypeString(t, pkgQual)
|
||||
}
|
||||
|
||||
switch t := typ.Underlying().(type) {
|
||||
case *types.Struct:
|
||||
_ = t
|
||||
name := typ.Obj().Name()
|
||||
fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name)
|
||||
fmt.Fprintf(buf, "// The result aliases no memory with the original.\n")
|
||||
fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name)
|
||||
writef := func(format string, args ...interface{}) {
|
||||
fmt.Fprintf(buf, "\t"+format+"\n", args...)
|
||||
}
|
||||
writef("if src == nil {")
|
||||
writef("\treturn nil")
|
||||
writef("}")
|
||||
writef("dst := new(%s)", name)
|
||||
writef("*dst = *src")
|
||||
for i := 0; i < t.NumFields(); i++ {
|
||||
fname := t.Field(i).Name()
|
||||
ft := t.Field(i).Type()
|
||||
if !containsPointers(ft) {
|
||||
continue
|
||||
}
|
||||
if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) {
|
||||
writef("dst.%s = *src.%s.Clone()", fname, fname)
|
||||
continue
|
||||
}
|
||||
switch ft := ft.Underlying().(type) {
|
||||
case *types.Slice:
|
||||
if containsPointers(ft.Elem()) {
|
||||
n := importedName(ft.Elem())
|
||||
writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname)
|
||||
writef("for i := range dst.%s {", fname)
|
||||
if _, isPtr := ft.Elem().(*types.Pointer); isPtr {
|
||||
writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname)
|
||||
} else {
|
||||
writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname)
|
||||
}
|
||||
writef("}")
|
||||
} else {
|
||||
writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname)
|
||||
}
|
||||
case *types.Pointer:
|
||||
if named, _ := ft.Elem().(*types.Named); named != nil && containsPointers(ft.Elem()) {
|
||||
writef("dst.%s = src.%s.Clone()", fname, fname)
|
||||
continue
|
||||
}
|
||||
n := importedName(ft.Elem())
|
||||
writef("if dst.%s != nil {", fname)
|
||||
writef("\tdst.%s = new(%s)", fname, n)
|
||||
writef("\t*dst.%s = *src.%s", fname, fname)
|
||||
if containsPointers(ft.Elem()) {
|
||||
writef("\t" + `panic("TODO pointers in pointers")`)
|
||||
}
|
||||
writef("}")
|
||||
case *types.Map:
|
||||
writef("if dst.%s != nil {", fname)
|
||||
writef("\tdst.%s = map[%s]%s{}", fname, importedName(ft.Key()), importedName(ft.Elem()))
|
||||
if sliceType, isSlice := ft.Elem().(*types.Slice); isSlice {
|
||||
n := importedName(sliceType.Elem())
|
||||
writef("\tfor k := range src.%s {", fname)
|
||||
// use zero-length slice instead of nil to ensure
|
||||
// the key is always copied.
|
||||
writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname)
|
||||
writef("\t}")
|
||||
} else if containsPointers(ft.Elem()) {
|
||||
writef("\t\t" + `panic("TODO map value pointers")`)
|
||||
} else {
|
||||
writef("\tfor k, v := range src.%s {", fname)
|
||||
writef("\t\tdst.%s[k] = v", fname)
|
||||
writef("\t}")
|
||||
}
|
||||
writef("}")
|
||||
case *types.Struct:
|
||||
writef(`panic("TODO struct %s")`, fname)
|
||||
default:
|
||||
writef(`panic(fmt.Sprintf("TODO: %T", ft))`)
|
||||
}
|
||||
}
|
||||
writef("return dst")
|
||||
fmt.Fprintf(buf, "}\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
func hasBasicUnderlying(typ types.Type) bool {
|
||||
switch typ.Underlying().(type) {
|
||||
case *types.Slice, *types.Map:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func containsPointers(typ types.Type) bool {
|
||||
switch typ.String() {
|
||||
case "time.Time":
|
||||
// time.Time contains a pointer that does not need copying
|
||||
return false
|
||||
case "inet.af/netaddr.IP":
|
||||
return false
|
||||
}
|
||||
switch ft := typ.Underlying().(type) {
|
||||
case *types.Array:
|
||||
return containsPointers(ft.Elem())
|
||||
case *types.Chan:
|
||||
return true
|
||||
case *types.Interface:
|
||||
return true // a little too broad
|
||||
case *types.Map:
|
||||
return true
|
||||
case *types.Pointer:
|
||||
return true
|
||||
case *types.Slice:
|
||||
return true
|
||||
case *types.Struct:
|
||||
for i := 0; i < ft.NumFields(); i++ {
|
||||
if containsPointers(ft.Field(i).Type()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
@@ -42,6 +43,8 @@ var (
|
||||
hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443")
|
||||
logCollection = flag.String("logcollection", "", "If non-empty, logtail collection to log to")
|
||||
runSTUN = flag.Bool("stun", false, "also run a STUN server")
|
||||
meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.")
|
||||
meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list")
|
||||
)
|
||||
|
||||
type config struct {
|
||||
@@ -118,6 +121,22 @@ func main() {
|
||||
letsEncrypt := tsweb.IsProd443(*addr)
|
||||
|
||||
s := derp.NewServer(key.Private(cfg.PrivateKey), log.Printf)
|
||||
|
||||
if *meshPSKFile != "" {
|
||||
b, err := ioutil.ReadFile(*meshPSKFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
key := strings.TrimSpace(string(b))
|
||||
if matched, _ := regexp.MatchString(`(?i)^[0-9a-f]{64,}$`, key); !matched {
|
||||
log.Fatalf("key in %s must contain 64+ hex digits", *meshPSKFile)
|
||||
}
|
||||
s.SetMeshKey(key)
|
||||
log.Printf("DERP mesh key configured")
|
||||
}
|
||||
if err := startMesh(s); err != nil {
|
||||
log.Fatalf("startMesh: %v", err)
|
||||
}
|
||||
expvar.Publish("derp", s.ExpVar())
|
||||
|
||||
// Create our own mux so we don't expose /debug/ stuff to the world.
|
||||
@@ -166,7 +185,7 @@ func main() {
|
||||
}
|
||||
httpsrv.TLSConfig = certManager.TLSConfig()
|
||||
go func() {
|
||||
err := http.ListenAndServe(":80", certManager.HTTPHandler(tsweb.Port80Handler{mux}))
|
||||
err := http.ListenAndServe(":80", certManager.HTTPHandler(tsweb.Port80Handler{Main: mux}))
|
||||
if err != nil {
|
||||
if err != http.ErrServerClosed {
|
||||
log.Fatal(err)
|
||||
@@ -185,6 +204,15 @@ func main() {
|
||||
|
||||
func debugHandler(s *derp.Server) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/debug/check" {
|
||||
err := s.ConsistencyCheck()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
} else {
|
||||
io.WriteString(w, "derp.Server ConsistencyCheck okay")
|
||||
}
|
||||
return
|
||||
}
|
||||
f := func(format string, args ...interface{}) { fmt.Fprintf(w, format, args...) }
|
||||
f(`<html><body>
|
||||
<h1>DERP debug</h1>
|
||||
@@ -192,12 +220,14 @@ func debugHandler(s *derp.Server) http.Handler {
|
||||
`)
|
||||
f("<li><b>Hostname:</b> %v</li>\n", *hostname)
|
||||
f("<li><b>Uptime:</b> %v</li>\n", tsweb.Uptime())
|
||||
f("<li><b>Mesh Key:</b> %v</li>\n", s.HasMeshKey())
|
||||
|
||||
f(`<li><a href="/debug/vars">/debug/vars</a> (Go)</li>
|
||||
<li><a href="/debug/varz">/debug/varz</a> (Prometheus)</li>
|
||||
<li><a href="/debug/pprof/">/debug/pprof/</a></li>
|
||||
<li><a href="/debug/pprof/goroutine?debug=1">/debug/pprof/goroutine</a> (collapsed)</li>
|
||||
<li><a href="/debug/pprof/goroutine?debug=2">/debug/pprof/goroutine</a> (full)</li>
|
||||
<li><a href="/debug/check">/debug/check</a> internal consistency check</li>
|
||||
<ul>
|
||||
</html>
|
||||
`)
|
||||
@@ -268,7 +298,7 @@ func serveSTUN() {
|
||||
}
|
||||
}
|
||||
|
||||
var validProdHostname = regexp.MustCompile(`^derp(\d+|\-\w+)?\.tailscale\.com\.?$`)
|
||||
var validProdHostname = regexp.MustCompile(`^derp([^.]*)\.tailscale\.com\.?$`)
|
||||
|
||||
func prodAutocertHostPolicy(_ context.Context, host string) error {
|
||||
if validProdHostname.MatchString(host) {
|
||||
@@ -276,3 +306,16 @@ func prodAutocertHostPolicy(_ context.Context, host string) error {
|
||||
}
|
||||
return errors.New("invalid hostname")
|
||||
}
|
||||
|
||||
func defaultMeshPSKFile() string {
|
||||
try := []string{
|
||||
"/home/derp/keys/derp-mesh.key",
|
||||
filepath.Join(os.Getenv("HOME"), "keys", "derp-mesh.key"),
|
||||
}
|
||||
for _, p := range try {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -17,10 +17,11 @@ func TestProdAutocertHostPolicy(t *testing.T) {
|
||||
{"derp.tailscale.com", true},
|
||||
{"derp.tailscale.com.", true},
|
||||
{"derp1.tailscale.com", true},
|
||||
{"derp1b.tailscale.com", true},
|
||||
{"derp2.tailscale.com", true},
|
||||
{"derp02.tailscale.com", true},
|
||||
{"derp-nyc.tailscale.com", true},
|
||||
{"derpfoo.tailscale.com", false},
|
||||
{"derpfoo.tailscale.com", true},
|
||||
{"derp02.bar.tailscale.com", false},
|
||||
{"example.net", false},
|
||||
}
|
||||
|
||||
45
cmd/derper/mesh.go
Normal file
45
cmd/derper/mesh.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
func startMesh(s *derp.Server) error {
|
||||
if *meshWith == "" {
|
||||
return nil
|
||||
}
|
||||
if !s.HasMeshKey() {
|
||||
return errors.New("--mesh-with requires --mesh-psk-file")
|
||||
}
|
||||
for _, host := range strings.Split(*meshWith, ",") {
|
||||
if err := startMeshWithHost(s, host); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func startMeshWithHost(s *derp.Server, host string) error {
|
||||
logf := logger.WithPrefix(log.Printf, fmt.Sprintf("mesh(%q): ", host))
|
||||
c, err := derphttp.NewClient(s.PrivateKey(), "https://"+host+"/derp", logf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.MeshKey = s.MeshKey()
|
||||
add := func(k key.Public) { s.AddPacketForwarder(k, c) }
|
||||
remove := func(k key.Public) { s.RemovePacketForwarder(k, c) }
|
||||
go c.RunWatchConnectionLoop(s.PublicKey(), add, remove)
|
||||
return nil
|
||||
}
|
||||
123
cmd/tailscale/cli/cli.go
Normal file
123
cmd/tailscale/cli/cli.go
Normal file
@@ -0,0 +1,123 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package cli contains the cmd/tailscale CLI code in a package that can be included
|
||||
// in other wrapper binaries such as the Mac and Windows clients.
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/peterbourgon/ff/v2/ffcli"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/paths"
|
||||
"tailscale.com/safesocket"
|
||||
)
|
||||
|
||||
// ActLikeCLI reports whether a GUI application should act like the
|
||||
// CLI based on os.Args, GOOS, the context the process is running in
|
||||
// (pty, parent PID), etc.
|
||||
func ActLikeCLI() bool {
|
||||
if len(os.Args) < 2 {
|
||||
return false
|
||||
}
|
||||
switch os.Args[1] {
|
||||
case "up", "status", "netcheck", "version",
|
||||
"-V", "--version", "-h", "--help":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Run runs the CLI. The args do not include the binary name.
|
||||
func Run(args []string) error {
|
||||
if len(args) == 1 && (args[0] == "-V" || args[0] == "--version") {
|
||||
args = []string{"version"}
|
||||
}
|
||||
|
||||
rootfs := flag.NewFlagSet("tailscale", flag.ExitOnError)
|
||||
rootfs.StringVar(&rootArgs.socket, "socket", paths.DefaultTailscaledSocket(), "path to tailscaled's unix socket")
|
||||
|
||||
rootCmd := &ffcli.Command{
|
||||
Name: "tailscale",
|
||||
ShortUsage: "tailscale subcommand [flags]",
|
||||
ShortHelp: "The easiest, most secure way to use WireGuard.",
|
||||
LongHelp: strings.TrimSpace(`
|
||||
This CLI is still under active development. Commands and flags will
|
||||
change in the future.
|
||||
`),
|
||||
Subcommands: []*ffcli.Command{
|
||||
upCmd,
|
||||
netcheckCmd,
|
||||
statusCmd,
|
||||
versionCmd,
|
||||
},
|
||||
FlagSet: rootfs,
|
||||
Exec: func(context.Context, []string) error { return flag.ErrHelp },
|
||||
}
|
||||
|
||||
if err := rootCmd.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := rootCmd.Run(context.Background())
|
||||
if err == flag.ErrHelp {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var rootArgs struct {
|
||||
socket string
|
||||
}
|
||||
|
||||
func connect(ctx context.Context) (net.Conn, *ipn.BackendClient, context.Context, context.CancelFunc) {
|
||||
c, err := safesocket.Connect(rootArgs.socket, 41112)
|
||||
if err != nil {
|
||||
if runtime.GOOS != "windows" && rootArgs.socket == "" {
|
||||
log.Fatalf("--socket cannot be empty")
|
||||
}
|
||||
log.Fatalf("Failed to connect to connect to tailscaled. (safesocket.Connect: %v)\n", err)
|
||||
}
|
||||
clientToServer := func(b []byte) {
|
||||
ipn.WriteMsg(c, b)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-interrupt
|
||||
c.Close()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
bc := ipn.NewBackendClient(log.Printf, clientToServer)
|
||||
return c, bc, ctx, cancel
|
||||
}
|
||||
|
||||
// pump receives backend messages on conn and pushes them into bc.
|
||||
func pump(ctx context.Context, bc *ipn.BackendClient, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
for ctx.Err() == nil {
|
||||
msg, err := ipn.ReadMsg(conn)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Printf("ReadMsg: %v\n", err)
|
||||
break
|
||||
}
|
||||
bc.GotNotifyMsg(msg)
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -117,6 +117,7 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error {
|
||||
}
|
||||
fmt.Printf("\t* MappingVariesByDestIP: %v\n", report.MappingVariesByDestIP)
|
||||
fmt.Printf("\t* HairPinning: %v\n", report.HairPinning)
|
||||
fmt.Printf("\t* PortMapping: %v\n", portMapping(report))
|
||||
|
||||
// When DERP latency checking failed,
|
||||
// magicsock will try to pick the DERP server that
|
||||
@@ -142,3 +143,20 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func portMapping(r *netcheck.Report) string {
|
||||
if !r.AnyPortMappingChecked() {
|
||||
return "not checked"
|
||||
}
|
||||
var got []string
|
||||
if r.UPnP.EqualBool(true) {
|
||||
got = append(got, "UPnP")
|
||||
}
|
||||
if r.PMP.EqualBool(true) {
|
||||
got = append(got, "NAT-PMP")
|
||||
}
|
||||
if r.PCP.EqualBool(true) {
|
||||
got = append(got, "PCP")
|
||||
}
|
||||
return strings.Join(got, ", ")
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/peterbourgon/ff/v2/ffcli"
|
||||
"github.com/toqueteos/webbrowser"
|
||||
@@ -24,13 +25,14 @@ import (
|
||||
|
||||
var statusCmd = &ffcli.Command{
|
||||
Name: "status",
|
||||
ShortUsage: "status [-web] [-json]",
|
||||
ShortUsage: "status [-active] [-web] [-json]",
|
||||
ShortHelp: "Show state of tailscaled and its connections",
|
||||
Exec: runStatus,
|
||||
FlagSet: (func() *flag.FlagSet {
|
||||
fs := flag.NewFlagSet("status", flag.ExitOnError)
|
||||
fs.BoolVar(&statusArgs.json, "json", false, "output in JSON format (WARNING: format subject to change)")
|
||||
fs.BoolVar(&statusArgs.web, "web", false, "run webserver with HTML showing status")
|
||||
fs.BoolVar(&statusArgs.active, "active", false, "filter output to only peers with active sessions (not applicable to web mode)")
|
||||
fs.StringVar(&statusArgs.listen, "listen", "127.0.0.1:8384", "listen address; use port 0 for automatic")
|
||||
fs.BoolVar(&statusArgs.browser, "browser", true, "Open a browser in web mode")
|
||||
return fs
|
||||
@@ -42,6 +44,7 @@ var statusArgs struct {
|
||||
web bool // run webserver
|
||||
listen string // in web mode, webserver address to listen on, empty means auto
|
||||
browser bool // in web mode, whether to open browser
|
||||
active bool // in CLI mode, filter output to only peers with active sessions
|
||||
}
|
||||
|
||||
func runStatus(ctx context.Context, args []string) error {
|
||||
@@ -75,6 +78,13 @@ func runStatus(ctx context.Context, args []string) error {
|
||||
return err
|
||||
}
|
||||
if statusArgs.json {
|
||||
if statusArgs.active {
|
||||
for peer, ps := range st.Peer {
|
||||
if !peerActive(ps) {
|
||||
delete(st.Peer, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
j, err := json.MarshalIndent(st, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -119,6 +129,10 @@ func runStatus(ctx context.Context, args []string) error {
|
||||
f := func(format string, a ...interface{}) { fmt.Fprintf(&buf, format, a...) }
|
||||
for _, peer := range st.Peers() {
|
||||
ps := st.Peer[peer]
|
||||
active := peerActive(ps)
|
||||
if statusArgs.active && !active {
|
||||
continue
|
||||
}
|
||||
f("%s %-7s %-15s %-18s tx=%8d rx=%8d ",
|
||||
peer.ShortString(),
|
||||
ps.OS,
|
||||
@@ -127,6 +141,13 @@ func runStatus(ctx context.Context, args []string) error {
|
||||
ps.TxBytes,
|
||||
ps.RxBytes,
|
||||
)
|
||||
relay := ps.Relay
|
||||
if active && relay != "" && ps.CurAddr == "" {
|
||||
relay = "*" + relay + "*"
|
||||
} else {
|
||||
relay = " " + relay
|
||||
}
|
||||
f("%-6s", relay)
|
||||
for i, addr := range ps.Addrs {
|
||||
if i != 0 {
|
||||
f(", ")
|
||||
@@ -142,3 +163,10 @@ func runStatus(ctx context.Context, args []string) error {
|
||||
os.Stdout.Write(buf.Bytes())
|
||||
return nil
|
||||
}
|
||||
|
||||
// peerActive reports whether ps has recent activity.
|
||||
//
|
||||
// TODO: have the server report this bool instead.
|
||||
func peerActive(ps *ipnstate.PeerStatus) bool {
|
||||
return !ps.LastWrite.IsZero() && time.Since(ps.LastWrite) < 2*time.Minute
|
||||
}
|
||||
252
cmd/tailscale/cli/up.go
Normal file
252
cmd/tailscale/cli/up.go
Normal file
@@ -0,0 +1,252 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/peterbourgon/ff/v2/ffcli"
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/wgengine/router"
|
||||
)
|
||||
|
||||
// globalStateKey is the ipn.StateKey that tailscaled loads on
|
||||
// startup.
|
||||
//
|
||||
// We have to support multiple state keys for other OSes (Windows in
|
||||
// particular), but right now Unix daemons run with a single
|
||||
// node-global state. To keep open the option of having per-user state
|
||||
// later, the global state key doesn't look like a username.
|
||||
const globalStateKey = "_daemon"
|
||||
|
||||
var upCmd = &ffcli.Command{
|
||||
Name: "up",
|
||||
ShortUsage: "up [flags]",
|
||||
ShortHelp: "Connect to your Tailscale network",
|
||||
|
||||
LongHelp: strings.TrimSpace(`
|
||||
"tailscale up" connects this machine to your Tailscale network,
|
||||
triggering authentication if necessary.
|
||||
|
||||
The flags passed to this command are specific to this machine. If you don't
|
||||
specify any flags, options are reset to their default.
|
||||
`),
|
||||
FlagSet: (func() *flag.FlagSet {
|
||||
upf := flag.NewFlagSet("up", flag.ExitOnError)
|
||||
upf.StringVar(&upArgs.server, "login-server", "https://login.tailscale.com", "base URL of control server")
|
||||
upf.BoolVar(&upArgs.acceptRoutes, "accept-routes", false, "accept routes advertised by other Tailscale nodes")
|
||||
upf.BoolVar(&upArgs.acceptDNS, "accept-dns", true, "accept DNS configuration from the admin panel")
|
||||
upf.BoolVar(&upArgs.singleRoutes, "host-routes", true, "install host routes to other Tailscale nodes")
|
||||
upf.BoolVar(&upArgs.shieldsUp, "shields-up", false, "don't allow incoming connections")
|
||||
upf.StringVar(&upArgs.advertiseTags, "advertise-tags", "", "ACL tags to request (comma-separated, e.g. eng,montreal,ssh)")
|
||||
upf.StringVar(&upArgs.authKey, "authkey", "", "node authorization key")
|
||||
upf.StringVar(&upArgs.hostname, "hostname", "", "hostname to use instead of the one provided by the OS")
|
||||
upf.BoolVar(&upArgs.enableDERP, "enable-derp", true, "enable the use of DERP servers")
|
||||
if runtime.GOOS == "linux" || isBSD(runtime.GOOS) {
|
||||
upf.StringVar(&upArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. 10.0.0.0/8,192.168.0.0/24)")
|
||||
}
|
||||
if runtime.GOOS == "linux" {
|
||||
upf.BoolVar(&upArgs.snat, "snat-subnet-routes", true, "source NAT traffic to local routes advertised with -advertise-routes")
|
||||
upf.StringVar(&upArgs.netfilterMode, "netfilter-mode", "on", "netfilter mode (one of on, nodivert, off)")
|
||||
}
|
||||
return upf
|
||||
})(),
|
||||
Exec: runUp,
|
||||
}
|
||||
|
||||
var upArgs struct {
|
||||
server string
|
||||
acceptRoutes bool
|
||||
acceptDNS bool
|
||||
singleRoutes bool
|
||||
shieldsUp bool
|
||||
advertiseRoutes string
|
||||
advertiseTags string
|
||||
enableDERP bool
|
||||
snat bool
|
||||
netfilterMode string
|
||||
authKey string
|
||||
hostname string
|
||||
}
|
||||
|
||||
// parseIPOrCIDR parses an IP address or a CIDR prefix. If the input
|
||||
// is an IP address, it is returned in CIDR form with a /32 mask for
|
||||
// IPv4 or a /128 mask for IPv6.
|
||||
func parseIPOrCIDR(s string) (wgcfg.CIDR, bool) {
|
||||
if strings.Contains(s, "/") {
|
||||
ret, err := wgcfg.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return wgcfg.CIDR{}, false
|
||||
}
|
||||
return ret, true
|
||||
}
|
||||
|
||||
ip, ok := wgcfg.ParseIP(s)
|
||||
if !ok {
|
||||
return wgcfg.CIDR{}, false
|
||||
}
|
||||
if ip.Is4() {
|
||||
return wgcfg.CIDR{IP: ip, Mask: 32}, true
|
||||
} else {
|
||||
return wgcfg.CIDR{IP: ip, Mask: 128}, true
|
||||
}
|
||||
}
|
||||
|
||||
func isBSD(s string) bool {
|
||||
return s == "dragonfly" || s == "freebsd" || s == "netbsd" || s == "openbsd"
|
||||
}
|
||||
|
||||
func warning(format string, args ...interface{}) {
|
||||
fmt.Printf("Warning: "+format+"\n", args...)
|
||||
}
|
||||
|
||||
// checkIPForwarding prints warnings on linux if IP forwarding is not
|
||||
// enabled, or if we were unable to verify the state of IP forwarding.
|
||||
func checkIPForwarding() {
|
||||
var key string
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
key = "net.ipv4.ip_forward"
|
||||
} else if isBSD(runtime.GOOS) {
|
||||
key = "net.inet.ip.forwarding"
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
bs, err := exec.Command("sysctl", "-n", key).Output()
|
||||
if err != nil {
|
||||
warning("couldn't check %s (%v).\nSubnet routes won't work without IP forwarding.", key, err)
|
||||
return
|
||||
}
|
||||
on, err := strconv.ParseBool(string(bytes.TrimSpace(bs)))
|
||||
if err != nil {
|
||||
warning("couldn't parse %s (%v).\nSubnet routes won't work without IP forwarding.", key, err)
|
||||
return
|
||||
}
|
||||
if !on {
|
||||
warning("%s is disabled. Subnet routes won't work.", key)
|
||||
}
|
||||
}
|
||||
|
||||
func runUp(ctx context.Context, args []string) error {
|
||||
if len(args) > 0 {
|
||||
log.Fatalf("too many non-flag arguments: %q", args)
|
||||
}
|
||||
|
||||
var routes []wgcfg.CIDR
|
||||
if upArgs.advertiseRoutes != "" {
|
||||
checkIPForwarding()
|
||||
advroutes := strings.Split(upArgs.advertiseRoutes, ",")
|
||||
for _, s := range advroutes {
|
||||
cidr, ok := parseIPOrCIDR(s)
|
||||
if !ok {
|
||||
log.Fatalf("%q is not a valid IP address or CIDR prefix", s)
|
||||
}
|
||||
routes = append(routes, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
var tags []string
|
||||
if upArgs.advertiseTags != "" {
|
||||
tags = strings.Split(upArgs.advertiseTags, ",")
|
||||
for _, tag := range tags {
|
||||
err := tailcfg.CheckTag(tag)
|
||||
if err != nil {
|
||||
log.Fatalf("tag: %q: %s", tag, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upArgs.hostname) > 256 {
|
||||
log.Fatalf("hostname too long: %d bytes (max 256)", len(upArgs.hostname))
|
||||
}
|
||||
|
||||
// TODO(apenwarr): fix different semantics between prefs and uflags
|
||||
// TODO(apenwarr): allow setting/using CorpDNS
|
||||
prefs := ipn.NewPrefs()
|
||||
prefs.ControlURL = upArgs.server
|
||||
prefs.WantRunning = true
|
||||
prefs.RouteAll = upArgs.acceptRoutes
|
||||
prefs.CorpDNS = upArgs.acceptDNS
|
||||
prefs.AllowSingleHosts = upArgs.singleRoutes
|
||||
prefs.ShieldsUp = upArgs.shieldsUp
|
||||
prefs.AdvertiseRoutes = routes
|
||||
prefs.AdvertiseTags = tags
|
||||
prefs.NoSNAT = !upArgs.snat
|
||||
prefs.DisableDERP = !upArgs.enableDERP
|
||||
prefs.Hostname = upArgs.hostname
|
||||
if runtime.GOOS == "linux" {
|
||||
switch upArgs.netfilterMode {
|
||||
case "on":
|
||||
prefs.NetfilterMode = router.NetfilterOn
|
||||
case "nodivert":
|
||||
prefs.NetfilterMode = router.NetfilterNoDivert
|
||||
warning("netfilter=nodivert; add iptables calls to ts-* chains manually.")
|
||||
case "off":
|
||||
prefs.NetfilterMode = router.NetfilterOff
|
||||
warning("netfilter=off; configure iptables yourself.")
|
||||
default:
|
||||
log.Fatalf("invalid value --netfilter-mode: %q", upArgs.netfilterMode)
|
||||
}
|
||||
}
|
||||
|
||||
c, bc, ctx, cancel := connect(ctx)
|
||||
defer cancel()
|
||||
|
||||
var printed bool
|
||||
|
||||
bc.SetPrefs(prefs)
|
||||
opts := ipn.Options{
|
||||
StateKey: globalStateKey,
|
||||
AuthKey: upArgs.authKey,
|
||||
Notify: func(n ipn.Notify) {
|
||||
if n.ErrMessage != nil {
|
||||
log.Fatalf("backend error: %v\n", *n.ErrMessage)
|
||||
}
|
||||
if s := n.State; s != nil {
|
||||
switch *s {
|
||||
case ipn.NeedsLogin:
|
||||
printed = true
|
||||
bc.StartLoginInteractive()
|
||||
case ipn.NeedsMachineAuth:
|
||||
printed = true
|
||||
fmt.Fprintf(os.Stderr, "\nTo authorize your machine, visit (as admin):\n\n\t%s/admin/machines\n\n", upArgs.server)
|
||||
case ipn.Starting, ipn.Running:
|
||||
// Done full authentication process
|
||||
if printed {
|
||||
// Only need to print an update if we printed the "please click" message earlier.
|
||||
fmt.Fprintf(os.Stderr, "Success.\n")
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
if url := n.BrowseToURL; url != nil {
|
||||
fmt.Fprintf(os.Stderr, "\nTo authenticate, visit:\n\n\t%s\n\n", *url)
|
||||
}
|
||||
},
|
||||
}
|
||||
// We still have to Start right now because it's the only way to
|
||||
// set up notifications and whatnot. This causes a bunch of churn
|
||||
// every time the CLI touches anything.
|
||||
//
|
||||
// TODO(danderson): redo the frontend/backend API to assume
|
||||
// ephemeral frontends that read/modify/write state, once
|
||||
// Windows/Mac state is moved into backend.
|
||||
bc.Start(opts)
|
||||
pump(ctx, bc, c)
|
||||
|
||||
return nil
|
||||
}
|
||||
69
cmd/tailscale/cli/version.go
Normal file
69
cmd/tailscale/cli/version.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/peterbourgon/ff/v2/ffcli"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
var versionCmd = &ffcli.Command{
|
||||
Name: "version",
|
||||
ShortUsage: "version [flags]",
|
||||
ShortHelp: "Print Tailscale version",
|
||||
FlagSet: (func() *flag.FlagSet {
|
||||
fs := flag.NewFlagSet("version", flag.ExitOnError)
|
||||
fs.BoolVar(&versionArgs.daemon, "daemon", false, "also print local node's daemon version")
|
||||
return fs
|
||||
})(),
|
||||
Exec: runVersion,
|
||||
}
|
||||
|
||||
var versionArgs struct {
|
||||
daemon bool // also check local node's daemon version
|
||||
}
|
||||
|
||||
func runVersion(ctx context.Context, args []string) error {
|
||||
if len(args) > 0 {
|
||||
log.Fatalf("too many non-flag arguments: %q", args)
|
||||
}
|
||||
if !versionArgs.daemon {
|
||||
fmt.Println(version.LONG)
|
||||
return nil
|
||||
}
|
||||
fmt.Printf("Client: %s\n", version.LONG)
|
||||
|
||||
c, bc, ctx, cancel := connect(ctx)
|
||||
defer cancel()
|
||||
|
||||
bc.AllowVersionSkew = true
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
bc.SetNotifyCallback(func(n ipn.Notify) {
|
||||
if n.ErrMessage != nil {
|
||||
log.Fatal(*n.ErrMessage)
|
||||
}
|
||||
if n.Status != nil {
|
||||
fmt.Printf("Daemon: %s\n", n.Version)
|
||||
close(done)
|
||||
}
|
||||
})
|
||||
go pump(ctx, bc, c)
|
||||
|
||||
bc.RequestStatus()
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
@@ -7,306 +7,22 @@
|
||||
package main // import "tailscale.com/cmd/tailscale"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/apenwarr/fixconsole"
|
||||
"github.com/peterbourgon/ff/v2/ffcli"
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/paths"
|
||||
"tailscale.com/safesocket"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/cmd/tailscale/cli"
|
||||
)
|
||||
|
||||
// globalStateKey is the ipn.StateKey that tailscaled loads on
|
||||
// startup.
|
||||
//
|
||||
// We have to support multiple state keys for other OSes (Windows in
|
||||
// particular), but right now Unix daemons run with a single
|
||||
// node-global state. To keep open the option of having per-user state
|
||||
// later, the global state key doesn't look like a username.
|
||||
const globalStateKey = "_daemon"
|
||||
|
||||
var rootArgs struct {
|
||||
socket string
|
||||
}
|
||||
|
||||
func main() {
|
||||
err := fixconsole.FixConsoleIfNeeded()
|
||||
if err != nil {
|
||||
log.Printf("fixConsoleOutput: %v\n", err)
|
||||
}
|
||||
|
||||
upf := flag.NewFlagSet("up", flag.ExitOnError)
|
||||
upf.StringVar(&upArgs.server, "login-server", "https://login.tailscale.com", "base URL of control server")
|
||||
upf.BoolVar(&upArgs.acceptRoutes, "accept-routes", false, "accept routes advertised by other Tailscale nodes")
|
||||
upf.BoolVar(&upArgs.singleRoutes, "host-routes", true, "install host routes to other Tailscale nodes")
|
||||
upf.BoolVar(&upArgs.shieldsUp, "shields-up", false, "don't allow incoming connections")
|
||||
upf.StringVar(&upArgs.advertiseTags, "advertise-tags", "", "ACL tags to request (comma-separated, e.g. eng,montreal,ssh)")
|
||||
upf.StringVar(&upArgs.authKey, "authkey", "", "node authorization key")
|
||||
upf.BoolVar(&upArgs.enableDERP, "enable-derp", true, "enable the use of DERP servers")
|
||||
if runtime.GOOS == "linux" {
|
||||
upf.StringVar(&upArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. 10.0.0.0/8,192.168.0.0/24)")
|
||||
upf.BoolVar(&upArgs.snat, "snat-subnet-routes", true, "source NAT traffic to local routes advertised with -advertise-routes")
|
||||
upf.StringVar(&upArgs.netfilterMode, "netfilter-mode", "on", "netfilter mode (one of on, nodivert, off)")
|
||||
}
|
||||
upCmd := &ffcli.Command{
|
||||
Name: "up",
|
||||
ShortUsage: "up [flags]",
|
||||
ShortHelp: "Connect to your Tailscale network",
|
||||
|
||||
LongHelp: strings.TrimSpace(`
|
||||
"tailscale up" connects this machine to your Tailscale network,
|
||||
triggering authentication if necessary.
|
||||
|
||||
The flags passed to this command are specific to this machine. If you don't
|
||||
specify any flags, options are reset to their default.
|
||||
`),
|
||||
FlagSet: upf,
|
||||
Exec: runUp,
|
||||
}
|
||||
|
||||
rootfs := flag.NewFlagSet("tailscale", flag.ExitOnError)
|
||||
rootfs.StringVar(&rootArgs.socket, "socket", paths.DefaultTailscaledSocket(), "path to tailscaled's unix socket")
|
||||
|
||||
rootCmd := &ffcli.Command{
|
||||
Name: "tailscale",
|
||||
ShortUsage: "tailscale subcommand [flags]",
|
||||
ShortHelp: "The easiest, most secure way to use WireGuard.",
|
||||
LongHelp: strings.TrimSpace(`
|
||||
This CLI is still under active development. Commands and flags will
|
||||
change in the future.
|
||||
`),
|
||||
Subcommands: []*ffcli.Command{
|
||||
upCmd,
|
||||
netcheckCmd,
|
||||
statusCmd,
|
||||
},
|
||||
FlagSet: rootfs,
|
||||
Exec: func(context.Context, []string) error { return flag.ErrHelp },
|
||||
}
|
||||
|
||||
if err := rootCmd.ParseAndRun(context.Background(), os.Args[1:]); err != nil && err != flag.ErrHelp {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var upArgs struct {
|
||||
server string
|
||||
acceptRoutes bool
|
||||
singleRoutes bool
|
||||
shieldsUp bool
|
||||
advertiseRoutes string
|
||||
advertiseTags string
|
||||
enableDERP bool
|
||||
snat bool
|
||||
netfilterMode string
|
||||
authKey string
|
||||
}
|
||||
|
||||
// parseIPOrCIDR parses an IP address or a CIDR prefix. If the input
|
||||
// is an IP address, it is returned in CIDR form with a /32 mask for
|
||||
// IPv4 or a /128 mask for IPv6.
|
||||
func parseIPOrCIDR(s string) (wgcfg.CIDR, bool) {
|
||||
if strings.Contains(s, "/") {
|
||||
ret, err := wgcfg.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return wgcfg.CIDR{}, false
|
||||
}
|
||||
return ret, true
|
||||
}
|
||||
|
||||
ip, ok := wgcfg.ParseIP(s)
|
||||
if !ok {
|
||||
return wgcfg.CIDR{}, false
|
||||
}
|
||||
if ip.Is4() {
|
||||
return wgcfg.CIDR{ip, 32}, true
|
||||
} else {
|
||||
return wgcfg.CIDR{ip, 128}, true
|
||||
}
|
||||
}
|
||||
|
||||
func warning(format string, args ...interface{}) {
|
||||
fmt.Printf("Warning: "+format+"\n", args...)
|
||||
}
|
||||
|
||||
// checkIPForwarding prints warnings on linux if IP forwarding is not
|
||||
// enabled, or if we were unable to verify the state of IP forwarding.
|
||||
func checkIPForwarding() {
|
||||
if runtime.GOOS != "linux" {
|
||||
return
|
||||
}
|
||||
bs, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward")
|
||||
if err != nil {
|
||||
warning("couldn't check /proc/sys/net/ipv4/ip_forward (%v).\nSubnet routes won't work without IP forwarding.", err)
|
||||
return
|
||||
}
|
||||
on, err := strconv.ParseBool(string(bytes.TrimSpace(bs)))
|
||||
if err != nil {
|
||||
warning("couldn't parse /proc/sys/net/ipv4/ip_forward (%v).\nSubnet routes won't work without IP forwarding.", err)
|
||||
return
|
||||
}
|
||||
if !on {
|
||||
warning("/proc/sys/net/ipv4/ip_forward is disabled. Subnet routes won't work.")
|
||||
}
|
||||
}
|
||||
|
||||
func runUp(ctx context.Context, args []string) error {
|
||||
if len(args) > 0 {
|
||||
log.Fatalf("too many non-flag arguments: %q", args)
|
||||
}
|
||||
|
||||
var routes []wgcfg.CIDR
|
||||
if upArgs.advertiseRoutes != "" {
|
||||
checkIPForwarding()
|
||||
advroutes := strings.Split(upArgs.advertiseRoutes, ",")
|
||||
for _, s := range advroutes {
|
||||
cidr, ok := parseIPOrCIDR(s)
|
||||
if !ok {
|
||||
log.Fatalf("%q is not a valid IP address or CIDR prefix", s)
|
||||
}
|
||||
routes = append(routes, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
var tags []string
|
||||
if upArgs.advertiseTags != "" {
|
||||
tags = strings.Split(upArgs.advertiseTags, ",")
|
||||
for _, tag := range tags {
|
||||
err := tailcfg.CheckTag(tag)
|
||||
if err != nil {
|
||||
log.Fatalf("tag: %q: %s", tag, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(apenwarr): fix different semantics between prefs and uflags
|
||||
// TODO(apenwarr): allow setting/using CorpDNS
|
||||
prefs := ipn.NewPrefs()
|
||||
prefs.ControlURL = upArgs.server
|
||||
prefs.WantRunning = true
|
||||
prefs.RouteAll = upArgs.acceptRoutes
|
||||
prefs.AllowSingleHosts = upArgs.singleRoutes
|
||||
prefs.ShieldsUp = upArgs.shieldsUp
|
||||
prefs.AdvertiseRoutes = routes
|
||||
prefs.AdvertiseTags = tags
|
||||
prefs.NoSNAT = !upArgs.snat
|
||||
prefs.DisableDERP = !upArgs.enableDERP
|
||||
if runtime.GOOS == "linux" {
|
||||
switch upArgs.netfilterMode {
|
||||
case "on":
|
||||
prefs.NetfilterMode = router.NetfilterOn
|
||||
case "nodivert":
|
||||
prefs.NetfilterMode = router.NetfilterNoDivert
|
||||
warning("netfilter=nodivert; add iptables calls to ts-* chains manually.")
|
||||
case "off":
|
||||
prefs.NetfilterMode = router.NetfilterOff
|
||||
warning("netfilter=off; configure iptables yourself.")
|
||||
default:
|
||||
log.Fatalf("invalid value --netfilter-mode: %q", upArgs.netfilterMode)
|
||||
}
|
||||
}
|
||||
|
||||
c, bc, ctx, cancel := connect(ctx)
|
||||
defer cancel()
|
||||
|
||||
var printed bool
|
||||
|
||||
bc.SetPrefs(prefs)
|
||||
opts := ipn.Options{
|
||||
StateKey: globalStateKey,
|
||||
AuthKey: upArgs.authKey,
|
||||
Notify: func(n ipn.Notify) {
|
||||
if n.ErrMessage != nil {
|
||||
log.Fatalf("backend error: %v\n", *n.ErrMessage)
|
||||
}
|
||||
if s := n.State; s != nil {
|
||||
switch *s {
|
||||
case ipn.NeedsLogin:
|
||||
printed = true
|
||||
bc.StartLoginInteractive()
|
||||
case ipn.NeedsMachineAuth:
|
||||
printed = true
|
||||
fmt.Fprintf(os.Stderr, "\nTo authorize your machine, visit (as admin):\n\n\t%s/admin/machines\n\n", upArgs.server)
|
||||
case ipn.Starting, ipn.Running:
|
||||
// Done full authentication process
|
||||
if printed {
|
||||
// Only need to print an update if we printed the "please click" message earlier.
|
||||
fmt.Fprintf(os.Stderr, "Success.\n")
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
if url := n.BrowseToURL; url != nil {
|
||||
fmt.Fprintf(os.Stderr, "\nTo authenticate, visit:\n\n\t%s\n\n", *url)
|
||||
}
|
||||
},
|
||||
}
|
||||
// We still have to Start right now because it's the only way to
|
||||
// set up notifications and whatnot. This causes a bunch of churn
|
||||
// every time the CLI touches anything.
|
||||
//
|
||||
// TODO(danderson): redo the frontend/backend API to assume
|
||||
// ephemeral frontends that read/modify/write state, once
|
||||
// Windows/Mac state is moved into backend.
|
||||
bc.Start(opts)
|
||||
pump(ctx, bc, c)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func connect(ctx context.Context) (net.Conn, *ipn.BackendClient, context.Context, context.CancelFunc) {
|
||||
c, err := safesocket.Connect(rootArgs.socket, 41112)
|
||||
if err != nil {
|
||||
if runtime.GOOS != "windows" && rootArgs.socket == "" {
|
||||
log.Fatalf("--socket cannot be empty")
|
||||
}
|
||||
log.Fatalf("Failed to connect to connect to tailscaled. (safesocket.Connect: %v)\n", err)
|
||||
}
|
||||
clientToServer := func(b []byte) {
|
||||
ipn.WriteMsg(c, b)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-interrupt
|
||||
c.Close()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
bc := ipn.NewBackendClient(log.Printf, clientToServer)
|
||||
return c, bc, ctx, cancel
|
||||
}
|
||||
|
||||
// pump receives backend messages on conn and pushes them into bc.
|
||||
func pump(ctx context.Context, bc *ipn.BackendClient, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
for ctx.Err() == nil {
|
||||
msg, err := ipn.ReadMsg(conn)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Printf("ReadMsg: %v\n", err)
|
||||
break
|
||||
}
|
||||
bc.GotNotifyMsg(msg)
|
||||
if err := cli.Run(os.Args[1:]); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,8 +15,10 @@ import (
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/apenwarr/fixconsole"
|
||||
@@ -27,6 +29,7 @@ import (
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/magicsock"
|
||||
"tailscale.com/wgengine/router"
|
||||
)
|
||||
|
||||
// globalStateKey is the ipn.StateKey that tailscaled loads on
|
||||
@@ -38,6 +41,27 @@ import (
|
||||
// later, the global state key doesn't look like a username.
|
||||
const globalStateKey = "_daemon"
|
||||
|
||||
// defaultTunName returns the default tun device name for the platform.
|
||||
func defaultTunName() string {
|
||||
switch runtime.GOOS {
|
||||
case "openbsd":
|
||||
return "tun"
|
||||
case "windows":
|
||||
return "Tailscale"
|
||||
}
|
||||
return "tailscale0"
|
||||
}
|
||||
|
||||
var args struct {
|
||||
cleanup bool
|
||||
fake bool
|
||||
debug string
|
||||
tunname string
|
||||
port uint16
|
||||
statepath string
|
||||
socketpath string
|
||||
}
|
||||
|
||||
func main() {
|
||||
// We aren't very performance sensitive, and the parts that are
|
||||
// performance sensitive (wireguard) try hard not to do any memory
|
||||
@@ -47,77 +71,112 @@ func main() {
|
||||
debug.SetGCPercent(10)
|
||||
}
|
||||
|
||||
defaultTunName := "tailscale0"
|
||||
if runtime.GOOS == "openbsd" {
|
||||
defaultTunName = "tun"
|
||||
}
|
||||
// Set default values for getopt.
|
||||
args.tunname = defaultTunName()
|
||||
args.port = magicsock.DefaultPort
|
||||
args.statepath = paths.DefaultTailscaledStateFile()
|
||||
args.socketpath = paths.DefaultTailscaledSocket()
|
||||
|
||||
fake := getopt.BoolLong("fake", 0, "fake tunnel+routing instead of tuntap")
|
||||
debug := getopt.StringLong("debug", 0, "", "Address of debug server")
|
||||
tunname := getopt.StringLong("tun", 0, defaultTunName, "tunnel interface name")
|
||||
listenport := getopt.Uint16Long("port", 'p', magicsock.DefaultPort, "WireGuard port (0=autoselect)")
|
||||
statepath := getopt.StringLong("state", 0, paths.DefaultTailscaledStateFile(), "Path of state file")
|
||||
socketpath := getopt.StringLong("socket", 's', paths.DefaultTailscaledSocket(), "Path of the service unix socket")
|
||||
|
||||
logf := wgengine.RusagePrefixLog(log.Printf)
|
||||
logf = logger.RateLimitedFn(logf, 5*time.Second, 5, 100)
|
||||
getopt.FlagLong(&args.cleanup, "cleanup", 0, "clean up system state and exit")
|
||||
getopt.FlagLong(&args.fake, "fake", 0, "fake tunnel+routing instead of tuntap")
|
||||
getopt.FlagLong(&args.debug, "debug", 0, "address of debug server")
|
||||
getopt.FlagLong(&args.tunname, "tun", 0, "tunnel interface name")
|
||||
getopt.FlagLong(&args.port, "port", 'p', "WireGuard port (0=autoselect)")
|
||||
getopt.FlagLong(&args.statepath, "state", 0, "path of state file")
|
||||
getopt.FlagLong(&args.socketpath, "socket", 's', "path of the service unix socket")
|
||||
|
||||
err := fixconsole.FixConsoleIfNeeded()
|
||||
if err != nil {
|
||||
logf("fixConsoleOutput: %v", err)
|
||||
log.Fatalf("fixConsoleOutput: %v", err)
|
||||
}
|
||||
pol := logpolicy.New("tailnode.log.tailscale.io")
|
||||
|
||||
getopt.Parse()
|
||||
if len(getopt.Args()) > 0 {
|
||||
log.Fatalf("too many non-flag arguments: %#v", getopt.Args()[0])
|
||||
}
|
||||
|
||||
if *statepath == "" {
|
||||
if args.statepath == "" {
|
||||
log.Fatalf("--state is required")
|
||||
}
|
||||
|
||||
if *socketpath == "" {
|
||||
if args.socketpath == "" && runtime.GOOS != "windows" {
|
||||
log.Fatalf("--socket is required")
|
||||
}
|
||||
|
||||
if err := run(); err != nil {
|
||||
// No need to log; the func already did
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
var err error
|
||||
|
||||
pol := logpolicy.New("tailnode.log.tailscale.io")
|
||||
defer func() {
|
||||
// Finish uploading logs after closing everything else.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
pol.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
logf := wgengine.RusagePrefixLog(log.Printf)
|
||||
logf = logger.RateLimitedFn(logf, 5*time.Second, 5, 100)
|
||||
|
||||
if args.cleanup {
|
||||
router.Cleanup(logf, args.tunname)
|
||||
return nil
|
||||
}
|
||||
|
||||
var debugMux *http.ServeMux
|
||||
if *debug != "" {
|
||||
if args.debug != "" {
|
||||
debugMux = newDebugMux()
|
||||
go runDebugServer(debugMux, *debug)
|
||||
go runDebugServer(debugMux, args.debug)
|
||||
}
|
||||
|
||||
var e wgengine.Engine
|
||||
if *fake {
|
||||
if args.fake {
|
||||
e, err = wgengine.NewFakeUserspaceEngine(logf, 0)
|
||||
} else {
|
||||
e, err = wgengine.NewUserspaceEngine(logf, *tunname, *listenport)
|
||||
e, err = wgengine.NewUserspaceEngine(logf, args.tunname, args.port)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("wgengine.New: %v", err)
|
||||
logf("wgengine.New: %v", err)
|
||||
return err
|
||||
}
|
||||
e = wgengine.NewWatchdog(e)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Exit gracefully by cancelling the ipnserver context in most common cases:
|
||||
// interrupted from the TTY or killed by a service manager.
|
||||
go func() {
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
select {
|
||||
case <-interrupt:
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
// continue
|
||||
}
|
||||
}()
|
||||
|
||||
opts := ipnserver.Options{
|
||||
SocketPath: *socketpath,
|
||||
SocketPath: args.socketpath,
|
||||
Port: 41112,
|
||||
StatePath: *statepath,
|
||||
StatePath: args.statepath,
|
||||
AutostartStateKey: globalStateKey,
|
||||
LegacyConfigPath: paths.LegacyConfigPath,
|
||||
LegacyConfigPath: paths.LegacyConfigPath(),
|
||||
SurviveDisconnects: true,
|
||||
DebugMux: debugMux,
|
||||
}
|
||||
err = ipnserver.Run(context.Background(), logf, pol.PublicID.String(), opts, e)
|
||||
if err != nil {
|
||||
log.Fatalf("tailscaled: %v", err)
|
||||
err = ipnserver.Run(ctx, logf, pol.PublicID.String(), opts, e)
|
||||
// Cancelation is not an error: it is the only way to stop ipnserver.
|
||||
if err != nil && err != context.Canceled {
|
||||
logf("ipnserver.Run: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(crawshaw): It would be nice to start a timeout context the moment a signal
|
||||
// is received and use that timeout to give us a moment to finish uploading logs
|
||||
// here. But the signal is handled inside ipnserver.Run, so some plumbing is needed.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
pol.Shutdown(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDebugMux() *http.ServeMux {
|
||||
|
||||
@@ -9,6 +9,7 @@ StartLimitBurst=0
|
||||
[Service]
|
||||
EnvironmentFile=/etc/default/tailscaled
|
||||
ExecStart=/usr/sbin/tailscaled --state=/var/lib/tailscale/tailscaled.state --socket=/run/tailscale/tailscaled.sock --port $PORT $FLAGS
|
||||
ExecStopPost=/usr/sbin/tailscaled --cleanup
|
||||
|
||||
Restart=on-failure
|
||||
|
||||
|
||||
@@ -355,12 +355,13 @@ func (c *Client) authRoutine() {
|
||||
err = fmt.Errorf("weird: server required a new url?")
|
||||
report(err, "WaitLoginURL")
|
||||
}
|
||||
goal.url = url
|
||||
goal.token = nil
|
||||
goal.flags = LoginDefault
|
||||
|
||||
c.mu.Lock()
|
||||
c.loginGoal = goal
|
||||
c.loginGoal = &LoginGoal{
|
||||
wantLoggedIn: true,
|
||||
flags: LoginDefault,
|
||||
url: url,
|
||||
}
|
||||
c.state = StateURLVisitRequired
|
||||
c.synced = false
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -85,6 +85,7 @@ type Direct struct {
|
||||
newDecompressor func() (Decompressor, error)
|
||||
keepAlive bool
|
||||
logf logger.Logf
|
||||
discoPubKey tailcfg.DiscoKey
|
||||
|
||||
mu sync.Mutex // mutex guards the following fields
|
||||
serverKey wgcfg.Key
|
||||
@@ -92,9 +93,10 @@ type Direct struct {
|
||||
authKey string
|
||||
tryingNewKey wgcfg.PrivateKey
|
||||
expiry *time.Time
|
||||
hostinfo *tailcfg.Hostinfo // always non-nil
|
||||
endpoints []string
|
||||
localPort uint16 // or zero to mean auto
|
||||
// hostinfo is mutated in-place while mu is held.
|
||||
hostinfo *tailcfg.Hostinfo // always non-nil
|
||||
endpoints []string
|
||||
localPort uint16 // or zero to mean auto
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
@@ -103,6 +105,7 @@ type Options struct {
|
||||
AuthKey string // optional node auth key for auto registration
|
||||
TimeNow func() time.Time // time.Now implementation used by Client
|
||||
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
|
||||
DiscoPublicKey tailcfg.DiscoKey
|
||||
NewDecompressor func() (Decompressor, error)
|
||||
KeepAlive bool
|
||||
Logf logger.Logf
|
||||
@@ -152,6 +155,7 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
keepAlive: opts.KeepAlive,
|
||||
persist: opts.Persist,
|
||||
authKey: opts.AuthKey,
|
||||
discoPubKey: opts.DiscoPublicKey,
|
||||
}
|
||||
if opts.Hostinfo == nil {
|
||||
c.SetHostinfo(NewHostinfo())
|
||||
@@ -262,6 +266,8 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
|
||||
tryingNewKey := c.tryingNewKey
|
||||
serverKey := c.serverKey
|
||||
authKey := c.authKey
|
||||
hostinfo := c.hostinfo.Clone()
|
||||
backendLogID := hostinfo.BackendLogID
|
||||
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -318,7 +324,7 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
|
||||
if tryingNewKey == (wgcfg.PrivateKey{}) {
|
||||
log.Fatalf("tryingNewKey is empty, give up")
|
||||
}
|
||||
if c.hostinfo.BackendLogID == "" {
|
||||
if backendLogID == "" {
|
||||
err = errors.New("hostinfo: BackendLogID missing")
|
||||
return regen, url, err
|
||||
}
|
||||
@@ -326,7 +332,7 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
|
||||
Version: 1,
|
||||
OldNodeKey: tailcfg.NodeKey(oldNodeKey),
|
||||
NodeKey: tailcfg.NodeKey(tryingNewKey.Public()),
|
||||
Hostinfo: c.hostinfo,
|
||||
Hostinfo: hostinfo,
|
||||
Followup: url,
|
||||
}
|
||||
c.logf("RegisterReq: onode=%v node=%v fup=%v",
|
||||
@@ -381,7 +387,7 @@ func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags,
|
||||
// - user is disabled
|
||||
|
||||
if resp.AuthURL != "" {
|
||||
c.logf("AuthURL is %.20v...", resp.AuthURL)
|
||||
c.logf("AuthURL is %v", resp.AuthURL)
|
||||
} else {
|
||||
c.logf("No AuthURL")
|
||||
}
|
||||
@@ -445,19 +451,18 @@ func (c *Direct) SetEndpoints(localPort uint16, endpoints []string) (changed boo
|
||||
return c.newEndpoints(localPort, endpoints)
|
||||
}
|
||||
|
||||
var debugNetmap, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_NETMAP"))
|
||||
|
||||
func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error {
|
||||
c.mu.Lock()
|
||||
persist := c.persist
|
||||
serverURL := c.serverURL
|
||||
serverKey := c.serverKey
|
||||
hostinfo := c.hostinfo
|
||||
hostinfo := c.hostinfo.Clone()
|
||||
backendLogID := hostinfo.BackendLogID
|
||||
localPort := c.localPort
|
||||
ep := append([]string(nil), c.endpoints...)
|
||||
c.mu.Unlock()
|
||||
|
||||
if hostinfo.BackendLogID == "" {
|
||||
if backendLogID == "" {
|
||||
return errors.New("hostinfo: BackendLogID missing")
|
||||
}
|
||||
|
||||
@@ -465,18 +470,20 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
|
||||
c.logf("PollNetMap: stream=%v :%v %v", maxPolls, localPort, ep)
|
||||
|
||||
vlogf := logger.Discard
|
||||
if debugNetmap {
|
||||
if Debug.NetMap {
|
||||
vlogf = c.logf
|
||||
}
|
||||
|
||||
request := tailcfg.MapRequest{
|
||||
Version: 4,
|
||||
IncludeIPv6: includeIPv6(),
|
||||
KeepAlive: c.keepAlive,
|
||||
NodeKey: tailcfg.NodeKey(persist.PrivateNodeKey.Public()),
|
||||
Endpoints: ep,
|
||||
Stream: allowStream,
|
||||
Hostinfo: hostinfo,
|
||||
Version: 4,
|
||||
IncludeIPv6: true,
|
||||
KeepAlive: c.keepAlive,
|
||||
NodeKey: tailcfg.NodeKey(persist.PrivateNodeKey.Public()),
|
||||
DiscoKey: c.discoPubKey,
|
||||
Endpoints: ep,
|
||||
Stream: allowStream,
|
||||
Hostinfo: hostinfo,
|
||||
DebugForceDisco: Debug.ForceDisco,
|
||||
}
|
||||
if c.newDecompressor != nil {
|
||||
request.Compress = "zstd"
|
||||
@@ -593,7 +600,19 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
|
||||
lastDERPMap = resp.DERPMap
|
||||
}
|
||||
if resp.Debug != nil && resp.Debug.LogHeapPprof {
|
||||
logheap.LogHeap()
|
||||
go logheap.LogHeap(resp.Debug.LogHeapURL)
|
||||
}
|
||||
// Temporarily (2020-06-29) support removing all but
|
||||
// discovery-supporting nodes during development, for
|
||||
// less noise.
|
||||
if Debug.OnlyDisco {
|
||||
filtered := resp.Peers[:0]
|
||||
for _, p := range resp.Peers {
|
||||
if !p.DiscoKey.IsZero() {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
resp.Peers = filtered
|
||||
}
|
||||
|
||||
nm := &NetworkMap{
|
||||
@@ -612,6 +631,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
|
||||
Hostinfo: resp.Node.Hostinfo,
|
||||
PacketFilter: c.parsePacketFilter(resp.PacketFilter),
|
||||
DERPMap: lastDERPMap,
|
||||
Debug: resp.Debug,
|
||||
}
|
||||
for _, profile := range resp.UserProfiles {
|
||||
nm.UserProfiles[profile.ID] = profile
|
||||
@@ -658,8 +678,10 @@ func decode(res *http.Response, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg
|
||||
}
|
||||
|
||||
func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
|
||||
c.mu.Lock()
|
||||
mkey := c.persist.PrivateMachineKey
|
||||
serverKey := c.serverKey
|
||||
c.mu.Unlock()
|
||||
|
||||
decrypted, err := decryptMsg(msg, &serverKey, &mkey)
|
||||
if err != nil {
|
||||
@@ -758,13 +780,38 @@ func loadServerKey(ctx context.Context, httpc *http.Client, serverURL string) (w
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// includeIPv6 reports whether we should enable IPv6 for magicsock
|
||||
// connections. This is only here temporarily (2020-03-26) as a
|
||||
// opt-out in case there are problems.
|
||||
func includeIPv6() bool {
|
||||
if e := os.Getenv("DEBUG_INCLUDE_IPV6"); e != "" {
|
||||
v, _ := strconv.ParseBool(e)
|
||||
return v
|
||||
}
|
||||
return true
|
||||
// Debug contains temporary internal-only debug knobs.
|
||||
// They're unexported to not draw attention to them.
|
||||
var Debug = initDebug()
|
||||
|
||||
type debug struct {
|
||||
NetMap bool
|
||||
OnlyDisco bool
|
||||
Disco bool
|
||||
ForceDisco bool // ask control server to not filter out our disco key
|
||||
}
|
||||
|
||||
func initDebug() debug {
|
||||
d := debug{
|
||||
NetMap: envBool("TS_DEBUG_NETMAP"),
|
||||
OnlyDisco: os.Getenv("TS_DEBUG_USE_DISCO") == "only",
|
||||
ForceDisco: os.Getenv("TS_DEBUG_USE_DISCO") == "only" || envBool("TS_DEBUG_USE_DISCO"),
|
||||
}
|
||||
if d.ForceDisco || os.Getenv("TS_DEBUG_USE_DISCO") == "" {
|
||||
// This is now defaults to on.
|
||||
d.Disco = true
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func envBool(k string) bool {
|
||||
e := os.Getenv(k)
|
||||
if e == "" {
|
||||
return false
|
||||
}
|
||||
v, err := strconv.ParseBool(e)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid non-bool %q for env var %q", e, k))
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -5,16 +5,18 @@
|
||||
package controlclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
@@ -27,7 +29,7 @@ type NetworkMap struct {
|
||||
Addresses []wgcfg.CIDR
|
||||
LocalPort uint16 // used for debugging
|
||||
MachineStatus tailcfg.MachineStatus
|
||||
Peers []*tailcfg.Node
|
||||
Peers []*tailcfg.Node // sorted by Node.ID
|
||||
DNS []wgcfg.IP
|
||||
DNSDomains []string
|
||||
Hostinfo tailcfg.Hostinfo
|
||||
@@ -37,6 +39,9 @@ type NetworkMap struct {
|
||||
// between updates and should not be modified.
|
||||
DERPMap *tailcfg.DERPMap
|
||||
|
||||
// Debug knobs from control server for debug or feature gating.
|
||||
Debug *tailcfg.Debug
|
||||
|
||||
// ACLs
|
||||
|
||||
User tailcfg.UserID
|
||||
@@ -49,92 +54,146 @@ type NetworkMap struct {
|
||||
// TODO(crawshaw): Capabilities []tailcfg.Capability
|
||||
}
|
||||
|
||||
func (n *NetworkMap) Equal(n2 *NetworkMap) bool {
|
||||
// TODO(crawshaw): this is crude, but is an easy way to avoid bugs.
|
||||
b, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b2, err := json.Marshal(n2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bytes.Equal(b, b2)
|
||||
}
|
||||
|
||||
func (nm NetworkMap) String() string {
|
||||
return nm.Concise()
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) Concise() string {
|
||||
buf := new(strings.Builder)
|
||||
fmt.Fprintf(buf, "netmap: self: %v auth=%v :%v %v\n",
|
||||
nm.NodeKey.ShortString(), nm.MachineStatus,
|
||||
nm.LocalPort, nm.Addresses)
|
||||
|
||||
nm.printConciseHeader(buf)
|
||||
for _, p := range nm.Peers {
|
||||
aip := make([]string, len(p.AllowedIPs))
|
||||
for i, a := range p.AllowedIPs {
|
||||
s := fmt.Sprint(a)
|
||||
if strings.HasSuffix(s, "/32") {
|
||||
s = s[0 : len(s)-3]
|
||||
}
|
||||
aip[i] = s
|
||||
}
|
||||
|
||||
ep := make([]string, len(p.Endpoints))
|
||||
for i, e := range p.Endpoints {
|
||||
// Align vertically on the ':' between IP and port
|
||||
colon := strings.IndexByte(e, ':')
|
||||
for colon > 0 && len(e)-colon < 6 {
|
||||
e += " "
|
||||
colon--
|
||||
}
|
||||
ep[i] = fmt.Sprintf("%21v", e)
|
||||
}
|
||||
|
||||
derp := p.DERP
|
||||
const derpPrefix = "127.3.3.40:"
|
||||
if strings.HasPrefix(derp, derpPrefix) {
|
||||
derp = "D" + derp[len(derpPrefix):]
|
||||
}
|
||||
|
||||
// Most of the time, aip is just one element, so format the
|
||||
// table to look good in that case. This will also make multi-
|
||||
// subnet nodes stand out visually.
|
||||
fmt.Fprintf(buf, " %v %-2v %-15v : %v\n",
|
||||
p.Key.ShortString(), derp,
|
||||
strings.Join(aip, " "),
|
||||
strings.Join(ep, " "))
|
||||
printPeerConcise(buf, p)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// printConciseHeader prints a concise header line representing nm to buf.
|
||||
//
|
||||
// If this function is changed to access different fields of nm, keep
|
||||
// in equalConciseHeader in sync.
|
||||
func (nm *NetworkMap) printConciseHeader(buf *strings.Builder) {
|
||||
fmt.Fprintf(buf, "netmap: self: %v auth=%v",
|
||||
nm.NodeKey.ShortString(), nm.MachineStatus)
|
||||
if nm.LocalPort != 0 {
|
||||
fmt.Fprintf(buf, " port=%v", nm.LocalPort)
|
||||
}
|
||||
if nm.Debug != nil {
|
||||
j, _ := json.Marshal(nm.Debug)
|
||||
fmt.Fprintf(buf, " debug=%s", j)
|
||||
}
|
||||
fmt.Fprintf(buf, " %v", nm.Addresses)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
// equalConciseHeader reports whether a and b are equal for the fields
|
||||
// used by printConciseHeader.
|
||||
func (a *NetworkMap) equalConciseHeader(b *NetworkMap) bool {
|
||||
if a.NodeKey != b.NodeKey ||
|
||||
a.MachineStatus != b.MachineStatus ||
|
||||
a.LocalPort != b.LocalPort ||
|
||||
len(a.Addresses) != len(b.Addresses) {
|
||||
return false
|
||||
}
|
||||
for i, a := range a.Addresses {
|
||||
if b.Addresses[i] != a {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return (a.Debug == nil && b.Debug == nil) || reflect.DeepEqual(a.Debug, b.Debug)
|
||||
}
|
||||
|
||||
// printPeerConcise appends to buf a line repsenting the peer p.
|
||||
//
|
||||
// If this function is changed to access different fields of p, keep
|
||||
// in nodeConciseEqual in sync.
|
||||
func printPeerConcise(buf *strings.Builder, p *tailcfg.Node) {
|
||||
aip := make([]string, len(p.AllowedIPs))
|
||||
for i, a := range p.AllowedIPs {
|
||||
s := strings.TrimSuffix(fmt.Sprint(a), "/32")
|
||||
aip[i] = s
|
||||
}
|
||||
|
||||
ep := make([]string, len(p.Endpoints))
|
||||
for i, e := range p.Endpoints {
|
||||
// Align vertically on the ':' between IP and port
|
||||
colon := strings.IndexByte(e, ':')
|
||||
spaces := 0
|
||||
for colon > 0 && len(e)+spaces-colon < 6 {
|
||||
spaces++
|
||||
colon--
|
||||
}
|
||||
ep[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces))
|
||||
}
|
||||
|
||||
derp := p.DERP
|
||||
const derpPrefix = "127.3.3.40:"
|
||||
if strings.HasPrefix(derp, derpPrefix) {
|
||||
derp = "D" + derp[len(derpPrefix):]
|
||||
}
|
||||
|
||||
// Most of the time, aip is just one element, so format the
|
||||
// table to look good in that case. This will also make multi-
|
||||
// subnet nodes stand out visually.
|
||||
fmt.Fprintf(buf, " %v %-2v %-15v : %v\n",
|
||||
p.Key.ShortString(), derp,
|
||||
strings.Join(aip, " "),
|
||||
strings.Join(ep, " "))
|
||||
}
|
||||
|
||||
// nodeConciseEqual reports whether a and b are equal for the fields accessed by printPeerConcise.
|
||||
func nodeConciseEqual(a, b *tailcfg.Node) bool {
|
||||
return a.Key == b.Key &&
|
||||
a.DERP == b.DERP &&
|
||||
eqCIDRsIgnoreNil(a.AllowedIPs, b.AllowedIPs) &&
|
||||
eqStringsIgnoreNil(a.Endpoints, b.Endpoints)
|
||||
}
|
||||
|
||||
func (b *NetworkMap) ConciseDiffFrom(a *NetworkMap) string {
|
||||
out := []string{}
|
||||
ra := strings.Split(a.Concise(), "\n")
|
||||
rb := strings.Split(b.Concise(), "\n")
|
||||
var diff strings.Builder
|
||||
|
||||
ma := map[string]struct{}{}
|
||||
for _, s := range ra {
|
||||
ma[s] = struct{}{}
|
||||
// See if header (non-peers, "bare") part of the network map changed.
|
||||
// If so, print its diff lines first.
|
||||
if !a.equalConciseHeader(b) {
|
||||
diff.WriteByte('-')
|
||||
a.printConciseHeader(&diff)
|
||||
diff.WriteByte('+')
|
||||
b.printConciseHeader(&diff)
|
||||
}
|
||||
|
||||
mb := map[string]struct{}{}
|
||||
for _, s := range rb {
|
||||
mb[s] = struct{}{}
|
||||
}
|
||||
|
||||
for _, s := range ra {
|
||||
if _, ok := mb[s]; !ok {
|
||||
out = append(out, "-"+s)
|
||||
aps, bps := a.Peers, b.Peers
|
||||
for len(aps) > 0 && len(bps) > 0 {
|
||||
pa, pb := aps[0], bps[0]
|
||||
switch {
|
||||
case pa.ID == pb.ID:
|
||||
if !nodeConciseEqual(pa, pb) {
|
||||
diff.WriteByte('-')
|
||||
printPeerConcise(&diff, pa)
|
||||
diff.WriteByte('+')
|
||||
printPeerConcise(&diff, pb)
|
||||
}
|
||||
aps, bps = aps[1:], bps[1:]
|
||||
case pa.ID > pb.ID:
|
||||
// New peer in b.
|
||||
diff.WriteByte('+')
|
||||
printPeerConcise(&diff, pb)
|
||||
bps = bps[1:]
|
||||
case pb.ID > pa.ID:
|
||||
// Deleted peer in b.
|
||||
diff.WriteByte('-')
|
||||
printPeerConcise(&diff, pa)
|
||||
aps = aps[1:]
|
||||
}
|
||||
}
|
||||
for _, s := range rb {
|
||||
if _, ok := ma[s]; !ok {
|
||||
out = append(out, "+"+s)
|
||||
}
|
||||
for _, pa := range aps {
|
||||
diff.WriteByte('-')
|
||||
printPeerConcise(&diff, pa)
|
||||
}
|
||||
return strings.Join(out, "\n")
|
||||
for _, pb := range bps {
|
||||
diff.WriteByte('+')
|
||||
printPeerConcise(&diff, pb)
|
||||
}
|
||||
return diff.String()
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) JSON() string {
|
||||
@@ -145,138 +204,141 @@ func (nm *NetworkMap) JSON() string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// WGConfigFlags is a bitmask of flags to control the behavior of the
|
||||
// wireguard configuration generation done by NetMap.WGCfg.
|
||||
type WGConfigFlags int
|
||||
|
||||
const (
|
||||
UAllowSingleHosts = 1 << iota
|
||||
UAllowSubnetRoutes
|
||||
UAllowDefaultRoute
|
||||
UHackDefaultRoute
|
||||
|
||||
UDefault = 0
|
||||
AllowSingleHosts WGConfigFlags = 1 << iota
|
||||
AllowSubnetRoutes
|
||||
AllowDefaultRoute
|
||||
HackDefaultRoute
|
||||
)
|
||||
|
||||
// Several programs need to parse these arguments into uflags, so let's
|
||||
// centralize it here.
|
||||
func UFlagsHelper(uroutes, rroutes, droutes bool) int {
|
||||
uflags := 0
|
||||
if uroutes {
|
||||
uflags |= UAllowSingleHosts
|
||||
}
|
||||
if rroutes {
|
||||
uflags |= UAllowSubnetRoutes
|
||||
}
|
||||
if droutes {
|
||||
uflags |= UAllowDefaultRoute
|
||||
}
|
||||
return uflags
|
||||
}
|
||||
|
||||
// TODO(bradfitz): UAPI seems to only be used by the old confnode and
|
||||
// pingnode; delete this when those are deleted/rewritten?
|
||||
func (nm *NetworkMap) UAPI(uflags int, dnsOverride []wgcfg.IP) string {
|
||||
wgcfg, err := nm.WGCfg(uflags, dnsOverride)
|
||||
func (nm *NetworkMap) UAPI(flags WGConfigFlags, dnsOverride []wgcfg.IP) string {
|
||||
wgcfg, err := nm.WGCfg(log.Printf, flags, dnsOverride)
|
||||
if err != nil {
|
||||
log.Fatalf("WGCfg() failed unexpectedly: %v\n", err)
|
||||
log.Fatalf("WGCfg() failed unexpectedly: %v", err)
|
||||
}
|
||||
s, err := wgcfg.ToUAPI()
|
||||
if err != nil {
|
||||
log.Fatalf("ToUAPI() failed unexpectedly: %v\n", err)
|
||||
log.Fatalf("ToUAPI() failed unexpectedly: %v", err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) WGCfg(uflags int, dnsOverride []wgcfg.IP) (*wgcfg.Config, error) {
|
||||
s := nm._WireGuardConfig(uflags, dnsOverride, true)
|
||||
return wgcfg.FromWgQuick(s, "tailscale")
|
||||
}
|
||||
// EndpointDiscoSuffix is appended to the hex representation of a peer's discovery key
|
||||
// and is then the sole wireguard endpoint for peers with a non-zero discovery key.
|
||||
// This form is then recognize by magicsock's CreateEndpoint.
|
||||
const EndpointDiscoSuffix = ".disco.tailscale:12345"
|
||||
|
||||
// TODO(apenwarr): This mode is dangerous.
|
||||
// Discarding the extra endpoints is almost universally the wrong choice.
|
||||
// Except that plain wireguard can't handle a peer with multiple endpoints.
|
||||
// (Yet?)
|
||||
func (nm *NetworkMap) WireGuardConfigOneEndpoint(uflags int, dnsOverride []wgcfg.IP) string {
|
||||
return nm._WireGuardConfig(uflags, dnsOverride, false)
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) _WireGuardConfig(uflags int, dnsOverride []wgcfg.IP, allEndpoints bool) string {
|
||||
buf := new(strings.Builder)
|
||||
fmt.Fprintf(buf, "[Interface]\n")
|
||||
fmt.Fprintf(buf, "PrivateKey = %s\n", base64.StdEncoding.EncodeToString(nm.PrivateKey[:]))
|
||||
if len(nm.Addresses) > 0 {
|
||||
fmt.Fprintf(buf, "Address = ")
|
||||
for i, cidr := range nm.Addresses {
|
||||
if i > 0 {
|
||||
fmt.Fprintf(buf, ", ")
|
||||
}
|
||||
fmt.Fprintf(buf, "%s", cidr)
|
||||
}
|
||||
fmt.Fprintf(buf, "\n")
|
||||
// WGCfg returns the NetworkMaps's Wireguard configuration.
|
||||
func (nm *NetworkMap) WGCfg(logf logger.Logf, flags WGConfigFlags, dnsOverride []wgcfg.IP) (*wgcfg.Config, error) {
|
||||
cfg := &wgcfg.Config{
|
||||
Name: "tailscale",
|
||||
PrivateKey: nm.PrivateKey,
|
||||
Addresses: nm.Addresses,
|
||||
ListenPort: nm.LocalPort,
|
||||
DNS: append([]wgcfg.IP(nil), dnsOverride...),
|
||||
Peers: make([]wgcfg.Peer, 0, len(nm.Peers)),
|
||||
}
|
||||
fmt.Fprintf(buf, "ListenPort = %d\n", nm.LocalPort)
|
||||
if len(dnsOverride) > 0 {
|
||||
dnss := []string{}
|
||||
for _, ip := range dnsOverride {
|
||||
dnss = append(dnss, ip.String())
|
||||
}
|
||||
fmt.Fprintf(buf, "DNS = %s\n", strings.Join(dnss, ","))
|
||||
}
|
||||
fmt.Fprintf(buf, "\n")
|
||||
|
||||
for i, peer := range nm.Peers {
|
||||
if (uflags&UAllowSingleHosts) == 0 && len(peer.AllowedIPs) < 2 {
|
||||
log.Printf("wgcfg: %v skipping a single-host peer.\n", peer.Key.ShortString())
|
||||
for _, peer := range nm.Peers {
|
||||
if Debug.OnlyDisco && peer.DiscoKey.IsZero() {
|
||||
continue
|
||||
}
|
||||
if i > 0 {
|
||||
fmt.Fprintf(buf, "\n")
|
||||
if (flags&AllowSingleHosts) == 0 && len(peer.AllowedIPs) < 2 {
|
||||
logf("wgcfg: %v skipping a single-host peer.", peer.Key.ShortString())
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(buf, "[Peer]\n")
|
||||
fmt.Fprintf(buf, "PublicKey = %s\n", base64.StdEncoding.EncodeToString(peer.Key[:]))
|
||||
var endpoints []string
|
||||
if peer.DERP != "" {
|
||||
endpoints = append(endpoints, peer.DERP)
|
||||
cfg.Peers = append(cfg.Peers, wgcfg.Peer{
|
||||
PublicKey: wgcfg.Key(peer.Key),
|
||||
})
|
||||
cpeer := &cfg.Peers[len(cfg.Peers)-1]
|
||||
if peer.KeepAlive {
|
||||
cpeer.PersistentKeepalive = 25 // seconds
|
||||
}
|
||||
endpoints = append(endpoints, peer.Endpoints...)
|
||||
if len(endpoints) > 0 {
|
||||
if len(endpoints) == 1 {
|
||||
fmt.Fprintf(buf, "Endpoint = %s", endpoints[0])
|
||||
} else if allEndpoints {
|
||||
// TODO(apenwarr): This mode is incompatible.
|
||||
// Normal wireguard clients don't know how to
|
||||
// parse it (yet?)
|
||||
fmt.Fprintf(buf, "Endpoint = %s",
|
||||
strings.Join(endpoints, ","))
|
||||
} else {
|
||||
fmt.Fprintf(buf, "Endpoint = %s # other endpoints: %s",
|
||||
endpoints[0],
|
||||
strings.Join(endpoints[1:], ", "))
|
||||
|
||||
if !peer.DiscoKey.IsZero() {
|
||||
if err := appendEndpoint(cpeer, fmt.Sprintf("%x%s", peer.DiscoKey[:], EndpointDiscoSuffix)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cpeer.Endpoints = []wgcfg.Endpoint{{Host: fmt.Sprintf("%x.disco.tailscale", peer.DiscoKey[:]), Port: 12345}}
|
||||
} else {
|
||||
if err := appendEndpoint(cpeer, peer.DERP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, ep := range peer.Endpoints {
|
||||
if err := appendEndpoint(cpeer, ep); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
var aips []string
|
||||
for _, allowedIP := range peer.AllowedIPs {
|
||||
aip := allowedIP.String()
|
||||
if allowedIP.Mask == 0 {
|
||||
if (uflags & UAllowDefaultRoute) == 0 {
|
||||
log.Printf("wgcfg: %v skipping default route\n", peer.Key.ShortString())
|
||||
if (flags & AllowDefaultRoute) == 0 {
|
||||
logf("wgcfg: %v skipping default route", peer.Key.ShortString())
|
||||
continue
|
||||
}
|
||||
if (uflags & UHackDefaultRoute) != 0 {
|
||||
aip = "10.0.0.0/8"
|
||||
log.Printf("wgcfg: %v converting default route => %v\n", peer.Key.ShortString(), aip)
|
||||
if (flags & HackDefaultRoute) != 0 {
|
||||
allowedIP = wgcfg.CIDR{IP: wgcfg.IPv4(10, 0, 0, 0), Mask: 8}
|
||||
logf("wgcfg: %v converting default route => %v", peer.Key.ShortString(), allowedIP.String())
|
||||
}
|
||||
} else if allowedIP.Mask < 32 {
|
||||
if (uflags & UAllowSubnetRoutes) == 0 {
|
||||
log.Printf("wgcfg: %v skipping subnet route\n", peer.Key.ShortString())
|
||||
if (flags & AllowSubnetRoutes) == 0 {
|
||||
logf("wgcfg: %v skipping subnet route", peer.Key.ShortString())
|
||||
continue
|
||||
}
|
||||
}
|
||||
aips = append(aips, aip)
|
||||
}
|
||||
fmt.Fprintf(buf, "AllowedIPs = %s\n", strings.Join(aips, ", "))
|
||||
if peer.KeepAlive {
|
||||
fmt.Fprintf(buf, "PersistentKeepalive = 25\n")
|
||||
cpeer.AllowedIPs = append(cpeer.AllowedIPs, allowedIP)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func appendEndpoint(peer *wgcfg.Peer, epStr string) error {
|
||||
if epStr == "" {
|
||||
return nil
|
||||
}
|
||||
host, port, err := net.SplitHostPort(epStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("malformed endpoint %q for peer %v", epStr, peer.PublicKey.ShortString())
|
||||
}
|
||||
port16, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port in endpoint %q for peer %v", epStr, peer.PublicKey.ShortString())
|
||||
}
|
||||
peer.Endpoints = append(peer.Endpoints, wgcfg.Endpoint{Host: host, Port: uint16(port16)})
|
||||
return nil
|
||||
}
|
||||
|
||||
// eqStringsIgnoreNil reports whether a and b have the same length and
|
||||
// contents, but ignore whether a or b are nil.
|
||||
func eqStringsIgnoreNil(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// eqCIDRsIgnoreNil reports whether a and b have the same length and
|
||||
// contents, but ignore whether a or b are nil.
|
||||
func eqCIDRsIgnoreNil(a, b []wgcfg.CIDR) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
217
control/controlclient/netmap_test.go
Normal file
217
control/controlclient/netmap_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package controlclient
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func testNodeKey(b byte) (ret tailcfg.NodeKey) {
|
||||
for i := range ret {
|
||||
ret[i] = b
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestNetworkMapConcise(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
nm *NetworkMap
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
nm: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
{
|
||||
Key: testNodeKey(3),
|
||||
DERP: "127.3.3.40:4",
|
||||
Endpoints: []string{"10.2.0.100:12", "10.1.0.100:12345"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "netmap: self: [AQEBA] auth=machine-unknown []\n [AgICA] D2 : 192.168.0.100:12 192.168.0.100:12354\n [AwMDA] D4 : 10.2.0.100:12 10.1.0.100:12345\n",
|
||||
},
|
||||
{
|
||||
name: "debug_non_nil",
|
||||
nm: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Debug: &tailcfg.Debug{},
|
||||
},
|
||||
want: "netmap: self: [AQEBA] auth=machine-unknown debug={} []\n",
|
||||
},
|
||||
{
|
||||
name: "debug_values",
|
||||
nm: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Debug: &tailcfg.Debug{LogHeapPprof: true},
|
||||
},
|
||||
want: "netmap: self: [AQEBA] auth=machine-unknown debug={\"LogHeapPprof\":true} []\n",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got string
|
||||
n := int(testing.AllocsPerRun(1000, func() {
|
||||
got = tt.nm.Concise()
|
||||
}))
|
||||
t.Logf("Allocs = %d", n)
|
||||
if got != tt.want {
|
||||
t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConciseDiffFrom(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
a, b *NetworkMap
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no_change",
|
||||
a: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
b: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "header_change",
|
||||
a: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
b: &NetworkMap{
|
||||
NodeKey: testNodeKey(2),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "-netmap: self: [AQEBA] auth=machine-unknown []\n+netmap: self: [AgICA] auth=machine-unknown []\n",
|
||||
},
|
||||
{
|
||||
name: "peer_add",
|
||||
a: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
ID: 2,
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
b: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
ID: 1,
|
||||
Key: testNodeKey(1),
|
||||
DERP: "127.3.3.40:1",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Key: testNodeKey(3),
|
||||
DERP: "127.3.3.40:3",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "+ [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n+ [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n",
|
||||
},
|
||||
{
|
||||
name: "peer_remove",
|
||||
a: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
ID: 1,
|
||||
Key: testNodeKey(1),
|
||||
DERP: "127.3.3.40:1",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Key: testNodeKey(3),
|
||||
DERP: "127.3.3.40:3",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
b: &NetworkMap{
|
||||
NodeKey: testNodeKey(1),
|
||||
Peers: []*tailcfg.Node{
|
||||
{
|
||||
ID: 2,
|
||||
Key: testNodeKey(2),
|
||||
DERP: "127.3.3.40:2",
|
||||
Endpoints: []string{"192.168.0.100:12", "192.168.0.100:12354"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "- [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n- [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got string
|
||||
n := int(testing.AllocsPerRun(50, func() {
|
||||
got = tt.b.ConciseDiffFrom(tt.a)
|
||||
}))
|
||||
t.Logf("Allocs = %d", n)
|
||||
if got != tt.want {
|
||||
t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
28
derp/derp.go
28
derp/derp.go
@@ -32,10 +32,11 @@ const MaxPacketSize = 64 << 10
|
||||
const magic = "DERP🔑" // 8 bytes: 0x44 45 52 50 f0 9f 94 91
|
||||
|
||||
const (
|
||||
nonceLen = 24
|
||||
keyLen = 32
|
||||
maxInfoLen = 1 << 20
|
||||
keepAlive = 60 * time.Second
|
||||
nonceLen = 24
|
||||
frameHeaderLen = 1 + 4 // frameType byte + 4 byte length
|
||||
keyLen = 32
|
||||
maxInfoLen = 1 << 20
|
||||
keepAlive = 60 * time.Second
|
||||
)
|
||||
|
||||
// protocolVersion is bumped whenever there's a wire-incompatible change.
|
||||
@@ -71,6 +72,7 @@ const (
|
||||
frameClientInfo = frameType(0x02) // 32B pub key + 24B nonce + naclbox(json)
|
||||
frameServerInfo = frameType(0x03) // 24B nonce + naclbox(json)
|
||||
frameSendPacket = frameType(0x04) // 32B dest pub key + packet bytes
|
||||
frameForwardPacket = frameType(0x0a) // 32B src pub key + 32B dst pub key + packet bytes
|
||||
frameRecvPacket = frameType(0x05) // v0/1: packet bytes, v2: 32B src pub key + packet bytes
|
||||
frameKeepAlive = frameType(0x06) // no payload, no-op (to be replaced with ping/pong)
|
||||
frameNotePreferred = frameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node
|
||||
@@ -81,6 +83,24 @@ const (
|
||||
// framePeerGone to B so B can forget that a reverse path
|
||||
// exists on that connection to get back to A.
|
||||
framePeerGone = frameType(0x08) // 32B pub key of peer that's gone
|
||||
|
||||
// framePeerPresent is like framePeerGone, but for other
|
||||
// members of the DERP region when they're meshed up together.
|
||||
framePeerPresent = frameType(0x09) // 32B pub key of peer that's connected
|
||||
|
||||
// frameWatchConns is how one DERP node in a regional mesh
|
||||
// subscribes to the others in the region.
|
||||
// There's no payload. If the sender doesn't have permission, the connection
|
||||
// is closed. Otherwise, the client is initially flooded with
|
||||
// framePeerPresent for all connected nodes, and then a stream of
|
||||
// framePeerPresent & framePeerGone has peers connect and disconnect.
|
||||
frameWatchConns = frameType(0x10)
|
||||
|
||||
// frameClosePeer is a privileged frame type (requires the
|
||||
// mesh key for now) that closes the provided peer's
|
||||
// connection. (To be used for cluster load balancing
|
||||
// purposes, when clients end up on a non-ideal node)
|
||||
frameClosePeer = frameType(0x11) // 32B pub key of peer to close.
|
||||
)
|
||||
|
||||
var bin = binary.BigEndian
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// Client is a DERP client.
|
||||
type Client struct {
|
||||
serverKey key.Public // of the DERP server; not a machine or node key
|
||||
privateKey key.Private
|
||||
@@ -27,13 +28,48 @@ type Client struct {
|
||||
logf logger.Logf
|
||||
nc Conn
|
||||
br *bufio.Reader
|
||||
meshKey string
|
||||
|
||||
wmu sync.Mutex // hold while writing to bw
|
||||
bw *bufio.Writer
|
||||
wmu sync.Mutex // hold while writing to bw
|
||||
bw *bufio.Writer
|
||||
|
||||
// Owned by Recv:
|
||||
peeked int // bytes to discard on next Recv
|
||||
readErr error // sticky read error
|
||||
}
|
||||
|
||||
func NewClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logger.Logf) (*Client, error) {
|
||||
// ClientOpt is an option passed to NewClient.
|
||||
type ClientOpt interface {
|
||||
update(*clientOpt)
|
||||
}
|
||||
|
||||
type clientOptFunc func(*clientOpt)
|
||||
|
||||
func (f clientOptFunc) update(o *clientOpt) { f(o) }
|
||||
|
||||
// clientOpt are the options passed to newClient.
|
||||
type clientOpt struct {
|
||||
MeshKey string
|
||||
}
|
||||
|
||||
// MeshKey returns a ClientOpt to pass to the DERP server during connect to get
|
||||
// access to join the mesh.
|
||||
//
|
||||
// An empty key means to not use a mesh key.
|
||||
func MeshKey(key string) ClientOpt { return clientOptFunc(func(o *clientOpt) { o.MeshKey = key }) }
|
||||
|
||||
func NewClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opts ...ClientOpt) (*Client, error) {
|
||||
var opt clientOpt
|
||||
for _, o := range opts {
|
||||
if o == nil {
|
||||
return nil, errors.New("nil ClientOpt")
|
||||
}
|
||||
o.update(&opt)
|
||||
}
|
||||
return newClient(privateKey, nc, brw, logf, opt)
|
||||
}
|
||||
|
||||
func newClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opt clientOpt) (*Client, error) {
|
||||
c := &Client{
|
||||
privateKey: privateKey,
|
||||
publicKey: privateKey.Public(),
|
||||
@@ -41,8 +77,8 @@ func NewClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logg
|
||||
nc: nc,
|
||||
br: brw.Reader,
|
||||
bw: brw.Writer,
|
||||
meshKey: opt.MeshKey,
|
||||
}
|
||||
|
||||
if err := c.recvServerKey(); err != nil {
|
||||
return nil, fmt.Errorf("derp.Client: failed to receive server key: %v", err)
|
||||
}
|
||||
@@ -109,6 +145,12 @@ func (c *Client) recvServerInfo() (*serverInfo, error) {
|
||||
|
||||
type clientInfo struct {
|
||||
Version int // `json:"version,omitempty"`
|
||||
|
||||
// MeshKey optionally specifies a pre-shared key used by
|
||||
// trusted clients. It's required to subscribe to the
|
||||
// connection list & forward packets. It's empty for regular
|
||||
// users.
|
||||
MeshKey string // `json:"meshKey,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Client) sendClientKey() error {
|
||||
@@ -116,7 +158,10 @@ func (c *Client) sendClientKey() error {
|
||||
if _, err := crand.Read(nonce[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
msg, err := json.Marshal(clientInfo{Version: protocolVersion})
|
||||
msg, err := json.Marshal(clientInfo{
|
||||
Version: protocolVersion,
|
||||
MeshKey: c.meshKey,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -129,6 +174,9 @@ func (c *Client) sendClientKey() error {
|
||||
return writeFrame(c.bw, frameClientInfo, buf)
|
||||
}
|
||||
|
||||
// ServerPublicKey returns the server's public key.
|
||||
func (c *Client) ServerPublicKey() key.Public { return c.serverKey }
|
||||
|
||||
// Send sends a packet to the Tailscale node identified by dstKey.
|
||||
//
|
||||
// It is an error if the packet is larger than 64KB.
|
||||
@@ -160,6 +208,40 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
|
||||
return c.bw.Flush()
|
||||
}
|
||||
|
||||
func (c *Client) ForwardPacket(srcKey, dstKey key.Public, pkt []byte) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("derp.ForwardPacket: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(pkt) > MaxPacketSize {
|
||||
return fmt.Errorf("packet too big: %d", len(pkt))
|
||||
}
|
||||
|
||||
c.wmu.Lock()
|
||||
defer c.wmu.Unlock()
|
||||
|
||||
timer := time.AfterFunc(5*time.Second, c.writeTimeoutFired)
|
||||
defer timer.Stop()
|
||||
|
||||
if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.bw.Write(srcKey[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.bw.Write(dstKey[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.bw.Write(pkt); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.bw.Flush()
|
||||
}
|
||||
|
||||
func (c *Client) writeTimeoutFired() { c.nc.Close() }
|
||||
|
||||
// NotePreferred sends a packet that tells the server whether this
|
||||
// client is the user's preferred server. This is only used in the
|
||||
// server for stats.
|
||||
@@ -186,6 +268,25 @@ func (c *Client) NotePreferred(preferred bool) (err error) {
|
||||
return c.bw.Flush()
|
||||
}
|
||||
|
||||
// WatchConnectionChanges sends a request to subscribe to the peer's connection list.
|
||||
// It's a fatal error if the client wasn't created using MeshKey.
|
||||
func (c *Client) WatchConnectionChanges() error {
|
||||
c.wmu.Lock()
|
||||
defer c.wmu.Unlock()
|
||||
if err := writeFrameHeader(c.bw, frameWatchConns, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.bw.Flush()
|
||||
}
|
||||
|
||||
// ClosePeer asks the server to close target's TCP connection.
|
||||
// It's a fatal error if the client wasn't created using MeshKey.
|
||||
func (c *Client) ClosePeer(target key.Public) error {
|
||||
c.wmu.Lock()
|
||||
defer c.wmu.Unlock()
|
||||
return writeFrame(c.bw, frameClosePeer, target[:])
|
||||
}
|
||||
|
||||
// ReceivedMessage represents a type returned by Client.Recv. Unless
|
||||
// otherwise documented, the returned message aliases the byte slice
|
||||
// provided to Recv and thus the message is only as good as that
|
||||
@@ -211,11 +312,23 @@ type PeerGoneMessage key.Public
|
||||
|
||||
func (PeerGoneMessage) msg() {}
|
||||
|
||||
// PeerPresentMessage is a ReceivedMessage that indicates that the client
|
||||
// is connected to the server. (Only used by trusted mesh clients)
|
||||
type PeerPresentMessage key.Public
|
||||
|
||||
func (PeerPresentMessage) msg() {}
|
||||
|
||||
// Recv reads a message from the DERP server.
|
||||
// The provided buffer must be large enough to receive a complete packet,
|
||||
// which in practice are are 1.5-4 KB, but can be up to 64 KB.
|
||||
//
|
||||
// The returned message may alias memory owned by the Client; it
|
||||
// should only be accessed until the next call to Client.
|
||||
//
|
||||
// Once Recv returns an error, the Client is dead forever.
|
||||
func (c *Client) Recv(b []byte) (m ReceivedMessage, err error) {
|
||||
func (c *Client) Recv() (m ReceivedMessage, err error) {
|
||||
return c.recvTimeout(120 * time.Second)
|
||||
}
|
||||
|
||||
func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err error) {
|
||||
if c.readErr != nil {
|
||||
return nil, c.readErr
|
||||
}
|
||||
@@ -227,11 +340,45 @@ func (c *Client) Recv(b []byte) (m ReceivedMessage, err error) {
|
||||
}()
|
||||
|
||||
for {
|
||||
c.nc.SetReadDeadline(time.Now().Add(120 * time.Second))
|
||||
t, n, err := readFrame(c.br, 1<<20, b)
|
||||
c.nc.SetReadDeadline(time.Now().Add(timeout))
|
||||
|
||||
// Discard any peeked bytes from a previous Recv call.
|
||||
if c.peeked != 0 {
|
||||
if n, err := c.br.Discard(c.peeked); err != nil || n != c.peeked {
|
||||
// Documented to never fail, but might as well check.
|
||||
return nil, fmt.Errorf("bufio.Reader.Discard(%d bytes): got %v, %v", c.peeked, n, err)
|
||||
}
|
||||
c.peeked = 0
|
||||
}
|
||||
|
||||
t, n, err := readFrameHeader(c.br)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n > 1<<20 {
|
||||
return nil, fmt.Errorf("unexpectedly large frame of %d bytes returned", n)
|
||||
}
|
||||
|
||||
var b []byte // frame payload (past the 5 byte header)
|
||||
|
||||
// If the frame fits in our bufio.Reader buffer, just use it.
|
||||
// In practice it's 4KB (from derphttp.Client's bufio.NewReader(httpConn)) and
|
||||
// in practive, WireGuard packets (and thus DERP frames) are under 1.5KB.
|
||||
// So This is the common path.
|
||||
if int(n) <= c.br.Size() {
|
||||
b, err = c.br.Peek(int(n))
|
||||
c.peeked = int(n)
|
||||
} else {
|
||||
// But if for some reason we read a large DERP message (which isn't necessarily
|
||||
// a Wireguard packet), then just allocate memory for it.
|
||||
// TODO(bradfitz): use a pool if large frames ever happen in practice.
|
||||
b = make([]byte, n)
|
||||
_, err = io.ReadFull(c.br, b)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
default:
|
||||
continue
|
||||
@@ -248,6 +395,15 @@ func (c *Client) Recv(b []byte) (m ReceivedMessage, err error) {
|
||||
copy(pg[:], b[:keyLen])
|
||||
return pg, nil
|
||||
|
||||
case framePeerPresent:
|
||||
if n < keyLen {
|
||||
c.logf("[unexpected] dropping short peerPresent frame from DERP server")
|
||||
continue
|
||||
}
|
||||
var pg PeerPresentMessage
|
||||
copy(pg[:], b[:keyLen])
|
||||
return pg, nil
|
||||
|
||||
case frameRecvPacket:
|
||||
var rp ReceivedPacket
|
||||
if c.protoVersion < protocolSrcAddrs {
|
||||
|
||||
@@ -20,11 +20,13 @@ import (
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/metrics"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
@@ -37,41 +39,86 @@ const (
|
||||
writeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
const host64bit = (^uint(0) >> 32) & 1 // 1 on 64-bit, 0 on 32-bit
|
||||
|
||||
// pad32bit is 4 on 32-bit machines and 0 on 64-bit.
|
||||
// It exists so the Server struct's atomic fields can be aligned to 8
|
||||
// byte boundaries. (As tested by GOARCH=386 go test, etc)
|
||||
const pad32bit = 4 - host64bit*4 // 0 on 64-bit, 4 on 32-bit
|
||||
|
||||
// Server is a DERP server.
|
||||
type Server struct {
|
||||
// WriteTimeout, if non-zero, specifies how long to wait
|
||||
// before failing when writing to a client.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// OnlyDisco controls whether, for tests, non-discovery packets
|
||||
// are dropped. This is used by magicsock tests to verify that
|
||||
// NAT traversal works (using DERP for out-of-band messaging)
|
||||
// but the packets themselves aren't going via DERP.
|
||||
OnlyDisco bool
|
||||
_ [pad32bit]byte
|
||||
|
||||
privateKey key.Private
|
||||
publicKey key.Public
|
||||
logf logger.Logf
|
||||
memSys0 uint64 // runtime.MemStats.Sys at start (or early-ish)
|
||||
meshKey string
|
||||
|
||||
// Counters:
|
||||
packetsSent, bytesSent expvar.Int
|
||||
packetsRecv, bytesRecv expvar.Int
|
||||
packetsDropped expvar.Int
|
||||
packetsDroppedReason metrics.LabelMap
|
||||
packetsDroppedUnknown *expvar.Int // unknown dst pubkey
|
||||
packetsDroppedGone *expvar.Int // dst conn shutting down
|
||||
packetsDroppedQueueHead *expvar.Int // queue full, drop head packet
|
||||
packetsDroppedQueueTail *expvar.Int // queue full, drop tail packet
|
||||
packetsDroppedWrite *expvar.Int // error writing to dst conn
|
||||
peerGoneFrames expvar.Int // number of peer gone frames sent
|
||||
accepts expvar.Int
|
||||
curClients expvar.Int
|
||||
curHomeClients expvar.Int // ones with preferred
|
||||
clientsReplaced expvar.Int
|
||||
unknownFrames expvar.Int
|
||||
homeMovesIn expvar.Int // established clients announce home server moves in
|
||||
homeMovesOut expvar.Int // established clients announce home server moves out
|
||||
_ [pad32bit]byte
|
||||
packetsSent, bytesSent expvar.Int
|
||||
packetsRecv, bytesRecv expvar.Int
|
||||
packetsDropped expvar.Int
|
||||
packetsDroppedReason metrics.LabelMap
|
||||
packetsDroppedUnknown *expvar.Int // unknown dst pubkey
|
||||
packetsDroppedFwdUnknown *expvar.Int // unknown dst pubkey on forward
|
||||
packetsDroppedGone *expvar.Int // dst conn shutting down
|
||||
packetsDroppedQueueHead *expvar.Int // queue full, drop head packet
|
||||
packetsDroppedQueueTail *expvar.Int // queue full, drop tail packet
|
||||
packetsDroppedWrite *expvar.Int // error writing to dst conn
|
||||
_ [pad32bit]byte
|
||||
packetsForwardedOut expvar.Int
|
||||
packetsForwardedIn expvar.Int
|
||||
peerGoneFrames expvar.Int // number of peer gone frames sent
|
||||
accepts expvar.Int
|
||||
curClients expvar.Int
|
||||
curHomeClients expvar.Int // ones with preferred
|
||||
clientsReplaced expvar.Int
|
||||
unknownFrames expvar.Int
|
||||
homeMovesIn expvar.Int // established clients announce home server moves in
|
||||
homeMovesOut expvar.Int // established clients announce home server moves out
|
||||
multiForwarderCreated expvar.Int
|
||||
multiForwarderDeleted expvar.Int
|
||||
removePktForwardOther expvar.Int
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
netConns map[Conn]chan struct{} // chan is closed when conn closes
|
||||
clients map[key.Public]*sclient
|
||||
clientsEver map[key.Public]bool // never deleted from, for stats; fine for now
|
||||
watchers map[*sclient]bool // mesh peer -> true
|
||||
// clientsMesh tracks all clients in the cluster, both locally
|
||||
// and to mesh peers. If the value is nil, that means the
|
||||
// peer is only local (and thus in the clients Map, but not
|
||||
// remote). If the value is non-nil, it's remote (+ maybe also
|
||||
// local).
|
||||
clientsMesh map[key.Public]PacketForwarder
|
||||
// sentTo tracks which peers have sent to which other peers,
|
||||
// and at which connection number. This isn't on sclient
|
||||
// because it includes intra-region forwarded packets as the
|
||||
// src.
|
||||
sentTo map[key.Public]map[key.Public]int64 // src => dst => dst's latest sclient.connNum
|
||||
}
|
||||
|
||||
// PacketForwarder is something that can forward packets.
|
||||
//
|
||||
// It's mostly an inteface for circular dependency reasons; the
|
||||
// typical implementation is derphttp.Client. The other implementation
|
||||
// is a multiForwarder, which this package creates as needed if a
|
||||
// public key gets more than one PacketForwarder registered for it.
|
||||
type PacketForwarder interface {
|
||||
ForwardPacket(src, dst key.Public, payload []byte) error
|
||||
}
|
||||
|
||||
// Conn is the subset of the underlying net.Conn the DERP Server needs.
|
||||
@@ -97,12 +144,16 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server {
|
||||
publicKey: privateKey.Public(),
|
||||
logf: logf,
|
||||
packetsDroppedReason: metrics.LabelMap{Label: "reason"},
|
||||
clients: make(map[key.Public]*sclient),
|
||||
clientsEver: make(map[key.Public]bool),
|
||||
netConns: make(map[Conn]chan struct{}),
|
||||
clients: map[key.Public]*sclient{},
|
||||
clientsEver: map[key.Public]bool{},
|
||||
clientsMesh: map[key.Public]PacketForwarder{},
|
||||
netConns: map[Conn]chan struct{}{},
|
||||
memSys0: ms.Sys,
|
||||
watchers: map[*sclient]bool{},
|
||||
sentTo: map[key.Public]map[key.Public]int64{},
|
||||
}
|
||||
s.packetsDroppedUnknown = s.packetsDroppedReason.Get("unknown_dest")
|
||||
s.packetsDroppedFwdUnknown = s.packetsDroppedReason.Get("unknown_dest_on_fwd")
|
||||
s.packetsDroppedGone = s.packetsDroppedReason.Get("gone")
|
||||
s.packetsDroppedQueueHead = s.packetsDroppedReason.Get("queue_head")
|
||||
s.packetsDroppedQueueTail = s.packetsDroppedReason.Get("queue_tail")
|
||||
@@ -110,6 +161,26 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
// SetMesh sets the pre-shared key that regional DERP servers used to mesh
|
||||
// amongst themselves.
|
||||
//
|
||||
// It must be called before serving begins.
|
||||
func (s *Server) SetMeshKey(v string) {
|
||||
s.meshKey = v
|
||||
}
|
||||
|
||||
// HasMeshKey reports whether the server is configured with a mesh key.
|
||||
func (s *Server) HasMeshKey() bool { return s.meshKey != "" }
|
||||
|
||||
// MeshKey returns the configured mesh key, if any.
|
||||
func (s *Server) MeshKey() string { return s.meshKey }
|
||||
|
||||
// PrivateKey returns the server's private key.
|
||||
func (s *Server) PrivateKey() key.Private { return s.privateKey }
|
||||
|
||||
// PublicKey returns the server's public key.
|
||||
func (s *Server) PublicKey() key.Public { return s.publicKey }
|
||||
|
||||
// Close closes the server and waits for the connections to disconnect.
|
||||
func (s *Server) Close() error {
|
||||
s.mu.Lock()
|
||||
@@ -187,7 +258,23 @@ func (s *Server) registerClient(c *sclient) {
|
||||
}
|
||||
s.clients[c.key] = c
|
||||
s.clientsEver[c.key] = true
|
||||
if _, ok := s.clientsMesh[c.key]; !ok {
|
||||
s.clientsMesh[c.key] = nil // just for varz of total users in cluster
|
||||
}
|
||||
s.curClients.Add(1)
|
||||
s.broadcastPeerStateChangeLocked(c.key, true)
|
||||
}
|
||||
|
||||
// broadcastPeerStateChangeLocked enqueues a message to all watchers
|
||||
// (other DERP nodes in the region, or trusted clients) that peer's
|
||||
// presence changed.
|
||||
//
|
||||
// s.mu must be held.
|
||||
func (s *Server) broadcastPeerStateChangeLocked(peer key.Public, present bool) {
|
||||
for w := range s.watchers {
|
||||
w.peerStateChange = append(w.peerStateChange, peerConnState{peer: peer, present: present})
|
||||
go w.requestMeshUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
// unregisterClient removes a client from the server.
|
||||
@@ -198,30 +285,66 @@ func (s *Server) unregisterClient(c *sclient) {
|
||||
if cur == c {
|
||||
c.logf("removing connection")
|
||||
delete(s.clients, c.key)
|
||||
if v, ok := s.clientsMesh[c.key]; ok && v == nil {
|
||||
delete(s.clientsMesh, c.key)
|
||||
s.notePeerGoneFromRegionLocked(c.key)
|
||||
}
|
||||
s.broadcastPeerStateChangeLocked(c.key, false)
|
||||
}
|
||||
if c.canMesh {
|
||||
delete(s.watchers, c)
|
||||
}
|
||||
|
||||
s.curClients.Add(-1)
|
||||
if c.preferred {
|
||||
s.curHomeClients.Add(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// notePeerGoneFromRegionLocked sends peerGone frames to parties that
|
||||
// key has sent to previously (whether those sends were from a local
|
||||
// client or forwarded). It must only be called after the key has
|
||||
// been removed from clientsMesh.
|
||||
func (s *Server) notePeerGoneFromRegionLocked(key key.Public) {
|
||||
if _, ok := s.clientsMesh[key]; ok {
|
||||
panic("usage")
|
||||
}
|
||||
|
||||
// Find still-connected peers and either notify that we've gone away
|
||||
// so they can drop their route entries to us (issue 150)
|
||||
// or move them over to the active client (in case a replaced client
|
||||
// connection is being unregistered).
|
||||
for pubKey, connNum := range c.sentTo {
|
||||
for pubKey, connNum := range s.sentTo[key] {
|
||||
if peer, ok := s.clients[pubKey]; ok && peer.connNum == connNum {
|
||||
if cur == c {
|
||||
go peer.requestPeerGoneWrite(c.key)
|
||||
} else {
|
||||
// Only if the current client has not already accepted a newer
|
||||
// connection from the peer.
|
||||
if _, ok := cur.sentTo[pubKey]; !ok {
|
||||
cur.sentTo[pubKey] = connNum
|
||||
}
|
||||
}
|
||||
go peer.requestPeerGoneWrite(key)
|
||||
}
|
||||
}
|
||||
delete(s.sentTo, key)
|
||||
}
|
||||
|
||||
func (s *Server) addWatcher(c *sclient) {
|
||||
if !c.canMesh {
|
||||
panic("invariant: addWatcher called without permissions")
|
||||
}
|
||||
|
||||
if c.key == s.publicKey {
|
||||
// We're connecting to ourself. Do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Queue messages for each already-connected client.
|
||||
for peer := range s.clients {
|
||||
c.peerStateChange = append(c.peerStateChange, peerConnState{peer: peer, present: true})
|
||||
}
|
||||
|
||||
// And enroll the watcher in future updates (of both
|
||||
// connections & disconnections).
|
||||
s.watchers[c] = true
|
||||
|
||||
go c.requestMeshUpdate()
|
||||
}
|
||||
|
||||
func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connNum int64) error {
|
||||
@@ -258,7 +381,10 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connN
|
||||
connectedAt: time.Now(),
|
||||
sendQueue: make(chan pkt, perClientSendQueueDepth),
|
||||
peerGone: make(chan key.Public),
|
||||
sentTo: make(map[key.Public]int64),
|
||||
canMesh: clientInfo.MeshKey != "" && clientInfo.MeshKey == s.meshKey,
|
||||
}
|
||||
if c.canMesh {
|
||||
c.meshUpdate = make(chan struct{})
|
||||
}
|
||||
if clientInfo != nil {
|
||||
c.info = *clientInfo
|
||||
@@ -307,6 +433,12 @@ func (c *sclient) run(ctx context.Context) error {
|
||||
err = c.handleFrameNotePreferred(ft, fl)
|
||||
case frameSendPacket:
|
||||
err = c.handleFrameSendPacket(ft, fl)
|
||||
case frameForwardPacket:
|
||||
err = c.handleFrameForwardPacket(ft, fl)
|
||||
case frameWatchConns:
|
||||
err = c.handleFrameWatchConns(ft, fl)
|
||||
case frameClosePeer:
|
||||
err = c.handleFrameClosePeer(ft, fl)
|
||||
default:
|
||||
err = c.handleUnknownFrame(ft, fl)
|
||||
}
|
||||
@@ -333,6 +465,92 @@ func (c *sclient) handleFrameNotePreferred(ft frameType, fl uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error {
|
||||
if fl != 0 {
|
||||
return fmt.Errorf("handleFrameWatchConns wrong size")
|
||||
}
|
||||
if !c.canMesh {
|
||||
return fmt.Errorf("insufficient permissions")
|
||||
}
|
||||
c.s.addWatcher(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error {
|
||||
if fl != keyLen {
|
||||
return fmt.Errorf("handleFrameClosePeer wrong size")
|
||||
}
|
||||
if !c.canMesh {
|
||||
return fmt.Errorf("insufficient permissions")
|
||||
}
|
||||
var targetKey key.Public
|
||||
if _, err := io.ReadFull(c.br, targetKey[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
s := c.s
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if target, ok := s.clients[targetKey]; ok {
|
||||
c.logf("frameClosePeer closing peer %x", targetKey)
|
||||
go target.nc.Close()
|
||||
} else {
|
||||
c.logf("frameClosePeer failed to find peer %x", targetKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFrameForwardPacket reads a "forward packet" frame from the client
|
||||
// (which must be a trusted client, a peer in our mesh).
|
||||
func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
|
||||
if !c.canMesh {
|
||||
return fmt.Errorf("insufficient permissions")
|
||||
}
|
||||
s := c.s
|
||||
|
||||
srcKey, dstKey, contents, err := s.recvForwardPacket(c.br, fl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("client %x: recvForwardPacket: %v", c.key, err)
|
||||
}
|
||||
s.packetsForwardedIn.Add(1)
|
||||
|
||||
s.mu.Lock()
|
||||
dst := s.clients[dstKey]
|
||||
if dst != nil {
|
||||
s.notePeerSendLocked(srcKey, dst)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if dst == nil {
|
||||
s.packetsDropped.Add(1)
|
||||
s.packetsDroppedFwdUnknown.Add(1)
|
||||
if debug {
|
||||
c.logf("dropping forwarded packet for unknown %x", dstKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.sendPkt(dst, pkt{
|
||||
bs: contents,
|
||||
src: srcKey,
|
||||
})
|
||||
}
|
||||
|
||||
// notePeerSendLocked records that src sent to dst. We keep track of
|
||||
// that so when src disconnects, we can tell dst (if it's still
|
||||
// around) that src is gone (a peerGone frame).
|
||||
func (s *Server) notePeerSendLocked(src key.Public, dst *sclient) {
|
||||
m, ok := s.sentTo[src]
|
||||
if !ok {
|
||||
m = map[key.Public]int64{}
|
||||
s.sentTo[src] = m
|
||||
}
|
||||
m[dst.key] = dst.connNum
|
||||
}
|
||||
|
||||
// handleFrameSendPacket reads a "send packet" frame from the client.
|
||||
func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
|
||||
s := c.s
|
||||
|
||||
@@ -341,17 +559,30 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
|
||||
return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
|
||||
}
|
||||
|
||||
if s.OnlyDisco && !disco.LooksLikeDiscoWrapper(contents) {
|
||||
s.packetsDropped.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fwd PacketForwarder
|
||||
s.mu.Lock()
|
||||
dst := s.clients[dstKey]
|
||||
if dst != nil {
|
||||
// Track that we've sent to this peer, so if/when we
|
||||
// disconnect first, the server can inform all our old
|
||||
// recipients that we're gone. (Issue 150 optimization)
|
||||
c.sentTo[dstKey] = dst.connNum
|
||||
if dst == nil {
|
||||
fwd = s.clientsMesh[dstKey]
|
||||
} else {
|
||||
s.notePeerSendLocked(c.key, dst)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if dst == nil {
|
||||
if fwd != nil {
|
||||
s.packetsForwardedOut.Add(1)
|
||||
if err := fwd.ForwardPacket(c.key, dstKey, contents); err != nil {
|
||||
// TODO:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
s.packetsDropped.Add(1)
|
||||
s.packetsDroppedUnknown.Add(1)
|
||||
if debug {
|
||||
@@ -366,6 +597,13 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
|
||||
if dst.info.Version >= protocolSrcAddrs {
|
||||
p.src = c.key
|
||||
}
|
||||
return c.sendPkt(dst, p)
|
||||
}
|
||||
|
||||
func (c *sclient) sendPkt(dst *sclient, p pkt) error {
|
||||
s := c.s
|
||||
dstKey := dst.key
|
||||
|
||||
// Attempt to queue for sending up to 3 times. On each attempt, if
|
||||
// the queue is full, try to drop from queue head to prioritize
|
||||
// fresher packets.
|
||||
@@ -418,6 +656,16 @@ func (c *sclient) requestPeerGoneWrite(peer key.Public) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *sclient) requestMeshUpdate() {
|
||||
if !c.canMesh {
|
||||
panic("unexpected requestMeshUpdate")
|
||||
}
|
||||
select {
|
||||
case c.meshUpdate <- struct{}{}:
|
||||
case <-c.done:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) verifyClient(clientKey key.Public, info *clientInfo) error {
|
||||
// TODO(crawshaw): implement policy constraints on who can use the DERP server
|
||||
// TODO(bradfitz): ... and at what rate.
|
||||
@@ -464,60 +712,86 @@ func (s *Server) sendServerInfo(bw *bufio.Writer, clientKey key.Public) error {
|
||||
func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *clientInfo, err error) {
|
||||
fl, err := readFrameTypeHeader(br, frameClientInfo)
|
||||
if err != nil {
|
||||
return key.Public{}, nil, err
|
||||
return zpub, nil, err
|
||||
}
|
||||
const minLen = keyLen + nonceLen
|
||||
if fl < minLen {
|
||||
return key.Public{}, nil, errors.New("short client info")
|
||||
return zpub, nil, errors.New("short client info")
|
||||
}
|
||||
// We don't trust the client at all yet, so limit its input size to limit
|
||||
// things like JSON resource exhausting (http://github.com/golang/go/issues/31789).
|
||||
if fl > 256<<10 {
|
||||
return key.Public{}, nil, errors.New("long client info")
|
||||
return zpub, nil, errors.New("long client info")
|
||||
}
|
||||
if _, err := io.ReadFull(br, clientKey[:]); err != nil {
|
||||
return key.Public{}, nil, err
|
||||
return zpub, nil, err
|
||||
}
|
||||
var nonce [24]byte
|
||||
if _, err := io.ReadFull(br, nonce[:]); err != nil {
|
||||
return key.Public{}, nil, fmt.Errorf("nonce: %v", err)
|
||||
return zpub, nil, fmt.Errorf("nonce: %v", err)
|
||||
}
|
||||
msgLen := int(fl - minLen)
|
||||
msgbox := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(br, msgbox); err != nil {
|
||||
return key.Public{}, nil, fmt.Errorf("msgbox: %v", err)
|
||||
return zpub, nil, fmt.Errorf("msgbox: %v", err)
|
||||
}
|
||||
msg, ok := box.Open(nil, msgbox, &nonce, (*[32]byte)(&clientKey), s.privateKey.B32())
|
||||
if !ok {
|
||||
return key.Public{}, nil, fmt.Errorf("msgbox: cannot open len=%d with client key %x", msgLen, clientKey[:])
|
||||
return zpub, nil, fmt.Errorf("msgbox: cannot open len=%d with client key %x", msgLen, clientKey[:])
|
||||
}
|
||||
info = new(clientInfo)
|
||||
if err := json.Unmarshal(msg, info); err != nil {
|
||||
return key.Public{}, nil, fmt.Errorf("msg: %v", err)
|
||||
return zpub, nil, fmt.Errorf("msg: %v", err)
|
||||
}
|
||||
return clientKey, info, nil
|
||||
}
|
||||
|
||||
func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) {
|
||||
if frameLen < keyLen {
|
||||
return key.Public{}, nil, errors.New("short send packet frame")
|
||||
return zpub, nil, errors.New("short send packet frame")
|
||||
}
|
||||
if _, err := io.ReadFull(br, dstKey[:]); err != nil {
|
||||
return key.Public{}, nil, err
|
||||
return zpub, nil, err
|
||||
}
|
||||
packetLen := frameLen - keyLen
|
||||
if packetLen > MaxPacketSize {
|
||||
return key.Public{}, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
|
||||
return zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
|
||||
}
|
||||
contents = make([]byte, packetLen)
|
||||
if _, err := io.ReadFull(br, contents); err != nil {
|
||||
return key.Public{}, nil, err
|
||||
return zpub, nil, err
|
||||
}
|
||||
s.packetsRecv.Add(1)
|
||||
s.bytesRecv.Add(int64(len(contents)))
|
||||
return dstKey, contents, nil
|
||||
}
|
||||
|
||||
// zpub is the key.Public zero value.
|
||||
var zpub key.Public
|
||||
|
||||
func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.Public, contents []byte, err error) {
|
||||
if frameLen < keyLen*2 {
|
||||
return zpub, zpub, nil, errors.New("short send packet frame")
|
||||
}
|
||||
if _, err := io.ReadFull(br, srcKey[:]); err != nil {
|
||||
return zpub, zpub, nil, err
|
||||
}
|
||||
if _, err := io.ReadFull(br, dstKey[:]); err != nil {
|
||||
return zpub, zpub, nil, err
|
||||
}
|
||||
packetLen := frameLen - keyLen*2
|
||||
if packetLen > MaxPacketSize {
|
||||
return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
|
||||
}
|
||||
contents = make([]byte, packetLen)
|
||||
if _, err := io.ReadFull(br, contents); err != nil {
|
||||
return zpub, zpub, nil, err
|
||||
}
|
||||
// TODO: was s.packetsRecv.Add(1)
|
||||
// TODO: was s.bytesRecv.Add(int64(len(contents)))
|
||||
return srcKey, dstKey, contents, nil
|
||||
}
|
||||
|
||||
// sclient is a client connection to the server.
|
||||
//
|
||||
// (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go)
|
||||
@@ -532,7 +806,9 @@ type sclient struct {
|
||||
done <-chan struct{} // closed when connection closes
|
||||
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
|
||||
sendQueue chan pkt // packets queued to this client; never closed
|
||||
peerGone chan key.Public // write request that a previous sender has disconnected
|
||||
peerGone chan key.Public // write request that a previous sender has disconnected (not used by mesh peers)
|
||||
meshUpdate chan struct{} // write request to write peerStateChange
|
||||
canMesh bool // clientInfo had correct mesh token for inter-region routing
|
||||
|
||||
// Owned by run, not thread-safe.
|
||||
br *bufio.Reader
|
||||
@@ -542,11 +818,20 @@ type sclient struct {
|
||||
// Owned by sender, not thread-safe.
|
||||
bw *bufio.Writer
|
||||
|
||||
// Guarded by s.mu.
|
||||
// Guarded by s.mu
|
||||
//
|
||||
// sentTo tracks all the peers this client has ever sent a packet to, and at which
|
||||
// connection number.
|
||||
sentTo map[key.Public]int64 // recipient => rcpt's latest sclient.connNum
|
||||
// peerStateChange is used by mesh peers (a set of regional
|
||||
// DERP servers) and contains records that need to be sent to
|
||||
// the client for them to update their map of who's connected
|
||||
// to this node.
|
||||
peerStateChange []peerConnState
|
||||
}
|
||||
|
||||
// peerConnState represents whether a peer is connected to the server
|
||||
// or not.
|
||||
type peerConnState struct {
|
||||
peer key.Public
|
||||
present bool
|
||||
}
|
||||
|
||||
// pkt is a request to write a data frame to an sclient.
|
||||
@@ -628,6 +913,9 @@ func (c *sclient) sendLoop(ctx context.Context) error {
|
||||
case peer := <-c.peerGone:
|
||||
werr = c.sendPeerGone(peer)
|
||||
continue
|
||||
case <-c.meshUpdate:
|
||||
werr = c.sendMeshUpdates()
|
||||
continue
|
||||
case msg := <-c.sendQueue:
|
||||
werr = c.sendPacket(msg.src, msg.bs)
|
||||
continue
|
||||
@@ -648,6 +936,9 @@ func (c *sclient) sendLoop(ctx context.Context) error {
|
||||
return nil
|
||||
case peer := <-c.peerGone:
|
||||
werr = c.sendPeerGone(peer)
|
||||
case <-c.meshUpdate:
|
||||
werr = c.sendMeshUpdates()
|
||||
continue
|
||||
case msg := <-c.sendQueue:
|
||||
werr = c.sendPacket(msg.src, msg.bs)
|
||||
case <-keepAliveTick.C:
|
||||
@@ -677,6 +968,59 @@ func (c *sclient) sendPeerGone(peer key.Public) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// sendPeerPresent sends a peerPresent frame, without flushing.
|
||||
func (c *sclient) sendPeerPresent(peer key.Public) error {
|
||||
c.setWriteDeadline()
|
||||
if err := writeFrameHeader(c.bw, framePeerPresent, keyLen); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := c.bw.Write(peer[:])
|
||||
return err
|
||||
}
|
||||
|
||||
// sendMeshUpdates drains as many mesh peerStateChange entries as
|
||||
// possible into the write buffer WITHOUT flushing or otherwise
|
||||
// blocking (as it holds c.s.mu while working). If it can't drain them
|
||||
// all, it schedules itself to be called again in the future.
|
||||
func (c *sclient) sendMeshUpdates() error {
|
||||
c.s.mu.Lock()
|
||||
defer c.s.mu.Unlock()
|
||||
|
||||
writes := 0
|
||||
for _, pcs := range c.peerStateChange {
|
||||
if c.bw.Available() <= frameHeaderLen+keyLen {
|
||||
break
|
||||
}
|
||||
var err error
|
||||
if pcs.present {
|
||||
err = c.sendPeerPresent(pcs.peer)
|
||||
} else {
|
||||
err = c.sendPeerGone(pcs.peer)
|
||||
}
|
||||
if err != nil {
|
||||
// Shouldn't happen, though, as we're writing
|
||||
// into available buffer space, not the
|
||||
// network.
|
||||
return err
|
||||
}
|
||||
writes++
|
||||
}
|
||||
|
||||
remain := copy(c.peerStateChange, c.peerStateChange[writes:])
|
||||
c.peerStateChange = c.peerStateChange[:remain]
|
||||
|
||||
// Did we manage to write them all into the bufio buffer without flushing?
|
||||
if len(c.peerStateChange) == 0 {
|
||||
if cap(c.peerStateChange) > 16 {
|
||||
c.peerStateChange = nil
|
||||
}
|
||||
} else {
|
||||
// Didn't finish in the buffer space provided; schedule a future run.
|
||||
go c.requestMeshUpdate()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendPacket writes contents to the client in a RecvPacket frame. If
|
||||
// srcKey.IsZero, uses the old DERPv1 framing format, otherwise uses
|
||||
// DERPv2. The bytes of contents are only valid until this function
|
||||
@@ -716,6 +1060,114 @@ func (c *sclient) sendPacket(srcKey key.Public, contents []byte) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// AddPacketForwarder registers fwd as a packet forwarder for dst.
|
||||
// fwd must be comparable.
|
||||
func (s *Server) AddPacketForwarder(dst key.Public, fwd PacketForwarder) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if prev, ok := s.clientsMesh[dst]; ok {
|
||||
if prev == fwd {
|
||||
// Duplicate registration of same forwarder. Ignore.
|
||||
return
|
||||
}
|
||||
if m, ok := prev.(multiForwarder); ok {
|
||||
if _, ok := m[fwd]; !ok {
|
||||
// Duplicate registration of same forwarder in set; ignore.
|
||||
return
|
||||
}
|
||||
m[fwd] = m.maxVal() + 1
|
||||
return
|
||||
}
|
||||
if prev != nil {
|
||||
// Otherwise, the existing value is not a set,
|
||||
// not a dup, and not local-only (nil) so make
|
||||
// it a set.
|
||||
fwd = multiForwarder{
|
||||
prev: 1, // existed 1st, higher priority
|
||||
fwd: 2, // the passed in fwd is in 2nd place
|
||||
}
|
||||
s.multiForwarderCreated.Add(1)
|
||||
}
|
||||
}
|
||||
s.clientsMesh[dst] = fwd
|
||||
}
|
||||
|
||||
// RemovePacketForwarder removes fwd as a packet forwarder for dst.
|
||||
// fwd must be comparable.
|
||||
func (s *Server) RemovePacketForwarder(dst key.Public, fwd PacketForwarder) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, ok := s.clientsMesh[dst]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if m, ok := v.(multiForwarder); ok {
|
||||
if len(m) < 2 {
|
||||
panic("unexpected")
|
||||
}
|
||||
delete(m, fwd)
|
||||
// If fwd was in m and we no longer need to be a
|
||||
// multiForwarder, replace the entry with the
|
||||
// remaining PacketForwarder.
|
||||
if len(m) == 1 {
|
||||
var remain PacketForwarder
|
||||
for k := range m {
|
||||
remain = k
|
||||
}
|
||||
s.clientsMesh[dst] = remain
|
||||
s.multiForwarderDeleted.Add(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
if v != fwd {
|
||||
s.removePktForwardOther.Add(1)
|
||||
// Delete of an entry that wasn't in the
|
||||
// map. Harmless, so ignore.
|
||||
// (This might happen if a user is moving around
|
||||
// between nodes and/or the server sent duplicate
|
||||
// connection change broadcasts.)
|
||||
return
|
||||
}
|
||||
|
||||
if _, isLocal := s.clients[dst]; isLocal {
|
||||
s.clientsMesh[dst] = nil
|
||||
} else {
|
||||
delete(s.clientsMesh, dst)
|
||||
s.notePeerGoneFromRegionLocked(dst)
|
||||
}
|
||||
}
|
||||
|
||||
// multiForwarder is a PacketForwarder that represents a set of
|
||||
// forwarding options. It's used in the rare cases that a client is
|
||||
// connected to multiple DERP nodes in a region. That shouldn't really
|
||||
// happen except for perhaps during brief moments while the client is
|
||||
// reconfiguring, in which case we don't want to forget where the
|
||||
// client is. The map value is unique connection number; the lowest
|
||||
// one has been seen the longest. It's used to make sure we forward
|
||||
// packets consistently to the same node and don't pick randomly.
|
||||
type multiForwarder map[PacketForwarder]uint8
|
||||
|
||||
func (m multiForwarder) maxVal() (max uint8) {
|
||||
for _, v := range m {
|
||||
if v > max {
|
||||
max = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m multiForwarder) ForwardPacket(src, dst key.Public, payload []byte) error {
|
||||
var fwd PacketForwarder
|
||||
var lowest uint8
|
||||
for k, v := range m {
|
||||
if fwd == nil || v < lowest {
|
||||
fwd = k
|
||||
lowest = v
|
||||
}
|
||||
}
|
||||
return fwd.ForwardPacket(src, dst, payload)
|
||||
}
|
||||
|
||||
func (s *Server) expVarFunc(f func() interface{}) expvar.Func {
|
||||
return expvar.Func(func() interface{} {
|
||||
s.mu.Lock()
|
||||
@@ -729,8 +1181,12 @@ func (s *Server) ExpVar() expvar.Var {
|
||||
m := new(metrics.Set)
|
||||
m.Set("counter_unique_clients_ever", s.expVarFunc(func() interface{} { return len(s.clientsEver) }))
|
||||
m.Set("gauge_memstats_sys0", expvar.Func(func() interface{} { return int64(s.memSys0) }))
|
||||
m.Set("gauge_current_connnections", &s.curClients)
|
||||
m.Set("gauge_current_home_connnections", &s.curHomeClients)
|
||||
m.Set("gauge_watchers", s.expVarFunc(func() interface{} { return len(s.watchers) }))
|
||||
m.Set("gauge_current_connections", &s.curClients)
|
||||
m.Set("gauge_current_home_connections", &s.curHomeClients)
|
||||
m.Set("gauge_clients_total", expvar.Func(func() interface{} { return len(s.clientsMesh) }))
|
||||
m.Set("gauge_clients_local", expvar.Func(func() interface{} { return len(s.clients) }))
|
||||
m.Set("gauge_clients_remote", expvar.Func(func() interface{} { return len(s.clientsMesh) - len(s.clients) }))
|
||||
m.Set("accepts", &s.accepts)
|
||||
m.Set("clients_replaced", &s.clientsReplaced)
|
||||
m.Set("bytes_received", &s.bytesRecv)
|
||||
@@ -743,5 +1199,49 @@ func (s *Server) ExpVar() expvar.Var {
|
||||
m.Set("home_moves_in", &s.homeMovesIn)
|
||||
m.Set("home_moves_out", &s.homeMovesOut)
|
||||
m.Set("peer_gone_frames", &s.peerGoneFrames)
|
||||
m.Set("packets_forwarded_out", &s.packetsForwardedOut)
|
||||
m.Set("packets_forwarded_in", &s.packetsForwardedIn)
|
||||
m.Set("multiforwarder_created", &s.multiForwarderCreated)
|
||||
m.Set("multiforwarder_deleted", &s.multiForwarderDeleted)
|
||||
m.Set("packet_forwarder_delete_other_value", &s.removePktForwardOther)
|
||||
return m
|
||||
}
|
||||
|
||||
func (s *Server) ConsistencyCheck() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var errs []string
|
||||
|
||||
var nilMeshNotInClient int
|
||||
for k, f := range s.clientsMesh {
|
||||
if f == nil {
|
||||
if _, ok := s.clients[k]; !ok {
|
||||
nilMeshNotInClient++
|
||||
}
|
||||
}
|
||||
}
|
||||
if nilMeshNotInClient != 0 {
|
||||
errs = append(errs, fmt.Sprintf("%d s.clientsMesh keys not in s.clients", nilMeshNotInClient))
|
||||
}
|
||||
|
||||
var clientNotInMesh int
|
||||
for k := range s.clients {
|
||||
if _, ok := s.clientsMesh[k]; !ok {
|
||||
clientNotInMesh++
|
||||
}
|
||||
}
|
||||
if clientNotInMesh != 0 {
|
||||
errs = append(errs, fmt.Sprintf("%d s.clients keys not in s.clientsMesh", clientNotInMesh))
|
||||
}
|
||||
|
||||
if s.curClients.Value() != int64(len(s.clients)) {
|
||||
errs = append(errs, fmt.Sprintf("expvar connections = %d != clients map says of %d",
|
||||
s.curClients.Value(),
|
||||
len(s.clients)))
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return errors.New(strings.Join(errs, ", "))
|
||||
}
|
||||
|
||||
@@ -13,11 +13,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/nettest"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
func newPrivateKey(t *testing.T) (k key.Private) {
|
||||
@@ -87,8 +90,7 @@ func TestSendRecv(t *testing.T) {
|
||||
for i := 0; i < numClients; i++ {
|
||||
go func(i int) {
|
||||
for {
|
||||
b := make([]byte, 1<<16)
|
||||
m, err := clients[i].Recv(b)
|
||||
m, err := clients[i].Recv()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
@@ -103,7 +105,7 @@ func TestSendRecv(t *testing.T) {
|
||||
if m.Source.IsZero() {
|
||||
t.Errorf("zero Source address in ReceivedPacket")
|
||||
}
|
||||
recvChs[i] <- m.Data
|
||||
recvChs[i] <- append([]byte(nil), m.Data...)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
@@ -256,8 +258,7 @@ func TestSendFreeze(t *testing.T) {
|
||||
recv := func(name string, client *Client) {
|
||||
ch := chs(name)
|
||||
for {
|
||||
b := make([]byte, 1<<9)
|
||||
m, err := client.Recv(b)
|
||||
m, err := client.Recv()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s: %w", name, err)
|
||||
return
|
||||
@@ -391,3 +392,353 @@ func TestSendFreeze(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
s *Server
|
||||
ln net.Listener
|
||||
logf logger.Logf
|
||||
|
||||
mu sync.Mutex
|
||||
pubName map[key.Public]string
|
||||
clients map[*testClient]bool
|
||||
}
|
||||
|
||||
func (ts *testServer) addTestClient(c *testClient) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.clients[c] = true
|
||||
}
|
||||
|
||||
func (ts *testServer) addKeyName(k key.Public, name string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.pubName[k] = name
|
||||
ts.logf("test adding named key %q for %x", name, k)
|
||||
}
|
||||
|
||||
func (ts *testServer) keyName(k key.Public) string {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if name, ok := ts.pubName[k]; ok {
|
||||
return name
|
||||
}
|
||||
return k.ShortString()
|
||||
}
|
||||
|
||||
func (ts *testServer) close(t *testing.T) error {
|
||||
ts.ln.Close()
|
||||
ts.s.Close()
|
||||
for c := range ts.clients {
|
||||
c.close(t)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *testServer {
|
||||
t.Helper()
|
||||
logf := logger.WithPrefix(t.Logf, "derp-server: ")
|
||||
s := NewServer(newPrivateKey(t), logf)
|
||||
s.SetMeshKey("mesh-key")
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
i := 0
|
||||
for {
|
||||
i++
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// TODO: register c in ts so Close also closes it?
|
||||
go func(i int) {
|
||||
brwServer := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
||||
go s.Accept(c, brwServer, fmt.Sprintf("test-client-%d", i))
|
||||
}(i)
|
||||
}
|
||||
}()
|
||||
return &testServer{
|
||||
s: s,
|
||||
ln: ln,
|
||||
logf: logf,
|
||||
clients: map[*testClient]bool{},
|
||||
pubName: map[key.Public]string{},
|
||||
}
|
||||
}
|
||||
|
||||
type testClient struct {
|
||||
name string
|
||||
c *Client
|
||||
nc net.Conn
|
||||
pub key.Public
|
||||
ts *testServer
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net.Conn, key.Private, logger.Logf) (*Client, error)) *testClient {
|
||||
t.Helper()
|
||||
nc, err := net.Dial("tcp", ts.ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key := newPrivateKey(t)
|
||||
ts.addKeyName(key.Public(), name)
|
||||
c, err := newClient(nc, key, logger.WithPrefix(t.Logf, "client-"+name+": "))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tc := &testClient{
|
||||
name: name,
|
||||
nc: nc,
|
||||
c: c,
|
||||
ts: ts,
|
||||
pub: key.Public(),
|
||||
}
|
||||
ts.addTestClient(tc)
|
||||
return tc
|
||||
}
|
||||
|
||||
func newRegularClient(t *testing.T, ts *testServer, name string) *testClient {
|
||||
return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) {
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc))
|
||||
return NewClient(priv, nc, brw, logf)
|
||||
})
|
||||
}
|
||||
|
||||
func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient {
|
||||
return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) {
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc))
|
||||
c, err := NewClient(priv, nc, brw, logf, MeshKey("mesh-key"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.WatchConnectionChanges(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) {
|
||||
t.Helper()
|
||||
want := map[key.Public]bool{}
|
||||
for _, k := range peers {
|
||||
want[k] = true
|
||||
}
|
||||
|
||||
for {
|
||||
m, err := tc.c.recvTimeout(time.Second)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
switch m := m.(type) {
|
||||
case PeerPresentMessage:
|
||||
got := key.Public(m)
|
||||
if !want[got] {
|
||||
t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) {
|
||||
for _, pub := range peers {
|
||||
fmt.Fprintf(bw, "%s ", tc.ts.keyName(pub))
|
||||
}
|
||||
}))
|
||||
}
|
||||
delete(want, got)
|
||||
if len(want) == 0 {
|
||||
return
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unexpected message type %T", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tc *testClient) wantGone(t *testing.T, peer key.Public) {
|
||||
t.Helper()
|
||||
m, err := tc.c.recvTimeout(time.Second)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
switch m := m.(type) {
|
||||
case PeerGoneMessage:
|
||||
got := key.Public(m)
|
||||
if peer != got {
|
||||
t.Errorf("got gone message for %v; want gone for %v", tc.ts.keyName(got), tc.ts.keyName(peer))
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unexpected message type %T", m)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testClient) close(t *testing.T) {
|
||||
t.Helper()
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
c.closed = true
|
||||
t.Logf("closing client %q (%x)", c.name, c.pub)
|
||||
c.nc.Close()
|
||||
}
|
||||
|
||||
// TestWatch tests the connection watcher mechanism used by regional
|
||||
// DERP nodes to mesh up with each other.
|
||||
func TestWatch(t *testing.T) {
|
||||
ts := newTestServer(t)
|
||||
defer ts.close(t)
|
||||
|
||||
w1 := newTestWatcher(t, ts, "w1")
|
||||
w1.wantPresent(t, w1.pub)
|
||||
|
||||
c1 := newRegularClient(t, ts, "c1")
|
||||
w1.wantPresent(t, c1.pub)
|
||||
|
||||
c2 := newRegularClient(t, ts, "c2")
|
||||
w1.wantPresent(t, c2.pub)
|
||||
|
||||
w2 := newTestWatcher(t, ts, "w2")
|
||||
w1.wantPresent(t, w2.pub)
|
||||
w2.wantPresent(t, w1.pub, w2.pub, c1.pub, c2.pub)
|
||||
|
||||
c3 := newRegularClient(t, ts, "c3")
|
||||
w1.wantPresent(t, c3.pub)
|
||||
w2.wantPresent(t, c3.pub)
|
||||
|
||||
c2.close(t)
|
||||
w1.wantGone(t, c2.pub)
|
||||
w2.wantGone(t, c2.pub)
|
||||
|
||||
w3 := newTestWatcher(t, ts, "w3")
|
||||
w1.wantPresent(t, w3.pub)
|
||||
w2.wantPresent(t, w3.pub)
|
||||
w3.wantPresent(t, c1.pub, c3.pub, w1.pub, w2.pub, w3.pub)
|
||||
|
||||
c1.close(t)
|
||||
w1.wantGone(t, c1.pub)
|
||||
w2.wantGone(t, c1.pub)
|
||||
w3.wantGone(t, c1.pub)
|
||||
}
|
||||
|
||||
type testFwd int
|
||||
|
||||
func (testFwd) ForwardPacket(key.Public, key.Public, []byte) error { panic("not called in tests") }
|
||||
|
||||
func pubAll(b byte) (ret key.Public) {
|
||||
for i := range ret {
|
||||
ret[i] = b
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestForwarderRegistration(t *testing.T) {
|
||||
s := &Server{
|
||||
clients: make(map[key.Public]*sclient),
|
||||
clientsMesh: map[key.Public]PacketForwarder{},
|
||||
}
|
||||
want := func(want map[key.Public]PacketForwarder) {
|
||||
t.Helper()
|
||||
if got := s.clientsMesh; !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want)
|
||||
}
|
||||
}
|
||||
wantCounter := func(c *expvar.Int, want int) {
|
||||
t.Helper()
|
||||
if got := c.Value(); got != int64(want) {
|
||||
t.Errorf("counter = %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
u1 := pubAll(1)
|
||||
u2 := pubAll(2)
|
||||
u3 := pubAll(3)
|
||||
|
||||
s.AddPacketForwarder(u1, testFwd(1))
|
||||
s.AddPacketForwarder(u2, testFwd(2))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: testFwd(1),
|
||||
u2: testFwd(2),
|
||||
})
|
||||
|
||||
// Verify a remove of non-registered forwarder is no-op.
|
||||
s.RemovePacketForwarder(u2, testFwd(999))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: testFwd(1),
|
||||
u2: testFwd(2),
|
||||
})
|
||||
|
||||
// Verify a remove of non-registered user is no-op.
|
||||
s.RemovePacketForwarder(u3, testFwd(1))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: testFwd(1),
|
||||
u2: testFwd(2),
|
||||
})
|
||||
|
||||
// Actual removal.
|
||||
s.RemovePacketForwarder(u2, testFwd(2))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: testFwd(1),
|
||||
})
|
||||
|
||||
// Adding a dup for a user.
|
||||
wantCounter(&s.multiForwarderCreated, 0)
|
||||
s.AddPacketForwarder(u1, testFwd(100))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: multiForwarder{
|
||||
testFwd(1): 1,
|
||||
testFwd(100): 2,
|
||||
},
|
||||
})
|
||||
wantCounter(&s.multiForwarderCreated, 1)
|
||||
|
||||
// Removing a forwarder in a multi set that doesn't exist; does nothing.
|
||||
s.RemovePacketForwarder(u1, testFwd(55))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: multiForwarder{
|
||||
testFwd(1): 1,
|
||||
testFwd(100): 2,
|
||||
},
|
||||
})
|
||||
|
||||
// Removing a forwarder in a multi set that does exist should collapse it away
|
||||
// from being a multiForwarder.
|
||||
wantCounter(&s.multiForwarderDeleted, 0)
|
||||
s.RemovePacketForwarder(u1, testFwd(1))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: testFwd(100),
|
||||
})
|
||||
wantCounter(&s.multiForwarderDeleted, 1)
|
||||
|
||||
// Removing an entry for a client that's still connected locally should result
|
||||
// in a nil forwarder.
|
||||
u1c := &sclient{
|
||||
key: u1,
|
||||
logf: logger.Discard,
|
||||
}
|
||||
s.clients[u1] = u1c
|
||||
s.RemovePacketForwarder(u1, testFwd(100))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: nil,
|
||||
})
|
||||
|
||||
// But once that client disconnects, it should go away.
|
||||
s.unregisterClient(u1c)
|
||||
want(map[key.Public]PacketForwarder{})
|
||||
|
||||
// But if it already has a forwarder, it's not removed.
|
||||
s.AddPacketForwarder(u1, testFwd(2))
|
||||
s.unregisterClient(u1c)
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: testFwd(2),
|
||||
})
|
||||
|
||||
// Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard
|
||||
// that they're also connected to a peer of ours. That sholdn't transition the forwarder
|
||||
// from nil to the new one, not a multiForwarder.
|
||||
s.clients[u1] = u1c
|
||||
s.clientsMesh[u1] = nil
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: nil,
|
||||
})
|
||||
s.AddPacketForwarder(u1, testFwd(3))
|
||||
want(map[key.Public]PacketForwarder{
|
||||
u1: testFwd(3),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ import (
|
||||
type Client struct {
|
||||
TLSConfig *tls.Config // optional; nil means default
|
||||
DNSCache *dnscache.Resolver // optional; nil means no caching
|
||||
MeshKey string // optional; for trusted clients
|
||||
|
||||
privateKey key.Private
|
||||
logf logger.Logf
|
||||
@@ -54,11 +55,13 @@ type Client struct {
|
||||
ctx context.Context // closed via cancelCtx in Client.Close
|
||||
cancelCtx context.CancelFunc
|
||||
|
||||
mu sync.Mutex
|
||||
preferred bool
|
||||
closed bool
|
||||
netConn io.Closer
|
||||
client *derp.Client
|
||||
mu sync.Mutex
|
||||
preferred bool
|
||||
closed bool
|
||||
netConn io.Closer
|
||||
client *derp.Client
|
||||
connGen int // incremented once per new connection; valid values are >0
|
||||
serverPubKey key.Public
|
||||
}
|
||||
|
||||
// NewRegionClient returns a new DERP-over-HTTP client. It connects lazily.
|
||||
@@ -106,10 +109,20 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Cli
|
||||
// Connect connects or reconnects to the server, unless already connected.
|
||||
// It returns nil if there was already a good connection, or if one was made.
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
_, err := c.connect(ctx, "derphttp.Client.Connect")
|
||||
_, _, err := c.connect(ctx, "derphttp.Client.Connect")
|
||||
return err
|
||||
}
|
||||
|
||||
// ServerPublicKey returns the server's public key.
|
||||
//
|
||||
// It only returns a non-zero value once a connection has succeeded
|
||||
// from an earlier call.
|
||||
func (c *Client) ServerPublicKey() key.Public {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.serverPubKey
|
||||
}
|
||||
|
||||
func urlPort(u *url.URL) string {
|
||||
if p := u.Port(); p != "" {
|
||||
return p
|
||||
@@ -152,14 +165,14 @@ func (c *Client) urlString(node *tailcfg.DERPNode) string {
|
||||
return fmt.Sprintf("https://%s/derp", node.HostName)
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
|
||||
func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, connGen int, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return nil, ErrClientClosed
|
||||
return nil, 0, ErrClientClosed
|
||||
}
|
||||
if c.client != nil {
|
||||
return c.client, nil
|
||||
return c.client, c.connGen, nil
|
||||
}
|
||||
|
||||
// timeout is the fallback maximum time (if ctx doesn't limit
|
||||
@@ -185,7 +198,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
||||
if c.getRegion != nil {
|
||||
reg = c.getRegion()
|
||||
if reg == nil {
|
||||
return nil, errors.New("DERP region not available")
|
||||
return nil, 0, errors.New("DERP region not available")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,7 +225,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
||||
tcpConn, node, err = c.dialRegion(ctx, reg)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Now that we have a TCP connection, force close it if the
|
||||
@@ -250,42 +263,44 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
||||
|
||||
req, err := http.NewRequest("GET", c.urlString(node), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Upgrade", "DERP")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
|
||||
if err := req.Write(brw); err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(brw.Reader, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
b, _ := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("GET failed: %v: %s", err, b)
|
||||
return nil, 0, fmt.Errorf("GET failed: %v: %s", err, b)
|
||||
}
|
||||
|
||||
derpClient, err := derp.NewClient(c.privateKey, httpConn, brw, c.logf)
|
||||
derpClient, err := derp.NewClient(c.privateKey, httpConn, brw, c.logf, derp.MeshKey(c.MeshKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
if c.preferred {
|
||||
if err := derpClient.NotePreferred(true); err != nil {
|
||||
go httpConn.Close()
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
c.serverPubKey = derpClient.ServerPublicKey()
|
||||
c.client = derpClient
|
||||
c.netConn = tcpConn
|
||||
return c.client, nil
|
||||
c.connGen++
|
||||
return c.client, c.connGen, nil
|
||||
}
|
||||
|
||||
func (c *Client) dialURL(ctx context.Context) (net.Conn, error) {
|
||||
@@ -323,6 +338,9 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
|
||||
var firstErr error
|
||||
for _, n := range reg.Nodes {
|
||||
if n.STUNOnly {
|
||||
if firstErr == nil {
|
||||
firstErr = fmt.Errorf("no non-STUNOnly nodes for %s", c.targetString(reg))
|
||||
}
|
||||
continue
|
||||
}
|
||||
c, err := c.dialNode(ctx, n)
|
||||
@@ -463,7 +481,7 @@ func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, e
|
||||
}
|
||||
|
||||
func (c *Client) Send(dstKey key.Public, b []byte) error {
|
||||
client, err := c.connect(context.TODO(), "derphttp.Client.Send")
|
||||
client, _, err := c.connect(context.TODO(), "derphttp.Client.Send")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -473,6 +491,17 @@ func (c *Client) Send(dstKey key.Public, b []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) ForwardPacket(from, to key.Public, b []byte) error {
|
||||
client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.ForwardPacket(from, to, b); err != nil {
|
||||
c.closeForReconnect(client)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// NotePreferred notes whether this Client is the caller's preferred
|
||||
// (home) DERP node. It's only used for stats.
|
||||
func (c *Client) NotePreferred(v bool) {
|
||||
@@ -492,18 +521,58 @@ func (c *Client) NotePreferred(v bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Recv(b []byte) (derp.ReceivedMessage, error) {
|
||||
client, err := c.connect(context.TODO(), "derphttp.Client.Recv")
|
||||
// WatchConnectionChanges sends a request to subscribe to
|
||||
// notifications about clients connecting & disconnecting.
|
||||
//
|
||||
// Only trusted connections (using MeshKey) are allowed to use this.
|
||||
func (c *Client) WatchConnectionChanges() error {
|
||||
client, _, err := c.connect(context.TODO(), "derphttp.Client.WatchConnectionChanges")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
m, err := client.Recv(b)
|
||||
err = client.WatchConnectionChanges()
|
||||
if err != nil {
|
||||
c.closeForReconnect(client)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ClosePeer asks the server to close target's TCP connection.
|
||||
//
|
||||
// Only trusted connections (using MeshKey) are allowed to use this.
|
||||
func (c *Client) ClosePeer(target key.Public) error {
|
||||
client, _, err := c.connect(context.TODO(), "derphttp.Client.ClosePeer")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = client.ClosePeer(target)
|
||||
if err != nil {
|
||||
c.closeForReconnect(client)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Recv reads a message from c. The returned message may alias memory from Client.
|
||||
// The message should only be used until the next Client call.
|
||||
func (c *Client) Recv() (derp.ReceivedMessage, error) {
|
||||
m, _, err := c.RecvDetail()
|
||||
return m, err
|
||||
}
|
||||
|
||||
// RecvDetail is like Recv, but additional returns the connection generation on each message.
|
||||
// The connGen value is incremented every time the derphttp.Client reconnects to the server.
|
||||
func (c *Client) RecvDetail() (m derp.ReceivedMessage, connGen int, err error) {
|
||||
client, connGen, err := c.connect(context.TODO(), "derphttp.Client.Recv")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
m, err = client.Recv()
|
||||
if err != nil {
|
||||
c.closeForReconnect(client)
|
||||
}
|
||||
return m, connGen, err
|
||||
}
|
||||
|
||||
// Close closes the client. It will not automatically reconnect after
|
||||
// being closed.
|
||||
func (c *Client) Close() error {
|
||||
|
||||
@@ -93,8 +93,7 @@ func TestSendRecv(t *testing.T) {
|
||||
return
|
||||
default:
|
||||
}
|
||||
b := make([]byte, 1<<16)
|
||||
m, err := c.Recv(b)
|
||||
m, err := c.Recv()
|
||||
if err != nil {
|
||||
t.Logf("client%d: %v", i, err)
|
||||
break
|
||||
@@ -106,7 +105,7 @@ func TestSendRecv(t *testing.T) {
|
||||
case derp.PeerGoneMessage:
|
||||
// Ignore.
|
||||
case derp.ReceivedPacket:
|
||||
recvChs[i] <- m.Data
|
||||
recvChs[i] <- append([]byte(nil), m.Data...)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
|
||||
122
derp/derphttp/mesh_client.go
Normal file
122
derp/derphttp/mesh_client.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package derphttp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// RunWatchConnectionLoop loops forever, sending WatchConnectionChanges and subscribing to
|
||||
// connection changes.
|
||||
//
|
||||
// If the server's public key is ignoreServerKey, RunWatchConnectionLoop returns.
|
||||
//
|
||||
// Otherwise, the add and remove funcs are called as clients come & go.
|
||||
func (c *Client) RunWatchConnectionLoop(ignoreServerKey key.Public, add, remove func(key.Public)) {
|
||||
logf := c.logf
|
||||
const retryInterval = 5 * time.Second
|
||||
const statusInterval = 10 * time.Second
|
||||
var (
|
||||
mu sync.Mutex
|
||||
present = map[key.Public]bool{}
|
||||
loggedConnected = false
|
||||
)
|
||||
clear := func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(present) == 0 {
|
||||
return
|
||||
}
|
||||
logf("reconnected; clearing %d forwarding mappings", len(present))
|
||||
for k := range present {
|
||||
remove(k)
|
||||
}
|
||||
present = map[key.Public]bool{}
|
||||
}
|
||||
lastConnGen := 0
|
||||
lastStatus := time.Now()
|
||||
logConnectedLocked := func() {
|
||||
if loggedConnected {
|
||||
return
|
||||
}
|
||||
logf("connected; %d peers", len(present))
|
||||
loggedConnected = true
|
||||
}
|
||||
|
||||
const logConnectedDelay = 200 * time.Millisecond
|
||||
timer := time.AfterFunc(2*time.Second, func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
logConnectedLocked()
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
updatePeer := func(k key.Public, isPresent bool) {
|
||||
if isPresent {
|
||||
add(k)
|
||||
} else {
|
||||
remove(k)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if isPresent {
|
||||
present[k] = true
|
||||
if !loggedConnected {
|
||||
timer.Reset(logConnectedDelay)
|
||||
}
|
||||
} else {
|
||||
// If we got a peerGone message, that means the initial connection's
|
||||
// flood of peerPresent messages is done, so we can log already:
|
||||
logConnectedLocked()
|
||||
delete(present, k)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
err := c.WatchConnectionChanges()
|
||||
if err != nil {
|
||||
clear()
|
||||
logf("WatchConnectionChanges: %v", err)
|
||||
time.Sleep(retryInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
if c.ServerPublicKey() == ignoreServerKey {
|
||||
logf("detected self-connect; ignoring host")
|
||||
return
|
||||
}
|
||||
for {
|
||||
m, connGen, err := c.RecvDetail()
|
||||
if err != nil {
|
||||
clear()
|
||||
logf("Recv: %v", err)
|
||||
time.Sleep(retryInterval)
|
||||
break
|
||||
}
|
||||
if connGen != lastConnGen {
|
||||
lastConnGen = connGen
|
||||
clear()
|
||||
}
|
||||
switch m := m.(type) {
|
||||
case derp.PeerPresentMessage:
|
||||
updatePeer(key.Public(m), true)
|
||||
case derp.PeerGoneMessage:
|
||||
updatePeer(key.Public(m), false)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if now := time.Now(); now.Sub(lastStatus) > statusInterval {
|
||||
lastStatus = now
|
||||
logf("%d peers", len(present))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
179
disco/disco.go
Normal file
179
disco/disco.go
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package disco contains the discovery message types.
|
||||
//
|
||||
// A discovery message is:
|
||||
//
|
||||
// Header:
|
||||
// magic [6]byte // “TS💬” (0x54 53 f0 9f 92 ac)
|
||||
// senderDiscoPub [32]byte // nacl public key
|
||||
// nonce [24]byte
|
||||
//
|
||||
// The recipient then decrypts the bytes following (the nacl secretbox)
|
||||
// and then the inner payload structure is:
|
||||
//
|
||||
// messageType byte (the MessageType constants below)
|
||||
// messageVersion byte (0 for now; but always ignore bytes at the end)
|
||||
// message-paylod [...]byte
|
||||
package disco
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// Magic is the 6 byte header of all discovery messages.
|
||||
const Magic = "TS💬" // 6 bytes: 0x54 53 f0 9f 92 ac
|
||||
|
||||
const keyLen = 32
|
||||
|
||||
// NonceLen is the length of the nonces used by nacl secretboxes.
|
||||
const NonceLen = 24
|
||||
|
||||
type MessageType byte
|
||||
|
||||
const (
|
||||
TypePing = MessageType(0x01)
|
||||
TypePong = MessageType(0x02)
|
||||
TypeCallMeMaybe = MessageType(0x03)
|
||||
)
|
||||
|
||||
const v0 = byte(0)
|
||||
|
||||
var errShort = errors.New("short message")
|
||||
|
||||
// LooksLikeDiscoWrapper reports whether p looks like it's a packet
|
||||
// containing an encrypted disco message.
|
||||
func LooksLikeDiscoWrapper(p []byte) bool {
|
||||
if len(p) < len(Magic)+keyLen+NonceLen {
|
||||
return false
|
||||
}
|
||||
return string(p[:len(Magic)]) == Magic
|
||||
}
|
||||
|
||||
// Parse parses the encrypted part of the message from inside the
|
||||
// nacl secretbox.
|
||||
func Parse(p []byte) (Message, error) {
|
||||
if len(p) < 2 {
|
||||
return nil, errShort
|
||||
}
|
||||
t, ver, p := MessageType(p[0]), p[1], p[2:]
|
||||
switch t {
|
||||
case TypePing:
|
||||
return parsePing(ver, p)
|
||||
case TypePong:
|
||||
return parsePong(ver, p)
|
||||
case TypeCallMeMaybe:
|
||||
return CallMeMaybe{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type 0x%02x", byte(t))
|
||||
}
|
||||
}
|
||||
|
||||
// Message a discovery message.
|
||||
type Message interface {
|
||||
// AppendMarshal appends the message's marshaled representation.
|
||||
AppendMarshal([]byte) []byte
|
||||
}
|
||||
|
||||
// appendMsgHeader appends two bytes (for t and ver) and then also
|
||||
// dataLen bytes to b, returning the appended slice in all. The
|
||||
// returned data slice is a subslice of all with just dataLen bytes of
|
||||
// where the caller will fill in the data.
|
||||
func appendMsgHeader(b []byte, t MessageType, ver uint8, dataLen int) (all, data []byte) {
|
||||
// TODO: optimize this?
|
||||
all = append(b, make([]byte, dataLen+2)...)
|
||||
all[len(b)] = byte(t)
|
||||
all[len(b)+1] = ver
|
||||
data = all[len(b)+2:]
|
||||
return
|
||||
}
|
||||
|
||||
type Ping struct {
|
||||
TxID [12]byte
|
||||
}
|
||||
|
||||
func (m *Ping) AppendMarshal(b []byte) []byte {
|
||||
ret, d := appendMsgHeader(b, TypePing, v0, 12)
|
||||
copy(d, m.TxID[:])
|
||||
return ret
|
||||
}
|
||||
|
||||
func parsePing(ver uint8, p []byte) (m *Ping, err error) {
|
||||
if len(p) < 12 {
|
||||
return nil, errShort
|
||||
}
|
||||
m = new(Ping)
|
||||
copy(m.TxID[:], p)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// CallMeMaybe is a message sent only over DERP to request that the recipient try
|
||||
// to open up a magicsock path back to the sender.
|
||||
//
|
||||
// The sender should've already sent UDP packets to the peer to open
|
||||
// up the stateful firewall mappings inbound.
|
||||
//
|
||||
// The recipient may choose to not open a path back, if it's already
|
||||
// happy with its path. But usually it will.
|
||||
type CallMeMaybe struct{}
|
||||
|
||||
func (CallMeMaybe) AppendMarshal(b []byte) []byte {
|
||||
ret, _ := appendMsgHeader(b, TypeCallMeMaybe, v0, 0)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Pong is a response a Ping.
|
||||
//
|
||||
// It includes the sender's source IP + port, so it's effectively a
|
||||
// STUN response.
|
||||
type Pong struct {
|
||||
TxID [12]byte
|
||||
Src netaddr.IPPort // 18 bytes (16+2) on the wire; v4-mapped ipv6 for IPv4
|
||||
}
|
||||
|
||||
const pongLen = 12 + 16 + 2
|
||||
|
||||
func (m *Pong) AppendMarshal(b []byte) []byte {
|
||||
ret, d := appendMsgHeader(b, TypePong, v0, pongLen)
|
||||
d = d[copy(d, m.TxID[:]):]
|
||||
ip16 := m.Src.IP.As16()
|
||||
d = d[copy(d, ip16[:]):]
|
||||
binary.BigEndian.PutUint16(d, m.Src.Port)
|
||||
return ret
|
||||
}
|
||||
|
||||
func parsePong(ver uint8, p []byte) (m *Pong, err error) {
|
||||
if len(p) < pongLen {
|
||||
return nil, errShort
|
||||
}
|
||||
m = new(Pong)
|
||||
copy(m.TxID[:], p)
|
||||
p = p[12:]
|
||||
|
||||
m.Src.IP, _ = netaddr.FromStdIP(net.IP(p[:16]))
|
||||
p = p[16:]
|
||||
|
||||
m.Src.Port = binary.BigEndian.Uint16(p)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// MessageSummary returns a short summary of m for logging purposes.
|
||||
func MessageSummary(m Message) string {
|
||||
switch m := m.(type) {
|
||||
case *Ping:
|
||||
return fmt.Sprintf("ping tx=%x", m.TxID[:6])
|
||||
case *Pong:
|
||||
return fmt.Sprintf("pong tx=%x", m.TxID[:6])
|
||||
case CallMeMaybe:
|
||||
return "call-me-maybe"
|
||||
default:
|
||||
return fmt.Sprintf("%#v", m)
|
||||
}
|
||||
}
|
||||
82
disco/disco_test.go
Normal file
82
disco/disco_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package disco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func TestMarshalAndParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want string
|
||||
m Message
|
||||
}{
|
||||
{
|
||||
name: "ping",
|
||||
m: &Ping{
|
||||
TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
||||
},
|
||||
want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c",
|
||||
},
|
||||
{
|
||||
name: "pong",
|
||||
m: &Pong{
|
||||
TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
||||
Src: mustIPPort("2.3.4.5:1234"),
|
||||
},
|
||||
want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2",
|
||||
},
|
||||
{
|
||||
name: "pongv6",
|
||||
m: &Pong{
|
||||
TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
||||
Src: mustIPPort("[fed0::12]:6666"),
|
||||
},
|
||||
want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a",
|
||||
},
|
||||
{
|
||||
name: "call_me_maybe",
|
||||
m: CallMeMaybe{},
|
||||
want: "03 00",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
foo := []byte("foo")
|
||||
got := string(tt.m.AppendMarshal(foo))
|
||||
if !strings.HasPrefix(got, "foo") {
|
||||
t.Fatalf("didn't start with foo: got %q", got)
|
||||
}
|
||||
got = strings.TrimPrefix(got, "foo")
|
||||
|
||||
gotHex := fmt.Sprintf("% x", got)
|
||||
if gotHex != tt.want {
|
||||
t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want)
|
||||
}
|
||||
|
||||
back, err := Parse([]byte(got))
|
||||
if err != nil {
|
||||
t.Fatalf("parse back: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(back, tt.m) {
|
||||
t.Errorf("message in %+v doesn't match Parse back result %+v", tt.m, back)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustIPPort(s string) netaddr.IPPort {
|
||||
ipp, err := netaddr.ParseIPPort(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ipp
|
||||
}
|
||||
15
go.mod
15
go.mod
@@ -1,6 +1,6 @@
|
||||
module tailscale.com
|
||||
|
||||
go 1.13
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239 // indirect
|
||||
@@ -9,25 +9,30 @@ require (
|
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
|
||||
github.com/gliderlabs/ssh v0.2.2
|
||||
github.com/go-ole/go-ole v1.2.4
|
||||
github.com/godbus/dbus/v5 v5.0.3
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e
|
||||
github.com/google/go-cmp v0.4.0
|
||||
github.com/goreleaser/nfpm v1.1.10
|
||||
github.com/klauspost/compress v1.9.8
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4
|
||||
github.com/klauspost/compress v1.10.10
|
||||
github.com/kr/pty v1.1.1
|
||||
github.com/mdlayher/netlink v1.1.0
|
||||
github.com/miekg/dns v1.1.30
|
||||
github.com/pborman/getopt v0.0.0-20190409184431-ee0cd42419d3
|
||||
github.com/peterbourgon/ff/v2 v2.0.0
|
||||
github.com/tailscale/winipcfg-go v0.0.0-20200413171540-609dcf2df55f
|
||||
github.com/tailscale/wireguard-go v0.0.0-20200515231107-62868271d710
|
||||
github.com/tailscale/wireguard-go v0.0.0-20200724155040-d554a2a5e7e1
|
||||
github.com/tcnksm/go-httpstat v0.2.0
|
||||
github.com/toqueteos/webbrowser v1.2.0
|
||||
go4.org/mem v0.0.0-20200601023850-d8ee1dfa5518
|
||||
go4.org/mem v0.0.0-20200706164138-185c595c3ecc
|
||||
golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6
|
||||
golang.org/x/net v0.0.0-20200301022130-244492dfa37a
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e
|
||||
golang.org/x/sys v0.0.0-20200501052902-10377860bb8e
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
|
||||
inet.af/netaddr v0.0.0-20200513162223-787f13e36cbe
|
||||
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425
|
||||
honnef.co/go/tools v0.0.1-2020.1.4
|
||||
inet.af/netaddr v0.0.0-20200718043157-99321d6ad24c
|
||||
rsc.io/goversion v1.2.0
|
||||
)
|
||||
|
||||
39
go.sum
39
go.sum
@@ -30,6 +30,8 @@ github.com/gliderlabs/ssh v0.2.2 h1:6zsha5zo/TWhRhwqCD3+EarCAgZ2yN28ipRnGPnwkI0=
|
||||
github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
|
||||
github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI=
|
||||
github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM=
|
||||
github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME=
|
||||
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
|
||||
@@ -38,6 +40,7 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/rpmpack v0.0.0-20191226140753-aa36bfddb3a0 h1:BW6OvS3kpT5UEPbCZ+KyX/OB4Ks9/MNMhWjqPPkZxsE=
|
||||
github.com/google/rpmpack v0.0.0-20191226140753-aa36bfddb3a0/go.mod h1:RaTPr0KUf2K7fnZYLNDrr8rxAamWs3iNywJLtQ2AzBg=
|
||||
github.com/goreleaser/nfpm v1.1.10 h1:0nwzKUJTcygNxTzVKq2Dh9wpVP1W2biUH6SNKmoxR3w=
|
||||
@@ -47,8 +50,9 @@ github.com/imdario/mergo v0.3.8/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4 h1:nwOc1YaOrYJ37sEBrtWZrdqzK22hiJs3GpDmP3sR2Yw=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
|
||||
github.com/klauspost/compress v1.9.8 h1:VMAMUUOh+gaxKTMk+zqbjsSjsIcUcL/LF4o63i82QyA=
|
||||
github.com/klauspost/compress v1.9.8/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.10.10 h1:a/y8CglcM7gLGYmlbP/stPE5sR3hbhFRUjCBfd/0B3I=
|
||||
github.com/klauspost/compress v1.10.10/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=
|
||||
@@ -61,6 +65,8 @@ github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE
|
||||
github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
|
||||
github.com/mdlayher/netlink v1.1.0 h1:mpdLgm+brq10nI9zM1BpX1kpDbh3NLl3RSnVq6ZSkfg=
|
||||
github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY=
|
||||
github.com/miekg/dns v1.1.30 h1:Qww6FseFn8PRfw07jueqIXqodm0JKiiKuK0DeXSqfyo=
|
||||
github.com/miekg/dns v1.1.30/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||
github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 h1:lDH9UUVJtmYCjyT0CI4q8xvlXPxeZ0gYCVvWbmPlp88=
|
||||
github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk=
|
||||
github.com/pborman/getopt v0.0.0-20190409184431-ee0cd42419d3 h1:YtFkrqsMEj7YqpIhRteVxJxCeC3jJBieuLr0d4C4rSA=
|
||||
@@ -72,6 +78,7 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/sassoftware/go-rpmutils v0.0.0-20190420191620-a8f1baeba37b h1:+gCnWOZV8Z/8jehJ2CdqB47Z3S+SREmQcuXkRFLNsiI=
|
||||
github.com/sassoftware/go-rpmutils v0.0.0-20190420191620-a8f1baeba37b/go.mod h1:am+Fp8Bt506lA3Rk3QCmSqmYmLMnPDhdDUcosQCAx+I=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -79,8 +86,6 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/tailscale/winipcfg-go v0.0.0-20200413171540-609dcf2df55f h1:uFj5bslHsMzxIM8UTjAhq4VXeo6GfNW91rpoh/WMJaY=
|
||||
github.com/tailscale/winipcfg-go v0.0.0-20200413171540-609dcf2df55f/go.mod h1:x880GWw5fvrl2DVTQ04ttXQD4DuppTt1Yz6wLibbjNE=
|
||||
github.com/tailscale/wireguard-go v0.0.0-20200515231107-62868271d710 h1:I6aq3tOYbZob9uwhGpr7R266qTeU9PFqS6NnpfCqEzo=
|
||||
github.com/tailscale/wireguard-go v0.0.0-20200515231107-62868271d710/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4=
|
||||
github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0=
|
||||
github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8=
|
||||
github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ=
|
||||
@@ -89,17 +94,23 @@ github.com/ulikunitz/xz v0.5.6 h1:jGHAfXawEGZQ3blwU5wnWKQJvAraT7Ftq9EXjnXYgt8=
|
||||
github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8=
|
||||
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo=
|
||||
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos=
|
||||
go4.org/mem v0.0.0-20200601023850-d8ee1dfa5518 h1:AA3bSGklCgkrqIGnvL4894oa/2K9ltE0RejXh8CgyvA=
|
||||
go4.org/mem v0.0.0-20200601023850-d8ee1dfa5518/go.mod h1:NEYvpHWemiG/E5UWfaN5QAIGZeT1sa0Z2UNk6oeMb/k=
|
||||
go4.org/mem v0.0.0-20200706164138-185c595c3ecc h1:paujszgN6SpsO/UsXC7xax3gQAKz/XQKCYZLQdU34Tw=
|
||||
go4.org/mem v0.0.0-20200706164138-185c595c3ecc/go.mod h1:NEYvpHWemiG/E5UWfaN5QAIGZeT1sa0Z2UNk6oeMb/k=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6 h1:TjszyFsQsyZNHwdVdZ5m7bjmreu0znc2kRYsEml9/Ww=
|
||||
golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191007182048-72f939374954/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI=
|
||||
@@ -110,6 +121,7 @@ golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BG
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -119,6 +131,7 @@ golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 h1:LfCXLvNmTYH9kEmVgqbnsWfruoXZIrh4YBgqVHtDvw0=
|
||||
@@ -131,18 +144,28 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d h1:/iIZNFGxc/a7C3yWjGcnboV+Tkc7mxr+p6fDztwoxuM=
|
||||
golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA=
|
||||
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
||||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
inet.af/netaddr v0.0.0-20200513162223-787f13e36cbe h1:WjJ6wZhXEWQA3FFSwOjG8tO2q1NDFSqrUwNcTvxwMEQ=
|
||||
inet.af/netaddr v0.0.0-20200513162223-787f13e36cbe/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww=
|
||||
honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8=
|
||||
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
||||
inet.af/netaddr v0.0.0-20200718043157-99321d6ad24c h1:si3Owrfem175Ry6gKqnh59eOXxDojyBTIHxUKuvK/Eo=
|
||||
inet.af/netaddr v0.0.0-20200718043157-99321d6ad24c/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww=
|
||||
rsc.io/goversion v1.2.0 h1:SPn+NLTiAG7w30IRK/DKp1BjvpWabYgxlLp/+kx5J8w=
|
||||
rsc.io/goversion v1.2.0/go.mod h1:Eih9y/uIBS3ulggl7KNJ09xGSLcuNaLgmvvqa07sgfo=
|
||||
|
||||
93
internal/deepprint/deepprint.go
Normal file
93
internal/deepprint/deepprint.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package deepprint walks a Go value recursively, in a predictable
|
||||
// order, without looping, and prints each value out to a given
|
||||
// Writer, which is assumed to be a hash.Hash, as this package doesn't
|
||||
// format things nicely.
|
||||
//
|
||||
// This is intended as a lighter version of go-spew, etc. We don't need its
|
||||
// features when our writer is just a hash.
|
||||
package deepprint
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func Hash(v interface{}) string {
|
||||
h := sha256.New()
|
||||
Print(h, v)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
func Print(w io.Writer, v interface{}) {
|
||||
print(w, reflect.ValueOf(v), make(map[uintptr]bool))
|
||||
}
|
||||
|
||||
func print(w io.Writer, v reflect.Value, visited map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
}
|
||||
switch v.Kind() {
|
||||
default:
|
||||
panic(fmt.Sprintf("unhandled kind %v for type %v", v.Kind(), v.Type()))
|
||||
case reflect.Ptr:
|
||||
ptr := v.Pointer()
|
||||
if visited[ptr] {
|
||||
return
|
||||
}
|
||||
visited[ptr] = true
|
||||
print(w, v.Elem(), visited)
|
||||
return
|
||||
case reflect.Struct:
|
||||
fmt.Fprintf(w, "struct{\n")
|
||||
t := v.Type()
|
||||
for i, n := 0, v.NumField(); i < n; i++ {
|
||||
sf := t.Field(i)
|
||||
fmt.Fprintf(w, "%s: ", sf.Name)
|
||||
print(w, v.Field(i), visited)
|
||||
fmt.Fprintf(w, "\n")
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
if v.Type().Elem().Kind() == reflect.Uint8 && v.CanInterface() {
|
||||
fmt.Fprintf(w, "%q", v.Interface())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "[%d]{\n", v.Len())
|
||||
for i, ln := 0, v.Len(); i < ln; i++ {
|
||||
fmt.Fprintf(w, " [%d]: ", i)
|
||||
print(w, v.Index(i), visited)
|
||||
fmt.Fprintf(w, "\n")
|
||||
}
|
||||
fmt.Fprintf(w, "}\n")
|
||||
case reflect.Interface:
|
||||
print(w, v.Elem(), visited)
|
||||
case reflect.Map:
|
||||
sm := newSortedMap(v)
|
||||
fmt.Fprintf(w, "map[%d]{\n", len(sm.Key))
|
||||
for i, k := range sm.Key {
|
||||
print(w, k, visited)
|
||||
fmt.Fprintf(w, ": ")
|
||||
print(w, sm.Value[i], visited)
|
||||
fmt.Fprintf(w, "\n")
|
||||
}
|
||||
fmt.Fprintf(w, "}\n")
|
||||
|
||||
case reflect.String:
|
||||
fmt.Fprintf(w, "%s", v.String())
|
||||
case reflect.Bool:
|
||||
fmt.Fprintf(w, "%v", v.Bool())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
fmt.Fprintf(w, "%v", v.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
fmt.Fprintf(w, "%v", v.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fmt.Fprintf(w, "%v", v.Float())
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
fmt.Fprintf(w, "%v", v.Complex())
|
||||
}
|
||||
}
|
||||
70
internal/deepprint/deepprint_test.go
Normal file
70
internal/deepprint/deepprint_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package deepprint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/wgengine/router"
|
||||
)
|
||||
|
||||
func TestDeepPrint(t *testing.T) {
|
||||
// v contains the types of values we care about for our current callers.
|
||||
// Mostly we're just testing that we don't panic on handled types.
|
||||
v := getVal()
|
||||
|
||||
var buf bytes.Buffer
|
||||
Print(&buf, v)
|
||||
t.Logf("Got: %s", buf.Bytes())
|
||||
|
||||
hash1 := Hash(v)
|
||||
t.Logf("hash: %v", hash1)
|
||||
for i := 0; i < 20; i++ {
|
||||
hash2 := Hash(getVal())
|
||||
if hash1 != hash2 {
|
||||
t.Error("second hash didn't match")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getVal() []interface{} {
|
||||
return []interface{}{
|
||||
&wgcfg.Config{
|
||||
Name: "foo",
|
||||
Addresses: []wgcfg.CIDR{{Mask: 5, IP: wgcfg.IP{Addr: [16]byte{3: 3}}}},
|
||||
ListenPort: 5,
|
||||
Peers: []wgcfg.Peer{
|
||||
{
|
||||
Endpoints: []wgcfg.Endpoint{
|
||||
{
|
||||
Host: "foo",
|
||||
Port: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
&router.Config{
|
||||
DNSConfig: router.DNSConfig{
|
||||
Nameservers: []netaddr.IP{netaddr.IPv4(8, 8, 8, 8)},
|
||||
Domains: []string{"tailscale.net"},
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"key1": "val1",
|
||||
"key2": "val2",
|
||||
"key3": "val3",
|
||||
"key4": "val4",
|
||||
"key5": "val5",
|
||||
"key6": "val6",
|
||||
"key7": "val7",
|
||||
"key8": "val8",
|
||||
"key9": "val9",
|
||||
},
|
||||
}
|
||||
}
|
||||
224
internal/deepprint/fmtsort.go
Normal file
224
internal/deepprint/fmtsort.go
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// and
|
||||
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This is a slightly modified fork of Go's src/internal/fmtsort/sort.go
|
||||
|
||||
package deepprint
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Note: Throughout this package we avoid calling reflect.Value.Interface as
|
||||
// it is not always legal to do so and it's easier to avoid the issue than to face it.
|
||||
|
||||
// sortedMap represents a map's keys and values. The keys and values are
|
||||
// aligned in index order: Value[i] is the value in the map corresponding to Key[i].
|
||||
type sortedMap struct {
|
||||
Key []reflect.Value
|
||||
Value []reflect.Value
|
||||
}
|
||||
|
||||
func (o *sortedMap) Len() int { return len(o.Key) }
|
||||
func (o *sortedMap) Less(i, j int) bool { return compare(o.Key[i], o.Key[j]) < 0 }
|
||||
func (o *sortedMap) Swap(i, j int) {
|
||||
o.Key[i], o.Key[j] = o.Key[j], o.Key[i]
|
||||
o.Value[i], o.Value[j] = o.Value[j], o.Value[i]
|
||||
}
|
||||
|
||||
// Sort accepts a map and returns a sortedMap that has the same keys and
|
||||
// values but in a stable sorted order according to the keys, modulo issues
|
||||
// raised by unorderable key values such as NaNs.
|
||||
//
|
||||
// The ordering rules are more general than with Go's < operator:
|
||||
//
|
||||
// - when applicable, nil compares low
|
||||
// - ints, floats, and strings order by <
|
||||
// - NaN compares less than non-NaN floats
|
||||
// - bool compares false before true
|
||||
// - complex compares real, then imag
|
||||
// - pointers compare by machine address
|
||||
// - channel values compare by machine address
|
||||
// - structs compare each field in turn
|
||||
// - arrays compare each element in turn.
|
||||
// Otherwise identical arrays compare by length.
|
||||
// - interface values compare first by reflect.Type describing the concrete type
|
||||
// and then by concrete value as described in the previous rules.
|
||||
//
|
||||
func newSortedMap(mapValue reflect.Value) *sortedMap {
|
||||
if mapValue.Type().Kind() != reflect.Map {
|
||||
return nil
|
||||
}
|
||||
// Note: this code is arranged to not panic even in the presence
|
||||
// of a concurrent map update. The runtime is responsible for
|
||||
// yelling loudly if that happens. See issue 33275.
|
||||
n := mapValue.Len()
|
||||
key := make([]reflect.Value, 0, n)
|
||||
value := make([]reflect.Value, 0, n)
|
||||
iter := mapValue.MapRange()
|
||||
for iter.Next() {
|
||||
key = append(key, iter.Key())
|
||||
value = append(value, iter.Value())
|
||||
}
|
||||
sorted := &sortedMap{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
sort.Stable(sorted)
|
||||
return sorted
|
||||
}
|
||||
|
||||
// compare compares two values of the same type. It returns -1, 0, 1
|
||||
// according to whether a > b (1), a == b (0), or a < b (-1).
|
||||
// If the types differ, it returns -1.
|
||||
// See the comment on Sort for the comparison rules.
|
||||
func compare(aVal, bVal reflect.Value) int {
|
||||
aType, bType := aVal.Type(), bVal.Type()
|
||||
if aType != bType {
|
||||
return -1 // No good answer possible, but don't return 0: they're not equal.
|
||||
}
|
||||
switch aVal.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
a, b := aVal.Int(), bVal.Int()
|
||||
switch {
|
||||
case a < b:
|
||||
return -1
|
||||
case a > b:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
a, b := aVal.Uint(), bVal.Uint()
|
||||
switch {
|
||||
case a < b:
|
||||
return -1
|
||||
case a > b:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
case reflect.String:
|
||||
a, b := aVal.String(), bVal.String()
|
||||
switch {
|
||||
case a < b:
|
||||
return -1
|
||||
case a > b:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return floatCompare(aVal.Float(), bVal.Float())
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
a, b := aVal.Complex(), bVal.Complex()
|
||||
if c := floatCompare(real(a), real(b)); c != 0 {
|
||||
return c
|
||||
}
|
||||
return floatCompare(imag(a), imag(b))
|
||||
case reflect.Bool:
|
||||
a, b := aVal.Bool(), bVal.Bool()
|
||||
switch {
|
||||
case a == b:
|
||||
return 0
|
||||
case a:
|
||||
return 1
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
case reflect.Ptr:
|
||||
a, b := aVal.Pointer(), bVal.Pointer()
|
||||
switch {
|
||||
case a < b:
|
||||
return -1
|
||||
case a > b:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
case reflect.Chan:
|
||||
if c, ok := nilCompare(aVal, bVal); ok {
|
||||
return c
|
||||
}
|
||||
ap, bp := aVal.Pointer(), bVal.Pointer()
|
||||
switch {
|
||||
case ap < bp:
|
||||
return -1
|
||||
case ap > bp:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
case reflect.Struct:
|
||||
for i := 0; i < aVal.NumField(); i++ {
|
||||
if c := compare(aVal.Field(i), bVal.Field(i)); c != 0 {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return 0
|
||||
case reflect.Array:
|
||||
for i := 0; i < aVal.Len(); i++ {
|
||||
if c := compare(aVal.Index(i), bVal.Index(i)); c != 0 {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return 0
|
||||
case reflect.Interface:
|
||||
if c, ok := nilCompare(aVal, bVal); ok {
|
||||
return c
|
||||
}
|
||||
c := compare(reflect.ValueOf(aVal.Elem().Type()), reflect.ValueOf(bVal.Elem().Type()))
|
||||
if c != 0 {
|
||||
return c
|
||||
}
|
||||
return compare(aVal.Elem(), bVal.Elem())
|
||||
default:
|
||||
// Certain types cannot appear as keys (maps, funcs, slices), but be explicit.
|
||||
panic("bad type in compare: " + aType.String())
|
||||
}
|
||||
}
|
||||
|
||||
// nilCompare checks whether either value is nil. If not, the boolean is false.
|
||||
// If either value is nil, the boolean is true and the integer is the comparison
|
||||
// value. The comparison is defined to be 0 if both are nil, otherwise the one
|
||||
// nil value compares low. Both arguments must represent a chan, func,
|
||||
// interface, map, pointer, or slice.
|
||||
func nilCompare(aVal, bVal reflect.Value) (int, bool) {
|
||||
if aVal.IsNil() {
|
||||
if bVal.IsNil() {
|
||||
return 0, true
|
||||
}
|
||||
return -1, true
|
||||
}
|
||||
if bVal.IsNil() {
|
||||
return 1, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// floatCompare compares two floating-point values. NaNs compare low.
|
||||
func floatCompare(a, b float64) int {
|
||||
switch {
|
||||
case isNaN(a):
|
||||
return -1 // No good answer if b is a NaN so don't bother checking.
|
||||
case isNaN(b):
|
||||
return 1
|
||||
case a < b:
|
||||
return -1
|
||||
case a > b:
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func isNaN(a float64) bool {
|
||||
return a != a
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -27,6 +28,10 @@ const (
|
||||
Running
|
||||
)
|
||||
|
||||
// GoogleIDToken Type is the oauth2.Token.TokenType for the Google
|
||||
// ID tokens used by the Android client.
|
||||
const GoogleIDTokenType = "ts_android_google_login"
|
||||
|
||||
func (s State) String() string {
|
||||
return [...]string{"NoState", "NeedsLogin", "NeedsMachineAuth",
|
||||
"Stopped", "Starting", "Running"}[s]
|
||||
@@ -58,6 +63,12 @@ type Notify struct {
|
||||
BrowseToURL *string // UI should open a browser right now
|
||||
BackendLogID *string // public logtail id used by backend
|
||||
|
||||
// LocalTCPPort, if non-nil, informs the UI frontend which
|
||||
// (non-zero) localhost TCP port it's listening on.
|
||||
// This is currently only used by Tailscale when run in the
|
||||
// macOS Network Extension.
|
||||
LocalTCPPort *uint16 `json:",omitempty"`
|
||||
|
||||
// type is mirrored in xcode/Shared/IPN.swift
|
||||
}
|
||||
|
||||
@@ -123,6 +134,8 @@ type Backend interface {
|
||||
// flow. This should trigger a new BrowseToURL notification
|
||||
// eventually.
|
||||
StartLoginInteractive()
|
||||
// Login logs in with an OAuth2 token.
|
||||
Login(token *oauth2.Token)
|
||||
// Logout terminates the current login session and stops the
|
||||
// wireguard engine.
|
||||
Logout()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
)
|
||||
@@ -42,6 +43,14 @@ func (b *FakeBackend) newState(s State) {
|
||||
func (b *FakeBackend) StartLoginInteractive() {
|
||||
u := b.serverURL + "/this/is/fake"
|
||||
b.notify(Notify{BrowseToURL: &u})
|
||||
b.login()
|
||||
}
|
||||
|
||||
func (b *FakeBackend) Login(token *oauth2.Token) {
|
||||
b.login()
|
||||
}
|
||||
|
||||
func (b *FakeBackend) login() {
|
||||
b.newState(NeedsMachineAuth)
|
||||
b.newState(Stopped)
|
||||
// TODO(apenwarr): Fill in a more interesting netmap here.
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
@@ -154,6 +155,10 @@ func (h *Handle) StartLoginInteractive() {
|
||||
h.b.StartLoginInteractive()
|
||||
}
|
||||
|
||||
func (h *Handle) Login(token *oauth2.Token) {
|
||||
h.b.Login(token)
|
||||
}
|
||||
|
||||
func (h *Handle) Logout() {
|
||||
h.b.Logout()
|
||||
}
|
||||
|
||||
@@ -18,11 +18,11 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/safesocket"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/wgengine"
|
||||
@@ -33,15 +33,19 @@ type Options struct {
|
||||
// SocketPath, on unix systems, is the unix socket path to listen
|
||||
// on for frontend connections.
|
||||
SocketPath string
|
||||
|
||||
// Port, on windows, is the localhost TCP port to listen on for
|
||||
// frontend connections.
|
||||
Port int
|
||||
|
||||
// StatePath is the path to the stored agent state.
|
||||
StatePath string
|
||||
|
||||
// AutostartStateKey, if non-empty, immediately starts the agent
|
||||
// using the given StateKey. If empty, the agent stays idle and
|
||||
// waits for a frontend to start it.
|
||||
AutostartStateKey ipn.StateKey
|
||||
|
||||
// LegacyConfigPath optionally specifies the old-style relaynode
|
||||
// relay.conf location. If both LegacyConfigPath and
|
||||
// AutostartStateKey are specified and the requested state doesn't
|
||||
@@ -51,53 +55,151 @@ type Options struct {
|
||||
// TODO(danderson): remove some time after the transition to
|
||||
// tailscaled is done.
|
||||
LegacyConfigPath string
|
||||
|
||||
// SurviveDisconnects specifies how the server reacts to its
|
||||
// frontend disconnecting. If true, the server keeps running on
|
||||
// its existing state, and accepts new frontend connections. If
|
||||
// false, the server dumps its state and becomes idle.
|
||||
//
|
||||
// To support CLI connections (notably, "tailscale status"),
|
||||
// the actual definition of "disconnect" is when the
|
||||
// connection count transitions from 1 to 0.
|
||||
SurviveDisconnects bool
|
||||
|
||||
// DebugMux, if non-nil, specifies an HTTP ServeMux in which
|
||||
// to register a debug handler.
|
||||
DebugMux *http.ServeMux
|
||||
|
||||
// ErrorMessage, if not empty, signals that the server will exist
|
||||
// only to relay the provided critical error message to the user.
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
func pump(logf logger.Logf, ctx context.Context, bs *ipn.BackendServer, s net.Conn) {
|
||||
defer logf("Control connection done.")
|
||||
// server is an IPN backend and its set of 0 or more active connections
|
||||
// talking to an IPN backend.
|
||||
type server struct {
|
||||
resetOnZero bool // call bs.Reset on transition from 1->0 connections
|
||||
|
||||
for ctx.Err() == nil && !bs.GotQuit {
|
||||
msg, err := ipn.ReadMsg(s)
|
||||
bsMu sync.Mutex // lock order: bsMu, then mu
|
||||
bs *ipn.BackendServer
|
||||
|
||||
mu sync.Mutex
|
||||
clients map[net.Conn]bool
|
||||
}
|
||||
|
||||
func (s *server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) {
|
||||
s.addConn(c)
|
||||
logf("incoming control connection")
|
||||
defer s.removeAndCloseConn(c)
|
||||
for ctx.Err() == nil {
|
||||
msg, err := ipn.ReadMsg(c)
|
||||
if err != nil {
|
||||
logf("ReadMsg: %v", err)
|
||||
break
|
||||
if ctx.Err() == nil {
|
||||
logf("ReadMsg: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = bs.GotCommandMsg(msg)
|
||||
if err != nil {
|
||||
s.bsMu.Lock()
|
||||
if err := s.bs.GotCommandMsg(msg); err != nil {
|
||||
logf("GotCommandMsg: %v", err)
|
||||
break
|
||||
}
|
||||
gotQuit := s.bs.GotQuit
|
||||
s.bsMu.Unlock()
|
||||
if gotQuit {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e wgengine.Engine) (err error) {
|
||||
runDone := make(chan error, 1)
|
||||
defer func() { runDone <- err }()
|
||||
func (s *server) addConn(c net.Conn) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.clients == nil {
|
||||
s.clients = map[net.Conn]bool{}
|
||||
}
|
||||
s.clients[c] = true
|
||||
}
|
||||
|
||||
func (s *server) removeAndCloseConn(c net.Conn) {
|
||||
s.mu.Lock()
|
||||
delete(s.clients, c)
|
||||
remain := len(s.clients)
|
||||
s.mu.Unlock()
|
||||
|
||||
if remain == 0 && s.resetOnZero {
|
||||
s.bsMu.Lock()
|
||||
s.bs.Reset()
|
||||
s.bsMu.Unlock()
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func (s *server) stopAll() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for c := range s.clients {
|
||||
safesocket.ConnCloseRead(c)
|
||||
safesocket.ConnCloseWrite(c)
|
||||
}
|
||||
s.clients = nil
|
||||
}
|
||||
|
||||
func (s *server) writeToClients(b []byte) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for c := range s.clients {
|
||||
ipn.WriteMsg(c, b)
|
||||
}
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, logf logger.Logf, logid string, opts Options, e wgengine.Engine) error {
|
||||
runDone := make(chan struct{})
|
||||
defer close(runDone)
|
||||
|
||||
listen, _, err := safesocket.Listen(opts.SocketPath, uint16(opts.Port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("safesocket.Listen: %v", err)
|
||||
}
|
||||
|
||||
// Go listeners can't take a context, close it instead.
|
||||
server := &server{
|
||||
resetOnZero: !opts.SurviveDisconnects,
|
||||
}
|
||||
|
||||
// When the context is closed or when we return, whichever is first, close our listner
|
||||
// and all open connections.
|
||||
go func() {
|
||||
select {
|
||||
case <-rctx.Done():
|
||||
case <-ctx.Done():
|
||||
case <-runDone:
|
||||
}
|
||||
server.stopAll()
|
||||
listen.Close()
|
||||
}()
|
||||
logf("Listening on %v", listen.Addr())
|
||||
|
||||
bo := backoff.NewBackoff("ipnserver", logf)
|
||||
|
||||
if opts.ErrorMessage != "" {
|
||||
for i := 1; ctx.Err() == nil; i++ {
|
||||
s, err := listen.Accept()
|
||||
if err != nil {
|
||||
logf("%d: Accept: %v", i, err)
|
||||
bo.BackOff(ctx, err)
|
||||
continue
|
||||
}
|
||||
serverToClient := func(b []byte) {
|
||||
ipn.WriteMsg(s, b)
|
||||
}
|
||||
go func() {
|
||||
defer s.Close()
|
||||
bs := ipn.NewBackendServer(logf, nil, serverToClient)
|
||||
bs.SendErrorMessage(opts.ErrorMessage)
|
||||
s.Read(make([]byte, 1))
|
||||
}()
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var store ipn.StateStore
|
||||
if opts.StatePath != "" {
|
||||
store, err = ipn.NewFileStore(opts.StatePath)
|
||||
@@ -112,11 +214,9 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w
|
||||
if err != nil {
|
||||
return fmt.Errorf("NewLocalBackend: %v", err)
|
||||
}
|
||||
defer b.Shutdown()
|
||||
b.SetDecompressor(func() (controlclient.Decompressor, error) {
|
||||
return zstd.NewReader(nil,
|
||||
zstd.WithDecoderLowmem(true),
|
||||
zstd.WithDecoderConcurrency(1),
|
||||
)
|
||||
return smallzstd.NewDecoder(nil)
|
||||
})
|
||||
|
||||
if opts.DebugMux != nil {
|
||||
@@ -128,17 +228,10 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w
|
||||
})
|
||||
}
|
||||
|
||||
var s net.Conn
|
||||
serverToClient := func(b []byte) {
|
||||
if s != nil { // TODO: racy access to s?
|
||||
ipn.WriteMsg(s, b)
|
||||
}
|
||||
}
|
||||
|
||||
bs := ipn.NewBackendServer(logf, b, serverToClient)
|
||||
server.bs = ipn.NewBackendServer(logf, b, server.writeToClients)
|
||||
|
||||
if opts.AutostartStateKey != "" {
|
||||
bs.GotCommand(&ipn.Command{
|
||||
server.bs.GotCommand(&ipn.Command{
|
||||
Version: version.LONG,
|
||||
Start: &ipn.StartArgs{
|
||||
Opts: ipn.Options{
|
||||
@@ -149,55 +242,18 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
oldS net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
stopAll := func() {
|
||||
// Currently we only support one client connection at a time.
|
||||
// Theoretically we could allow multiple clients, by passing
|
||||
// notifications to all of them and accepting commands from
|
||||
// any of them, but there doesn't seem to be much need for
|
||||
// that right now.
|
||||
if oldS != nil {
|
||||
cancel()
|
||||
safesocket.ConnCloseRead(oldS)
|
||||
safesocket.ConnCloseWrite(oldS)
|
||||
}
|
||||
}
|
||||
|
||||
bo := backoff.NewBackoff("ipnserver", logf)
|
||||
|
||||
for i := 1; rctx.Err() == nil; i++ {
|
||||
s, err = listen.Accept()
|
||||
for i := 1; ctx.Err() == nil; i++ {
|
||||
c, err := listen.Accept()
|
||||
if err != nil {
|
||||
logf("%d: Accept: %v", i, err)
|
||||
bo.BackOff(rctx, err)
|
||||
if ctx.Err() == nil {
|
||||
logf("ipnserver: Accept: %v", err)
|
||||
bo.BackOff(ctx, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
logf("%d: Incoming control connection.", i)
|
||||
stopAll()
|
||||
|
||||
ctx, cancel = context.WithCancel(rctx)
|
||||
oldS = s
|
||||
|
||||
go func(ctx context.Context, s net.Conn, i int) {
|
||||
logf := logger.WithPrefix(logf, fmt.Sprintf("%d: ", i))
|
||||
pump(logf, ctx, bs, s)
|
||||
if !opts.SurviveDisconnects || bs.GotQuit {
|
||||
bs.Reset()
|
||||
s.Close()
|
||||
}
|
||||
// Quitting not allowed, just keep going.
|
||||
bs.GotQuit = false
|
||||
}(ctx, s, i)
|
||||
|
||||
bo.BackOff(ctx, nil)
|
||||
go server.serveConn(ctx, c, logger.WithPrefix(logf, fmt.Sprintf("ipnserver: conn%d: ", i)))
|
||||
}
|
||||
stopAll()
|
||||
|
||||
return rctx.Err()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func BabysitProc(ctx context.Context, args []string, logf logger.Logf) {
|
||||
|
||||
@@ -49,10 +49,12 @@ type PeerStatus struct {
|
||||
// Endpoints:
|
||||
Addrs []string
|
||||
CurAddr string // one of Addrs, or unique if roaming
|
||||
Relay string // DERP region
|
||||
|
||||
RxBytes int64
|
||||
TxBytes int64
|
||||
Created time.Time // time registered with tailcontrol
|
||||
LastWrite time.Time // time last packet sent
|
||||
LastSeen time.Time // last seen to tailcontrol
|
||||
LastHandshake time.Time // with local wireguard
|
||||
KeepAlive bool
|
||||
@@ -135,6 +137,9 @@ func (sb *StatusBuilder) AddPeer(peer key.Public, st *PeerStatus) {
|
||||
if v := st.HostName; v != "" {
|
||||
e.HostName = v
|
||||
}
|
||||
if v := st.Relay; v != "" {
|
||||
e.Relay = v
|
||||
}
|
||||
if v := st.UserID; v != 0 {
|
||||
e.UserID = v
|
||||
}
|
||||
@@ -165,6 +170,9 @@ func (sb *StatusBuilder) AddPeer(peer key.Public, st *PeerStatus) {
|
||||
if v := st.LastSeen; !v.IsZero() {
|
||||
e.LastSeen = v
|
||||
}
|
||||
if v := st.LastWrite; !v.IsZero() {
|
||||
e.LastWrite = v
|
||||
}
|
||||
if st.InNetworkMap {
|
||||
e.InNetworkMap = true
|
||||
}
|
||||
@@ -211,28 +219,19 @@ table tbody tr:nth-child(even) td { background-color: #f5f5f5; }
|
||||
//f("<p><b>opts:</b> <code>%s</code></p>\n", html.EscapeString(fmt.Sprintf("%+v", opts)))
|
||||
|
||||
f("<table>\n<thead>\n")
|
||||
f("<tr><th>Peer</th><th>Node</th><th>Owner</th><th>Rx</th><th>Tx</th><th>Handshake</th><th>Endpoints</th></tr>\n")
|
||||
f("<tr><th>Peer</th><th>Node</th><th>Owner</th><th>Rx</th><th>Tx</th><th>Activity</th><th>Endpoints</th></tr>\n")
|
||||
f("</thead>\n<tbody>\n")
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// The tailcontrol server rounds LastSeen to 10 minutes. So we
|
||||
// declare that a longAgo seen time of 15 minutes means
|
||||
// they're not connected.
|
||||
longAgo := now.Add(-15 * time.Minute)
|
||||
|
||||
for _, peer := range st.Peers() {
|
||||
ps := st.Peer[peer]
|
||||
var hsAgo string
|
||||
if !ps.LastHandshake.IsZero() {
|
||||
hsAgo = now.Sub(ps.LastHandshake).Round(time.Second).String() + " ago"
|
||||
} else {
|
||||
if ps.LastSeen.Before(longAgo) {
|
||||
hsAgo = "<i>offline</i>"
|
||||
} else if !ps.KeepAlive {
|
||||
hsAgo = "on demand"
|
||||
} else {
|
||||
hsAgo = "<b>pending</b>"
|
||||
var actAgo string
|
||||
if !ps.LastWrite.IsZero() {
|
||||
ago := now.Sub(ps.LastWrite)
|
||||
actAgo = ago.Round(time.Second).String() + " ago"
|
||||
if ago < 5*time.Minute {
|
||||
actAgo = "<b>" + actAgo + "</b>"
|
||||
}
|
||||
}
|
||||
var owner string
|
||||
@@ -250,9 +249,20 @@ table tbody tr:nth-child(even) td { background-color: #f5f5f5; }
|
||||
html.EscapeString(owner),
|
||||
ps.RxBytes,
|
||||
ps.TxBytes,
|
||||
hsAgo,
|
||||
actAgo,
|
||||
)
|
||||
f("<td class=\"aright\">")
|
||||
// TODO: let server report this active bool instead
|
||||
active := !ps.LastWrite.IsZero() && time.Since(ps.LastWrite) < 2*time.Minute
|
||||
relay := ps.Relay
|
||||
if relay != "" {
|
||||
if active && ps.CurAddr == "" {
|
||||
f("🔗 <b>derp-%v</b><br>", html.EscapeString(relay))
|
||||
} else {
|
||||
f("derp-%v<br>", html.EscapeString(relay))
|
||||
}
|
||||
}
|
||||
|
||||
match := false
|
||||
for _, addr := range ps.Addrs {
|
||||
if addr == ps.CurAddr {
|
||||
|
||||
451
ipn/local.go
451
ipn/local.go
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"golang.org/x/oauth2"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/tsdns"
|
||||
)
|
||||
|
||||
// LocalBackend is the glue between the major pieces of the Tailscale
|
||||
@@ -48,6 +50,7 @@ type LocalBackend struct {
|
||||
store StateStore
|
||||
backendLogID string
|
||||
portpoll *portlist.Poller // may be nil
|
||||
portpollOnce sync.Once
|
||||
newDecompressor func() (controlclient.Decompressor, error)
|
||||
|
||||
// TODO: these fields are accessed unsafely by concurrent
|
||||
@@ -56,14 +59,16 @@ type LocalBackend struct {
|
||||
lastFilterPrint time.Time
|
||||
|
||||
// The mutex protects the following elements.
|
||||
mu sync.Mutex
|
||||
notify func(Notify)
|
||||
c *controlclient.Client
|
||||
stateKey StateKey
|
||||
prefs *Prefs
|
||||
state State
|
||||
hiCache *tailcfg.Hostinfo
|
||||
netMapCache *controlclient.NetworkMap
|
||||
mu sync.Mutex
|
||||
notify func(Notify)
|
||||
c *controlclient.Client
|
||||
stateKey StateKey
|
||||
prefs *Prefs
|
||||
state State
|
||||
// hostinfo is mutated in-place while mu is held.
|
||||
hostinfo *tailcfg.Hostinfo
|
||||
// netMap is not mutated in-place once set.
|
||||
netMap *controlclient.NetworkMap
|
||||
engineStatus EngineStatus
|
||||
endpoints []string
|
||||
blocked bool
|
||||
@@ -105,11 +110,6 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin
|
||||
}
|
||||
b.statusChanged = sync.NewCond(&b.statusLock)
|
||||
|
||||
if b.portpoll != nil {
|
||||
go b.portpoll.Run(ctx)
|
||||
go b.readPoller()
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
@@ -145,11 +145,11 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) {
|
||||
|
||||
// TODO: hostinfo, and its networkinfo
|
||||
// TODO: EngineStatus copy (and deprecate it?)
|
||||
if b.netMapCache != nil {
|
||||
for id, up := range b.netMapCache.UserProfiles {
|
||||
if b.netMap != nil {
|
||||
for id, up := range b.netMap.UserProfiles {
|
||||
sb.AddUser(id, up)
|
||||
}
|
||||
for _, p := range b.netMapCache.Peers {
|
||||
for _, p := range b.netMap.Peers {
|
||||
var lastSeen time.Time
|
||||
if p.LastSeen != nil {
|
||||
lastSeen = *p.LastSeen
|
||||
@@ -183,6 +183,123 @@ func (b *LocalBackend) SetDecompressor(fn func() (controlclient.Decompressor, er
|
||||
b.newDecompressor = fn
|
||||
}
|
||||
|
||||
// setClientStatus is the callback invoked by the control client whenever it posts a new status.
|
||||
// Among other things, this is where we update the netmap, packet filters, DNS and DERP maps.
|
||||
func (b *LocalBackend) setClientStatus(st controlclient.Status) {
|
||||
if st.LoginFinished != nil {
|
||||
// Auth completed, unblock the engine
|
||||
b.blockEngineUpdates(false)
|
||||
b.authReconfig()
|
||||
b.send(Notify{LoginFinished: &empty.Message{}})
|
||||
}
|
||||
if st.Persist != nil {
|
||||
persist := *st.Persist // copy
|
||||
|
||||
b.mu.Lock()
|
||||
b.prefs.Persist = &persist
|
||||
prefs := b.prefs.Clone()
|
||||
stateKey := b.stateKey
|
||||
b.mu.Unlock()
|
||||
|
||||
if stateKey != "" {
|
||||
if err := b.store.WriteState(stateKey, prefs.ToBytes()); err != nil {
|
||||
b.logf("Failed to save new controlclient state: %v", err)
|
||||
}
|
||||
}
|
||||
b.send(Notify{Prefs: prefs})
|
||||
}
|
||||
if st.NetMap != nil {
|
||||
// Netmap is unchanged only when the diff is empty.
|
||||
changed := true
|
||||
b.mu.Lock()
|
||||
if b.netMap != nil {
|
||||
diff := st.NetMap.ConciseDiffFrom(b.netMap)
|
||||
if strings.TrimSpace(diff) == "" {
|
||||
changed = false
|
||||
b.logf("netmap diff: (none)")
|
||||
} else {
|
||||
b.logf("netmap diff:\n%v", diff)
|
||||
}
|
||||
}
|
||||
disableDERP := b.prefs != nil && b.prefs.DisableDERP
|
||||
b.netMap = st.NetMap
|
||||
b.mu.Unlock()
|
||||
|
||||
b.send(Notify{NetMap: st.NetMap})
|
||||
// There is nothing to update if the map hasn't changed.
|
||||
if changed {
|
||||
b.updateFilter(st.NetMap)
|
||||
b.updateDNSMap(st.NetMap)
|
||||
b.e.SetNetworkMap(st.NetMap)
|
||||
}
|
||||
if disableDERP {
|
||||
b.e.SetDERPMap(nil)
|
||||
} else {
|
||||
b.e.SetDERPMap(st.NetMap.DERPMap)
|
||||
}
|
||||
}
|
||||
if st.URL != "" {
|
||||
b.logf("Received auth URL: %.20v...", st.URL)
|
||||
|
||||
b.mu.Lock()
|
||||
interact := b.interact
|
||||
b.authURL = st.URL
|
||||
b.mu.Unlock()
|
||||
|
||||
if interact > 0 {
|
||||
b.popBrowserAuthNow()
|
||||
}
|
||||
}
|
||||
if st.Err != "" {
|
||||
// TODO(crawshaw): display in the UI.
|
||||
b.logf("Received error: %v", st.Err)
|
||||
return
|
||||
}
|
||||
if st.NetMap != nil {
|
||||
b.mu.Lock()
|
||||
if b.state == NeedsLogin {
|
||||
b.prefs.WantRunning = true
|
||||
}
|
||||
prefs := b.prefs
|
||||
b.mu.Unlock()
|
||||
|
||||
b.SetPrefs(prefs)
|
||||
}
|
||||
b.stateMachine()
|
||||
}
|
||||
|
||||
// setWgengineStatus is the callback by the wireguard engine whenever it posts a new status.
|
||||
// This updates the endpoints both in the backend and in the control client.
|
||||
func (b *LocalBackend) setWgengineStatus(s *wgengine.Status, err error) {
|
||||
if err != nil {
|
||||
b.logf("wgengine status error: %#v", err)
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
b.logf("[unexpected] non-error wgengine update with status=nil: %v", s)
|
||||
return
|
||||
}
|
||||
|
||||
es := b.parseWgStatus(s)
|
||||
|
||||
b.mu.Lock()
|
||||
c := b.c
|
||||
b.engineStatus = es
|
||||
b.endpoints = append([]string{}, s.LocalAddrs...)
|
||||
b.mu.Unlock()
|
||||
|
||||
if c != nil {
|
||||
c.UpdateEndpoints(0, s.LocalAddrs)
|
||||
}
|
||||
b.stateMachine()
|
||||
|
||||
b.statusLock.Lock()
|
||||
b.statusChanged.Broadcast()
|
||||
b.statusLock.Unlock()
|
||||
|
||||
b.send(Notify{Engine: &es})
|
||||
}
|
||||
|
||||
// Start applies the configuration specified in opts, and starts the
|
||||
// state machine.
|
||||
//
|
||||
@@ -204,9 +321,9 @@ func (b *LocalBackend) Start(opts Options) error {
|
||||
b.logf("Start")
|
||||
}
|
||||
|
||||
hi := controlclient.NewHostinfo()
|
||||
hi.BackendLogID = b.backendLogID
|
||||
hi.FrontendLogID = opts.FrontendLogID
|
||||
hostinfo := controlclient.NewHostinfo()
|
||||
hostinfo.BackendLogID = b.backendLogID
|
||||
hostinfo.FrontendLogID = opts.FrontendLogID
|
||||
|
||||
b.mu.Lock()
|
||||
|
||||
@@ -221,11 +338,11 @@ func (b *LocalBackend) Start(opts Options) error {
|
||||
b.c.Shutdown()
|
||||
}
|
||||
|
||||
if b.hiCache != nil {
|
||||
hi.Services = b.hiCache.Services // keep any previous session and netinfo
|
||||
hi.NetInfo = b.hiCache.NetInfo
|
||||
if b.hostinfo != nil {
|
||||
hostinfo.Services = b.hostinfo.Services // keep any previous session and netinfo
|
||||
hostinfo.NetInfo = b.hostinfo.NetInfo
|
||||
}
|
||||
b.hiCache = hi
|
||||
b.hostinfo = hostinfo
|
||||
b.state = NoState
|
||||
|
||||
if err := b.loadStateLocked(opts.StateKey, opts.Prefs, opts.LegacyConfigPath); err != nil {
|
||||
@@ -234,16 +351,22 @@ func (b *LocalBackend) Start(opts Options) error {
|
||||
}
|
||||
|
||||
b.serverURL = b.prefs.ControlURL
|
||||
hi.RoutableIPs = append(hi.RoutableIPs, b.prefs.AdvertiseRoutes...)
|
||||
hi.RequestTags = append(hi.RequestTags, b.prefs.AdvertiseTags...)
|
||||
hostinfo.RoutableIPs = append(hostinfo.RoutableIPs, b.prefs.AdvertiseRoutes...)
|
||||
hostinfo.RequestTags = append(hostinfo.RequestTags, b.prefs.AdvertiseTags...)
|
||||
applyPrefsToHostinfo(hostinfo, b.prefs)
|
||||
|
||||
b.notify = opts.Notify
|
||||
b.netMapCache = nil
|
||||
b.netMap = nil
|
||||
persist := b.prefs.Persist
|
||||
b.mu.Unlock()
|
||||
|
||||
b.updateFilter(nil)
|
||||
|
||||
var discoPublic tailcfg.DiscoKey
|
||||
if controlclient.Debug.Disco {
|
||||
discoPublic = b.e.DiscoPublicKey()
|
||||
}
|
||||
|
||||
var err error
|
||||
if persist == nil {
|
||||
// let controlclient initialize it
|
||||
@@ -254,15 +377,25 @@ func (b *LocalBackend) Start(opts Options) error {
|
||||
Persist: *persist,
|
||||
ServerURL: b.serverURL,
|
||||
AuthKey: opts.AuthKey,
|
||||
Hostinfo: hi,
|
||||
Hostinfo: hostinfo,
|
||||
KeepAlive: true,
|
||||
NewDecompressor: b.newDecompressor,
|
||||
HTTPTestClient: opts.HTTPTestClient,
|
||||
DiscoPublicKey: discoPublic,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// At this point, we have finished using hostinfo without synchronization,
|
||||
// so it is safe to start readPoller which concurrently writes to it.
|
||||
if b.portpoll != nil {
|
||||
b.portpollOnce.Do(func() {
|
||||
go b.portpoll.Run(b.ctx)
|
||||
go b.readPoller()
|
||||
})
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
b.c = cli
|
||||
endpoints := b.endpoints
|
||||
@@ -272,111 +405,8 @@ func (b *LocalBackend) Start(opts Options) error {
|
||||
cli.UpdateEndpoints(0, endpoints)
|
||||
}
|
||||
|
||||
cli.SetStatusFunc(func(newSt controlclient.Status) {
|
||||
if newSt.LoginFinished != nil {
|
||||
// Auth completed, unblock the engine
|
||||
b.blockEngineUpdates(false)
|
||||
b.authReconfig()
|
||||
b.send(Notify{LoginFinished: &empty.Message{}})
|
||||
}
|
||||
if newSt.Persist != nil {
|
||||
persist := *newSt.Persist // copy
|
||||
|
||||
b.mu.Lock()
|
||||
b.prefs.Persist = &persist
|
||||
prefs := b.prefs.Clone()
|
||||
stateKey := b.stateKey
|
||||
b.mu.Unlock()
|
||||
|
||||
if stateKey != "" {
|
||||
if err := b.store.WriteState(stateKey, prefs.ToBytes()); err != nil {
|
||||
b.logf("Failed to save new controlclient state: %v", err)
|
||||
}
|
||||
}
|
||||
b.send(Notify{Prefs: prefs})
|
||||
}
|
||||
if newSt.NetMap != nil {
|
||||
b.mu.Lock()
|
||||
if b.netMapCache != nil {
|
||||
diff := newSt.NetMap.ConciseDiffFrom(b.netMapCache)
|
||||
if strings.TrimSpace(diff) == "" {
|
||||
b.logf("netmap diff: (none)")
|
||||
} else {
|
||||
b.logf("netmap diff:\n%v", diff)
|
||||
}
|
||||
}
|
||||
disableDERP := b.prefs != nil && b.prefs.DisableDERP
|
||||
b.netMapCache = newSt.NetMap
|
||||
b.mu.Unlock()
|
||||
|
||||
b.send(Notify{NetMap: newSt.NetMap})
|
||||
b.updateFilter(newSt.NetMap)
|
||||
if disableDERP {
|
||||
b.e.SetDERPMap(nil)
|
||||
} else {
|
||||
b.e.SetDERPMap(newSt.NetMap.DERPMap)
|
||||
}
|
||||
}
|
||||
if newSt.URL != "" {
|
||||
b.logf("Received auth URL: %.20v...", newSt.URL)
|
||||
|
||||
b.mu.Lock()
|
||||
interact := b.interact
|
||||
b.authURL = newSt.URL
|
||||
b.mu.Unlock()
|
||||
|
||||
if interact > 0 {
|
||||
b.popBrowserAuthNow()
|
||||
}
|
||||
}
|
||||
if newSt.Err != "" {
|
||||
// TODO(crawshaw): display in the UI.
|
||||
b.logf("Received error: %v", newSt.Err)
|
||||
return
|
||||
}
|
||||
if newSt.NetMap != nil {
|
||||
b.mu.Lock()
|
||||
if b.state == NeedsLogin {
|
||||
b.prefs.WantRunning = true
|
||||
}
|
||||
prefs := b.prefs
|
||||
b.mu.Unlock()
|
||||
|
||||
b.SetPrefs(prefs)
|
||||
}
|
||||
b.stateMachine()
|
||||
})
|
||||
|
||||
b.e.SetStatusCallback(func(s *wgengine.Status, err error) {
|
||||
if err != nil {
|
||||
b.logf("wgengine status error: %#v", err)
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
b.logf("weird: non-error wgengine update with status=nil: %v", s)
|
||||
return
|
||||
}
|
||||
|
||||
es := b.parseWgStatus(s)
|
||||
|
||||
b.mu.Lock()
|
||||
c := b.c
|
||||
b.engineStatus = es
|
||||
b.endpoints = append([]string{}, s.LocalAddrs...)
|
||||
b.mu.Unlock()
|
||||
|
||||
if c != nil {
|
||||
c.UpdateEndpoints(0, s.LocalAddrs)
|
||||
}
|
||||
b.stateMachine()
|
||||
|
||||
b.statusLock.Lock()
|
||||
b.statusChanged.Broadcast()
|
||||
b.statusLock.Unlock()
|
||||
|
||||
b.send(Notify{Engine: &es})
|
||||
})
|
||||
|
||||
cli.SetStatusFunc(b.setClientStatus)
|
||||
b.e.SetStatusCallback(b.setWgengineStatus)
|
||||
b.e.SetNetInfoCallback(b.setNetInfo)
|
||||
|
||||
b.mu.Lock()
|
||||
@@ -427,6 +457,34 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap) {
|
||||
b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf))
|
||||
}
|
||||
|
||||
// updateDNSMap updates the domain map in the DNS resolver in wgengine
|
||||
// based on the given netMap and user preferences.
|
||||
func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) {
|
||||
if netMap == nil {
|
||||
return
|
||||
}
|
||||
|
||||
domainToIP := make(map[string]netaddr.IP)
|
||||
set := func(hostname string, addrs []wgcfg.CIDR) {
|
||||
if len(addrs) == 0 {
|
||||
return
|
||||
}
|
||||
domain := hostname
|
||||
// Like PeerStatus.SimpleHostName()
|
||||
domain = strings.TrimSuffix(domain, ".local")
|
||||
domain = strings.TrimSuffix(domain, ".localdomain")
|
||||
domain = domain + ".b.tailscale.net"
|
||||
domainToIP[domain] = netaddr.IPFrom16(addrs[0].IP.Addr)
|
||||
}
|
||||
|
||||
for _, peer := range netMap.Peers {
|
||||
set(peer.Hostinfo.Hostname, peer.Addresses)
|
||||
}
|
||||
set(netMap.Hostinfo.Hostname, netMap.Addresses)
|
||||
|
||||
b.e.SetDNSMap(tsdns.NewMap(domainToIP))
|
||||
}
|
||||
|
||||
// readPoller is a goroutine that receives service lists from
|
||||
// b.portpoll and propagates them into the controlclient's HostInfo.
|
||||
func (b *LocalBackend) readPoller() {
|
||||
@@ -448,13 +506,11 @@ func (b *LocalBackend) readPoller() {
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
if b.hiCache == nil {
|
||||
// TODO(bradfitz): it's a little weird that this port poller
|
||||
// is started (by NewLocalBackend) before the Start call.
|
||||
b.hiCache = new(tailcfg.Hostinfo)
|
||||
if b.hostinfo == nil {
|
||||
b.hostinfo = new(tailcfg.Hostinfo)
|
||||
}
|
||||
b.hiCache.Services = sl
|
||||
hi := b.hiCache
|
||||
b.hostinfo.Services = sl
|
||||
hi := b.hostinfo
|
||||
b.mu.Unlock()
|
||||
|
||||
b.doSetHostinfoFilterServices(hi)
|
||||
@@ -471,6 +527,8 @@ func (b *LocalBackend) send(n Notify) {
|
||||
if notify != nil {
|
||||
n.Version = version.LONG
|
||||
notify(n)
|
||||
} else {
|
||||
b.logf("nil notify callback; dropping %+v", n)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,7 +581,7 @@ func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath st
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrStateNotExist) {
|
||||
if legacyPath != "" {
|
||||
b.prefs, err = LoadPrefs(legacyPath, true)
|
||||
b.prefs, err = LoadPrefs(legacyPath)
|
||||
if err != nil {
|
||||
b.logf("Failed to load legacy prefs: %v", err)
|
||||
b.prefs = NewPrefs()
|
||||
@@ -565,6 +623,16 @@ func (b *LocalBackend) getEngineStatus() EngineStatus {
|
||||
return b.engineStatus
|
||||
}
|
||||
|
||||
// Login implements Backend.
|
||||
func (b *LocalBackend) Login(token *oauth2.Token) {
|
||||
b.mu.Lock()
|
||||
b.assertClientLocked()
|
||||
c := b.c
|
||||
b.mu.Unlock()
|
||||
|
||||
c.Login(token, controlclient.LoginInteractive)
|
||||
}
|
||||
|
||||
// StartLoginInteractive implements Backend. It requests a new
|
||||
// interactive login from controlclient, unless such a flow is already
|
||||
// in progress, in which case StartLoginInteractive attempts to pick
|
||||
@@ -588,13 +656,23 @@ func (b *LocalBackend) StartLoginInteractive() {
|
||||
// FakeExpireAfter implements Backend.
|
||||
func (b *LocalBackend) FakeExpireAfter(x time.Duration) {
|
||||
b.logf("FakeExpireAfter: %v", x)
|
||||
if b.netMapCache != nil {
|
||||
e := b.netMapCache.Expiry
|
||||
if e.IsZero() || time.Until(e) > x {
|
||||
b.netMapCache.Expiry = time.Now().Add(x)
|
||||
}
|
||||
b.send(Notify{NetMap: b.netMapCache})
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.netMap == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// This function is called very rarely,
|
||||
// so we prefer to fully copy the netmap over introducing in-place modification here.
|
||||
mapCopy := *b.netMap
|
||||
e := mapCopy.Expiry
|
||||
if e.IsZero() || time.Until(e) > x {
|
||||
mapCopy.Expiry = time.Now().Add(x)
|
||||
}
|
||||
b.netMap = &mapCopy
|
||||
b.send(Notify{NetMap: b.netMap})
|
||||
}
|
||||
|
||||
func (b *LocalBackend) parseWgStatus(s *wgengine.Status) (ret EngineStatus) {
|
||||
@@ -651,22 +729,30 @@ func (b *LocalBackend) SetPrefs(new *Prefs) {
|
||||
b.logf("Failed to save new controlclient state: %v", err)
|
||||
}
|
||||
}
|
||||
oldHi := b.hiCache
|
||||
oldHi := b.hostinfo
|
||||
newHi := oldHi.Clone()
|
||||
newHi.RoutableIPs = append([]wgcfg.CIDR(nil), b.prefs.AdvertiseRoutes...)
|
||||
if h := new.Hostname; h != "" {
|
||||
newHi.Hostname = h
|
||||
}
|
||||
b.hiCache = newHi
|
||||
applyPrefsToHostinfo(newHi, new)
|
||||
b.hostinfo = newHi
|
||||
hostInfoChanged := !oldHi.Equal(newHi)
|
||||
b.mu.Unlock()
|
||||
|
||||
b.logf("SetPrefs: %v", new.Pretty())
|
||||
|
||||
if old.ShieldsUp != new.ShieldsUp || !oldHi.Equal(newHi) {
|
||||
if old.ShieldsUp != new.ShieldsUp || hostInfoChanged {
|
||||
b.doSetHostinfoFilterServices(newHi)
|
||||
}
|
||||
|
||||
b.updateFilter(b.netMapCache)
|
||||
b.updateFilter(b.netMap)
|
||||
// TODO(dmytro): when Prefs gain an EnableTailscaleDNS toggle, updateDNSMap here.
|
||||
|
||||
turnDERPOff := new.DisableDERP && !old.DisableDERP
|
||||
turnDERPOn := !new.DisableDERP && old.DisableDERP
|
||||
if turnDERPOff {
|
||||
b.e.SetDERPMap(nil)
|
||||
} else if turnDERPOn && b.netMap != nil {
|
||||
b.e.SetDERPMap(b.netMap.DERPMap)
|
||||
}
|
||||
|
||||
if old.WantRunning != new.WantRunning {
|
||||
b.stateMachine()
|
||||
@@ -703,7 +789,9 @@ func (b *LocalBackend) doSetHostinfoFilterServices(hi *tailcfg.Hostinfo) {
|
||||
// NetMap returns the latest cached network map received from
|
||||
// controlclient, or nil if no network map was received yet.
|
||||
func (b *LocalBackend) NetMap() *controlclient.NetworkMap {
|
||||
return b.netMapCache
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.netMap
|
||||
}
|
||||
|
||||
// blockEngineUpdate sets b.blocked to block, while holding b.mu. Its
|
||||
@@ -724,7 +812,7 @@ func (b *LocalBackend) authReconfig() {
|
||||
b.mu.Lock()
|
||||
blocked := b.blocked
|
||||
uc := b.prefs
|
||||
nm := b.netMapCache
|
||||
nm := b.netMap
|
||||
b.mu.Unlock()
|
||||
|
||||
if blocked {
|
||||
@@ -740,20 +828,20 @@ func (b *LocalBackend) authReconfig() {
|
||||
return
|
||||
}
|
||||
|
||||
uflags := controlclient.UDefault
|
||||
var flags controlclient.WGConfigFlags
|
||||
if uc.RouteAll {
|
||||
uflags |= controlclient.UAllowDefaultRoute
|
||||
flags |= controlclient.AllowDefaultRoute
|
||||
// TODO(apenwarr): Make subnet routes a different pref?
|
||||
uflags |= controlclient.UAllowSubnetRoutes
|
||||
flags |= controlclient.AllowSubnetRoutes
|
||||
// TODO(apenwarr): Remove this once we sort out subnet routes.
|
||||
// Right now default routes are broken in Windows, but
|
||||
// controlclient doesn't properly send subnet routes. So
|
||||
// let's convert a default route into a subnet route in order
|
||||
// to allow experimentation.
|
||||
uflags |= controlclient.UHackDefaultRoute
|
||||
flags |= controlclient.HackDefaultRoute
|
||||
}
|
||||
if uc.AllowSingleHosts {
|
||||
uflags |= controlclient.UAllowSingleHosts
|
||||
flags |= controlclient.AllowSingleHosts
|
||||
}
|
||||
|
||||
dns := nm.DNS
|
||||
@@ -762,7 +850,7 @@ func (b *LocalBackend) authReconfig() {
|
||||
dns = []wgcfg.IP{}
|
||||
dom = []string{}
|
||||
}
|
||||
cfg, err := nm.WGCfg(uflags, dns)
|
||||
cfg, err := nm.WGCfg(b.logf, flags, dns)
|
||||
if err != nil {
|
||||
b.logf("wgcfg: %v", err)
|
||||
return
|
||||
@@ -772,7 +860,7 @@ func (b *LocalBackend) authReconfig() {
|
||||
if err == wgengine.ErrNoChanges {
|
||||
return
|
||||
}
|
||||
b.logf("authReconfig: ra=%v dns=%v 0x%02x: %v", uc.RouteAll, uc.CorpDNS, uflags, err)
|
||||
b.logf("authReconfig: ra=%v dns=%v 0x%02x: %v", uc.RouteAll, uc.CorpDNS, flags, err)
|
||||
}
|
||||
|
||||
// routerConfig produces a router.Config from a wireguard config,
|
||||
@@ -788,17 +876,26 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs, dnsDomains []string) *router.
|
||||
|
||||
rs := &router.Config{
|
||||
LocalAddrs: wgCIDRToNetaddr(addrs),
|
||||
DNS: wgIPToNetaddr(cfg.DNS),
|
||||
DNSDomains: dnsDomains,
|
||||
SubnetRoutes: wgCIDRToNetaddr(prefs.AdvertiseRoutes),
|
||||
SNATSubnetRoutes: !prefs.NoSNAT,
|
||||
NetfilterMode: prefs.NetfilterMode,
|
||||
DNSConfig: router.DNSConfig{
|
||||
Nameservers: wgIPToNetaddr(cfg.DNS),
|
||||
Domains: dnsDomains,
|
||||
},
|
||||
}
|
||||
|
||||
for _, peer := range cfg.Peers {
|
||||
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
|
||||
}
|
||||
|
||||
// The Tailscale DNS IP.
|
||||
// TODO(dmytro): make this configurable.
|
||||
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
|
||||
IP: netaddr.IPv4(100, 100, 100, 100),
|
||||
Bits: 32,
|
||||
})
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
@@ -842,6 +939,18 @@ func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
||||
return ret
|
||||
}
|
||||
|
||||
func applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *Prefs) {
|
||||
if h := prefs.Hostname; h != "" {
|
||||
hi.Hostname = h
|
||||
}
|
||||
if v := prefs.OSVersion; v != "" {
|
||||
hi.OSVersion = v
|
||||
}
|
||||
if m := prefs.DeviceModel; m != "" {
|
||||
hi.DeviceModel = m
|
||||
}
|
||||
}
|
||||
|
||||
// enterState transitions the backend into newState, updating internal
|
||||
// state and propagating events out as needed.
|
||||
//
|
||||
@@ -852,6 +961,7 @@ func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
||||
func (b *LocalBackend) enterState(newState State) {
|
||||
b.mu.Lock()
|
||||
state := b.state
|
||||
b.state = newState
|
||||
prefs := b.prefs
|
||||
notify := b.notify
|
||||
b.mu.Unlock()
|
||||
@@ -865,7 +975,6 @@ func (b *LocalBackend) enterState(newState State) {
|
||||
b.send(Notify{State: &newState})
|
||||
}
|
||||
|
||||
b.state = newState
|
||||
switch newState {
|
||||
case NeedsLogin:
|
||||
b.blockEngineUpdates(true)
|
||||
@@ -894,7 +1003,7 @@ func (b *LocalBackend) nextState() State {
|
||||
b.assertClientLocked()
|
||||
var (
|
||||
c = b.c
|
||||
netMap = b.netMapCache
|
||||
netMap = b.netMap
|
||||
state = b.state
|
||||
wantRunning = b.prefs.WantRunning
|
||||
)
|
||||
@@ -941,7 +1050,7 @@ func (b *LocalBackend) RequestEngineStatus() {
|
||||
// RequestStatus implements Backend.
|
||||
func (b *LocalBackend) RequestStatus() {
|
||||
st := b.Status()
|
||||
b.notify(Notify{Status: st})
|
||||
b.send(Notify{Status: st})
|
||||
}
|
||||
|
||||
// stateMachine updates the state machine state based on other things
|
||||
@@ -992,13 +1101,13 @@ func (b *LocalBackend) Logout() {
|
||||
b.mu.Lock()
|
||||
b.assertClientLocked()
|
||||
c := b.c
|
||||
b.netMapCache = nil
|
||||
b.netMap = nil
|
||||
b.mu.Unlock()
|
||||
|
||||
c.Logout()
|
||||
|
||||
b.mu.Lock()
|
||||
b.netMapCache = nil
|
||||
b.netMap = nil
|
||||
b.mu.Unlock()
|
||||
|
||||
b.stateMachine()
|
||||
@@ -1011,13 +1120,13 @@ func (b *LocalBackend) assertClientLocked() {
|
||||
}
|
||||
}
|
||||
|
||||
// setNetInfo sets b.hiCache.NetInfo to ni, and passes ni along to the
|
||||
// setNetInfo sets b.hostinfo.NetInfo to ni, and passes ni along to the
|
||||
// controlclient, if one exists.
|
||||
func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) {
|
||||
b.mu.Lock()
|
||||
c := b.c
|
||||
if b.hiCache != nil {
|
||||
b.hiCache.NetInfo = ni.Clone()
|
||||
if b.hostinfo != nil {
|
||||
b.hostinfo.NetInfo = ni.Clone()
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/structs"
|
||||
"tailscale.com/version"
|
||||
@@ -49,6 +50,7 @@ type Command struct {
|
||||
Quit *NoArgs
|
||||
Start *StartArgs
|
||||
StartLoginInteractive *NoArgs
|
||||
Login *oauth2.Token
|
||||
Logout *NoArgs
|
||||
SetPrefs *SetPrefsArgs
|
||||
RequestEngineStatus *NoArgs
|
||||
@@ -80,6 +82,10 @@ func (bs *BackendServer) send(n Notify) {
|
||||
bs.sendNotifyMsg(b)
|
||||
}
|
||||
|
||||
func (bs *BackendServer) SendErrorMessage(msg string) {
|
||||
bs.send(Notify{ErrMessage: &msg})
|
||||
}
|
||||
|
||||
// GotCommandMsg parses the incoming message b as a JSON Command and
|
||||
// calls GotCommand with it.
|
||||
func (bs *BackendServer) GotCommandMsg(b []byte) error {
|
||||
@@ -124,6 +130,9 @@ func (bs *BackendServer) GotCommand(cmd *Command) error {
|
||||
} else if c := cmd.StartLoginInteractive; c != nil {
|
||||
bs.b.StartLoginInteractive()
|
||||
return nil
|
||||
} else if c := cmd.Login; c != nil {
|
||||
bs.b.Login(c)
|
||||
return nil
|
||||
} else if c := cmd.Logout; c != nil {
|
||||
bs.b.Logout()
|
||||
return nil
|
||||
@@ -221,6 +230,10 @@ func (bc *BackendClient) StartLoginInteractive() {
|
||||
bc.send(Command{StartLoginInteractive: &NoArgs{}})
|
||||
}
|
||||
|
||||
func (bc *BackendClient) Login(token *oauth2.Token) {
|
||||
bc.send(Command{Login: token})
|
||||
}
|
||||
|
||||
func (bc *BackendClient) Logout() {
|
||||
bc.send(Command{Logout: &NoArgs{}})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
@@ -177,4 +178,10 @@ func TestClientServer(t *testing.T) {
|
||||
|
||||
h.Logout()
|
||||
flushUntil(NeedsLogin)
|
||||
|
||||
h.Login(&oauth2.Token{
|
||||
AccessToken: "google_id_token",
|
||||
TokenType: GoogleIDTokenType,
|
||||
})
|
||||
flushUntil(Running)
|
||||
}
|
||||
|
||||
13
ipn/prefs.go
13
ipn/prefs.go
@@ -54,6 +54,10 @@ type Prefs struct {
|
||||
// Hostname is the hostname to use for identifying the node. If
|
||||
// not set, os.Hostname is used.
|
||||
Hostname string
|
||||
// OSVersion overrides tailcfg.Hostinfo's OSVersion.
|
||||
OSVersion string
|
||||
// DeviceModel overrides tailcfg.Hostinfo's DeviceModel.
|
||||
DeviceModel string
|
||||
|
||||
// NotepadURLs is a debugging setting that opens OAuth URLs in
|
||||
// notepad.exe on Windows, rather than loading them in a browser.
|
||||
@@ -138,6 +142,8 @@ func (p *Prefs) Equals(p2 *Prefs) bool {
|
||||
p.NoSNAT == p2.NoSNAT &&
|
||||
p.NetfilterMode == p2.NetfilterMode &&
|
||||
p.Hostname == p2.Hostname &&
|
||||
p.OSVersion == p2.OSVersion &&
|
||||
p.DeviceModel == p2.DeviceModel &&
|
||||
compareIPNets(p.AdvertiseRoutes, p2.AdvertiseRoutes) &&
|
||||
compareStrings(p.AdvertiseTags, p2.AdvertiseTags) &&
|
||||
p.Persist.Equals(p2.Persist)
|
||||
@@ -217,10 +223,9 @@ func (p *Prefs) Clone() *Prefs {
|
||||
return p2
|
||||
}
|
||||
|
||||
// LoadLegacyPrefs loads a legacy relaynode config file into Prefs
|
||||
// with sensible migration defaults set. If enforceDefaults is true,
|
||||
// Prefs.RouteAll and Prefs.AllowSingleHosts are forced on.
|
||||
func LoadPrefs(filename string, enforceDefaults bool) (*Prefs, error) {
|
||||
// LoadPrefs loads a legacy relaynode config file into Prefs
|
||||
// with sensible migration defaults set.
|
||||
func LoadPrefs(filename string) (*Prefs, error) {
|
||||
data, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading prefs from %q: %v", filename, err)
|
||||
|
||||
@@ -24,7 +24,7 @@ func fieldsOf(t reflect.Type) (fields []string) {
|
||||
func TestPrefsEqual(t *testing.T) {
|
||||
tstest.PanicOnLog()
|
||||
|
||||
prefsHandles := []string{"ControlURL", "RouteAll", "AllowSingleHosts", "CorpDNS", "WantRunning", "ShieldsUp", "AdvertiseTags", "Hostname", "NotepadURLs", "DisableDERP", "AdvertiseRoutes", "NoSNAT", "NetfilterMode", "Persist"}
|
||||
prefsHandles := []string{"ControlURL", "RouteAll", "AllowSingleHosts", "CorpDNS", "WantRunning", "ShieldsUp", "AdvertiseTags", "Hostname", "OSVersion", "DeviceModel", "NotepadURLs", "DisableDERP", "AdvertiseRoutes", "NoSNAT", "NetfilterMode", "Persist"}
|
||||
if have := fieldsOf(reflect.TypeOf(Prefs{})); !reflect.DeepEqual(have, prefsHandles) {
|
||||
t.Errorf("Prefs.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
||||
have, prefsHandles)
|
||||
|
||||
@@ -7,9 +7,9 @@ package logheap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"time"
|
||||
@@ -17,29 +17,25 @@ import (
|
||||
|
||||
// LogHeap writes a JSON logtail record with the base64 heap pprof to
|
||||
// os.Stderr.
|
||||
func LogHeap() {
|
||||
logHeap(os.Stderr)
|
||||
}
|
||||
|
||||
type logTail struct {
|
||||
ClientTime string `json:"client_time"`
|
||||
}
|
||||
|
||||
type pprofRec struct {
|
||||
Heap []byte `json:"heap,omitempty"`
|
||||
}
|
||||
|
||||
type logLine struct {
|
||||
LogTail logTail `json:"logtail"`
|
||||
Pprof pprofRec `json:"pprof"`
|
||||
}
|
||||
|
||||
func logHeap(w io.Writer) error {
|
||||
func LogHeap(postURL string) {
|
||||
if postURL == "" {
|
||||
return
|
||||
}
|
||||
runtime.GC()
|
||||
buf := new(bytes.Buffer)
|
||||
pprof.WriteHeapProfile(buf)
|
||||
return json.NewEncoder(w).Encode(logLine{
|
||||
LogTail: logTail{ClientTime: time.Now().Format(time.RFC3339Nano)},
|
||||
Pprof: pprofRec{Heap: buf.Bytes()},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", postURL, buf)
|
||||
if err != nil {
|
||||
log.Printf("LogHeap: %v", err)
|
||||
return
|
||||
}
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("LogHeap: %v", err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
}
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package logheap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLogHeap(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
if err := logHeap(&buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Got line: %s", buf.Bytes())
|
||||
|
||||
var ll logLine
|
||||
if err := json.Unmarshal(buf.Bytes(), &ll); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
zr, err := gzip.NewReader(bytes.NewReader(ll.Pprof.Heap))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rawProto, err := ioutil.ReadAll(zr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Just sanity check it. Too lazy to properly decode the protobuf. But see that
|
||||
// it contains an expected sample name.
|
||||
if !bytes.Contains(rawProto, []byte("alloc_objects")) {
|
||||
t.Errorf("raw proto didn't contain `alloc_objects`: %q", rawProto)
|
||||
}
|
||||
}
|
||||
@@ -25,13 +25,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"tailscale.com/atomicfile"
|
||||
"tailscale.com/logtail"
|
||||
"tailscale.com/logtail/filch"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/tlsdial"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
@@ -305,7 +305,9 @@ func New(collection string) *Policy {
|
||||
|
||||
dir := logsDir()
|
||||
|
||||
tryFixLogStateLocation(dir, version.CmdName())
|
||||
if runtime.GOOS != "windows" { // version.CmdName call was blowing some Windows stack limit via goversion DLL loading
|
||||
tryFixLogStateLocation(dir, version.CmdName())
|
||||
}
|
||||
|
||||
cfgPath := filepath.Join(dir, fmt.Sprintf("%s.log.conf", version.CmdName()))
|
||||
var oldc *Config
|
||||
@@ -348,11 +350,7 @@ func New(collection string) *Policy {
|
||||
PrivateID: newc.PrivateID,
|
||||
Stderr: logWriter{console},
|
||||
NewZstdEncoder: func() logtail.Encoder {
|
||||
w, err := zstd.NewWriter(nil,
|
||||
zstd.WithEncoderLevel(zstd.SpeedFastest),
|
||||
zstd.WithEncoderConcurrency(1),
|
||||
zstd.WithWindowSize(8192),
|
||||
)
|
||||
w, err := smallzstd.NewEncoder(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -273,10 +273,10 @@ func (l *logger) uploading(ctx context.Context) {
|
||||
if err != nil {
|
||||
fmt.Fprintf(l.stderr, "logtail: upload: %v\n", err)
|
||||
}
|
||||
l.bo.BackOff(ctx, err)
|
||||
if uploaded {
|
||||
break
|
||||
}
|
||||
l.bo.BackOff(ctx, err)
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
@@ -40,3 +40,10 @@ func (m *LabelMap) Get(key string) *expvar.Int {
|
||||
m.Add(key, 0)
|
||||
return m.Map.Get(key).(*expvar.Int)
|
||||
}
|
||||
|
||||
// GetFloat returns a direct pointer to the expvar.Float for key, creating it
|
||||
// if necessary.
|
||||
func (m *LabelMap) GetFloat(key string) *expvar.Float {
|
||||
m.AddFloat(key, 0.0)
|
||||
return m.Map.Get(key).(*expvar.Float)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
// Tailscale returns the current machine's Tailscale interface, if any.
|
||||
@@ -52,7 +53,7 @@ func maybeTailscaleInterfaceName(s string) bool {
|
||||
// Tailscale virtual network interfaces.
|
||||
func IsTailscaleIP(ip net.IP) bool {
|
||||
nip, _ := netaddr.FromStdIP(ip) // TODO: push this up to caller, change func signature
|
||||
return cgNAT.Contains(nip)
|
||||
return tsaddr.IsTailscaleIP(nip)
|
||||
}
|
||||
|
||||
func isUp(nif *net.Interface) bool { return nif.Flags&net.FlagUp != 0 }
|
||||
@@ -95,7 +96,7 @@ func LocalAddresses() (regular, loopback []string, err error) {
|
||||
// very well be something we can route to
|
||||
// directly, because both nodes are
|
||||
// behind the same CGNAT router.
|
||||
if cgNAT.Contains(ip) {
|
||||
if tsaddr.IsTailscaleIP(ip) {
|
||||
continue
|
||||
}
|
||||
if linkLocalIPv4.Contains(ip) {
|
||||
@@ -230,6 +231,38 @@ func HTTPOfListener(ln net.Listener) string {
|
||||
|
||||
}
|
||||
|
||||
var likelyHomeRouterIP func() (netaddr.IP, bool)
|
||||
|
||||
// LikelyHomeRouterIP returns the likely IP of the residential router,
|
||||
// which will always be an IPv4 private address, if found.
|
||||
// In addition, it returns the IP address of the current machine on
|
||||
// the LAN using that gateway.
|
||||
// This is used as the destination for UPnP, NAT-PMP, PCP, etc queries.
|
||||
func LikelyHomeRouterIP() (gateway, myIP netaddr.IP, ok bool) {
|
||||
if likelyHomeRouterIP != nil {
|
||||
gateway, ok = likelyHomeRouterIP()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ForeachInterfaceAddress(func(i Interface, ip netaddr.IP) {
|
||||
if !i.IsUp() || ip.IsZero() || !myIP.IsZero() {
|
||||
return
|
||||
}
|
||||
for _, prefix := range privatev4s {
|
||||
if prefix.Contains(gateway) && prefix.Contains(ip) {
|
||||
myIP = ip
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return gateway, myIP, !myIP.IsZero()
|
||||
}
|
||||
|
||||
func isPrivateIP(ip netaddr.IP) bool {
|
||||
return private1.Contains(ip) || private2.Contains(ip) || private3.Contains(ip)
|
||||
}
|
||||
@@ -250,7 +283,7 @@ var (
|
||||
private1 = mustCIDR("10.0.0.0/8")
|
||||
private2 = mustCIDR("172.16.0.0/12")
|
||||
private3 = mustCIDR("192.168.0.0/16")
|
||||
cgNAT = mustCIDR("100.64.0.0/10")
|
||||
privatev4s = []netaddr.IPPrefix{private1, private2, private3}
|
||||
linkLocalIPv4 = mustCIDR("169.254.0.0/16")
|
||||
v6Global1 = mustCIDR("2000::/3")
|
||||
)
|
||||
|
||||
66
net/interfaces/interfaces_darwin.go
Normal file
66
net/interfaces/interfaces_darwin.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
|
||||
"go4.org/mem"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/util/lineread"
|
||||
)
|
||||
|
||||
func init() {
|
||||
likelyHomeRouterIP = likelyHomeRouterIPDarwin
|
||||
}
|
||||
|
||||
/*
|
||||
Parse out 10.0.0.1 from:
|
||||
|
||||
$ netstat -r -n -f inet
|
||||
Routing tables
|
||||
|
||||
Internet:
|
||||
Destination Gateway Flags Netif Expire
|
||||
default 10.0.0.1 UGSc en0
|
||||
default link#14 UCSI utun2
|
||||
10/16 link#4 UCS en0 !
|
||||
10.0.0.1/32 link#4 UCS en0 !
|
||||
...
|
||||
|
||||
*/
|
||||
func likelyHomeRouterIPDarwin() (ret netaddr.IP, ok bool) {
|
||||
cmd := exec.Command("/usr/sbin/netstat", "-r", "-n", "-f", "inet")
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return
|
||||
}
|
||||
defer cmd.Wait()
|
||||
|
||||
var f []mem.RO
|
||||
lineread.Reader(stdout, func(lineb []byte) error {
|
||||
line := mem.B(lineb)
|
||||
if !mem.Contains(line, mem.S("default")) {
|
||||
return nil
|
||||
}
|
||||
f = mem.AppendFields(f[:0], line)
|
||||
if len(f) < 3 || !f[0].EqualString("default") {
|
||||
return nil
|
||||
}
|
||||
ipm, flagsm := f[1], f[2]
|
||||
if !mem.Contains(flagsm, mem.S("G")) {
|
||||
return nil
|
||||
}
|
||||
ip, err := netaddr.ParseIP(string(mem.Append(nil, ipm)))
|
||||
if err == nil && isPrivateIP(ip) {
|
||||
ret = ip
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return ret, !ret.IsZero()
|
||||
}
|
||||
59
net/interfaces/interfaces_linux.go
Normal file
59
net/interfaces/interfaces_linux.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"go4.org/mem"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/util/lineread"
|
||||
)
|
||||
|
||||
func init() {
|
||||
likelyHomeRouterIP = likelyHomeRouterIPLinux
|
||||
}
|
||||
|
||||
/*
|
||||
Parse 10.0.0.1 out of:
|
||||
|
||||
$ cat /proc/net/route
|
||||
Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
ens18 00000000 0100000A 0003 0 0 0 00000000 0 0 0
|
||||
ens18 0000000A 00000000 0001 0 0 0 0000FFFF 0 0 0
|
||||
*/
|
||||
func likelyHomeRouterIPLinux() (ret netaddr.IP, ok bool) {
|
||||
lineNum := 0
|
||||
var f []mem.RO
|
||||
lineread.File("/proc/net/route", func(line []byte) error {
|
||||
lineNum++
|
||||
if lineNum == 1 {
|
||||
// Skip header line.
|
||||
return nil
|
||||
}
|
||||
f = mem.AppendFields(f[:0], mem.B(line))
|
||||
if len(f) < 4 {
|
||||
return nil
|
||||
}
|
||||
gwHex, flagsHex := f[2], f[3]
|
||||
flags, err := mem.ParseUint(flagsHex, 16, 16)
|
||||
if err != nil {
|
||||
return nil // ignore error, skip line and keep going
|
||||
}
|
||||
const RTF_UP = 0x0001
|
||||
const RTF_GATEWAY = 0x0002
|
||||
if flags&(RTF_UP|RTF_GATEWAY) != RTF_UP|RTF_GATEWAY {
|
||||
return nil
|
||||
}
|
||||
ipu32, err := mem.ParseUint(gwHex, 16, 32)
|
||||
if err != nil {
|
||||
return nil // ignore error, skip line and keep going
|
||||
}
|
||||
ip := netaddr.IPv4(byte(ipu32), byte(ipu32>>8), byte(ipu32>>16), byte(ipu32>>24))
|
||||
if isPrivateIP(ip) {
|
||||
ret = ip
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return ret, !ret.IsZero()
|
||||
}
|
||||
@@ -47,3 +47,12 @@ func TestGetState(t *testing.T) {
|
||||
t.Fatal("two States back-to-back were not equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLikelyHomeRouterIP(t *testing.T) {
|
||||
gw, my, ok := LikelyHomeRouterIP()
|
||||
if !ok {
|
||||
t.Logf("no result")
|
||||
return
|
||||
}
|
||||
t.Logf("myIP = %v; gw = %v", my, gw)
|
||||
}
|
||||
|
||||
73
net/interfaces/interfaces_windows.go
Normal file
73
net/interfaces/interfaces_windows.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"go4.org/mem"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/util/lineread"
|
||||
)
|
||||
|
||||
func init() {
|
||||
likelyHomeRouterIP = likelyHomeRouterIPWindows
|
||||
}
|
||||
|
||||
/*
|
||||
Parse out 10.0.0.1 from:
|
||||
|
||||
Z:\>route print -4
|
||||
===========================================================================
|
||||
Interface List
|
||||
15...aa 15 48 ff 1c 72 ......Red Hat VirtIO Ethernet Adapter
|
||||
5...........................Tailscale Tunnel
|
||||
1...........................Software Loopback Interface 1
|
||||
===========================================================================
|
||||
|
||||
IPv4 Route Table
|
||||
===========================================================================
|
||||
Active Routes:
|
||||
Network Destination Netmask Gateway Interface Metric
|
||||
0.0.0.0 0.0.0.0 10.0.0.1 10.0.28.63 5
|
||||
10.0.0.0 255.255.0.0 On-link 10.0.28.63 261
|
||||
10.0.28.63 255.255.255.255 On-link 10.0.28.63 261
|
||||
10.0.42.0 255.255.255.0 100.103.42.106 100.103.42.106 5
|
||||
10.0.255.255 255.255.255.255 On-link 10.0.28.63 261
|
||||
34.193.248.174 255.255.255.255 100.103.42.106 100.103.42.106 5
|
||||
|
||||
*/
|
||||
func likelyHomeRouterIPWindows() (ret netaddr.IP, ok bool) {
|
||||
cmd := exec.Command("route", "print", "-4")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return
|
||||
}
|
||||
defer cmd.Wait()
|
||||
|
||||
var f []mem.RO
|
||||
lineread.Reader(stdout, func(lineb []byte) error {
|
||||
line := mem.B(lineb)
|
||||
if !mem.Contains(line, mem.S("0.0.0.0")) {
|
||||
return nil
|
||||
}
|
||||
f = mem.AppendFields(f[:0], line)
|
||||
if len(f) < 3 || !f[0].EqualString("0.0.0.0") || !f[1].EqualString("0.0.0.0") {
|
||||
return nil
|
||||
}
|
||||
ipm := f[2]
|
||||
ip, err := netaddr.ParseIP(string(mem.Append(nil, ipm)))
|
||||
if err == nil && isPrivateIP(ip) {
|
||||
ret = ip
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return ret, !ret.IsZero()
|
||||
}
|
||||
@@ -6,9 +6,11 @@
|
||||
package netcheck
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/tcnksm/go-httpstat"
|
||||
"go4.org/mem"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/net/dnscache"
|
||||
@@ -34,15 +37,26 @@ import (
|
||||
)
|
||||
|
||||
type Report struct {
|
||||
UDP bool // UDP works
|
||||
IPv6 bool // IPv6 works
|
||||
IPv4 bool // IPv4 works
|
||||
MappingVariesByDestIP opt.Bool // for IPv4
|
||||
HairPinning opt.Bool // for IPv4
|
||||
PreferredDERP int // or 0 for unknown
|
||||
RegionLatency map[int]time.Duration // keyed by DERP Region ID
|
||||
RegionV4Latency map[int]time.Duration // keyed by DERP Region ID
|
||||
RegionV6Latency map[int]time.Duration // keyed by DERP Region ID
|
||||
UDP bool // UDP works
|
||||
IPv6 bool // IPv6 works
|
||||
IPv4 bool // IPv4 works
|
||||
MappingVariesByDestIP opt.Bool // for IPv4
|
||||
HairPinning opt.Bool // for IPv4
|
||||
|
||||
// UPnP is whether UPnP appears present on the LAN.
|
||||
// Empty means not checked.
|
||||
UPnP opt.Bool
|
||||
// PMP is whether NAT-PMP appears present on the LAN.
|
||||
// Empty means not checked.
|
||||
PMP opt.Bool
|
||||
// PCP is whether PCP appears present on the LAN.
|
||||
// Empty means not checked.
|
||||
PCP opt.Bool
|
||||
|
||||
PreferredDERP int // or 0 for unknown
|
||||
RegionLatency map[int]time.Duration // keyed by DERP Region ID
|
||||
RegionV4Latency map[int]time.Duration // keyed by DERP Region ID
|
||||
RegionV6Latency map[int]time.Duration // keyed by DERP Region ID
|
||||
|
||||
GlobalV4 string // ip:port of global IPv4
|
||||
GlobalV6 string // [ip]:port of global IPv6
|
||||
@@ -50,6 +64,11 @@ type Report struct {
|
||||
// TODO: update Clone when adding new fields
|
||||
}
|
||||
|
||||
// AnyPortMappingChecked reports whether any of UPnP, PMP, or PCP are non-empty.
|
||||
func (r *Report) AnyPortMappingChecked() bool {
|
||||
return r.UPnP != "" || r.PMP != "" || r.PCP != ""
|
||||
}
|
||||
|
||||
func (r *Report) Clone() *Report {
|
||||
if r == nil {
|
||||
return nil
|
||||
@@ -127,7 +146,7 @@ func (c *Client) vlogf(format string, a ...interface{}) {
|
||||
|
||||
// handleHairSTUN reports whether pkt (from src) was our magic hairpin
|
||||
// probe packet that we sent to ourselves.
|
||||
func (c *Client) handleHairSTUNLocked(pkt []byte, src *net.UDPAddr) bool {
|
||||
func (c *Client) handleHairSTUNLocked(pkt []byte, src netaddr.IPPort) bool {
|
||||
rs := c.curState
|
||||
if rs == nil {
|
||||
return false
|
||||
@@ -150,11 +169,7 @@ func (c *Client) MakeNextReportFull() {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) ReceiveSTUNPacket(pkt []byte, src *net.UDPAddr) {
|
||||
if src == nil || src.IP == nil {
|
||||
panic("bogus src")
|
||||
}
|
||||
|
||||
func (c *Client) ReceiveSTUNPacket(pkt []byte, src netaddr.IPPort) {
|
||||
c.mu.Lock()
|
||||
if c.handleHairSTUNLocked(pkt, src) {
|
||||
c.mu.Unlock()
|
||||
@@ -421,7 +436,9 @@ func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) {
|
||||
if !stun.Is(pkt) {
|
||||
continue
|
||||
}
|
||||
c.ReceiveSTUNPacket(pkt, ua)
|
||||
if ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone); ok {
|
||||
c.ReceiveSTUNPacket(pkt, ipp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -429,13 +446,14 @@ func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) {
|
||||
type reportState struct {
|
||||
c *Client
|
||||
hairTX stun.TxID
|
||||
gotHairSTUN chan *net.UDPAddr
|
||||
gotHairSTUN chan netaddr.IPPort
|
||||
hairTimeout chan struct{} // closed on timeout
|
||||
pc4 STUNConn
|
||||
pc6 STUNConn
|
||||
pc4Hair net.PacketConn
|
||||
incremental bool // doing a lite, follow-up netcheck
|
||||
stopProbeCh chan struct{}
|
||||
waitPortMap sync.WaitGroup
|
||||
|
||||
mu sync.Mutex
|
||||
sentHairCheck bool
|
||||
@@ -601,6 +619,102 @@ func (rs *reportState) stopProbes() {
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *reportState) setOptBool(b *opt.Bool, v bool) {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
b.Set(v)
|
||||
}
|
||||
|
||||
func (rs *reportState) probePortMapServices() {
|
||||
defer rs.waitPortMap.Done()
|
||||
gw, myIP, ok := interfaces.LikelyHomeRouterIP()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
rs.setOptBool(&rs.report.UPnP, false)
|
||||
rs.setOptBool(&rs.report.PMP, false)
|
||||
rs.setOptBool(&rs.report.PCP, false)
|
||||
|
||||
port1900 := netaddr.IPPort{IP: gw, Port: 1900}.UDPAddr()
|
||||
port5351 := netaddr.IPPort{IP: gw, Port: 5351}.UDPAddr()
|
||||
|
||||
rs.c.logf("probePortMapServices: me %v -> gw %v", myIP, gw)
|
||||
|
||||
// Create a UDP4 socket used just for querying for UPnP, NAT-PMP, and PCP.
|
||||
uc, err := netns.Listener().ListenPacket(context.Background(), "udp4", ":0")
|
||||
if err != nil {
|
||||
rs.c.logf("probePortMapServices: %v", err)
|
||||
return
|
||||
}
|
||||
defer uc.Close()
|
||||
tempPort := uc.LocalAddr().(*net.UDPAddr).Port
|
||||
uc.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
|
||||
// Send request packets for all three protocols.
|
||||
uc.WriteTo(uPnPPacket, port1900)
|
||||
uc.WriteTo(pmpPacket, port5351)
|
||||
uc.WriteTo(pcpPacket(myIP, tempPort, false), port5351)
|
||||
|
||||
res := make([]byte, 1500)
|
||||
for {
|
||||
n, addr, err := uc.ReadFrom(res)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch addr.(*net.UDPAddr).Port {
|
||||
case 1900:
|
||||
if mem.Contains(mem.B(res[:n]), mem.S(":InternetGatewayDevice:")) {
|
||||
rs.setOptBool(&rs.report.UPnP, true)
|
||||
}
|
||||
case 5351:
|
||||
if n == 12 && res[0] == 0x00 { // right length and version 0
|
||||
rs.setOptBool(&rs.report.PMP, true)
|
||||
}
|
||||
if n == 60 && res[0] == 0x02 { // right length and version 2
|
||||
rs.setOptBool(&rs.report.PCP, true)
|
||||
|
||||
// And now delete the mapping.
|
||||
// (PCP is the only protocol of the three that requires
|
||||
// we cause a side effect to detect whether it's present,
|
||||
// so we need to redo that side effect now.)
|
||||
uc.WriteTo(pcpPacket(myIP, tempPort, true), port5351)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pmpPacket = []byte{0, 0} // version 0, opcode 0 = "Public address request"
|
||||
|
||||
var uPnPPacket = []byte("M-SEARCH * HTTP/1.1\r\n" +
|
||||
"HOST: 239.255.255.250:1900\r\n" +
|
||||
"ST: ssdp:all\r\n" +
|
||||
"MAN: \"ssdp:discover\"\r\n" +
|
||||
"MX: 2\r\n\r\n")
|
||||
|
||||
var v4unspec, _ = netaddr.ParseIP("0.0.0.0")
|
||||
|
||||
func pcpPacket(myIP netaddr.IP, mapToLocalPort int, delete bool) []byte {
|
||||
const udpProtoNumber = 17
|
||||
lifetimeSeconds := uint32(1)
|
||||
if delete {
|
||||
lifetimeSeconds = 0
|
||||
}
|
||||
const opMap = 1
|
||||
pkt := make([]byte, (32+32+128)/8+(96+8+24+16+16+128)/8)
|
||||
pkt[0] = 2 // version
|
||||
pkt[1] = opMap
|
||||
binary.BigEndian.PutUint32(pkt[4:8], lifetimeSeconds)
|
||||
myIP16 := myIP.As16()
|
||||
copy(pkt[8:], myIP16[:])
|
||||
rand.Read(pkt[24 : 24+12])
|
||||
pkt[36] = udpProtoNumber
|
||||
binary.BigEndian.PutUint16(pkt[40:], uint16(mapToLocalPort))
|
||||
v4unspec16 := v4unspec.As16()
|
||||
copy(pkt[40:], v4unspec16[:])
|
||||
return pkt
|
||||
}
|
||||
|
||||
func newReport() *Report {
|
||||
return &Report{
|
||||
RegionLatency: make(map[int]time.Duration),
|
||||
@@ -638,7 +752,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, e
|
||||
report: newReport(),
|
||||
inFlight: map[stun.TxID]func(netaddr.IPPort){},
|
||||
hairTX: stun.NewTxID(), // random payload
|
||||
gotHairSTUN: make(chan *net.UDPAddr, 1),
|
||||
gotHairSTUN: make(chan netaddr.IPPort, 1),
|
||||
hairTimeout: make(chan struct{}),
|
||||
stopProbeCh: make(chan struct{}, 1),
|
||||
}
|
||||
@@ -673,6 +787,22 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, e
|
||||
}
|
||||
defer rs.pc4Hair.Close()
|
||||
|
||||
rs.waitPortMap.Add(1)
|
||||
go rs.probePortMapServices()
|
||||
|
||||
// At least the Apple Airport Extreme doesn't allow hairpin
|
||||
// sends from a private socket until it's seen traffic from
|
||||
// that src IP:port to something else out on the internet.
|
||||
//
|
||||
// See https://github.com/tailscale/tailscale/issues/188#issuecomment-600728643
|
||||
//
|
||||
// And it seems that even sending to a likely-filtered RFC 5737
|
||||
// documentation-only IPv4 range is enough to set up the mapping.
|
||||
// So do that for now. In the future we might want to classify networks
|
||||
// that do and don't require this separately. But for now help it.
|
||||
const documentationIP = "203.0.113.1"
|
||||
rs.pc4Hair.WriteTo([]byte("tailscale netcheck; see https://github.com/tailscale/tailscale/issues/188"), &net.UDPAddr{IP: net.ParseIP(documentationIP), Port: 12345})
|
||||
|
||||
if f := c.GetSTUNConn4; f != nil {
|
||||
rs.pc4 = f()
|
||||
} else {
|
||||
@@ -727,6 +857,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, e
|
||||
}
|
||||
|
||||
rs.waitHairCheck(ctx)
|
||||
rs.waitPortMap.Wait()
|
||||
rs.stopTimers()
|
||||
|
||||
// Try HTTPS latency check if all STUN probes failed due to UDP presumably being blocked.
|
||||
@@ -841,42 +972,48 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio
|
||||
}
|
||||
|
||||
func (c *Client) logConciseReport(r *Report, dm *tailcfg.DERPMap) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 256)) // empirically: 5 DERPs + IPv6 == ~233 bytes
|
||||
fmt.Fprintf(buf, "udp=%v", r.UDP)
|
||||
if !r.IPv4 {
|
||||
fmt.Fprintf(buf, " v4=%v", r.IPv4)
|
||||
}
|
||||
fmt.Fprintf(buf, " v6=%v", r.IPv6)
|
||||
fmt.Fprintf(buf, " mapvarydest=%v", r.MappingVariesByDestIP)
|
||||
fmt.Fprintf(buf, " hair=%v", r.HairPinning)
|
||||
if r.GlobalV4 != "" {
|
||||
fmt.Fprintf(buf, " v4a=%v", r.GlobalV4)
|
||||
}
|
||||
if r.GlobalV6 != "" {
|
||||
fmt.Fprintf(buf, " v6a=%v", r.GlobalV6)
|
||||
}
|
||||
fmt.Fprintf(buf, " derp=%v", r.PreferredDERP)
|
||||
if r.PreferredDERP != 0 {
|
||||
fmt.Fprintf(buf, " derpdist=")
|
||||
for i, rid := range dm.RegionIDs() {
|
||||
if i != 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
c.logf("%v", logger.ArgWriter(func(w *bufio.Writer) {
|
||||
fmt.Fprintf(w, "udp=%v", r.UDP)
|
||||
if !r.IPv4 {
|
||||
fmt.Fprintf(w, " v4=%v", r.IPv4)
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, " v6=%v", r.IPv6)
|
||||
fmt.Fprintf(w, " mapvarydest=%v", r.MappingVariesByDestIP)
|
||||
fmt.Fprintf(w, " hair=%v", r.HairPinning)
|
||||
if r.AnyPortMappingChecked() {
|
||||
fmt.Fprintf(w, " portmap=%v%v%v", conciseOptBool(r.UPnP, "U"), conciseOptBool(r.PMP, "M"), conciseOptBool(r.PCP, "C"))
|
||||
} else {
|
||||
fmt.Fprintf(w, " portmap=?")
|
||||
}
|
||||
if r.GlobalV4 != "" {
|
||||
fmt.Fprintf(w, " v4a=%v", r.GlobalV4)
|
||||
}
|
||||
if r.GlobalV6 != "" {
|
||||
fmt.Fprintf(w, " v6a=%v", r.GlobalV6)
|
||||
}
|
||||
fmt.Fprintf(w, " derp=%v", r.PreferredDERP)
|
||||
if r.PreferredDERP != 0 {
|
||||
fmt.Fprintf(w, " derpdist=")
|
||||
needComma := false
|
||||
if d := r.RegionV4Latency[rid]; d != 0 {
|
||||
fmt.Fprintf(buf, "%dv4:%v", rid, d.Round(time.Millisecond))
|
||||
needComma = true
|
||||
}
|
||||
if d := r.RegionV6Latency[rid]; d != 0 {
|
||||
if needComma {
|
||||
buf.WriteByte(',')
|
||||
for _, rid := range dm.RegionIDs() {
|
||||
if d := r.RegionV4Latency[rid]; d != 0 {
|
||||
if needComma {
|
||||
w.WriteByte(',')
|
||||
}
|
||||
fmt.Fprintf(w, "%dv4:%v", rid, d.Round(time.Millisecond))
|
||||
needComma = true
|
||||
}
|
||||
if d := r.RegionV6Latency[rid]; d != 0 {
|
||||
if needComma {
|
||||
w.WriteByte(',')
|
||||
}
|
||||
fmt.Fprintf(w, "%dv6:%v", rid, d.Round(time.Millisecond))
|
||||
needComma = true
|
||||
}
|
||||
fmt.Fprintf(buf, "%dv6:%v", rid, d.Round(time.Millisecond))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.logf("%s", buf.Bytes())
|
||||
}))
|
||||
}
|
||||
|
||||
func (c *Client) timeNow() time.Time {
|
||||
@@ -1009,6 +1146,20 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP
|
||||
if port < 0 || port > 1<<16-1 {
|
||||
return nil
|
||||
}
|
||||
if n.STUNTestIP != "" {
|
||||
ip, err := netaddr.ParseIP(n.STUNTestIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if proto == probeIPv4 && ip.Is6() {
|
||||
return nil
|
||||
}
|
||||
if proto == probeIPv6 && ip.Is4() {
|
||||
return nil
|
||||
}
|
||||
return netaddr.IPPort{IP: ip, Port: uint16(port)}.UDPAddr()
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case probeIPv4:
|
||||
if n.IPv4 != "" {
|
||||
@@ -1016,7 +1167,7 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP
|
||||
if !ip.Is4() {
|
||||
return nil
|
||||
}
|
||||
return netaddr.IPPort{ip, uint16(port)}.UDPAddr()
|
||||
return netaddr.IPPort{IP: ip, Port: uint16(port)}.UDPAddr()
|
||||
}
|
||||
case probeIPv6:
|
||||
if n.IPv6 != "" {
|
||||
@@ -1024,7 +1175,7 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP
|
||||
if !ip.Is6() {
|
||||
return nil
|
||||
}
|
||||
return netaddr.IPPort{ip, uint16(port)}.UDPAddr()
|
||||
return netaddr.IPPort{IP: ip, Port: uint16(port)}.UDPAddr()
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
@@ -1057,3 +1208,17 @@ func maxDurationValue(m map[int]time.Duration) (max time.Duration) {
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
func conciseOptBool(b opt.Bool, trueVal string) string {
|
||||
if b == "" {
|
||||
return "_"
|
||||
}
|
||||
v, ok := b.Get()
|
||||
if !ok {
|
||||
return "x"
|
||||
}
|
||||
if v {
|
||||
return trueVal
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package netcheck
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/stun"
|
||||
"tailscale.com/net/stun/stuntest"
|
||||
@@ -26,14 +28,14 @@ func TestHairpinSTUN(t *testing.T) {
|
||||
c := &Client{
|
||||
curState: &reportState{
|
||||
hairTX: tx,
|
||||
gotHairSTUN: make(chan *net.UDPAddr, 1),
|
||||
gotHairSTUN: make(chan netaddr.IPPort, 1),
|
||||
},
|
||||
}
|
||||
req := stun.Request(tx)
|
||||
if !stun.Is(req) {
|
||||
t.Fatal("expected STUN message")
|
||||
}
|
||||
if !c.handleHairSTUNLocked(req, nil) {
|
||||
if !c.handleHairSTUNLocked(req, netaddr.IPPort{}) {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
select {
|
||||
@@ -98,6 +100,9 @@ func TestWorksWhenUDPBlocked(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := newReport()
|
||||
r.UPnP = ""
|
||||
r.PMP = ""
|
||||
r.PCP = ""
|
||||
|
||||
if !reflect.DeepEqual(r, want) {
|
||||
t.Errorf("mismatch\n got: %+v\nwant: %+v\n", r, want)
|
||||
@@ -443,3 +448,114 @@ func (p probeProto) String() string {
|
||||
}
|
||||
return "?"
|
||||
}
|
||||
|
||||
func TestLogConciseReport(t *testing.T) {
|
||||
dm := &tailcfg.DERPMap{
|
||||
Regions: map[int]*tailcfg.DERPRegion{
|
||||
1: nil,
|
||||
2: nil,
|
||||
3: nil,
|
||||
},
|
||||
}
|
||||
const ms = time.Millisecond
|
||||
tests := []struct {
|
||||
name string
|
||||
r *Report
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no_udp",
|
||||
r: &Report{},
|
||||
want: "udp=false v4=false v6=false mapvarydest= hair= portmap=? derp=0",
|
||||
},
|
||||
{
|
||||
name: "ipv4_one_region",
|
||||
r: &Report{
|
||||
UDP: true,
|
||||
IPv4: true,
|
||||
PreferredDERP: 1,
|
||||
RegionLatency: map[int]time.Duration{
|
||||
1: 10 * ms,
|
||||
},
|
||||
RegionV4Latency: map[int]time.Duration{
|
||||
1: 10 * ms,
|
||||
},
|
||||
},
|
||||
want: "udp=true v6=false mapvarydest= hair= portmap=? derp=1 derpdist=1v4:10ms",
|
||||
},
|
||||
{
|
||||
name: "ipv4_all_region",
|
||||
r: &Report{
|
||||
UDP: true,
|
||||
IPv4: true,
|
||||
PreferredDERP: 1,
|
||||
RegionLatency: map[int]time.Duration{
|
||||
1: 10 * ms,
|
||||
2: 20 * ms,
|
||||
3: 30 * ms,
|
||||
},
|
||||
RegionV4Latency: map[int]time.Duration{
|
||||
1: 10 * ms,
|
||||
2: 20 * ms,
|
||||
3: 30 * ms,
|
||||
},
|
||||
},
|
||||
want: "udp=true v6=false mapvarydest= hair= portmap=? derp=1 derpdist=1v4:10ms,2v4:20ms,3v4:30ms",
|
||||
},
|
||||
{
|
||||
name: "ipboth_all_region",
|
||||
r: &Report{
|
||||
UDP: true,
|
||||
IPv4: true,
|
||||
IPv6: true,
|
||||
PreferredDERP: 1,
|
||||
RegionLatency: map[int]time.Duration{
|
||||
1: 10 * ms,
|
||||
2: 20 * ms,
|
||||
3: 30 * ms,
|
||||
},
|
||||
RegionV4Latency: map[int]time.Duration{
|
||||
1: 10 * ms,
|
||||
2: 20 * ms,
|
||||
3: 30 * ms,
|
||||
},
|
||||
RegionV6Latency: map[int]time.Duration{
|
||||
1: 10 * ms,
|
||||
2: 20 * ms,
|
||||
3: 30 * ms,
|
||||
},
|
||||
},
|
||||
want: "udp=true v6=true mapvarydest= hair= portmap=? derp=1 derpdist=1v4:10ms,1v6:10ms,2v4:20ms,2v6:20ms,3v4:30ms,3v6:30ms",
|
||||
},
|
||||
{
|
||||
name: "portmap_all",
|
||||
r: &Report{
|
||||
UDP: true,
|
||||
UPnP: "true",
|
||||
PMP: "true",
|
||||
PCP: "true",
|
||||
},
|
||||
want: "udp=true v4=false v6=false mapvarydest= hair= portmap=UMC derp=0",
|
||||
},
|
||||
{
|
||||
name: "portmap_some",
|
||||
r: &Report{
|
||||
UDP: true,
|
||||
UPnP: "true",
|
||||
PMP: "false",
|
||||
PCP: "true",
|
||||
},
|
||||
want: "udp=true v4=false v6=false mapvarydest= hair= portmap=UC derp=0",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := &Client{Logf: func(f string, a ...interface{}) { fmt.Fprintf(&buf, f, a...) }}
|
||||
c.logConciseReport(tt.r, dm)
|
||||
if got := buf.String(); got != tt.want {
|
||||
t.Errorf("unexpected result.\n got: %#q\nwant: %#q\n", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
//
|
||||
// Keep this in sync with tailscaleBypassMark in
|
||||
// wgengine/router/router_linux.go.
|
||||
const tailscaleBypassMark = 0x20000
|
||||
const tailscaleBypassMark = 0x80000
|
||||
|
||||
// ipRuleOnce is the sync.Once & cached value for ipRuleAvailable.
|
||||
var ipRuleOnce struct {
|
||||
|
||||
@@ -29,8 +29,6 @@ const (
|
||||
bindingRequest = "\x00\x01"
|
||||
magicCookie = "\x21\x12\xa4\x42"
|
||||
lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32
|
||||
ipv4Len = 4
|
||||
ipv6Len = 16
|
||||
headerLen = 20
|
||||
)
|
||||
|
||||
@@ -135,7 +133,6 @@ var (
|
||||
func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error {
|
||||
for len(b) > 0 {
|
||||
if len(b) < 4 {
|
||||
return errors.New("effed-f1")
|
||||
return ErrMalformedAttrs
|
||||
}
|
||||
attrType := binary.BigEndian.Uint16(b[:2])
|
||||
@@ -143,7 +140,6 @@ func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error {
|
||||
attrLenPad := attrLen % 4
|
||||
b = b[4:]
|
||||
if attrLen+attrLenPad > len(b) {
|
||||
return errors.New("effed-f2")
|
||||
return ErrMalformedAttrs
|
||||
}
|
||||
if err := fn(attrType, b[:attrLen]); err != nil {
|
||||
@@ -161,9 +157,9 @@ func Response(txID TxID, ip net.IP, port uint16) []byte {
|
||||
}
|
||||
var fam byte
|
||||
switch len(ip) {
|
||||
case 4:
|
||||
case net.IPv4len:
|
||||
fam = 1
|
||||
case 16:
|
||||
case net.IPv6len:
|
||||
fam = 2
|
||||
default:
|
||||
return nil
|
||||
@@ -194,8 +190,6 @@ func Response(txID TxID, ip net.IP, port uint16) []byte {
|
||||
return b
|
||||
}
|
||||
|
||||
func beu16(b []byte) uint16 { return binary.BigEndian.Uint16(b) }
|
||||
|
||||
// ParseResponse parses a successful binding response STUN packet.
|
||||
// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute.
|
||||
// The returned addr slice is owned by the caller and does not alias b.
|
||||
@@ -207,7 +201,7 @@ func ParseResponse(b []byte) (tID TxID, addr []byte, port uint16, err error) {
|
||||
if b[0] != 0x01 || b[1] != 0x01 {
|
||||
return tID, nil, 0, ErrNotSuccessResponse
|
||||
}
|
||||
attrsLen := int(beu16(b[2:4]))
|
||||
attrsLen := int(binary.BigEndian.Uint16(b[2:4]))
|
||||
b = b[headerLen:] // remove STUN header
|
||||
if attrsLen > len(b) {
|
||||
return tID, nil, 0, ErrMalformedAttrs
|
||||
@@ -272,7 +266,7 @@ func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error)
|
||||
if len(b) < 4 {
|
||||
return nil, 0, ErrMalformedAttrs
|
||||
}
|
||||
xorPort := beu16(b[2:4])
|
||||
xorPort := binary.BigEndian.Uint16(b[2:4])
|
||||
addrField := b[4:]
|
||||
port = xorPort ^ 0x2112 // first half of magicCookie
|
||||
|
||||
@@ -298,9 +292,9 @@ func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error)
|
||||
func familyAddrLen(fam byte) int {
|
||||
switch fam {
|
||||
case 0x01: // IPv4
|
||||
return ipv4Len
|
||||
return net.IPv4len
|
||||
case 0x02: // IPv6
|
||||
return ipv6Len
|
||||
return net.IPv6len
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package stuntest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/stun"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/nettype"
|
||||
)
|
||||
|
||||
type stunStats struct {
|
||||
@@ -25,18 +27,22 @@ type stunStats struct {
|
||||
}
|
||||
|
||||
func Serve(t *testing.T) (addr *net.UDPAddr, cleanupFn func()) {
|
||||
return ServeWithPacketListener(t, nettype.Std{})
|
||||
}
|
||||
|
||||
func ServeWithPacketListener(t *testing.T, ln nettype.PacketListener) (addr *net.UDPAddr, cleanupFn func()) {
|
||||
t.Helper()
|
||||
|
||||
// TODO(crawshaw): use stats to test re-STUN logic
|
||||
var stats stunStats
|
||||
|
||||
pc, err := net.ListenPacket("udp4", ":0")
|
||||
pc, err := ln.ListenPacket(context.Background(), "udp4", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open STUN listener: %v", err)
|
||||
}
|
||||
addr = &net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: pc.LocalAddr().(*net.UDPAddr).Port,
|
||||
addr = pc.LocalAddr().(*net.UDPAddr)
|
||||
if len(addr.IP) == 0 || addr.IP.IsUnspecified() {
|
||||
addr.IP = net.ParseIP("127.0.0.1")
|
||||
}
|
||||
doneCh := make(chan struct{})
|
||||
go runSTUN(t, pc, &stats, doneCh)
|
||||
|
||||
52
net/tsaddr/tsaddr.go
Normal file
52
net/tsaddr/tsaddr.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tsaddr handles Tailscale-specific IPs and ranges.
|
||||
package tsaddr
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// ChromeOSVMRange returns the subset of the CGNAT IPv4 range used by
|
||||
// ChromeOS to interconnect the host OS to containers and VMs. We
|
||||
// avoid allocating Tailscale IPs from it, to avoid conflicts.
|
||||
func ChromeOSVMRange() netaddr.IPPrefix {
|
||||
chromeOSRange.Do(func() { mustPrefix(&chromeOSRange.v, "100.115.92.0/23") })
|
||||
return chromeOSRange.v
|
||||
}
|
||||
|
||||
var chromeOSRange oncePrefix
|
||||
|
||||
// CGNATRange returns the Carrier Grade NAT address range that
|
||||
// is the superset range that Tailscale assigns out of.
|
||||
// See https://tailscale.com/kb/1015/100.x-addresses.
|
||||
// Note that Tailscale does not assign out of the ChromeOSVMRange.
|
||||
func CGNATRange() netaddr.IPPrefix {
|
||||
cgnatRange.Do(func() { mustPrefix(&cgnatRange.v, "100.64.0.0/10") })
|
||||
return cgnatRange.v
|
||||
}
|
||||
|
||||
var cgnatRange oncePrefix
|
||||
|
||||
// IsTailscaleIP reports whether ip is an IP address in a range that
|
||||
// Tailscale assigns from.
|
||||
func IsTailscaleIP(ip netaddr.IP) bool {
|
||||
return CGNATRange().Contains(ip) && !ChromeOSVMRange().Contains(ip)
|
||||
}
|
||||
|
||||
func mustPrefix(v *netaddr.IPPrefix, prefix string) {
|
||||
var err error
|
||||
*v, err = netaddr.ParseIPPrefix(prefix)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type oncePrefix struct {
|
||||
sync.Once
|
||||
v netaddr.IPPrefix
|
||||
}
|
||||
19
net/tsaddr/tsaddr_test.go
Normal file
19
net/tsaddr/tsaddr_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tsaddr
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestChromeOSVMRange(t *testing.T) {
|
||||
if got, want := ChromeOSVMRange().String(), "100.115.92.0/23"; got != want {
|
||||
t.Errorf("got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCGNATRange(t *testing.T) {
|
||||
if got, want := CGNATRange().String(), "100.64.0.0/10"; got != want {
|
||||
t.Errorf("got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -11,9 +11,15 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// LegacyConfigPath is the path used by the pre-tailscaled "relaynode"
|
||||
// daemon's config file.
|
||||
const LegacyConfigPath = "/var/lib/tailscale/relay.conf"
|
||||
// LegacyConfigPath returns the path used by the pre-tailscaled
|
||||
// "relaynode" daemon's config file. It returns the empty string for
|
||||
// platforms where relaynode never ran.
|
||||
func LegacyConfigPath() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return ""
|
||||
}
|
||||
return "/var/lib/tailscale/relay.conf"
|
||||
}
|
||||
|
||||
// DefaultTailscaledSocket returns the path to the tailscaled Unix socket
|
||||
// or the empty string if there's no reasonable default.
|
||||
|
||||
@@ -13,12 +13,22 @@ import (
|
||||
exec "tailscale.com/tempfork/osexec"
|
||||
)
|
||||
|
||||
var osHideWindow func(*exec.Cmd) // non-nil on Windows; see portlist_windows.go
|
||||
|
||||
// hideWindow returns c. On Windows it first sets SysProcAttr.HideWindow.
|
||||
func hideWindow(c *exec.Cmd) *exec.Cmd {
|
||||
if osHideWindow != nil {
|
||||
osHideWindow(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func listPortsNetstat(arg string) (List, error) {
|
||||
exe, err := exec.LookPath("netstat")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("netstat: lookup: %v", err)
|
||||
}
|
||||
output, err := exec.Command(exe, arg).Output()
|
||||
output, err := hideWindow(exec.Command(exe, arg)).Output()
|
||||
if err != nil {
|
||||
xe, ok := err.(*exec.ExitError)
|
||||
stderr := ""
|
||||
|
||||
@@ -5,7 +5,10 @@
|
||||
package portlist
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
exec "tailscale.com/tempfork/osexec"
|
||||
)
|
||||
|
||||
// Forking on Windows is insanely expensive, so don't do it too often.
|
||||
@@ -18,3 +21,9 @@ func listPorts() (List, error) {
|
||||
func addProcesses(pl []Port) ([]Port, error) {
|
||||
return listPortsNetstat("-nab")
|
||||
}
|
||||
|
||||
func init() {
|
||||
osHideWindow = func(c *exec.Cmd) {
|
||||
c.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
@@ -77,6 +78,16 @@ func listen(path string, port uint16) (ln net.Listener, _ uint16, err error) {
|
||||
// * it also picks a random hex string that acts as an auth token
|
||||
// * it then creates a file named "sameuserproof-$PORT-$TOKEN" and leaves
|
||||
// that file descriptor open forever.
|
||||
//
|
||||
// Then, we do different things depending on whether the user is
|
||||
// running cmd/tailscale that they built themselves (running as
|
||||
// themselves, outside the App Sandbox), or whether the user is
|
||||
// running the CLI via the GUI binary
|
||||
// (e.g. /Applications/Tailscale.app/Contents/MacOS/Tailscale <args>),
|
||||
// in which case we're running within the App Sandbox.
|
||||
//
|
||||
// If we're outside the App Sandbox:
|
||||
//
|
||||
// * then we come along here, running as the same UID, but outside
|
||||
// of the sandbox, and look for it. We can run lsof on our own processes,
|
||||
// but other users on the system can't.
|
||||
@@ -86,7 +97,38 @@ func listen(path string, port uint16) (ln net.Listener, _ uint16, err error) {
|
||||
// * server verifies $TOKEN, sends "#IPN\n" if okay.
|
||||
// * server is now protocol switched
|
||||
// * we return the net.Conn and the caller speaks the normal protocol
|
||||
//
|
||||
// If we're inside the App Sandbox, then TS_MACOS_CLI_SHARED_DIR has
|
||||
// been set to our shared directory. We now have to find the most
|
||||
// recent "sameuserproof" file (there should only be 1, but previous
|
||||
// versions of the macOS app didn't clean them up).
|
||||
func connectMacOSAppSandbox() (net.Conn, error) {
|
||||
// Are we running the Tailscale.app GUI binary as a CLI, running within the App Sandbox?
|
||||
if d := os.Getenv("TS_MACOS_CLI_SHARED_DIR"); d != "" {
|
||||
fis, err := ioutil.ReadDir(d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading TS_MACOS_CLI_SHARED_DIR: %w", err)
|
||||
}
|
||||
var best os.FileInfo
|
||||
for _, fi := range fis {
|
||||
if !strings.HasPrefix(fi.Name(), "sameuserproof-") || strings.Count(fi.Name(), "-") != 2 {
|
||||
continue
|
||||
}
|
||||
if best == nil || fi.ModTime().After(best.ModTime()) {
|
||||
best = fi
|
||||
}
|
||||
}
|
||||
if best == nil {
|
||||
return nil, fmt.Errorf("no sameuserproof token found in TS_MACOS_CLI_SHARED_DIR %q", d)
|
||||
}
|
||||
f := strings.SplitN(best.Name(), "-", 3)
|
||||
portStr, token := f[1], f[2]
|
||||
return connectMacTCP(portStr, token)
|
||||
}
|
||||
|
||||
// Otherwise, assume we're running the cmd/tailscale binary from outside the
|
||||
// App Sandbox.
|
||||
|
||||
out, err := exec.Command("lsof",
|
||||
"-n", // numeric sockets; don't do DNS lookups, etc
|
||||
"-a", // logical AND remaining options
|
||||
@@ -110,22 +152,26 @@ func connectMacOSAppSandbox() (net.Conn, error) {
|
||||
continue
|
||||
}
|
||||
portStr, token := f[0], f[1]
|
||||
c, err := net.Dial("tcp", "localhost:"+portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error dialing IPNExtension: %w", err)
|
||||
}
|
||||
if _, err := io.WriteString(c, token+"\n"); err != nil {
|
||||
return nil, fmt.Errorf("error writing auth token: %w", err)
|
||||
}
|
||||
buf := make([]byte, 5)
|
||||
const authOK = "#IPN\n"
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
return nil, fmt.Errorf("error reading from IPNExtension post-auth: %w", err)
|
||||
}
|
||||
if string(buf) != authOK {
|
||||
return nil, fmt.Errorf("invalid response reading from IPNExtension post-auth")
|
||||
}
|
||||
return c, nil
|
||||
return connectMacTCP(portStr, token)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to find Tailscale's IPNExtension process")
|
||||
}
|
||||
|
||||
func connectMacTCP(portStr, token string) (net.Conn, error) {
|
||||
c, err := net.Dial("tcp", "localhost:"+portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error dialing IPNExtension: %w", err)
|
||||
}
|
||||
if _, err := io.WriteString(c, token+"\n"); err != nil {
|
||||
return nil, fmt.Errorf("error writing auth token: %w", err)
|
||||
}
|
||||
buf := make([]byte, 5)
|
||||
const authOK = "#IPN\n"
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
return nil, fmt.Errorf("error reading from IPNExtension post-auth: %w", err)
|
||||
}
|
||||
if string(buf) != authOK {
|
||||
return nil, fmt.Errorf("invalid response reading from IPNExtension post-auth")
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
14
smallzstd/testdata
Normal file
14
smallzstd/testdata
Normal file
@@ -0,0 +1,14 @@
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"}
|
||||
{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"}
|
||||
79
smallzstd/zstd.go
Normal file
79
smallzstd/zstd.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package smallzstd produces zstd encoders and decoders optimized for
|
||||
// low memory usage, at the expense of compression efficiency.
|
||||
//
|
||||
// This package is optimized primarily for the memory cost of
|
||||
// compressing and decompressing data. We reduce this cost in two
|
||||
// major ways: disable parallelism within the library (i.e. don't use
|
||||
// multiple CPU cores to decompress), and drop the compression window
|
||||
// down from the defaults of 4-16MiB, to 8kiB.
|
||||
//
|
||||
// Decompressors cost 2x the window size in RAM to run, so by using an
|
||||
// 8kiB window, we can run ~1000 more decompressors per unit of memory
|
||||
// than with the defaults.
|
||||
//
|
||||
// Depending on context, the benefit is either being able to run more
|
||||
// decoders (e.g. in our logs processing system), or having a lower
|
||||
// memory footprint when using compression in network protocols
|
||||
// (e.g. in tailscaled, which should have a minimal RAM cost).
|
||||
package smallzstd
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// WindowSize is the window size used for zstd compression. Decoder
|
||||
// memory usage scales linearly with WindowSize.
|
||||
const WindowSize = 8 << 10 // 8kiB
|
||||
|
||||
// NewDecoder returns a zstd.Decoder configured for low memory usage,
|
||||
// at the expense of decompression performance.
|
||||
func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) {
|
||||
defaults := []zstd.DOption{
|
||||
// Default is GOMAXPROCS, which costs many KiB in stacks.
|
||||
zstd.WithDecoderConcurrency(1),
|
||||
// Default is to allocate more upfront for performance. We
|
||||
// prefer lower memory use and a bit of GC load.
|
||||
zstd.WithDecoderLowmem(true),
|
||||
// You might expect to see zstd.WithDecoderMaxMemory
|
||||
// here. However, it's not terribly safe to use if you're
|
||||
// doing stateless decoding, because it sets the maximum
|
||||
// amount of memory the decompressed data can occupy, rather
|
||||
// than the window size of the zstd stream. This means a very
|
||||
// compressible piece of data might violate the max memory
|
||||
// limit here, even if the window size (and thus total memory
|
||||
// required to decompress the data) is small.
|
||||
//
|
||||
// As a result, we don't set a decoder limit here, and rely on
|
||||
// the encoder below producing "cheap" streams. Callers are
|
||||
// welcome to set their own max memory setting, if
|
||||
// contextually there is a clearly correct value (e.g. it's
|
||||
// known from the upper layer protocol that the decoded data
|
||||
// can never be more than 1MiB).
|
||||
}
|
||||
|
||||
return zstd.NewReader(r, append(defaults, options...)...)
|
||||
}
|
||||
|
||||
// NewEncoder returns a zstd.Encoder configured for low memory usage,
|
||||
// both during compression and at decompression time, at the expense
|
||||
// of performance and compression efficiency.
|
||||
func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) {
|
||||
defaults := []zstd.EOption{
|
||||
// Default is GOMAXPROCS, which costs many KiB in stacks.
|
||||
zstd.WithEncoderConcurrency(1),
|
||||
// Default is several MiB, which bloats both encoders and
|
||||
// their corresponding decoders.
|
||||
zstd.WithWindowSize(WindowSize),
|
||||
// Encode zero-length inputs in a way that the `zstd` utility
|
||||
// can read, because interoperability is handy.
|
||||
zstd.WithZeroFrames(true),
|
||||
}
|
||||
|
||||
return zstd.NewWriter(w, append(defaults, options...)...)
|
||||
}
|
||||
131
smallzstd/zstd_test.go
Normal file
131
smallzstd/zstd_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package smallzstd
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
func BenchmarkSmallEncoder(b *testing.B) {
|
||||
benchEncoder(b, func() (*zstd.Encoder, error) { return NewEncoder(nil) })
|
||||
}
|
||||
|
||||
func BenchmarkSmallEncoderWithBuild(b *testing.B) {
|
||||
benchEncoderWithConstruction(b, func() (*zstd.Encoder, error) { return NewEncoder(nil) })
|
||||
}
|
||||
|
||||
func BenchmarkStockEncoder(b *testing.B) {
|
||||
benchEncoder(b, func() (*zstd.Encoder, error) { return zstd.NewWriter(nil) })
|
||||
}
|
||||
|
||||
func BenchmarkStockEncoderWithBuild(b *testing.B) {
|
||||
benchEncoderWithConstruction(b, func() (*zstd.Encoder, error) { return zstd.NewWriter(nil) })
|
||||
}
|
||||
|
||||
func BenchmarkSmallDecoder(b *testing.B) {
|
||||
benchDecoder(b, func() (*zstd.Decoder, error) { return NewDecoder(nil) })
|
||||
}
|
||||
|
||||
func BenchmarkSmallDecoderWithBuild(b *testing.B) {
|
||||
benchDecoderWithConstruction(b, func() (*zstd.Decoder, error) { return NewDecoder(nil) })
|
||||
}
|
||||
|
||||
func BenchmarkStockDecoder(b *testing.B) {
|
||||
benchDecoder(b, func() (*zstd.Decoder, error) { return zstd.NewReader(nil) })
|
||||
}
|
||||
|
||||
func BenchmarkStockDecoderWithBuild(b *testing.B) {
|
||||
benchDecoderWithConstruction(b, func() (*zstd.Decoder, error) { return zstd.NewReader(nil) })
|
||||
}
|
||||
|
||||
func benchEncoder(b *testing.B, mk func() (*zstd.Encoder, error)) {
|
||||
b.ReportAllocs()
|
||||
|
||||
in := testdata(b)
|
||||
out := make([]byte, 0, 10<<10) // 10kiB
|
||||
|
||||
e, err := mk()
|
||||
if err != nil {
|
||||
b.Fatalf("making encoder: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e.EncodeAll(in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func benchEncoderWithConstruction(b *testing.B, mk func() (*zstd.Encoder, error)) {
|
||||
b.ReportAllocs()
|
||||
|
||||
in := testdata(b)
|
||||
out := make([]byte, 0, 10<<10) // 10kiB
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
e, err := mk()
|
||||
if err != nil {
|
||||
b.Fatalf("making encoder: %v", err)
|
||||
}
|
||||
|
||||
e.EncodeAll(in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func benchDecoder(b *testing.B, mk func() (*zstd.Decoder, error)) {
|
||||
b.ReportAllocs()
|
||||
|
||||
in := compressedTestdata(b)
|
||||
out := make([]byte, 0, 10<<10)
|
||||
|
||||
d, err := mk()
|
||||
if err != nil {
|
||||
b.Fatalf("creating decoder: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.DecodeAll(in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func benchDecoderWithConstruction(b *testing.B, mk func() (*zstd.Decoder, error)) {
|
||||
b.ReportAllocs()
|
||||
|
||||
in := compressedTestdata(b)
|
||||
out := make([]byte, 0, 10<<10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d, err := mk()
|
||||
if err != nil {
|
||||
b.Fatalf("creating decoder: %v", err)
|
||||
}
|
||||
|
||||
d.DecodeAll(in, out)
|
||||
}
|
||||
}
|
||||
|
||||
func testdata(b *testing.B) []byte {
|
||||
b.Helper()
|
||||
in, err := ioutil.ReadFile("testdata")
|
||||
if err != nil {
|
||||
b.Fatalf("reading testdata: %v", err)
|
||||
}
|
||||
return in
|
||||
}
|
||||
|
||||
func compressedTestdata(b *testing.B) []byte {
|
||||
b.Helper()
|
||||
uncomp := testdata(b)
|
||||
e, err := NewEncoder(nil)
|
||||
if err != nil {
|
||||
b.Fatalf("creating encoder: %v", err)
|
||||
}
|
||||
return e.EncodeAll(uncomp, nil)
|
||||
}
|
||||
@@ -117,4 +117,8 @@ type DERPNode struct {
|
||||
// of using the default port of 443. If non-zero, TLS
|
||||
// verification is skipped.
|
||||
DERPTestPort int `json:",omitempty"`
|
||||
|
||||
// STUNTestIP is used in tests to override the STUN server's IP.
|
||||
// If empty, it's assumed to be the same as the DERP server.
|
||||
STUNTestIP string `json:",omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
package tailcfg
|
||||
|
||||
//go:generate go run tailscale.com/cmd/cloner -type=User,Node,Hostinfo,NetInfo -output=tailcfg_clone.go
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
@@ -13,7 +15,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"go4.org/mem"
|
||||
"golang.org/x/oauth2"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/opt"
|
||||
"tailscale.com/types/structs"
|
||||
)
|
||||
@@ -38,6 +42,10 @@ type MachineKey [32]byte
|
||||
// NodeKey is the curve25519 public key for a node.
|
||||
type NodeKey [32]byte
|
||||
|
||||
// DiscoKey is the curve25519 public key for path discovery key.
|
||||
// It's never written to disk or reused between network start-ups.
|
||||
type DiscoKey [32]byte
|
||||
|
||||
type Group struct {
|
||||
ID GroupID
|
||||
Name string
|
||||
@@ -87,18 +95,6 @@ type User struct {
|
||||
// Note: be sure to update Clone when adding new fields
|
||||
}
|
||||
|
||||
// Clone returns a copy of u that aliases no memory with the original.
|
||||
func (u *User) Clone() *User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u2 := new(User)
|
||||
*u2 = *u
|
||||
u2.Logins = append([]LoginID(nil), u.Logins...)
|
||||
u2.Roles = append([]RoleID(nil), u.Roles...)
|
||||
return u2
|
||||
}
|
||||
|
||||
type Login struct {
|
||||
_ structs.Incomparable
|
||||
ID LoginID
|
||||
@@ -114,8 +110,8 @@ type Login struct {
|
||||
// It also includes derived data from one of the user's logins.
|
||||
type UserProfile struct {
|
||||
ID UserID
|
||||
LoginName string // for display purposes only (provider is not listed)
|
||||
DisplayName string
|
||||
LoginName string // "alice@smith.com"; for display purposes only (provider is not listed)
|
||||
DisplayName string // "Alice Smith"
|
||||
ProfilePicURL string
|
||||
Roles []RoleID
|
||||
}
|
||||
@@ -127,6 +123,7 @@ type Node struct {
|
||||
Key NodeKey
|
||||
KeyExpiry time.Time
|
||||
Machine MachineKey
|
||||
DiscoKey DiscoKey
|
||||
Addresses []wgcfg.CIDR // IP addresses of this Node directly
|
||||
AllowedIPs []wgcfg.CIDR // range of IP addresses to route to this node
|
||||
Endpoints []string `json:",omitempty"` // IP+port (public via STUN, and local LANs)
|
||||
@@ -143,23 +140,6 @@ type Node struct {
|
||||
// require changes to Node.Clone.
|
||||
}
|
||||
|
||||
// Clone makes a deep copy of Node.
|
||||
// The result aliases no memory with the original.
|
||||
func (n *Node) Clone() (res *Node) {
|
||||
res = new(Node)
|
||||
*res = *n
|
||||
|
||||
res.Addresses = append([]wgcfg.CIDR{}, res.Addresses...)
|
||||
res.AllowedIPs = append([]wgcfg.CIDR{}, res.AllowedIPs...)
|
||||
res.Endpoints = append([]string{}, res.Endpoints...)
|
||||
if res.LastSeen != nil {
|
||||
lastSeen := *res.LastSeen
|
||||
res.LastSeen = &lastSeen
|
||||
}
|
||||
res.Hostinfo = *res.Hostinfo.Clone()
|
||||
return res
|
||||
}
|
||||
|
||||
type MachineStatus int
|
||||
|
||||
const (
|
||||
@@ -277,6 +257,8 @@ type Hostinfo struct {
|
||||
FrontendLogID string // logtail ID of frontend instance
|
||||
BackendLogID string // logtail ID of backend instance
|
||||
OS string // operating system the client runs on (a version.OS value)
|
||||
OSVersion string // operating system version, with optional distro prefix ("Debian 10.4", "Windows 10 Pro 10.0.19041")
|
||||
DeviceModel string // mobile phone model ("Pixel 3a", "iPhone 11 Pro")
|
||||
Hostname string // name of the host the client runs on
|
||||
RoutableIPs []wgcfg.CIDR `json:",omitempty"` // set of IP ranges this client can route
|
||||
RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim
|
||||
@@ -303,6 +285,18 @@ type NetInfo struct {
|
||||
// WorkingUDP is whether UDP works.
|
||||
WorkingUDP opt.Bool
|
||||
|
||||
// UPnP is whether UPnP appears present on the LAN.
|
||||
// Empty means not checked.
|
||||
UPnP opt.Bool
|
||||
|
||||
// PMP is whether NAT-PMP appears present on the LAN.
|
||||
// Empty means not checked.
|
||||
PMP opt.Bool
|
||||
|
||||
// PCP is whether PCP appears present on the LAN.
|
||||
// Empty means not checked.
|
||||
PCP opt.Bool
|
||||
|
||||
// PreferredDERP is this node's preferred DERP server
|
||||
// for incoming traffic. The node might be be temporarily
|
||||
// connected to multiple DERP servers (to send to other nodes)
|
||||
@@ -331,9 +325,32 @@ func (ni *NetInfo) String() string {
|
||||
if ni == nil {
|
||||
return "NetInfo(nil)"
|
||||
}
|
||||
return fmt.Sprintf("NetInfo{varies=%v hairpin=%v ipv6=%v udp=%v derp=#%v link=%q}",
|
||||
return fmt.Sprintf("NetInfo{varies=%v hairpin=%v ipv6=%v udp=%v derp=#%v portmap=%v link=%q}",
|
||||
ni.MappingVariesByDestIP, ni.HairPinning, ni.WorkingIPv6,
|
||||
ni.WorkingUDP, ni.PreferredDERP, ni.LinkType)
|
||||
ni.WorkingUDP, ni.PreferredDERP,
|
||||
ni.portMapSummary(),
|
||||
ni.LinkType)
|
||||
}
|
||||
|
||||
func (ni *NetInfo) portMapSummary() string {
|
||||
if ni.UPnP == "" && ni.PMP == "" && ni.PCP == "" {
|
||||
return "?"
|
||||
}
|
||||
return conciseOptBool(ni.UPnP, "U") + conciseOptBool(ni.PMP, "M") + conciseOptBool(ni.PCP, "C")
|
||||
}
|
||||
|
||||
func conciseOptBool(b opt.Bool, trueVal string) string {
|
||||
if b == "" {
|
||||
return "_"
|
||||
}
|
||||
v, ok := b.Get()
|
||||
if !ok {
|
||||
return "x"
|
||||
}
|
||||
if v {
|
||||
return trueVal
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BasicallyEqual reports whether ni and ni2 are basically equal, ignoring
|
||||
@@ -349,39 +366,21 @@ func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool {
|
||||
ni.HairPinning == ni2.HairPinning &&
|
||||
ni.WorkingIPv6 == ni2.WorkingIPv6 &&
|
||||
ni.WorkingUDP == ni2.WorkingUDP &&
|
||||
ni.UPnP == ni2.UPnP &&
|
||||
ni.PMP == ni2.PMP &&
|
||||
ni.PCP == ni2.PCP &&
|
||||
ni.PreferredDERP == ni2.PreferredDERP &&
|
||||
ni.LinkType == ni2.LinkType
|
||||
}
|
||||
|
||||
func (ni *NetInfo) Clone() (res *NetInfo) {
|
||||
if ni == nil {
|
||||
return nil
|
||||
}
|
||||
res = new(NetInfo)
|
||||
*res = *ni
|
||||
if ni.DERPLatency != nil {
|
||||
res.DERPLatency = map[string]float64{}
|
||||
for k, v := range ni.DERPLatency {
|
||||
res.DERPLatency[k] = v
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Clone makes a deep copy of Hostinfo.
|
||||
// The result aliases no memory with the original.
|
||||
func (h *Hostinfo) Clone() (res *Hostinfo) {
|
||||
res = new(Hostinfo)
|
||||
*res = *h
|
||||
|
||||
res.RoutableIPs = append([]wgcfg.CIDR{}, h.RoutableIPs...)
|
||||
res.Services = append([]Service{}, h.Services...)
|
||||
res.NetInfo = h.NetInfo.Clone()
|
||||
return res
|
||||
}
|
||||
|
||||
// Equal reports whether h and h2 are equal.
|
||||
func (h *Hostinfo) Equal(h2 *Hostinfo) bool {
|
||||
if h == nil && h2 == nil {
|
||||
return true
|
||||
}
|
||||
if (h == nil) != (h2 == nil) {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(h, h2)
|
||||
}
|
||||
|
||||
@@ -408,6 +407,8 @@ type RegisterRequest struct {
|
||||
|
||||
// Clone makes a deep copy of RegisterRequest.
|
||||
// The result aliases no memory with the original.
|
||||
//
|
||||
// TODO: extend cmd/cloner to generate this method.
|
||||
func (req *RegisterRequest) Clone() *RegisterRequest {
|
||||
res := new(RegisterRequest)
|
||||
*res = *req
|
||||
@@ -442,10 +443,17 @@ type MapRequest struct {
|
||||
Compress string // "zstd" or "" (no compression)
|
||||
KeepAlive bool // server sends keep-alives
|
||||
NodeKey NodeKey
|
||||
DiscoKey DiscoKey
|
||||
Endpoints []string // caller's endpoints (IPv4 or IPv6)
|
||||
IncludeIPv6 bool // include IPv6 endpoints in returned Node Endpoints
|
||||
Stream bool // if true, multiple MapResponse objects are returned
|
||||
Hostinfo *Hostinfo
|
||||
|
||||
// DebugForceDisco is a temporary flag during the deployment
|
||||
// of magicsock active discovery. It says that that the client
|
||||
// has environment variables explicitly turning discovery on,
|
||||
// so control should not disable it.
|
||||
DebugForceDisco bool `json:"debugForceDisco,omitempty"`
|
||||
}
|
||||
|
||||
// PortRange represents a range of UDP or TCP port numbers.
|
||||
@@ -509,65 +517,59 @@ type MapResponse struct {
|
||||
// Debug are instructions from the control server to the client
|
||||
// to adjust debug settings.
|
||||
type Debug struct {
|
||||
// LogHeapPprof controls whether the client should logs
|
||||
// LogHeapPprof controls whether the client should log
|
||||
// its heap pprof data. Each true value sent from the server
|
||||
// means that client should do one more log.
|
||||
LogHeapPprof bool `json:",omitempty"`
|
||||
|
||||
// LogHeapURL is the URL to POST its heap pprof to.
|
||||
// Empty means to not log.
|
||||
LogHeapURL string `json:",omitempty"`
|
||||
|
||||
// ForceBackgroundSTUN controls whether magicsock should
|
||||
// always do its background STUN queries (see magicsock's
|
||||
// periodicReSTUN), regardless of inactivity.
|
||||
ForceBackgroundSTUN bool `json:",omitempty"`
|
||||
}
|
||||
|
||||
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
|
||||
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
|
||||
func (k MachineKey) MarshalText() ([]byte, error) { return keyMarshalText("mkey:", k), nil }
|
||||
func (k *MachineKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "mkey:", text) }
|
||||
|
||||
func (k MachineKey) MarshalText() ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
fmt.Fprintf(buf, "mkey:%x", k[:])
|
||||
return buf.Bytes(), nil
|
||||
func keyMarshalText(prefix string, k [32]byte) []byte {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(prefix)+64))
|
||||
fmt.Fprintf(buf, "%s%x", prefix, k[:])
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (k *MachineKey) UnmarshalText(text []byte) error {
|
||||
s := string(text)
|
||||
if !strings.HasPrefix(s, "mkey:") {
|
||||
return errors.New(`MachineKey.UnmarshalText: missing prefix`)
|
||||
func keyUnmarshalText(dst []byte, prefix string, text []byte) error {
|
||||
if len(text) < len(prefix) || string(text[:len(prefix)]) != prefix {
|
||||
return fmt.Errorf("UnmarshalText: missing %q prefix", prefix)
|
||||
}
|
||||
s = strings.TrimPrefix(s, `mkey:`)
|
||||
key, err := wgcfg.ParseHexKey(s)
|
||||
pub, err := key.NewPublicFromHexMem(mem.B(text[len(prefix):]))
|
||||
if err != nil {
|
||||
return fmt.Errorf("MachineKey.UnmarhsalText: %v", err)
|
||||
return fmt.Errorf("UnmarshalText: after %q: %v", prefix, err)
|
||||
}
|
||||
copy(k[:], key[:])
|
||||
copy(dst[:], pub[:])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) }
|
||||
func (k NodeKey) ShortString() string { return (key.Public(k)).ShortString() }
|
||||
|
||||
func (k NodeKey) ShortString() string {
|
||||
pk := wgcfg.Key(k)
|
||||
return pk.ShortString()
|
||||
}
|
||||
func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) }
|
||||
func (k NodeKey) MarshalText() ([]byte, error) { return keyMarshalText("nodekey:", k), nil }
|
||||
func (k *NodeKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "nodekey:", text) }
|
||||
|
||||
func (k NodeKey) MarshalText() ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
fmt.Fprintf(buf, "nodekey:%x", k[:])
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
// IsZero reports whether k is the zero value.
|
||||
func (k NodeKey) IsZero() bool { return k == NodeKey{} }
|
||||
|
||||
func (k *NodeKey) UnmarshalText(text []byte) error {
|
||||
s := string(text)
|
||||
if !strings.HasPrefix(s, "nodekey:") {
|
||||
return errors.New(`Nodekey.UnmarshalText: missing prefix`)
|
||||
}
|
||||
s = strings.TrimPrefix(s, "nodekey:")
|
||||
key, err := wgcfg.ParseHexKey(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tailcfg.Ukey.UnmarhsalText: %v", err)
|
||||
}
|
||||
copy(k[:], key[:])
|
||||
return nil
|
||||
}
|
||||
func (k DiscoKey) String() string { return fmt.Sprintf("discokey:%x", k[:]) }
|
||||
func (k DiscoKey) MarshalText() ([]byte, error) { return keyMarshalText("discokey:", k), nil }
|
||||
func (k *DiscoKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "discokey:", text) }
|
||||
func (k DiscoKey) ShortString() string { return fmt.Sprintf("d:%x", k[:8]) }
|
||||
|
||||
// IsZero reports whether k is the NodeKey zero value.
|
||||
func (k NodeKey) IsZero() bool {
|
||||
return k == NodeKey{}
|
||||
}
|
||||
// IsZero reports whether k is the zero value.
|
||||
func (k DiscoKey) IsZero() bool { return k == DiscoKey{} }
|
||||
|
||||
func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) }
|
||||
func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) }
|
||||
@@ -589,11 +591,40 @@ func (n *Node) Equal(n2 *Node) bool {
|
||||
n.Key == n2.Key &&
|
||||
n.KeyExpiry.Equal(n2.KeyExpiry) &&
|
||||
n.Machine == n2.Machine &&
|
||||
reflect.DeepEqual(n.Addresses, n2.Addresses) &&
|
||||
reflect.DeepEqual(n.AllowedIPs, n2.AllowedIPs) &&
|
||||
reflect.DeepEqual(n.Endpoints, n2.Endpoints) &&
|
||||
reflect.DeepEqual(n.Hostinfo, n2.Hostinfo) &&
|
||||
n.DiscoKey == n2.DiscoKey &&
|
||||
eqCIDRs(n.Addresses, n2.Addresses) &&
|
||||
eqCIDRs(n.AllowedIPs, n2.AllowedIPs) &&
|
||||
eqStrings(n.Endpoints, n2.Endpoints) &&
|
||||
n.Hostinfo.Equal(&n2.Hostinfo) &&
|
||||
n.Created.Equal(n2.Created) &&
|
||||
reflect.DeepEqual(n.LastSeen, n2.LastSeen) &&
|
||||
eqTimePtr(n.LastSeen, n2.LastSeen) &&
|
||||
n.MachineAuthorized == n2.MachineAuthorized
|
||||
}
|
||||
|
||||
func eqStrings(a, b []string) bool {
|
||||
if len(a) != len(b) || ((a == nil) != (b == nil)) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func eqCIDRs(a, b []wgcfg.CIDR) bool {
|
||||
if len(a) != len(b) || ((a == nil) != (b == nil)) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func eqTimePtr(a, b *time.Time) bool {
|
||||
return ((a == nil) == (b == nil)) && (a == nil || a.Equal(*b))
|
||||
}
|
||||
|
||||
75
tailcfg/tailcfg_clone.go
Normal file
75
tailcfg/tailcfg_clone.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Code generated by tailscale.com/cmd/cloner -type User,Node,Hostinfo,NetInfo; DO NOT EDIT.
|
||||
|
||||
package tailcfg
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Clone makes a deep copy of User.
|
||||
// The result aliases no memory with the original.
|
||||
func (src *User) Clone() *User {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := new(User)
|
||||
*dst = *src
|
||||
dst.Logins = append(src.Logins[:0:0], src.Logins...)
|
||||
dst.Roles = append(src.Roles[:0:0], src.Roles...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// Clone makes a deep copy of Node.
|
||||
// The result aliases no memory with the original.
|
||||
func (src *Node) Clone() *Node {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := new(Node)
|
||||
*dst = *src
|
||||
dst.Addresses = append(src.Addresses[:0:0], src.Addresses...)
|
||||
dst.AllowedIPs = append(src.AllowedIPs[:0:0], src.AllowedIPs...)
|
||||
dst.Endpoints = append(src.Endpoints[:0:0], src.Endpoints...)
|
||||
dst.Hostinfo = *src.Hostinfo.Clone()
|
||||
if dst.LastSeen != nil {
|
||||
dst.LastSeen = new(time.Time)
|
||||
*dst.LastSeen = *src.LastSeen
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// Clone makes a deep copy of Hostinfo.
|
||||
// The result aliases no memory with the original.
|
||||
func (src *Hostinfo) Clone() *Hostinfo {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := new(Hostinfo)
|
||||
*dst = *src
|
||||
dst.RoutableIPs = append(src.RoutableIPs[:0:0], src.RoutableIPs...)
|
||||
dst.RequestTags = append(src.RequestTags[:0:0], src.RequestTags...)
|
||||
dst.Services = append(src.Services[:0:0], src.Services...)
|
||||
dst.NetInfo = src.NetInfo.Clone()
|
||||
return dst
|
||||
}
|
||||
|
||||
// Clone makes a deep copy of NetInfo.
|
||||
// The result aliases no memory with the original.
|
||||
func (src *NetInfo) Clone() *NetInfo {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
dst := new(NetInfo)
|
||||
*dst = *src
|
||||
if dst.DERPLatency != nil {
|
||||
dst.DERPLatency = map[string]float64{}
|
||||
for k, v := range src.DERPLatency {
|
||||
dst.DERPLatency[k] = v
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
@@ -5,7 +5,9 @@
|
||||
package tailcfg
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,7 +23,8 @@ func fieldsOf(t reflect.Type) (fields []string) {
|
||||
|
||||
func TestHostinfoEqual(t *testing.T) {
|
||||
hiHandles := []string{
|
||||
"IPNVersion", "FrontendLogID", "BackendLogID", "OS", "Hostname", "RoutableIPs", "RequestTags", "Services",
|
||||
"IPNVersion", "FrontendLogID", "BackendLogID", "OS", "OSVersion",
|
||||
"DeviceModel", "Hostname", "RoutableIPs", "RequestTags", "Services",
|
||||
"NetInfo",
|
||||
}
|
||||
if have := fieldsOf(reflect.TypeOf(Hostinfo{})); !reflect.DeepEqual(have, hiHandles) {
|
||||
@@ -176,7 +179,7 @@ func TestHostinfoEqual(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNodeEqual(t *testing.T) {
|
||||
nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"}
|
||||
nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "DiscoKey", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"}
|
||||
if have := fieldsOf(reflect.TypeOf(Node{})); !reflect.DeepEqual(have, nodeHandles) {
|
||||
t.Errorf("Node.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
||||
have, nodeHandles)
|
||||
@@ -327,6 +330,9 @@ func TestNetInfoFields(t *testing.T) {
|
||||
"HairPinning",
|
||||
"WorkingIPv6",
|
||||
"WorkingUDP",
|
||||
"UPnP",
|
||||
"PMP",
|
||||
"PCP",
|
||||
"PreferredDERP",
|
||||
"LinkType",
|
||||
"DERPLatency",
|
||||
@@ -336,3 +342,91 @@ func TestNetInfoFields(t *testing.T) {
|
||||
have, handled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMachineKeyMarshal(t *testing.T) {
|
||||
var k1, k2 MachineKey
|
||||
for i := range k1 {
|
||||
k1[i] = byte(i)
|
||||
}
|
||||
testKey(t, "mkey:", k1, &k2)
|
||||
}
|
||||
|
||||
func TestNodeKeyMarshal(t *testing.T) {
|
||||
var k1, k2 NodeKey
|
||||
for i := range k1 {
|
||||
k1[i] = byte(i)
|
||||
}
|
||||
testKey(t, "nodekey:", k1, &k2)
|
||||
}
|
||||
|
||||
func TestDiscoKeyMarshal(t *testing.T) {
|
||||
var k1, k2 DiscoKey
|
||||
for i := range k1 {
|
||||
k1[i] = byte(i)
|
||||
}
|
||||
testKey(t, "discokey:", k1, &k2)
|
||||
}
|
||||
|
||||
type keyIn interface {
|
||||
String() string
|
||||
MarshalText() ([]byte, error)
|
||||
}
|
||||
|
||||
func testKey(t *testing.T, prefix string, in keyIn, out encoding.TextUnmarshaler) {
|
||||
got, err := in.MarshalText()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := out.UnmarshalText(got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s := in.String(); string(got) != s {
|
||||
t.Errorf("MarshalText = %q != String %q", got, s)
|
||||
}
|
||||
if !strings.HasPrefix(string(got), prefix) {
|
||||
t.Errorf("%q didn't start with prefix %q", got, prefix)
|
||||
}
|
||||
if reflect.ValueOf(out).Elem().Interface() != in {
|
||||
t.Errorf("mismatch after unmarshal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
u *User
|
||||
}{
|
||||
{"nil_logins", &User{}},
|
||||
{"zero_logins", &User{Logins: make([]LoginID, 0)}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u2 := tt.u.Clone()
|
||||
if !reflect.DeepEqual(tt.u, u2) {
|
||||
t.Errorf("not equal")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v *Node
|
||||
}{
|
||||
{"nil_fields", &Node{}},
|
||||
{"zero_fields", &Node{
|
||||
Addresses: make([]wgcfg.CIDR, 0),
|
||||
AllowedIPs: make([]wgcfg.CIDR, 0),
|
||||
Endpoints: make([]string, 0),
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v2 := tt.v.Clone()
|
||||
if !reflect.DeepEqual(tt.v, v2) {
|
||||
t.Errorf("not equal")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
4
tempfork/pprof/README.md
Normal file
4
tempfork/pprof/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
This is a fork of net/http/pprof that doesn't use init side effects
|
||||
and doesn't use html/template (which ends up calling
|
||||
reflect.Value.MethodByName, which disables some linker deadcode
|
||||
optimizations).
|
||||
301
tempfork/pprof/pprof.go
Normal file
301
tempfork/pprof/pprof.go
Normal file
@@ -0,0 +1,301 @@
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package pprof serves via its HTTP server runtime profiling data
|
||||
// in the format expected by the pprof visualization tool.
|
||||
//
|
||||
// See Go's net/http/pprof for docs.
|
||||
//
|
||||
// This is a fork of net/http/pprof that doesn't use init side effects
|
||||
// and doesn't use html/template (which ends up calling
|
||||
// reflect.Value.MethodByName, which disables some linker deadcode
|
||||
// optimizations).
|
||||
package pprof
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"runtime/trace"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AddHandlers(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/debug/pprof/", Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", Trace)
|
||||
}
|
||||
|
||||
// Cmdline responds with the running program's
|
||||
// command line, with arguments separated by NUL bytes.
|
||||
// The package initialization registers it as /debug/pprof/cmdline.
|
||||
func Cmdline(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprintf(w, strings.Join(os.Args, "\x00"))
|
||||
}
|
||||
|
||||
func sleep(w http.ResponseWriter, d time.Duration) {
|
||||
var clientGone <-chan bool
|
||||
if cn, ok := w.(http.CloseNotifier); ok {
|
||||
clientGone = cn.CloseNotify()
|
||||
}
|
||||
select {
|
||||
case <-time.After(d):
|
||||
case <-clientGone:
|
||||
}
|
||||
}
|
||||
|
||||
func durationExceedsWriteTimeout(r *http.Request, seconds float64) bool {
|
||||
srv, ok := r.Context().Value(http.ServerContextKey).(*http.Server)
|
||||
return ok && srv.WriteTimeout != 0 && seconds >= srv.WriteTimeout.Seconds()
|
||||
}
|
||||
|
||||
func serveError(w http.ResponseWriter, status int, txt string) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("X-Go-Pprof", "1")
|
||||
w.Header().Del("Content-Disposition")
|
||||
w.WriteHeader(status)
|
||||
fmt.Fprintln(w, txt)
|
||||
}
|
||||
|
||||
// Profile responds with the pprof-formatted cpu profile.
|
||||
// Profiling lasts for duration specified in seconds GET parameter, or for 30 seconds if not specified.
|
||||
// The package initialization registers it as /debug/pprof/profile.
|
||||
func Profile(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64)
|
||||
if sec <= 0 || err != nil {
|
||||
sec = 30
|
||||
}
|
||||
|
||||
if durationExceedsWriteTimeout(r, float64(sec)) {
|
||||
serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
|
||||
return
|
||||
}
|
||||
|
||||
// Set Content Type assuming StartCPUProfile will work,
|
||||
// because if it does it starts writing.
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="profile"`)
|
||||
if err := pprof.StartCPUProfile(w); err != nil {
|
||||
// StartCPUProfile failed, so no writes yet.
|
||||
serveError(w, http.StatusInternalServerError,
|
||||
fmt.Sprintf("Could not enable CPU profiling: %s", err))
|
||||
return
|
||||
}
|
||||
sleep(w, time.Duration(sec)*time.Second)
|
||||
pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
// Trace responds with the execution trace in binary form.
|
||||
// Tracing lasts for duration specified in seconds GET parameter, or for 1 second if not specified.
|
||||
// The package initialization registers it as /debug/pprof/trace.
|
||||
func Trace(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
sec, err := strconv.ParseFloat(r.FormValue("seconds"), 64)
|
||||
if sec <= 0 || err != nil {
|
||||
sec = 1
|
||||
}
|
||||
|
||||
if durationExceedsWriteTimeout(r, sec) {
|
||||
serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
|
||||
return
|
||||
}
|
||||
|
||||
// Set Content Type assuming trace.Start will work,
|
||||
// because if it does it starts writing.
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Disposition", `attachment; filename="trace"`)
|
||||
if err := trace.Start(w); err != nil {
|
||||
// trace.Start failed, so no writes yet.
|
||||
serveError(w, http.StatusInternalServerError,
|
||||
fmt.Sprintf("Could not enable tracing: %s", err))
|
||||
return
|
||||
}
|
||||
sleep(w, time.Duration(sec*float64(time.Second)))
|
||||
trace.Stop()
|
||||
}
|
||||
|
||||
// Symbol looks up the program counters listed in the request,
|
||||
// responding with a table mapping program counters to function names.
|
||||
// The package initialization registers it as /debug/pprof/symbol.
|
||||
func Symbol(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
|
||||
// We have to read the whole POST body before
|
||||
// writing any output. Buffer the output here.
|
||||
var buf bytes.Buffer
|
||||
|
||||
// We don't know how many symbols we have, but we
|
||||
// do have symbol information. Pprof only cares whether
|
||||
// this number is 0 (no symbols available) or > 0.
|
||||
fmt.Fprintf(&buf, "num_symbols: 1\n")
|
||||
|
||||
var b *bufio.Reader
|
||||
if r.Method == "POST" {
|
||||
b = bufio.NewReader(r.Body)
|
||||
} else {
|
||||
b = bufio.NewReader(strings.NewReader(r.URL.RawQuery))
|
||||
}
|
||||
|
||||
for {
|
||||
word, err := b.ReadSlice('+')
|
||||
if err == nil {
|
||||
word = word[0 : len(word)-1] // trim +
|
||||
}
|
||||
pc, _ := strconv.ParseUint(string(word), 0, 64)
|
||||
if pc != 0 {
|
||||
f := runtime.FuncForPC(uintptr(pc))
|
||||
if f != nil {
|
||||
fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until here to check for err; the last
|
||||
// symbol will have an err because it doesn't end in +.
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
fmt.Fprintf(&buf, "reading request: %v\n", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
// Handler returns an HTTP handler that serves the named profile.
|
||||
func Handler(name string) http.Handler {
|
||||
return handler(name)
|
||||
}
|
||||
|
||||
type handler string
|
||||
|
||||
func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
p := pprof.Lookup(string(name))
|
||||
if p == nil {
|
||||
serveError(w, http.StatusNotFound, "Unknown profile")
|
||||
return
|
||||
}
|
||||
gc, _ := strconv.Atoi(r.FormValue("gc"))
|
||||
if name == "heap" && gc > 0 {
|
||||
runtime.GC()
|
||||
}
|
||||
debug, _ := strconv.Atoi(r.FormValue("debug"))
|
||||
if debug != 0 {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, name))
|
||||
}
|
||||
p.WriteTo(w, debug)
|
||||
}
|
||||
|
||||
var profileDescriptions = map[string]string{
|
||||
"allocs": "A sampling of all past memory allocations",
|
||||
"block": "Stack traces that led to blocking on synchronization primitives",
|
||||
"cmdline": "The command line invocation of the current program",
|
||||
"goroutine": "Stack traces of all current goroutines",
|
||||
"heap": "A sampling of memory allocations of live objects. You can specify the gc GET parameter to run GC before taking the heap sample.",
|
||||
"mutex": "Stack traces of holders of contended mutexes",
|
||||
"profile": "CPU profile. You can specify the duration in the seconds GET parameter. After you get the profile file, use the go tool pprof command to investigate the profile.",
|
||||
"threadcreate": "Stack traces that led to the creation of new OS threads",
|
||||
"trace": "A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.",
|
||||
}
|
||||
|
||||
// Index responds with the pprof-formatted profile named by the request.
|
||||
// For example, "/debug/pprof/heap" serves the "heap" profile.
|
||||
// Index responds to a request for "/debug/pprof/" with an HTML page
|
||||
// listing the available profiles.
|
||||
func Index(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/debug/pprof/") {
|
||||
name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/")
|
||||
if name != "" {
|
||||
handler(name).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type profile struct {
|
||||
Name string
|
||||
Href string
|
||||
Desc string
|
||||
Count int
|
||||
}
|
||||
var profiles []profile
|
||||
for _, p := range pprof.Profiles() {
|
||||
profiles = append(profiles, profile{
|
||||
Name: p.Name(),
|
||||
Href: p.Name() + "?debug=1",
|
||||
Desc: profileDescriptions[p.Name()],
|
||||
Count: p.Count(),
|
||||
})
|
||||
}
|
||||
|
||||
// Adding other profiles exposed from within this package
|
||||
for _, p := range []string{"cmdline", "profile", "trace"} {
|
||||
profiles = append(profiles, profile{
|
||||
Name: p,
|
||||
Href: p,
|
||||
Desc: profileDescriptions[p],
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(profiles, func(i, j int) bool {
|
||||
return profiles[i].Name < profiles[j].Name
|
||||
})
|
||||
|
||||
io.WriteString(w, `<html>
|
||||
<head>
|
||||
<title>/debug/pprof/</title>
|
||||
<style>
|
||||
.profile-name{
|
||||
display:inline-block;
|
||||
width:6rem;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
/debug/pprof/<br>
|
||||
<br>
|
||||
Types of profiles available:
|
||||
<table>
|
||||
<thead><td>Count</td><td>Profile</td></thead>
|
||||
`)
|
||||
for _, p := range profiles {
|
||||
fmt.Fprintf(w, "<tr><td>%d</td><td><a href=%v>%v</td></tr>\n",
|
||||
p.Count, html.EscapeString(p.Href), html.EscapeString(p.Name))
|
||||
}
|
||||
io.WriteString(w, `</table>
|
||||
<a href="goroutine?debug=2">full goroutine stack dump</a>
|
||||
<br/>
|
||||
<p>
|
||||
Profile Descriptions:
|
||||
<ul>
|
||||
`)
|
||||
for _, p := range profiles {
|
||||
fmt.Fprintf(w, "<li><div class=profile-name>%s:</div> %s</li>\n",
|
||||
html.EscapeString(p.Name), html.EscapeString(p.Desc))
|
||||
}
|
||||
io.WriteString(w, `
|
||||
</ul>
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
`)
|
||||
}
|
||||
81
tempfork/pprof/pprof_test.go
Normal file
81
tempfork/pprof/pprof_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package pprof
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime/pprof"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDescriptions checks that the profile names under runtime/pprof package
|
||||
// have a key in the description map.
|
||||
func TestDescriptions(t *testing.T) {
|
||||
for _, p := range pprof.Profiles() {
|
||||
_, ok := profileDescriptions[p.Name()]
|
||||
if ok != true {
|
||||
t.Errorf("%s does not exist in profileDescriptions map\n", p.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
path string
|
||||
handler http.HandlerFunc
|
||||
statusCode int
|
||||
contentType string
|
||||
contentDisposition string
|
||||
resp []byte
|
||||
}{
|
||||
{"/debug/pprof/<script>scripty<script>", Index, http.StatusNotFound, "text/plain; charset=utf-8", "", []byte("Unknown profile\n")},
|
||||
{"/debug/pprof/heap", Index, http.StatusOK, "application/octet-stream", `attachment; filename="heap"`, nil},
|
||||
{"/debug/pprof/heap?debug=1", Index, http.StatusOK, "text/plain; charset=utf-8", "", nil},
|
||||
{"/debug/pprof/cmdline", Cmdline, http.StatusOK, "text/plain; charset=utf-8", "", nil},
|
||||
{"/debug/pprof/profile?seconds=1", Profile, http.StatusOK, "application/octet-stream", `attachment; filename="profile"`, nil},
|
||||
{"/debug/pprof/symbol", Symbol, http.StatusOK, "text/plain; charset=utf-8", "", nil},
|
||||
{"/debug/pprof/trace", Trace, http.StatusOK, "application/octet-stream", `attachment; filename="trace"`, nil},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com"+tc.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
tc.handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if got, want := resp.StatusCode, tc.statusCode; got != want {
|
||||
t.Errorf("status code: got %d; want %d", got, want)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("when reading response body, expected non-nil err; got %v", err)
|
||||
}
|
||||
if got, want := resp.Header.Get("X-Content-Type-Options"), "nosniff"; got != want {
|
||||
t.Errorf("X-Content-Type-Options: got %q; want %q", got, want)
|
||||
}
|
||||
if got, want := resp.Header.Get("Content-Type"), tc.contentType; got != want {
|
||||
t.Errorf("Content-Type: got %q; want %q", got, want)
|
||||
}
|
||||
if got, want := resp.Header.Get("Content-Disposition"), tc.contentDisposition; got != want {
|
||||
t.Errorf("Content-Disposition: got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return
|
||||
}
|
||||
if got, want := resp.Header.Get("X-Go-Pprof"), "1"; got != want {
|
||||
t.Errorf("X-Go-Pprof: got %q; want %q", got, want)
|
||||
}
|
||||
if !bytes.Equal(body, tc.resp) {
|
||||
t.Errorf("response: got %q; want %q", body, tc.resp)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
11
tempfork/registry/export_test.go
Normal file
11
tempfork/registry/export_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry
|
||||
|
||||
func (k Key) SetValue(name string, valtype uint32, data []byte) error {
|
||||
return k.setValue(name, valtype, data)
|
||||
}
|
||||
8
tempfork/registry/fix_corp_redo_build_test.go
Normal file
8
tempfork/registry/fix_corp_redo_build_test.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package registry_test
|
||||
|
||||
// Tailscale's redo-based build system doesn't know how to skip running Go tests
|
||||
// in directories that don't contain files for the current OS.
|
||||
//
|
||||
// https://github.com/tailscale/corp/issues/293
|
||||
//
|
||||
// So this is a dummy file for now to make it happy.
|
||||
204
tempfork/registry/key.go
Normal file
204
tempfork/registry/key.go
Normal file
@@ -0,0 +1,204 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
// Package registry provides access to the Windows registry.
|
||||
//
|
||||
// Here is a simple example, opening a registry key and reading a string value from it.
|
||||
//
|
||||
// k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// defer k.Close()
|
||||
//
|
||||
// s, _, err := k.GetStringValue("SystemRoot")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// fmt.Printf("Windows system root is %q\n", s)
|
||||
//
|
||||
package registry
|
||||
|
||||
import (
|
||||
"io"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Registry key security and access rights.
|
||||
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms724878.aspx
|
||||
// for details.
|
||||
ALL_ACCESS = 0xf003f
|
||||
CREATE_LINK = 0x00020
|
||||
CREATE_SUB_KEY = 0x00004
|
||||
ENUMERATE_SUB_KEYS = 0x00008
|
||||
EXECUTE = 0x20019
|
||||
NOTIFY = 0x00010
|
||||
QUERY_VALUE = 0x00001
|
||||
READ = 0x20019
|
||||
SET_VALUE = 0x00002
|
||||
WOW64_32KEY = 0x00200
|
||||
WOW64_64KEY = 0x00100
|
||||
WRITE = 0x20006
|
||||
)
|
||||
|
||||
// Key is a handle to an open Windows registry key.
|
||||
// Keys can be obtained by calling OpenKey; there are
|
||||
// also some predefined root keys such as CURRENT_USER.
|
||||
// Keys can be used directly in the Windows API.
|
||||
type Key syscall.Handle
|
||||
|
||||
const (
|
||||
// Windows defines some predefined root keys that are always open.
|
||||
// An application can use these keys as entry points to the registry.
|
||||
// Normally these keys are used in OpenKey to open new keys,
|
||||
// but they can also be used anywhere a Key is required.
|
||||
CLASSES_ROOT = Key(syscall.HKEY_CLASSES_ROOT)
|
||||
CURRENT_USER = Key(syscall.HKEY_CURRENT_USER)
|
||||
LOCAL_MACHINE = Key(syscall.HKEY_LOCAL_MACHINE)
|
||||
USERS = Key(syscall.HKEY_USERS)
|
||||
CURRENT_CONFIG = Key(syscall.HKEY_CURRENT_CONFIG)
|
||||
PERFORMANCE_DATA = Key(syscall.HKEY_PERFORMANCE_DATA)
|
||||
)
|
||||
|
||||
// Close closes open key k.
|
||||
func (k Key) Close() error {
|
||||
return syscall.RegCloseKey(syscall.Handle(k))
|
||||
}
|
||||
|
||||
// WaitChange waits for k to change using RegNotifyChangeKeyValue.
|
||||
// The subtree parameter is whether subtrees should also be watched.
|
||||
func (k Key) WaitChange(subtree bool) error {
|
||||
return regNotifyChangeKeyValue(syscall.Handle(k), subtree, 0, 0, false)
|
||||
}
|
||||
|
||||
// OpenKey opens a new key with path name relative to key k.
|
||||
// It accepts any open key, including CURRENT_USER and others,
|
||||
// and returns the new key and an error.
|
||||
// The access parameter specifies desired access rights to the
|
||||
// key to be opened.
|
||||
func OpenKey(k Key, path string, access uint32) (Key, error) {
|
||||
p, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var subkey syscall.Handle
|
||||
err = syscall.RegOpenKeyEx(syscall.Handle(k), p, 0, access, &subkey)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return Key(subkey), nil
|
||||
}
|
||||
|
||||
// OpenRemoteKey opens a predefined registry key on another
|
||||
// computer pcname. The key to be opened is specified by k, but
|
||||
// can only be one of LOCAL_MACHINE, PERFORMANCE_DATA or USERS.
|
||||
// If pcname is "", OpenRemoteKey returns local computer key.
|
||||
func OpenRemoteKey(pcname string, k Key) (Key, error) {
|
||||
var err error
|
||||
var p *uint16
|
||||
if pcname != "" {
|
||||
p, err = syscall.UTF16PtrFromString(`\\` + pcname)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var remoteKey syscall.Handle
|
||||
err = regConnectRegistry(p, syscall.Handle(k), &remoteKey)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return Key(remoteKey), nil
|
||||
}
|
||||
|
||||
// ReadSubKeyNames returns the names of subkeys of key k.
|
||||
// The parameter n controls the number of returned names,
|
||||
// analogous to the way os.File.Readdirnames works.
|
||||
func (k Key) ReadSubKeyNames(n int) ([]string, error) {
|
||||
names := make([]string, 0)
|
||||
// Registry key size limit is 255 bytes and described there:
|
||||
// https://msdn.microsoft.com/library/windows/desktop/ms724872.aspx
|
||||
buf := make([]uint16, 256) //plus extra room for terminating zero byte
|
||||
loopItems:
|
||||
for i := uint32(0); ; i++ {
|
||||
if n > 0 {
|
||||
if len(names) == n {
|
||||
return names, nil
|
||||
}
|
||||
}
|
||||
l := uint32(len(buf))
|
||||
for {
|
||||
err := syscall.RegEnumKeyEx(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err == syscall.ERROR_MORE_DATA {
|
||||
// Double buffer size and try again.
|
||||
l = uint32(2 * len(buf))
|
||||
buf = make([]uint16, l)
|
||||
continue
|
||||
}
|
||||
if err == _ERROR_NO_MORE_ITEMS {
|
||||
break loopItems
|
||||
}
|
||||
return names, err
|
||||
}
|
||||
names = append(names, syscall.UTF16ToString(buf[:l]))
|
||||
}
|
||||
if n > len(names) {
|
||||
return names, io.EOF
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// CreateKey creates a key named path under open key k.
|
||||
// CreateKey returns the new key and a boolean flag that reports
|
||||
// whether the key already existed.
|
||||
// The access parameter specifies the access rights for the key
|
||||
// to be created.
|
||||
func CreateKey(k Key, path string, access uint32) (newk Key, openedExisting bool, err error) {
|
||||
var h syscall.Handle
|
||||
var d uint32
|
||||
err = regCreateKeyEx(syscall.Handle(k), syscall.StringToUTF16Ptr(path),
|
||||
0, nil, _REG_OPTION_NON_VOLATILE, access, nil, &h, &d)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return Key(h), d == _REG_OPENED_EXISTING_KEY, nil
|
||||
}
|
||||
|
||||
// DeleteKey deletes the subkey path of key k and its values.
|
||||
func DeleteKey(k Key, path string) error {
|
||||
return regDeleteKey(syscall.Handle(k), syscall.StringToUTF16Ptr(path))
|
||||
}
|
||||
|
||||
// A KeyInfo describes the statistics of a key. It is returned by Stat.
|
||||
type KeyInfo struct {
|
||||
SubKeyCount uint32
|
||||
MaxSubKeyLen uint32 // size of the key's subkey with the longest name, in Unicode characters, not including the terminating zero byte
|
||||
ValueCount uint32
|
||||
MaxValueNameLen uint32 // size of the key's longest value name, in Unicode characters, not including the terminating zero byte
|
||||
MaxValueLen uint32 // longest data component among the key's values, in bytes
|
||||
lastWriteTime syscall.Filetime
|
||||
}
|
||||
|
||||
// ModTime returns the key's last write time.
|
||||
func (ki *KeyInfo) ModTime() time.Time {
|
||||
return time.Unix(0, ki.lastWriteTime.Nanoseconds())
|
||||
}
|
||||
|
||||
// Stat retrieves information about the open key k.
|
||||
func (k Key) Stat() (*KeyInfo, error) {
|
||||
var ki KeyInfo
|
||||
err := syscall.RegQueryInfoKey(syscall.Handle(k), nil, nil, nil,
|
||||
&ki.SubKeyCount, &ki.MaxSubKeyLen, nil, &ki.ValueCount,
|
||||
&ki.MaxValueNameLen, &ki.MaxValueLen, nil, &ki.lastWriteTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ki, nil
|
||||
}
|
||||
9
tempfork/registry/mksyscall.go
Normal file
9
tempfork/registry/mksyscall.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build generate
|
||||
|
||||
package registry
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall.go
|
||||
701
tempfork/registry/registry_test.go
Normal file
701
tempfork/registry/registry_test.go
Normal file
@@ -0,0 +1,701 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"tailscale.com/tempfork/registry"
|
||||
)
|
||||
|
||||
func randKeyName(prefix string) string {
|
||||
const numbers = "0123456789"
|
||||
buf := make([]byte, 10)
|
||||
rand.Read(buf)
|
||||
for i, b := range buf {
|
||||
buf[i] = numbers[b%byte(len(numbers))]
|
||||
}
|
||||
return prefix + string(buf)
|
||||
}
|
||||
|
||||
func TestReadSubKeyNames(t *testing.T) {
|
||||
k, err := registry.OpenKey(registry.CLASSES_ROOT, "TypeLib", registry.ENUMERATE_SUB_KEYS)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
names, err := k.ReadSubKeyNames(-1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var foundStdOle bool
|
||||
for _, name := range names {
|
||||
// Every PC has "stdole 2.0 OLE Automation" library installed.
|
||||
if name == "{00020430-0000-0000-C000-000000000046}" {
|
||||
foundStdOle = true
|
||||
}
|
||||
}
|
||||
if !foundStdOle {
|
||||
t.Fatal("could not find stdole 2.0 OLE Automation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOpenDeleteKey(t *testing.T) {
|
||||
k, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
testKName := randKeyName("TestCreateOpenDeleteKey_")
|
||||
|
||||
testK, exist, err := registry.CreateKey(k, testKName, registry.CREATE_SUB_KEY)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testK.Close()
|
||||
|
||||
if exist {
|
||||
t.Fatalf("key %q already exists", testKName)
|
||||
}
|
||||
|
||||
testKAgain, exist, err := registry.CreateKey(k, testKName, registry.CREATE_SUB_KEY)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testKAgain.Close()
|
||||
|
||||
if !exist {
|
||||
t.Fatalf("key %q should already exist", testKName)
|
||||
}
|
||||
|
||||
testKOpened, err := registry.OpenKey(k, testKName, registry.ENUMERATE_SUB_KEYS)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testKOpened.Close()
|
||||
|
||||
err = registry.DeleteKey(k, testKName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testKOpenedAgain, err := registry.OpenKey(k, testKName, registry.ENUMERATE_SUB_KEYS)
|
||||
if err == nil {
|
||||
defer testKOpenedAgain.Close()
|
||||
t.Fatalf("key %q should already been deleted", testKName)
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Fatalf(`unexpected error ("not exist" expected): %v`, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatch(t *testing.T) {
|
||||
k, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE|registry.WRITE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
testKName := randKeyName("TestWatch_")
|
||||
testK, _, err := registry.CreateKey(k, testKName, registry.CREATE_SUB_KEY|registry.NOTIFY|registry.WRITE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testK.Close()
|
||||
|
||||
timer := time.AfterFunc(100*time.Millisecond, func() {
|
||||
err := registry.DeleteKey(k, testKName)
|
||||
t.Logf("DeleteKey: %v", err)
|
||||
})
|
||||
defer timer.Stop()
|
||||
t.Logf("pre-wait")
|
||||
t0 := time.Now()
|
||||
err = testK.WaitChange(true)
|
||||
t.Logf("WaitChange after %v: %v", time.Since(t0).Round(time.Millisecond), err)
|
||||
}
|
||||
|
||||
func equalStringSlice(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if a == nil {
|
||||
return true
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type ValueTest struct {
|
||||
Type uint32
|
||||
Name string
|
||||
Value interface{}
|
||||
WillFail bool
|
||||
}
|
||||
|
||||
var ValueTests = []ValueTest{
|
||||
{Type: registry.SZ, Name: "String1", Value: ""},
|
||||
{Type: registry.SZ, Name: "String2", Value: "\000", WillFail: true},
|
||||
{Type: registry.SZ, Name: "String3", Value: "Hello World"},
|
||||
{Type: registry.SZ, Name: "String4", Value: "Hello World\000", WillFail: true},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString1", Value: ""},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString2", Value: "\000", WillFail: true},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString3", Value: "Hello World"},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString4", Value: "Hello\000World", WillFail: true},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString5", Value: "%PATH%"},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString6", Value: "%NO_SUCH_VARIABLE%"},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString7", Value: "%PATH%;."},
|
||||
{Type: registry.BINARY, Name: "Binary1", Value: []byte{}},
|
||||
{Type: registry.BINARY, Name: "Binary2", Value: []byte{1, 2, 3}},
|
||||
{Type: registry.BINARY, Name: "Binary3", Value: []byte{3, 2, 1, 0, 1, 2, 3}},
|
||||
{Type: registry.DWORD, Name: "Dword1", Value: uint64(0)},
|
||||
{Type: registry.DWORD, Name: "Dword2", Value: uint64(1)},
|
||||
{Type: registry.DWORD, Name: "Dword3", Value: uint64(0xff)},
|
||||
{Type: registry.DWORD, Name: "Dword4", Value: uint64(0xffff)},
|
||||
{Type: registry.QWORD, Name: "Qword1", Value: uint64(0)},
|
||||
{Type: registry.QWORD, Name: "Qword2", Value: uint64(1)},
|
||||
{Type: registry.QWORD, Name: "Qword3", Value: uint64(0xff)},
|
||||
{Type: registry.QWORD, Name: "Qword4", Value: uint64(0xffff)},
|
||||
{Type: registry.QWORD, Name: "Qword5", Value: uint64(0xffffff)},
|
||||
{Type: registry.QWORD, Name: "Qword6", Value: uint64(0xffffffff)},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString1", Value: []string{"a", "b", "c"}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString2", Value: []string{"abc", "", "cba"}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString3", Value: []string{""}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString4", Value: []string{"abcdef"}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString5", Value: []string{"\000"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString6", Value: []string{"a\000b"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString7", Value: []string{"ab", "\000", "cd"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString8", Value: []string{"\000", "cd"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString9", Value: []string{"ab", "\000"}, WillFail: true},
|
||||
}
|
||||
|
||||
func setValues(t *testing.T, k registry.Key) {
|
||||
for _, test := range ValueTests {
|
||||
var err error
|
||||
switch test.Type {
|
||||
case registry.SZ:
|
||||
err = k.SetStringValue(test.Name, test.Value.(string))
|
||||
case registry.EXPAND_SZ:
|
||||
err = k.SetExpandStringValue(test.Name, test.Value.(string))
|
||||
case registry.MULTI_SZ:
|
||||
err = k.SetStringsValue(test.Name, test.Value.([]string))
|
||||
case registry.BINARY:
|
||||
err = k.SetBinaryValue(test.Name, test.Value.([]byte))
|
||||
case registry.DWORD:
|
||||
err = k.SetDWordValue(test.Name, uint32(test.Value.(uint64)))
|
||||
case registry.QWORD:
|
||||
err = k.SetQWordValue(test.Name, test.Value.(uint64))
|
||||
default:
|
||||
t.Fatalf("unsupported type %d for %s value", test.Type, test.Name)
|
||||
}
|
||||
if test.WillFail {
|
||||
if err == nil {
|
||||
t.Fatalf("setting %s value %q should fail, but succeeded", test.Name, test.Value)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enumerateValues(t *testing.T, k registry.Key) {
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
haveNames := make(map[string]bool)
|
||||
for _, n := range names {
|
||||
haveNames[n] = false
|
||||
}
|
||||
for _, test := range ValueTests {
|
||||
wantFound := !test.WillFail
|
||||
_, haveFound := haveNames[test.Name]
|
||||
if wantFound && !haveFound {
|
||||
t.Errorf("value %s is not found while enumerating", test.Name)
|
||||
}
|
||||
if haveFound && !wantFound {
|
||||
t.Errorf("value %s is found while enumerating, but expected to fail", test.Name)
|
||||
}
|
||||
if haveFound {
|
||||
delete(haveNames, test.Name)
|
||||
}
|
||||
}
|
||||
for n, v := range haveNames {
|
||||
t.Errorf("value %s (%v) is found while enumerating, but has not been cretaed", n, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testErrNotExist(t *testing.T, name string, err error) {
|
||||
if err == nil {
|
||||
t.Errorf("%s value should not exist", name)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Errorf("reading %s value should return 'not exist' error, but got: %s", name, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testErrUnexpectedType(t *testing.T, test ValueTest, gottype uint32, err error) {
|
||||
if err == nil {
|
||||
t.Errorf("GetXValue(%q) should not succeed", test.Name)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrUnexpectedType {
|
||||
t.Errorf("reading %s value should return 'unexpected key value type' error, but got: %s", test.Name, err)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetStringValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetStringValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetStringValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if got != test.Value {
|
||||
t.Errorf("want %s value %q, got %q", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
if gottype == registry.EXPAND_SZ {
|
||||
_, err = registry.ExpandString(got)
|
||||
if err != nil {
|
||||
t.Errorf("ExpandString(%s) failed: %v", got, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testGetIntegerValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetIntegerValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetIntegerValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if got != test.Value.(uint64) {
|
||||
t.Errorf("want %s value %v, got %v", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetBinaryValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetBinaryValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetBinaryValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(got, test.Value.([]byte)) {
|
||||
t.Errorf("want %s value %v, got %v", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetStringsValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetStringsValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetStringsValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if !equalStringSlice(got, test.Value.([]string)) {
|
||||
t.Errorf("want %s value %#v, got %#v", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetValue(t *testing.T, k registry.Key, test ValueTest, size int) {
|
||||
if size <= 0 {
|
||||
return
|
||||
}
|
||||
// read data with no buffer
|
||||
gotsize, gottype, err := k.GetValue(test.Name, nil)
|
||||
if err != nil {
|
||||
t.Errorf("GetValue(%s, [%d]byte) failed: %v", test.Name, size, err)
|
||||
return
|
||||
}
|
||||
if gotsize != size {
|
||||
t.Errorf("want %s value size of %d, got %v", test.Name, size, gotsize)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
// read data with short buffer
|
||||
gotsize, gottype, err = k.GetValue(test.Name, make([]byte, size-1))
|
||||
if err == nil {
|
||||
t.Errorf("GetValue(%s, [%d]byte) should fail, but succeeded", test.Name, size-1)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrShortBuffer {
|
||||
t.Errorf("reading %s value should return 'short buffer' error, but got: %s", test.Name, err)
|
||||
return
|
||||
}
|
||||
if gotsize != size {
|
||||
t.Errorf("want %s value size of %d, got %v", test.Name, size, gotsize)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
// read full data
|
||||
gotsize, gottype, err = k.GetValue(test.Name, make([]byte, size))
|
||||
if err != nil {
|
||||
t.Errorf("GetValue(%s, [%d]byte) failed: %v", test.Name, size, err)
|
||||
return
|
||||
}
|
||||
if gotsize != size {
|
||||
t.Errorf("want %s value size of %d, got %v", test.Name, size, gotsize)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
// check GetValue returns ErrNotExist as required
|
||||
_, _, err = k.GetValue(test.Name+"_not_there", make([]byte, size))
|
||||
if err == nil {
|
||||
t.Errorf("GetValue(%q) should not succeed", test.Name)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Errorf("GetValue(%q) should return 'not exist' error, but got: %s", test.Name, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testValues(t *testing.T, k registry.Key) {
|
||||
for _, test := range ValueTests {
|
||||
switch test.Type {
|
||||
case registry.SZ, registry.EXPAND_SZ:
|
||||
if test.WillFail {
|
||||
_, _, err := k.GetStringValue(test.Name)
|
||||
testErrNotExist(t, test.Name, err)
|
||||
} else {
|
||||
testGetStringValue(t, k, test)
|
||||
_, gottype, err := k.GetIntegerValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
// Size of utf16 string in bytes is not perfect,
|
||||
// but correct for current test values.
|
||||
// Size also includes terminating 0.
|
||||
testGetValue(t, k, test, (len(test.Value.(string))+1)*2)
|
||||
}
|
||||
_, _, err := k.GetStringValue(test.Name + "_string_not_created")
|
||||
testErrNotExist(t, test.Name+"_string_not_created", err)
|
||||
case registry.DWORD, registry.QWORD:
|
||||
testGetIntegerValue(t, k, test)
|
||||
_, gottype, err := k.GetBinaryValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
_, _, err = k.GetIntegerValue(test.Name + "_int_not_created")
|
||||
testErrNotExist(t, test.Name+"_int_not_created", err)
|
||||
size := 8
|
||||
if test.Type == registry.DWORD {
|
||||
size = 4
|
||||
}
|
||||
testGetValue(t, k, test, size)
|
||||
case registry.BINARY:
|
||||
testGetBinaryValue(t, k, test)
|
||||
_, gottype, err := k.GetStringsValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
_, _, err = k.GetBinaryValue(test.Name + "_byte_not_created")
|
||||
testErrNotExist(t, test.Name+"_byte_not_created", err)
|
||||
testGetValue(t, k, test, len(test.Value.([]byte)))
|
||||
case registry.MULTI_SZ:
|
||||
if test.WillFail {
|
||||
_, _, err := k.GetStringsValue(test.Name)
|
||||
testErrNotExist(t, test.Name, err)
|
||||
} else {
|
||||
testGetStringsValue(t, k, test)
|
||||
_, gottype, err := k.GetStringValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
size := 0
|
||||
for _, s := range test.Value.([]string) {
|
||||
size += len(s) + 1 // nil terminated
|
||||
}
|
||||
size += 1 // extra nil at the end
|
||||
size *= 2 // count bytes, not uint16
|
||||
testGetValue(t, k, test, size)
|
||||
}
|
||||
_, _, err := k.GetStringsValue(test.Name + "_strings_not_created")
|
||||
testErrNotExist(t, test.Name+"_strings_not_created", err)
|
||||
default:
|
||||
t.Errorf("unsupported type %d for %s value", test.Type, test.Name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testStat(t *testing.T, k registry.Key) {
|
||||
subk, _, err := registry.CreateKey(k, "subkey", registry.CREATE_SUB_KEY)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer subk.Close()
|
||||
|
||||
defer registry.DeleteKey(k, "subkey")
|
||||
|
||||
ki, err := k.Stat()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if ki.SubKeyCount != 1 {
|
||||
t.Error("key must have 1 subkey")
|
||||
}
|
||||
if ki.MaxSubKeyLen != 6 {
|
||||
t.Error("key max subkey name length must be 6")
|
||||
}
|
||||
if ki.ValueCount != 24 {
|
||||
t.Errorf("key must have 24 values, but is %d", ki.ValueCount)
|
||||
}
|
||||
if ki.MaxValueNameLen != 12 {
|
||||
t.Errorf("key max value name length must be 10, but is %d", ki.MaxValueNameLen)
|
||||
}
|
||||
if ki.MaxValueLen != 38 {
|
||||
t.Errorf("key max value length must be 38, but is %d", ki.MaxValueLen)
|
||||
}
|
||||
if mt, ct := ki.ModTime(), time.Now(); ct.Sub(mt) > 100*time.Millisecond {
|
||||
t.Errorf("key mod time is not close to current time: mtime=%v current=%v delta=%v", mt, ct, ct.Sub(mt))
|
||||
}
|
||||
}
|
||||
|
||||
func deleteValues(t *testing.T, k registry.Key) {
|
||||
for _, test := range ValueTests {
|
||||
if test.WillFail {
|
||||
continue
|
||||
}
|
||||
err := k.DeleteValue(test.Name)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(names) != 0 {
|
||||
t.Errorf("some values remain after deletion: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
softwareK, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer softwareK.Close()
|
||||
|
||||
testKName := randKeyName("TestValues_")
|
||||
|
||||
k, exist, err := registry.CreateKey(softwareK, testKName, registry.CREATE_SUB_KEY|registry.QUERY_VALUE|registry.SET_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
if exist {
|
||||
t.Fatalf("key %q already exists", testKName)
|
||||
}
|
||||
|
||||
defer registry.DeleteKey(softwareK, testKName)
|
||||
|
||||
setValues(t, k)
|
||||
|
||||
enumerateValues(t, k)
|
||||
|
||||
testValues(t, k)
|
||||
|
||||
testStat(t, k)
|
||||
|
||||
deleteValues(t, k)
|
||||
}
|
||||
|
||||
func TestExpandString(t *testing.T) {
|
||||
got, err := registry.ExpandString("%PATH%")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := os.Getenv("PATH")
|
||||
if got != want {
|
||||
t.Errorf("want %q string expanded, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidValues(t *testing.T) {
|
||||
softwareK, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer softwareK.Close()
|
||||
|
||||
testKName := randKeyName("TestInvalidValues_")
|
||||
|
||||
k, exist, err := registry.CreateKey(softwareK, testKName, registry.CREATE_SUB_KEY|registry.QUERY_VALUE|registry.SET_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
if exist {
|
||||
t.Fatalf("key %q already exists", testKName)
|
||||
}
|
||||
|
||||
defer registry.DeleteKey(softwareK, testKName)
|
||||
|
||||
var tests = []struct {
|
||||
Type uint32
|
||||
Name string
|
||||
Data []byte
|
||||
}{
|
||||
{registry.DWORD, "Dword1", nil},
|
||||
{registry.DWORD, "Dword2", []byte{1, 2, 3}},
|
||||
{registry.QWORD, "Qword1", nil},
|
||||
{registry.QWORD, "Qword2", []byte{1, 2, 3}},
|
||||
{registry.QWORD, "Qword3", []byte{1, 2, 3, 4, 5, 6, 7}},
|
||||
{registry.MULTI_SZ, "MultiString1", nil},
|
||||
{registry.MULTI_SZ, "MultiString2", []byte{0}},
|
||||
{registry.MULTI_SZ, "MultiString3", []byte{'a', 'b', 0}},
|
||||
{registry.MULTI_SZ, "MultiString4", []byte{'a', 0, 0, 'b', 0}},
|
||||
{registry.MULTI_SZ, "MultiString5", []byte{'a', 0, 0}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
err := k.SetValue(test.Name, test.Type, test.Data)
|
||||
if err != nil {
|
||||
t.Fatalf("SetValue for %q failed: %v", test.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
switch test.Type {
|
||||
case registry.DWORD, registry.QWORD:
|
||||
value, valType, err := k.GetIntegerValue(test.Name)
|
||||
if err == nil {
|
||||
t.Errorf("GetIntegerValue(%q) succeeded. Returns type=%d value=%v", test.Name, valType, value)
|
||||
}
|
||||
case registry.MULTI_SZ:
|
||||
value, valType, err := k.GetStringsValue(test.Name)
|
||||
if err == nil {
|
||||
if len(value) != 0 {
|
||||
t.Errorf("GetStringsValue(%q) succeeded. Returns type=%d value=%v", test.Name, valType, value)
|
||||
}
|
||||
}
|
||||
default:
|
||||
t.Errorf("unsupported type %d for %s value", test.Type, test.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMUIStringValue(t *testing.T) {
|
||||
if err := registry.LoadRegLoadMUIString(); err != nil {
|
||||
t.Skip("regLoadMUIString not supported; skipping")
|
||||
}
|
||||
if err := procGetDynamicTimeZoneInformation.Find(); err != nil {
|
||||
t.Skipf("%s not supported; skipping", procGetDynamicTimeZoneInformation.Name)
|
||||
}
|
||||
var dtzi DynamicTimezoneinformation
|
||||
if _, err := GetDynamicTimeZoneInformation(&dtzi); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tzKeyName := syscall.UTF16ToString(dtzi.TimeZoneKeyName[:])
|
||||
timezoneK, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones\`+tzKeyName, registry.READ)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer timezoneK.Close()
|
||||
|
||||
type testType struct {
|
||||
name string
|
||||
want string
|
||||
}
|
||||
var tests = []testType{
|
||||
{"MUI_Std", syscall.UTF16ToString(dtzi.StandardName[:])},
|
||||
}
|
||||
if dtzi.DynamicDaylightTimeDisabled == 0 {
|
||||
tests = append(tests, testType{"MUI_Dlt", syscall.UTF16ToString(dtzi.DaylightName[:])})
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
got, err := timezoneK.GetMUIStringValue(test.name)
|
||||
if err != nil {
|
||||
t.Error("GetMUIStringValue:", err)
|
||||
}
|
||||
|
||||
if got != test.want {
|
||||
t.Errorf("GetMUIStringValue: %s: Got %q, want %q", test.name, got, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type DynamicTimezoneinformation struct {
|
||||
Bias int32
|
||||
StandardName [32]uint16
|
||||
StandardDate syscall.Systemtime
|
||||
StandardBias int32
|
||||
DaylightName [32]uint16
|
||||
DaylightDate syscall.Systemtime
|
||||
DaylightBias int32
|
||||
TimeZoneKeyName [128]uint16
|
||||
DynamicDaylightTimeDisabled uint8
|
||||
}
|
||||
|
||||
var (
|
||||
kernel32DLL = syscall.NewLazyDLL("kernel32")
|
||||
|
||||
procGetDynamicTimeZoneInformation = kernel32DLL.NewProc("GetDynamicTimeZoneInformation")
|
||||
)
|
||||
|
||||
func GetDynamicTimeZoneInformation(dtzi *DynamicTimezoneinformation) (rc uint32, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procGetDynamicTimeZoneInformation.Addr(), 1, uintptr(unsafe.Pointer(dtzi)), 0, 0)
|
||||
rc = uint32(r0)
|
||||
if rc == 0xffffffff {
|
||||
if e1 != 0 {
|
||||
err = error(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
33
tempfork/registry/syscall.go
Normal file
33
tempfork/registry/syscall.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry
|
||||
|
||||
import "syscall"
|
||||
|
||||
const (
|
||||
_REG_OPTION_NON_VOLATILE = 0
|
||||
|
||||
_REG_CREATED_NEW_KEY = 1
|
||||
_REG_OPENED_EXISTING_KEY = 2
|
||||
|
||||
_ERROR_NO_MORE_ITEMS syscall.Errno = 259
|
||||
)
|
||||
|
||||
func LoadRegLoadMUIString() error {
|
||||
return procRegLoadMUIStringW.Find()
|
||||
}
|
||||
|
||||
//sys regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) = advapi32.RegCreateKeyExW
|
||||
//sys regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) = advapi32.RegDeleteKeyW
|
||||
//sys regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) = advapi32.RegSetValueExW
|
||||
//sys regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) = advapi32.RegEnumValueW
|
||||
//sys regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) = advapi32.RegDeleteValueW
|
||||
//sys regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) = advapi32.RegLoadMUIStringW
|
||||
//sys regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) = advapi32.RegConnectRegistryW
|
||||
//sys regNotifyChangeKeyValue(key syscall.Handle, watchSubtree bool, notifyFilter uint32, event syscall.Handle, async bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue
|
||||
|
||||
//sys expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) = kernel32.ExpandEnvironmentStringsW
|
||||
386
tempfork/registry/value.go
Normal file
386
tempfork/registry/value.go
Normal file
@@ -0,0 +1,386 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Registry value types.
|
||||
NONE = 0
|
||||
SZ = 1
|
||||
EXPAND_SZ = 2
|
||||
BINARY = 3
|
||||
DWORD = 4
|
||||
DWORD_BIG_ENDIAN = 5
|
||||
LINK = 6
|
||||
MULTI_SZ = 7
|
||||
RESOURCE_LIST = 8
|
||||
FULL_RESOURCE_DESCRIPTOR = 9
|
||||
RESOURCE_REQUIREMENTS_LIST = 10
|
||||
QWORD = 11
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrShortBuffer is returned when the buffer was too short for the operation.
|
||||
ErrShortBuffer = syscall.ERROR_MORE_DATA
|
||||
|
||||
// ErrNotExist is returned when a registry key or value does not exist.
|
||||
ErrNotExist = syscall.ERROR_FILE_NOT_FOUND
|
||||
|
||||
// ErrUnexpectedType is returned by Get*Value when the value's type was unexpected.
|
||||
ErrUnexpectedType = errors.New("unexpected key value type")
|
||||
)
|
||||
|
||||
// GetValue retrieves the type and data for the specified value associated
|
||||
// with an open key k. It fills up buffer buf and returns the retrieved
|
||||
// byte count n. If buf is too small to fit the stored value it returns
|
||||
// ErrShortBuffer error along with the required buffer size n.
|
||||
// If no buffer is provided, it returns true and actual buffer size n.
|
||||
// If no buffer is provided, GetValue returns the value's type only.
|
||||
// If the value does not exist, the error returned is ErrNotExist.
|
||||
//
|
||||
// GetValue is a low level function. If value's type is known, use the appropriate
|
||||
// Get*Value function instead.
|
||||
func (k Key) GetValue(name string, buf []byte) (n int, valtype uint32, err error) {
|
||||
pname, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
var pbuf *byte
|
||||
if len(buf) > 0 {
|
||||
pbuf = (*byte)(unsafe.Pointer(&buf[0]))
|
||||
}
|
||||
l := uint32(len(buf))
|
||||
err = syscall.RegQueryValueEx(syscall.Handle(k), pname, nil, &valtype, pbuf, &l)
|
||||
if err != nil {
|
||||
return int(l), valtype, err
|
||||
}
|
||||
return int(l), valtype, nil
|
||||
}
|
||||
|
||||
func (k Key) getValue(name string, buf []byte) (data []byte, valtype uint32, err error) {
|
||||
p, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var t uint32
|
||||
n := uint32(len(buf))
|
||||
for {
|
||||
err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n)
|
||||
if err == nil {
|
||||
return buf[:n], t, nil
|
||||
}
|
||||
if err != syscall.ERROR_MORE_DATA {
|
||||
return nil, 0, err
|
||||
}
|
||||
if n <= uint32(len(buf)) {
|
||||
return nil, 0, err
|
||||
}
|
||||
buf = make([]byte, n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStringValue retrieves the string value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetStringValue returns ErrNotExist.
|
||||
// If value is not SZ or EXPAND_SZ, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetStringValue(name string) (val string, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 64))
|
||||
if err2 != nil {
|
||||
return "", typ, err2
|
||||
}
|
||||
switch typ {
|
||||
case SZ, EXPAND_SZ:
|
||||
default:
|
||||
return "", typ, ErrUnexpectedType
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return "", typ, nil
|
||||
}
|
||||
u := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2]
|
||||
return syscall.UTF16ToString(u), typ, nil
|
||||
}
|
||||
|
||||
// GetMUIStringValue retrieves the localized string value for
|
||||
// the specified value name associated with an open key k.
|
||||
// If the value name doesn't exist or the localized string value
|
||||
// can't be resolved, GetMUIStringValue returns ErrNotExist.
|
||||
// GetMUIStringValue panics if the system doesn't support
|
||||
// regLoadMUIString; use LoadRegLoadMUIString to check if
|
||||
// regLoadMUIString is supported before calling this function.
|
||||
func (k Key) GetMUIStringValue(name string) (string, error) {
|
||||
pname, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
buf := make([]uint16, 1024)
|
||||
var buflen uint32
|
||||
var pdir *uint16
|
||||
|
||||
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
|
||||
if err == syscall.ERROR_FILE_NOT_FOUND { // Try fallback path
|
||||
|
||||
// Try to resolve the string value using the system directory as
|
||||
// a DLL search path; this assumes the string value is of the form
|
||||
// @[path]\dllname,-strID but with no path given, e.g. @tzres.dll,-320.
|
||||
|
||||
// This approach works with tzres.dll but may have to be revised
|
||||
// in the future to allow callers to provide custom search paths.
|
||||
|
||||
var s string
|
||||
s, err = ExpandString("%SystemRoot%\\system32\\")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pdir, err = syscall.UTF16PtrFromString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
|
||||
}
|
||||
|
||||
for err == syscall.ERROR_MORE_DATA { // Grow buffer if needed
|
||||
if buflen <= uint32(len(buf)) {
|
||||
break // Buffer not growing, assume race; break
|
||||
}
|
||||
buf = make([]uint16, buflen)
|
||||
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return syscall.UTF16ToString(buf), nil
|
||||
}
|
||||
|
||||
// ExpandString expands environment-variable strings and replaces
|
||||
// them with the values defined for the current user.
|
||||
// Use ExpandString to expand EXPAND_SZ strings.
|
||||
func ExpandString(value string) (string, error) {
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
p, err := syscall.UTF16PtrFromString(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
r := make([]uint16, 100)
|
||||
for {
|
||||
n, err := expandEnvironmentStrings(p, &r[0], uint32(len(r)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n <= uint32(len(r)) {
|
||||
return syscall.UTF16ToString(r[:n]), nil
|
||||
}
|
||||
r = make([]uint16, n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStringsValue retrieves the []string value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetStringsValue returns ErrNotExist.
|
||||
// If value is not MULTI_SZ, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 64))
|
||||
if err2 != nil {
|
||||
return nil, typ, err2
|
||||
}
|
||||
if typ != MULTI_SZ {
|
||||
return nil, typ, ErrUnexpectedType
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, typ, nil
|
||||
}
|
||||
p := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2]
|
||||
if len(p) == 0 {
|
||||
return nil, typ, nil
|
||||
}
|
||||
if p[len(p)-1] == 0 {
|
||||
p = p[:len(p)-1] // remove terminating null
|
||||
}
|
||||
val = make([]string, 0, 5)
|
||||
from := 0
|
||||
for i, c := range p {
|
||||
if c == 0 {
|
||||
val = append(val, string(utf16.Decode(p[from:i])))
|
||||
from = i + 1
|
||||
}
|
||||
}
|
||||
return val, typ, nil
|
||||
}
|
||||
|
||||
// GetIntegerValue retrieves the integer value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetIntegerValue returns ErrNotExist.
|
||||
// If value is not DWORD or QWORD, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 8))
|
||||
if err2 != nil {
|
||||
return 0, typ, err2
|
||||
}
|
||||
switch typ {
|
||||
case DWORD:
|
||||
if len(data) != 4 {
|
||||
return 0, typ, errors.New("DWORD value is not 4 bytes long")
|
||||
}
|
||||
var val32 uint32
|
||||
copy((*[4]byte)(unsafe.Pointer(&val32))[:], data)
|
||||
return uint64(val32), DWORD, nil
|
||||
case QWORD:
|
||||
if len(data) != 8 {
|
||||
return 0, typ, errors.New("QWORD value is not 8 bytes long")
|
||||
}
|
||||
copy((*[8]byte)(unsafe.Pointer(&val))[:], data)
|
||||
return val, QWORD, nil
|
||||
default:
|
||||
return 0, typ, ErrUnexpectedType
|
||||
}
|
||||
}
|
||||
|
||||
// GetBinaryValue retrieves the binary value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetBinaryValue returns ErrNotExist.
|
||||
// If value is not BINARY, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetBinaryValue(name string) (val []byte, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 64))
|
||||
if err2 != nil {
|
||||
return nil, typ, err2
|
||||
}
|
||||
if typ != BINARY {
|
||||
return nil, typ, ErrUnexpectedType
|
||||
}
|
||||
return data, typ, nil
|
||||
}
|
||||
|
||||
func (k Key) setValue(name string, valtype uint32, data []byte) error {
|
||||
p, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return regSetValueEx(syscall.Handle(k), p, 0, valtype, nil, 0)
|
||||
}
|
||||
return regSetValueEx(syscall.Handle(k), p, 0, valtype, &data[0], uint32(len(data)))
|
||||
}
|
||||
|
||||
// SetDWordValue sets the data and type of a name value
|
||||
// under key k to value and DWORD.
|
||||
func (k Key) SetDWordValue(name string, value uint32) error {
|
||||
return k.setValue(name, DWORD, (*[4]byte)(unsafe.Pointer(&value))[:])
|
||||
}
|
||||
|
||||
// SetQWordValue sets the data and type of a name value
|
||||
// under key k to value and QWORD.
|
||||
func (k Key) SetQWordValue(name string, value uint64) error {
|
||||
return k.setValue(name, QWORD, (*[8]byte)(unsafe.Pointer(&value))[:])
|
||||
}
|
||||
|
||||
func (k Key) setStringValue(name string, valtype uint32, value string) error {
|
||||
v, err := syscall.UTF16FromString(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2]
|
||||
return k.setValue(name, valtype, buf)
|
||||
}
|
||||
|
||||
// SetStringValue sets the data and type of a name value
|
||||
// under key k to value and SZ. The value must not contain a zero byte.
|
||||
func (k Key) SetStringValue(name, value string) error {
|
||||
return k.setStringValue(name, SZ, value)
|
||||
}
|
||||
|
||||
// SetExpandStringValue sets the data and type of a name value
|
||||
// under key k to value and EXPAND_SZ. The value must not contain a zero byte.
|
||||
func (k Key) SetExpandStringValue(name, value string) error {
|
||||
return k.setStringValue(name, EXPAND_SZ, value)
|
||||
}
|
||||
|
||||
// SetStringsValue sets the data and type of a name value
|
||||
// under key k to value and MULTI_SZ. The value strings
|
||||
// must not contain a zero byte.
|
||||
func (k Key) SetStringsValue(name string, value []string) error {
|
||||
ss := ""
|
||||
for _, s := range value {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == 0 {
|
||||
return errors.New("string cannot have 0 inside")
|
||||
}
|
||||
}
|
||||
ss += s + "\x00"
|
||||
}
|
||||
v := utf16.Encode([]rune(ss + "\x00"))
|
||||
buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2]
|
||||
return k.setValue(name, MULTI_SZ, buf)
|
||||
}
|
||||
|
||||
// SetBinaryValue sets the data and type of a name value
|
||||
// under key k to value and BINARY.
|
||||
func (k Key) SetBinaryValue(name string, value []byte) error {
|
||||
return k.setValue(name, BINARY, value)
|
||||
}
|
||||
|
||||
// DeleteValue removes a named value from the key k.
|
||||
func (k Key) DeleteValue(name string) error {
|
||||
return regDeleteValue(syscall.Handle(k), syscall.StringToUTF16Ptr(name))
|
||||
}
|
||||
|
||||
// ReadValueNames returns the value names of key k.
|
||||
// The parameter n controls the number of returned names,
|
||||
// analogous to the way os.File.Readdirnames works.
|
||||
func (k Key) ReadValueNames(n int) ([]string, error) {
|
||||
ki, err := k.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
names := make([]string, 0, ki.ValueCount)
|
||||
buf := make([]uint16, ki.MaxValueNameLen+1) // extra room for terminating null character
|
||||
loopItems:
|
||||
for i := uint32(0); ; i++ {
|
||||
if n > 0 {
|
||||
if len(names) == n {
|
||||
return names, nil
|
||||
}
|
||||
}
|
||||
l := uint32(len(buf))
|
||||
for {
|
||||
err := regEnumValue(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err == syscall.ERROR_MORE_DATA {
|
||||
// Double buffer size and try again.
|
||||
l = uint32(2 * len(buf))
|
||||
buf = make([]uint16, l)
|
||||
continue
|
||||
}
|
||||
if err == _ERROR_NO_MORE_ITEMS {
|
||||
break loopItems
|
||||
}
|
||||
return names, err
|
||||
}
|
||||
names = append(names, syscall.UTF16ToString(buf[:l]))
|
||||
}
|
||||
if n > len(names) {
|
||||
return names, io.EOF
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
141
tempfork/registry/zsyscall_windows.go
Normal file
141
tempfork/registry/zsyscall_windows.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package registry
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return nil
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
|
||||
procRegCreateKeyExW = modadvapi32.NewProc("RegCreateKeyExW")
|
||||
procRegDeleteKeyW = modadvapi32.NewProc("RegDeleteKeyW")
|
||||
procRegSetValueExW = modadvapi32.NewProc("RegSetValueExW")
|
||||
procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW")
|
||||
procRegDeleteValueW = modadvapi32.NewProc("RegDeleteValueW")
|
||||
procRegLoadMUIStringW = modadvapi32.NewProc("RegLoadMUIStringW")
|
||||
procRegConnectRegistryW = modadvapi32.NewProc("RegConnectRegistryW")
|
||||
procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue")
|
||||
procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW")
|
||||
)
|
||||
|
||||
func regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall9(procRegCreateKeyExW.Addr(), 9, uintptr(key), uintptr(unsafe.Pointer(subkey)), uintptr(reserved), uintptr(unsafe.Pointer(class)), uintptr(options), uintptr(desired), uintptr(unsafe.Pointer(sa)), uintptr(unsafe.Pointer(result)), uintptr(unsafe.Pointer(disposition)))
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall(procRegDeleteKeyW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(subkey)), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall6(procRegSetValueExW.Addr(), 6, uintptr(key), uintptr(unsafe.Pointer(valueName)), uintptr(reserved), uintptr(vtype), uintptr(unsafe.Pointer(buf)), uintptr(bufsize))
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valtype)), uintptr(unsafe.Pointer(buf)), uintptr(unsafe.Pointer(buflen)), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall(procRegDeleteValueW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(name)), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall9(procRegLoadMUIStringW.Addr(), 7, uintptr(key), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(unsafe.Pointer(buflenCopied)), uintptr(flags), uintptr(unsafe.Pointer(dir)), 0, 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall(procRegConnectRegistryW.Addr(), 3, uintptr(unsafe.Pointer(machinename)), uintptr(key), uintptr(unsafe.Pointer(result)))
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regNotifyChangeKeyValue(key syscall.Handle, watchSubtree bool, notifyFilter uint32, event syscall.Handle, async bool) (regerrno error) {
|
||||
var _p0 uint32
|
||||
if watchSubtree {
|
||||
_p0 = 1
|
||||
} else {
|
||||
_p0 = 0
|
||||
}
|
||||
var _p1 uint32
|
||||
if async {
|
||||
_p1 = 1
|
||||
} else {
|
||||
_p1 = 0
|
||||
}
|
||||
r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procExpandEnvironmentStringsW.Addr(), 3, uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(size))
|
||||
n = uint32(r0)
|
||||
if n == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -30,16 +30,27 @@ type Clock struct {
|
||||
func (c *Clock) Now() time.Time {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.initLocked()
|
||||
step := c.Step
|
||||
ret := c.Present
|
||||
c.Present = c.Present.Add(step)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *Clock) Advance(d time.Duration) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.initLocked()
|
||||
c.Present = c.Present.Add(d)
|
||||
}
|
||||
|
||||
func (c *Clock) initLocked() {
|
||||
if c.Start.IsZero() {
|
||||
c.Start = time.Now()
|
||||
}
|
||||
if c.Present.Before(c.Start) {
|
||||
c.Present = c.Start
|
||||
}
|
||||
step := c.Step
|
||||
ret := c.Present
|
||||
c.Present = c.Present.Add(step)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Reset rewinds the virtual clock to its start time.
|
||||
|
||||
157
tstest/natlab/firewall.go
Normal file
157
tstest/natlab/firewall.go
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package natlab
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// FirewallType is the type of filtering a stateful firewall
|
||||
// does. Values express different modes defined by RFC 4787.
|
||||
type FirewallType int
|
||||
|
||||
const (
|
||||
// AddressAndPortDependentFirewall specifies a destination
|
||||
// address-and-port dependent firewall. Outbound traffic to an
|
||||
// ip:port authorizes traffic from that ip:port exactly, and
|
||||
// nothing else.
|
||||
AddressAndPortDependentFirewall FirewallType = iota
|
||||
// AddressDependentFirewall specifies a destination address
|
||||
// dependent firewall. Once outbound traffic has been seen to an
|
||||
// IP address, that IP address can talk back from any port.
|
||||
AddressDependentFirewall
|
||||
// EndpointIndependentFirewall specifies a destination endpoint
|
||||
// independent firewall. Once outbound traffic has been seen from
|
||||
// a source, anyone can talk back to that source.
|
||||
EndpointIndependentFirewall
|
||||
)
|
||||
|
||||
// fwKey is the lookup key for a firewall session. While it contains a
|
||||
// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out
|
||||
// some fields, so in practice the key is either a 2-tuple (src only),
|
||||
// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port).
|
||||
type fwKey struct {
|
||||
src netaddr.IPPort
|
||||
dst netaddr.IPPort
|
||||
}
|
||||
|
||||
// key returns an fwKey for the given src and dst, trimmed according
|
||||
// to the FirewallType. fwKeys are always constructed from the
|
||||
// "outbound" point of view (i.e. src is the "trusted" side of the
|
||||
// world), it's the caller's responsibility to swap src and dst in the
|
||||
// call to key when processing packets inbound from the "untrusted"
|
||||
// world.
|
||||
func (s FirewallType) key(src, dst netaddr.IPPort) fwKey {
|
||||
k := fwKey{src: src}
|
||||
switch s {
|
||||
case EndpointIndependentFirewall:
|
||||
case AddressDependentFirewall:
|
||||
k.dst.IP = dst.IP
|
||||
case AddressAndPortDependentFirewall:
|
||||
k.dst = dst
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown firewall selectivity %v", s))
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
// DefaultSessionTimeout is the default timeout for a firewall
|
||||
// session.
|
||||
const DefaultSessionTimeout = 30 * time.Second
|
||||
|
||||
// Firewall is a simple stateful firewall that allows all outbound
|
||||
// traffic and filters inbound traffic based on recently seen outbound
|
||||
// traffic. Its HandlePacket method should be attached to a Machine to
|
||||
// give it a stateful firewall.
|
||||
type Firewall struct {
|
||||
// SessionTimeout is the lifetime of idle sessions in the firewall
|
||||
// state. Packets transiting from the TrustedInterface reset the
|
||||
// session lifetime to SessionTimeout. If zero,
|
||||
// DefaultSessionTimeout is used.
|
||||
SessionTimeout time.Duration
|
||||
// Type specifies how precisely return traffic must match
|
||||
// previously seen outbound traffic to be allowed. Defaults to
|
||||
// AddressAndPortDependentFirewall.
|
||||
Type FirewallType
|
||||
// TrustedInterface is an optional interface that is considered
|
||||
// trusted in addition to PacketConns local to the Machine. All
|
||||
// other interfaces can only respond to traffic from
|
||||
// TrustedInterface or the local host.
|
||||
TrustedInterface *Interface
|
||||
// TimeNow is a function returning the current time. If nil,
|
||||
// time.Now is used.
|
||||
TimeNow func() time.Time
|
||||
|
||||
// TODO: refresh directionality: outbound-only, both
|
||||
|
||||
mu sync.Mutex
|
||||
seen map[fwKey]time.Time // session -> deadline
|
||||
}
|
||||
|
||||
func (f *Firewall) timeNow() time.Time {
|
||||
if f.TimeNow != nil {
|
||||
return f.TimeNow()
|
||||
}
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func (f *Firewall) init() {
|
||||
if f.seen == nil {
|
||||
f.seen = map[fwKey]time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.init()
|
||||
|
||||
k := f.Type.key(p.Src, p.Dst)
|
||||
f.seen[k] = f.timeNow().Add(f.sessionTimeoutLocked())
|
||||
p.Trace("firewall out ok")
|
||||
return p
|
||||
}
|
||||
|
||||
func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.init()
|
||||
|
||||
// reverse src and dst because the session table is from the POV
|
||||
// of outbound packets.
|
||||
k := f.Type.key(p.Dst, p.Src)
|
||||
now := f.timeNow()
|
||||
if now.After(f.seen[k]) {
|
||||
p.Trace("firewall drop")
|
||||
return nil
|
||||
}
|
||||
p.Trace("firewall in ok")
|
||||
return p
|
||||
}
|
||||
|
||||
func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet {
|
||||
if iif == f.TrustedInterface {
|
||||
// Treat just like a locally originated packet
|
||||
return f.HandleOut(p, oif)
|
||||
}
|
||||
if oif != f.TrustedInterface {
|
||||
// Not a possible return packet from our trusted interface, drop.
|
||||
p.Trace("firewall drop, unexpected oif")
|
||||
return nil
|
||||
}
|
||||
// Otherwise, a session must exist, same as HandleIn.
|
||||
return f.HandleIn(p, iif)
|
||||
}
|
||||
|
||||
func (f *Firewall) sessionTimeoutLocked() time.Duration {
|
||||
if f.SessionTimeout == 0 {
|
||||
return DefaultSessionTimeout
|
||||
}
|
||||
return f.SessionTimeout
|
||||
}
|
||||
257
tstest/natlab/nat.go
Normal file
257
tstest/natlab/nat.go
Normal file
@@ -0,0 +1,257 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package natlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// mapping is the state of an allocated NAT session.
|
||||
type mapping struct {
|
||||
lanSrc netaddr.IPPort
|
||||
lanDst netaddr.IPPort
|
||||
wanSrc netaddr.IPPort
|
||||
deadline time.Time
|
||||
|
||||
// pc is a PacketConn that reserves an outbound port on the NAT's
|
||||
// WAN interface. We do this because ListenPacket already has
|
||||
// random port selection logic built in. Additionally this means
|
||||
// that concurrent use of ListenPacket for connections originating
|
||||
// from the NAT box won't conflict with NAT mappings, since both
|
||||
// use PacketConn to reserve ports on the machine.
|
||||
pc net.PacketConn
|
||||
}
|
||||
|
||||
// NATType is the mapping behavior of a NAT device. Values express
|
||||
// different modes defined by RFC 4787.
|
||||
type NATType int
|
||||
|
||||
const (
|
||||
// EndpointIndependentNAT specifies a destination endpoint
|
||||
// independent NAT. All traffic from a source ip:port gets mapped
|
||||
// to a single WAN ip:port.
|
||||
EndpointIndependentNAT NATType = iota
|
||||
// AddressDependentNAT specifies a destination address dependent
|
||||
// NAT. Every distinct destination IP gets its own WAN ip:port
|
||||
// allocation.
|
||||
AddressDependentNAT
|
||||
// AddressAndPortDependentNAT specifies a destination
|
||||
// address-and-port dependent NAT. Every distinct destination
|
||||
// ip:port gets its own WAN ip:port allocation.
|
||||
AddressAndPortDependentNAT
|
||||
)
|
||||
|
||||
// natKey is the lookup key for a NAT session. While it contains a
|
||||
// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some
|
||||
// fields, so in practice the key is either a 2-tuple (src only),
|
||||
// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port).
|
||||
type natKey struct {
|
||||
src, dst netaddr.IPPort
|
||||
}
|
||||
|
||||
func (t NATType) key(src, dst netaddr.IPPort) natKey {
|
||||
k := natKey{src: src}
|
||||
switch t {
|
||||
case EndpointIndependentNAT:
|
||||
case AddressDependentNAT:
|
||||
k.dst.IP = dst.IP
|
||||
case AddressAndPortDependentNAT:
|
||||
k.dst = dst
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown NAT type %v", t))
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
// DefaultMappingTimeout is the default timeout for a NAT mapping.
|
||||
const DefaultMappingTimeout = 30 * time.Second
|
||||
|
||||
// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with
|
||||
// optional builtin firewall.
|
||||
type SNAT44 struct {
|
||||
// Machine is the machine to which this NAT is attached. Altered
|
||||
// packets are injected back into this Machine for processing.
|
||||
Machine *Machine
|
||||
// ExternalInterface is the "WAN" interface of Machine. Packets
|
||||
// from other sources get NATed onto this interface.
|
||||
ExternalInterface *Interface
|
||||
// Type specifies the mapping allocation behavior for this NAT.
|
||||
Type NATType
|
||||
// MappingTimeout is the lifetime of individual NAT sessions. Once
|
||||
// a session expires, the mapped port effectively "closes" to new
|
||||
// traffic. If MappingTimeout is 0, DefaultMappingTimeout is used.
|
||||
MappingTimeout time.Duration
|
||||
// Firewall is an optional packet handler that will be invoked as
|
||||
// a firewall during NAT translation. The firewall always sees
|
||||
// packets in their "LAN form", i.e. before translation in the
|
||||
// outbound direction and after translation in the inbound
|
||||
// direction.
|
||||
Firewall PacketHandler
|
||||
// TimeNow is a function that returns the current time. If
|
||||
// nil, time.Now is used.
|
||||
TimeNow func() time.Time
|
||||
|
||||
mu sync.Mutex
|
||||
byLAN map[natKey]*mapping // lookup by outbound packet tuple
|
||||
byWAN map[netaddr.IPPort]*mapping // lookup by wan ip:port only
|
||||
}
|
||||
|
||||
func (n *SNAT44) timeNow() time.Time {
|
||||
if n.TimeNow != nil {
|
||||
return n.TimeNow()
|
||||
}
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func (n *SNAT44) mappingTimeout() time.Duration {
|
||||
if n.MappingTimeout == 0 {
|
||||
return DefaultMappingTimeout
|
||||
}
|
||||
return n.MappingTimeout
|
||||
}
|
||||
|
||||
func (n *SNAT44) initLocked() {
|
||||
if n.byLAN == nil {
|
||||
n.byLAN = map[natKey]*mapping{}
|
||||
n.byWAN = map[netaddr.IPPort]*mapping{}
|
||||
}
|
||||
if n.ExternalInterface.Machine() != n.Machine {
|
||||
panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet {
|
||||
// NATs don't affect locally originated packets.
|
||||
if n.Firewall != nil {
|
||||
return n.Firewall.HandleOut(p, oif)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet {
|
||||
if iif != n.ExternalInterface {
|
||||
// NAT can't apply, defer to firewall.
|
||||
if n.Firewall != nil {
|
||||
return n.Firewall.HandleIn(p, iif)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
n.initLocked()
|
||||
|
||||
now := n.timeNow()
|
||||
mapping := n.byWAN[p.Dst]
|
||||
if mapping == nil || now.After(mapping.deadline) {
|
||||
// NAT didn't hit, defer to firewall or allow in for local
|
||||
// socket handling.
|
||||
if n.Firewall != nil {
|
||||
return n.Firewall.HandleIn(p, iif)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
p.Dst = mapping.lanSrc
|
||||
p.Trace("dnat to %v", p.Dst)
|
||||
// Don't process firewall here. We mutated the packet such that
|
||||
// it's no longer destined locally, so we'll get reinvoked as
|
||||
// HandleForward and need to process the altered packet there.
|
||||
return p
|
||||
}
|
||||
|
||||
func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet {
|
||||
switch {
|
||||
case oif == n.ExternalInterface:
|
||||
if p.Src.IP == oif.V4() {
|
||||
// Packet already NATed and is just retraversing Forward,
|
||||
// don't touch it again.
|
||||
return p
|
||||
}
|
||||
|
||||
if n.Firewall != nil {
|
||||
p2 := n.Firewall.HandleForward(p, iif, oif)
|
||||
if p2 == nil {
|
||||
// firewall dropped, done
|
||||
return nil
|
||||
}
|
||||
if !p.Equivalent(p2) {
|
||||
// firewall mutated packet? Weird, but okay.
|
||||
return p2
|
||||
}
|
||||
}
|
||||
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
n.initLocked()
|
||||
|
||||
k := n.Type.key(p.Src, p.Dst)
|
||||
now := n.timeNow()
|
||||
m := n.byLAN[k]
|
||||
if m == nil || now.After(m.deadline) {
|
||||
pc, wanAddr := n.allocateMappedPort()
|
||||
m = &mapping{
|
||||
lanSrc: p.Src,
|
||||
lanDst: p.Dst,
|
||||
wanSrc: wanAddr,
|
||||
pc: pc,
|
||||
}
|
||||
n.byLAN[k] = m
|
||||
n.byWAN[wanAddr] = m
|
||||
}
|
||||
m.deadline = now.Add(n.mappingTimeout())
|
||||
p.Src = m.wanSrc
|
||||
p.Trace("snat from %v", p.Src)
|
||||
return p
|
||||
case iif == n.ExternalInterface:
|
||||
// Packet was already un-NAT-ed, we just need to either
|
||||
// firewall it or let it through.
|
||||
if n.Firewall != nil {
|
||||
return n.Firewall.HandleForward(p, iif, oif)
|
||||
}
|
||||
return p
|
||||
default:
|
||||
// No NAT applies, invoke firewall or drop.
|
||||
if n.Firewall != nil {
|
||||
return n.Firewall.HandleForward(p, iif, oif)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (n *SNAT44) allocateMappedPort() (net.PacketConn, netaddr.IPPort) {
|
||||
// Clean up old entries before trying to allocate, to free up any
|
||||
// expired ports.
|
||||
n.gc()
|
||||
|
||||
ip := n.ExternalInterface.V4()
|
||||
pc, err := n.Machine.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), "0"))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ran out of NAT ports: %v", err))
|
||||
}
|
||||
addr := netaddr.IPPort{
|
||||
IP: ip,
|
||||
Port: uint16(pc.LocalAddr().(*net.UDPAddr).Port),
|
||||
}
|
||||
return pc, addr
|
||||
}
|
||||
|
||||
func (n *SNAT44) gc() {
|
||||
now := n.timeNow()
|
||||
for _, m := range n.byLAN {
|
||||
if !now.After(m.deadline) {
|
||||
continue
|
||||
}
|
||||
m.pc.Close()
|
||||
delete(n.byLAN, n.Type.key(m.lanSrc, m.lanDst))
|
||||
delete(n.byWAN, m.wanSrc)
|
||||
}
|
||||
}
|
||||
860
tstest/natlab/natlab.go
Normal file
860
tstest/natlab/natlab.go
Normal file
@@ -0,0 +1,860 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//lint:file-ignore U1000 in development
|
||||
//lint:file-ignore S1000 in development
|
||||
|
||||
// Package natlab lets us simulate different types of networks all
|
||||
// in-memory without running VMs or requiring root, etc. Despite the
|
||||
// name, it does more than just NATs. But NATs are the most
|
||||
// interesting.
|
||||
package natlab
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
|
||||
|
||||
// Packet represents a UDP packet flowing through the virtual network.
|
||||
type Packet struct {
|
||||
Src, Dst netaddr.IPPort
|
||||
Payload []byte
|
||||
|
||||
// Prefix set by various internal methods of natlab, to locate
|
||||
// where in the network a trace occured.
|
||||
locator string
|
||||
}
|
||||
|
||||
// Equivalent returns true if Src, Dst and Payload are the same in p
|
||||
// and p2.
|
||||
func (p *Packet) Equivalent(p2 *Packet) bool {
|
||||
return p.Src == p2.Src && p.Dst == p2.Dst && bytes.Equal(p.Payload, p2.Payload)
|
||||
}
|
||||
|
||||
// Clone returns a copy of p that shares nothing with p.
|
||||
func (p *Packet) Clone() *Packet {
|
||||
return &Packet{
|
||||
Src: p.Src,
|
||||
Dst: p.Dst,
|
||||
Payload: append([]byte(nil), p.Payload...),
|
||||
locator: p.locator,
|
||||
}
|
||||
}
|
||||
|
||||
// short returns a short identifier for a packet payload,
|
||||
// suitable for printing trace information.
|
||||
func (p *Packet) short() string {
|
||||
s := sha256.Sum256(p.Payload)
|
||||
payload := base64.RawStdEncoding.EncodeToString(s[:])[:2]
|
||||
|
||||
s = sha256.Sum256([]byte(p.Src.String() + "_" + p.Dst.String()))
|
||||
tuple := base64.RawStdEncoding.EncodeToString(s[:])[:2]
|
||||
|
||||
return fmt.Sprintf("%s/%s", payload, tuple)
|
||||
}
|
||||
|
||||
func (p *Packet) Trace(msg string, args ...interface{}) {
|
||||
if !traceOn {
|
||||
return
|
||||
}
|
||||
allArgs := []interface{}{p.short(), p.locator, p.Src, p.Dst}
|
||||
allArgs = append(allArgs, args...)
|
||||
fmt.Fprintf(os.Stderr, "[%s]%s src=%s dst=%s "+msg+"\n", allArgs...)
|
||||
}
|
||||
|
||||
func (p *Packet) setLocator(msg string, args ...interface{}) {
|
||||
p.locator = fmt.Sprintf(" "+msg, args...)
|
||||
}
|
||||
|
||||
func mustPrefix(s string) netaddr.IPPrefix {
|
||||
ipp, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ipp
|
||||
}
|
||||
|
||||
// NewInternet returns a network that simulates the internet.
|
||||
func NewInternet() *Network {
|
||||
return &Network{
|
||||
Name: "internet",
|
||||
Prefix4: mustPrefix("203.0.113.0/24"), // documentation netblock that looks Internet-y
|
||||
Prefix6: mustPrefix("fc00:52::/64"),
|
||||
}
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
Name string
|
||||
Prefix4 netaddr.IPPrefix
|
||||
Prefix6 netaddr.IPPrefix
|
||||
|
||||
mu sync.Mutex
|
||||
machine map[netaddr.IP]*Interface
|
||||
defaultGW *Interface // optional
|
||||
lastV4 netaddr.IP
|
||||
lastV6 netaddr.IP
|
||||
}
|
||||
|
||||
func (n *Network) SetDefaultGateway(gwIf *Interface) {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
if gwIf.net != n {
|
||||
panic(fmt.Sprintf("can't set if=%s as net=%s's default gw, if not connected to net", gwIf.name, gwIf.net.Name))
|
||||
}
|
||||
n.defaultGW = gwIf
|
||||
}
|
||||
|
||||
func (n *Network) addMachineLocked(ip netaddr.IP, iface *Interface) {
|
||||
if iface == nil {
|
||||
return // for tests
|
||||
}
|
||||
if n.machine == nil {
|
||||
n.machine = map[netaddr.IP]*Interface{}
|
||||
}
|
||||
n.machine[ip] = iface
|
||||
}
|
||||
|
||||
func (n *Network) allocIPv4(iface *Interface) netaddr.IP {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
if n.Prefix4.IsZero() {
|
||||
return netaddr.IP{}
|
||||
}
|
||||
if n.lastV4.IsZero() {
|
||||
n.lastV4 = n.Prefix4.IP
|
||||
}
|
||||
a := n.lastV4.As16()
|
||||
addOne(&a, 15)
|
||||
n.lastV4 = netaddr.IPFrom16(a)
|
||||
if !n.Prefix4.Contains(n.lastV4) {
|
||||
panic("pool exhausted")
|
||||
}
|
||||
n.addMachineLocked(n.lastV4, iface)
|
||||
return n.lastV4
|
||||
}
|
||||
|
||||
func (n *Network) allocIPv6(iface *Interface) netaddr.IP {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
if n.Prefix6.IsZero() {
|
||||
return netaddr.IP{}
|
||||
}
|
||||
if n.lastV6.IsZero() {
|
||||
n.lastV6 = n.Prefix6.IP
|
||||
}
|
||||
a := n.lastV6.As16()
|
||||
addOne(&a, 15)
|
||||
n.lastV6 = netaddr.IPFrom16(a)
|
||||
if !n.Prefix6.Contains(n.lastV6) {
|
||||
panic("pool exhausted")
|
||||
}
|
||||
n.addMachineLocked(n.lastV6, iface)
|
||||
return n.lastV6
|
||||
}
|
||||
|
||||
func addOne(a *[16]byte, index int) {
|
||||
if v := a[index]; v < 255 {
|
||||
a[index]++
|
||||
} else {
|
||||
a[index] = 0
|
||||
addOne(a, index-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Network) write(p *Packet) (num int, err error) {
|
||||
p.setLocator("net=%s", n.Name)
|
||||
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
iface, ok := n.machine[p.Dst.IP]
|
||||
if !ok {
|
||||
if n.defaultGW == nil {
|
||||
p.Trace("no route to %v", p.Dst.IP)
|
||||
return len(p.Payload), nil
|
||||
}
|
||||
iface = n.defaultGW
|
||||
}
|
||||
|
||||
// Pretend it went across the network. Make a copy so nobody
|
||||
// can later mess with caller's memory.
|
||||
p.Trace("-> mach=%s if=%s", iface.machine.Name, iface.name)
|
||||
go iface.machine.deliverIncomingPacket(p, iface)
|
||||
return len(p.Payload), nil
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
machine *Machine
|
||||
net *Network
|
||||
name string // optional
|
||||
ips []netaddr.IP // static; not mutated once created
|
||||
}
|
||||
|
||||
func (f *Interface) Machine() *Machine {
|
||||
return f.machine
|
||||
}
|
||||
|
||||
func (f *Interface) Network() *Network {
|
||||
return f.net
|
||||
}
|
||||
|
||||
// V4 returns the machine's first IPv4 address, or the zero value if none.
|
||||
func (f *Interface) V4() netaddr.IP { return f.pickIP(netaddr.IP.Is4) }
|
||||
|
||||
// V6 returns the machine's first IPv6 address, or the zero value if none.
|
||||
func (f *Interface) V6() netaddr.IP { return f.pickIP(netaddr.IP.Is6) }
|
||||
|
||||
func (f *Interface) pickIP(pred func(netaddr.IP) bool) netaddr.IP {
|
||||
for _, ip := range f.ips {
|
||||
if pred(ip) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
return netaddr.IP{}
|
||||
}
|
||||
|
||||
func (f *Interface) String() string {
|
||||
// TODO: make this all better
|
||||
if f.name != "" {
|
||||
return f.name
|
||||
}
|
||||
return fmt.Sprintf("unamed-interface-on-network-%p", f.net)
|
||||
}
|
||||
|
||||
// Contains reports whether f contains ip as an IP.
|
||||
func (f *Interface) Contains(ip netaddr.IP) bool {
|
||||
for _, v := range f.ips {
|
||||
if ip == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type routeEntry struct {
|
||||
prefix netaddr.IPPrefix
|
||||
iface *Interface
|
||||
}
|
||||
|
||||
// A PacketVerdict is a decision of what to do with a packet.
|
||||
type PacketVerdict int
|
||||
|
||||
const (
|
||||
// Continue means the packet should be processed by the "local
|
||||
// sockets" logic of the Machine.
|
||||
Continue PacketVerdict = iota
|
||||
// Drop means the packet should not be handled further.
|
||||
Drop
|
||||
)
|
||||
|
||||
func (v PacketVerdict) String() string {
|
||||
switch v {
|
||||
case Continue:
|
||||
return "Continue"
|
||||
case Drop:
|
||||
return "Drop"
|
||||
default:
|
||||
return fmt.Sprintf("<unknown verdict %d>", v)
|
||||
}
|
||||
}
|
||||
|
||||
// A PacketHandler can look at packets arriving at, departing, and
|
||||
// transiting a Machine, and filter or mutate them.
|
||||
//
|
||||
// Each method is invoked with a Packet that natlab would like to keep
|
||||
// processing. Handlers can return that same Packet to allow
|
||||
// processing to continue; nil to drop the Packet; or a different
|
||||
// Packet that should be processed instead of the original.
|
||||
//
|
||||
// Packets passed to handlers share no state with anything else, and
|
||||
// are therefore safe to mutate. It's safe to return the original
|
||||
// packet mutated in-place, or a brand new packet initialized from
|
||||
// scratch.
|
||||
//
|
||||
// Packets mutated by a PacketHandler are processed anew by the
|
||||
// associated Machine, as if the packet had always been the mutated
|
||||
// one. For example, if HandleForward is invoked with a Packet, and
|
||||
// the handler changes the destination IP address to one of the
|
||||
// Machine's own IPs, the Machine restarts delivery, but this time
|
||||
// going to a local PacketConn (which in turn will invoke HandleIn,
|
||||
// since the packet is now destined for local delivery).
|
||||
type PacketHandler interface {
|
||||
// HandleIn processes a packet arriving on iif, whose destination
|
||||
// is an IP address owned by the attached Machine. If p is
|
||||
// returned unmodified, the Machine will go on to deliver the
|
||||
// Packet to the appropriate listening PacketConn, if one exists.
|
||||
HandleIn(p *Packet, iif *Interface) *Packet
|
||||
// HandleOut processes a packet about to depart on oif from a
|
||||
// local PacketConn. If p is returned unmodified, the Machine will
|
||||
// transmit the Packet on oif.
|
||||
HandleOut(p *Packet, oif *Interface) *Packet
|
||||
// HandleForward is called when the Machine wants to forward a
|
||||
// packet from iif to oif. If p is returned unmodified, the
|
||||
// Machine will transmit the packet on oif.
|
||||
HandleForward(p *Packet, iif, oif *Interface) *Packet
|
||||
}
|
||||
|
||||
// A Machine is a representation of an operating system's network
|
||||
// stack. It has a network routing table and can have multiple
|
||||
// attached networks. The zero value is valid, but lacks any
|
||||
// networking capability until Attach is called.
|
||||
type Machine struct {
|
||||
// Name is a pretty name for debugging and packet tracing. It need
|
||||
// not be globally unique.
|
||||
Name string
|
||||
|
||||
// PacketHandler, if not nil, is a PacketHandler implementation
|
||||
// that inspects all packets arriving, departing, or transiting
|
||||
// the Machine. See the definition of the PacketHandler interface
|
||||
// for semantics.
|
||||
//
|
||||
// If PacketHandler is nil, the machine allows all inbound
|
||||
// traffic, all outbound traffic, and drops forwarded packets.
|
||||
PacketHandler PacketHandler
|
||||
|
||||
mu sync.Mutex
|
||||
interfaces []*Interface
|
||||
routes []routeEntry // sorted by longest prefix to shortest
|
||||
|
||||
conns4 map[netaddr.IPPort]*conn // conns that want IPv4 packets
|
||||
conns6 map[netaddr.IPPort]*conn // conns that want IPv6 packets
|
||||
}
|
||||
|
||||
func (m *Machine) isLocalIP(ip netaddr.IP) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, intf := range m.interfaces {
|
||||
for _, iip := range intf.ips {
|
||||
if ip == iip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
|
||||
p.setLocator("mach=%s if=%s", m.Name, iface.name)
|
||||
|
||||
if m.isLocalIP(p.Dst.IP) {
|
||||
m.deliverLocalPacket(p, iface)
|
||||
} else {
|
||||
m.forwardPacket(p, iface)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Machine) deliverLocalPacket(p *Packet, iface *Interface) {
|
||||
// TODO: can't hold lock while handling packet. This is safe as
|
||||
// long as you set HandlePacket before traffic starts flowing.
|
||||
if m.PacketHandler != nil {
|
||||
p2 := m.PacketHandler.HandleIn(p.Clone(), iface)
|
||||
if p2 == nil {
|
||||
// Packet dropped, nothing left to do.
|
||||
return
|
||||
}
|
||||
if !p.Equivalent(p2) {
|
||||
// Restart delivery, this packet might be a forward packet
|
||||
// now.
|
||||
m.deliverIncomingPacket(p2, iface)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
conns := m.conns4
|
||||
if p.Dst.IP.Is6() {
|
||||
conns = m.conns6
|
||||
}
|
||||
possibleDsts := []netaddr.IPPort{
|
||||
p.Dst,
|
||||
netaddr.IPPort{IP: v6unspec, Port: p.Dst.Port},
|
||||
netaddr.IPPort{IP: v4unspec, Port: p.Dst.Port},
|
||||
}
|
||||
for _, dest := range possibleDsts {
|
||||
c, ok := conns[dest]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case c.in <- p:
|
||||
p.Trace("queued to conn")
|
||||
default:
|
||||
p.Trace("dropped, queue overflow")
|
||||
// Queue overflow. Just drop it.
|
||||
}
|
||||
return
|
||||
}
|
||||
p.Trace("dropped, no listening conn")
|
||||
}
|
||||
|
||||
func (m *Machine) forwardPacket(p *Packet, iif *Interface) {
|
||||
oif, err := m.interfaceForIP(p.Dst.IP)
|
||||
if err != nil {
|
||||
p.Trace("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if m.PacketHandler == nil {
|
||||
// Forwarding not allowed by default
|
||||
p.Trace("drop, forwarding not allowed")
|
||||
return
|
||||
}
|
||||
p2 := m.PacketHandler.HandleForward(p.Clone(), iif, oif)
|
||||
if p2 == nil {
|
||||
p.Trace("drop")
|
||||
// Packet dropped, done.
|
||||
return
|
||||
}
|
||||
if !p.Equivalent(p2) {
|
||||
// Packet changed, restart delivery.
|
||||
p2.Trace("PacketHandler mutated packet")
|
||||
m.deliverIncomingPacket(p2, iif)
|
||||
return
|
||||
}
|
||||
|
||||
p.Trace("-> net=%s oif=%s", oif.net.Name, oif)
|
||||
oif.net.write(p)
|
||||
}
|
||||
|
||||
func unspecOf(ip netaddr.IP) netaddr.IP {
|
||||
if ip.Is4() {
|
||||
return v4unspec
|
||||
}
|
||||
if ip.Is6() {
|
||||
return v6unspec
|
||||
}
|
||||
panic(fmt.Sprintf("bogus IP %#v", ip))
|
||||
}
|
||||
|
||||
// Attach adds an interface to a machine.
|
||||
//
|
||||
// The first interface added to a Machine becomes that machine's
|
||||
// default route.
|
||||
func (m *Machine) Attach(interfaceName string, n *Network) *Interface {
|
||||
f := &Interface{
|
||||
machine: m,
|
||||
net: n,
|
||||
name: interfaceName,
|
||||
}
|
||||
if ip := n.allocIPv4(f); !ip.IsZero() {
|
||||
f.ips = append(f.ips, ip)
|
||||
}
|
||||
if ip := n.allocIPv6(f); !ip.IsZero() {
|
||||
f.ips = append(f.ips, ip)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.interfaces = append(m.interfaces, f)
|
||||
if len(m.interfaces) == 1 {
|
||||
m.routes = append(m.routes,
|
||||
routeEntry{
|
||||
prefix: mustPrefix("0.0.0.0/0"),
|
||||
iface: f,
|
||||
},
|
||||
routeEntry{
|
||||
prefix: mustPrefix("::/0"),
|
||||
iface: f,
|
||||
})
|
||||
} else {
|
||||
if !n.Prefix4.IsZero() {
|
||||
m.routes = append(m.routes, routeEntry{
|
||||
prefix: n.Prefix4,
|
||||
iface: f,
|
||||
})
|
||||
}
|
||||
if !n.Prefix6.IsZero() {
|
||||
m.routes = append(m.routes, routeEntry{
|
||||
prefix: n.Prefix6,
|
||||
iface: f,
|
||||
})
|
||||
}
|
||||
}
|
||||
sort.Slice(m.routes, func(i, j int) bool {
|
||||
return m.routes[i].prefix.Bits > m.routes[j].prefix.Bits
|
||||
})
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
var (
|
||||
v4unspec = netaddr.IPv4(0, 0, 0, 0)
|
||||
v6unspec = netaddr.IPv6Unspecified()
|
||||
)
|
||||
|
||||
func (m *Machine) writePacket(p *Packet) (n int, err error) {
|
||||
p.setLocator("mach=%s", m.Name)
|
||||
|
||||
iface, err := m.interfaceForIP(p.Dst.IP)
|
||||
if err != nil {
|
||||
p.Trace("%v", err)
|
||||
return 0, err
|
||||
}
|
||||
origSrcIP := p.Src.IP
|
||||
switch {
|
||||
case p.Src.IP == v4unspec:
|
||||
p.Trace("assigning srcIP=%s", iface.V4())
|
||||
p.Src.IP = iface.V4()
|
||||
case p.Src.IP == v6unspec:
|
||||
// v6unspec in Go means "any src, but match address families"
|
||||
if p.Dst.IP.Is6() {
|
||||
p.Trace("assigning srcIP=%s", iface.V6())
|
||||
p.Src.IP = iface.V6()
|
||||
} else if p.Dst.IP.Is4() {
|
||||
p.Trace("assigning srcIP=%s", iface.V4())
|
||||
p.Src.IP = iface.V4()
|
||||
}
|
||||
default:
|
||||
if !iface.Contains(p.Src.IP) {
|
||||
err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.IP, p.Src.IP, iface)
|
||||
p.Trace("%v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if p.Src.IP.IsZero() {
|
||||
err := fmt.Errorf("no matching address for address family for %v", origSrcIP)
|
||||
p.Trace("%v", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if m.PacketHandler != nil {
|
||||
p2 := m.PacketHandler.HandleOut(p.Clone(), iface)
|
||||
if p2 == nil {
|
||||
// Packet dropped, done.
|
||||
return len(p.Payload), nil
|
||||
}
|
||||
if !p.Equivalent(p2) {
|
||||
// Restart transmission, src may have changed weirdly
|
||||
m.writePacket(p2)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.Trace("-> net=%s if=%s", iface.net.Name, iface)
|
||||
return iface.net.write(p)
|
||||
}
|
||||
|
||||
func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, re := range m.routes {
|
||||
if re.prefix.Contains(ip) {
|
||||
return re.iface, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no route found to %v", ip)
|
||||
}
|
||||
|
||||
func (m *Machine) hasv6() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, f := range m.interfaces {
|
||||
for _, ip := range f.ips {
|
||||
if ip.Is6() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Machine) pickEphemPort() (port uint16, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for tries := 0; tries < 500; tries++ {
|
||||
port := uint16(rand.Intn(32<<10) + 32<<10)
|
||||
if !m.portInUseLocked(port) {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
return 0, errors.New("failed to find an ephemeral port")
|
||||
}
|
||||
|
||||
func (m *Machine) portInUseLocked(port uint16) bool {
|
||||
for ipp := range m.conns4 {
|
||||
if ipp.Port == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for ipp := range m.conns6 {
|
||||
if ipp.Port == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Machine) registerConn4(c *conn) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if c.ipp.IP.Is6() && c.ipp.IP != v6unspec {
|
||||
return fmt.Errorf("registerConn4 got IPv6 %s", c.ipp)
|
||||
}
|
||||
return registerConn(&m.conns4, c)
|
||||
}
|
||||
|
||||
func (m *Machine) unregisterConn4(c *conn) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.conns4, c.ipp)
|
||||
}
|
||||
|
||||
func (m *Machine) registerConn6(c *conn) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if c.ipp.IP.Is4() {
|
||||
return fmt.Errorf("registerConn6 got IPv4 %s", c.ipp)
|
||||
}
|
||||
return registerConn(&m.conns6, c)
|
||||
}
|
||||
|
||||
func (m *Machine) unregisterConn6(c *conn) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.conns6, c.ipp)
|
||||
}
|
||||
|
||||
func registerConn(conns *map[netaddr.IPPort]*conn, c *conn) error {
|
||||
if _, ok := (*conns)[c.ipp]; ok {
|
||||
return fmt.Errorf("duplicate conn listening on %v", c.ipp)
|
||||
}
|
||||
if *conns == nil {
|
||||
*conns = map[netaddr.IPPort]*conn{}
|
||||
}
|
||||
(*conns)[c.ipp] = c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Machine) AddNetwork(n *Network) {}
|
||||
|
||||
func (m *Machine) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||
// if udp4, udp6, etc... look at address IP vs unspec
|
||||
var (
|
||||
fam uint8
|
||||
ip netaddr.IP
|
||||
)
|
||||
switch network {
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported network type %q", network)
|
||||
case "udp":
|
||||
fam = 0
|
||||
ip = v6unspec
|
||||
case "udp4":
|
||||
fam = 4
|
||||
ip = v4unspec
|
||||
case "udp6":
|
||||
fam = 6
|
||||
ip = v6unspec
|
||||
}
|
||||
|
||||
host, portStr, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if host != "" {
|
||||
ip, err = netaddr.ParseIP(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fam == 0 && (ip != v4unspec && ip != v6unspec) {
|
||||
// We got an explicit IP address, need to switch the
|
||||
// family to the right one.
|
||||
if ip.Is4() {
|
||||
fam = 4
|
||||
} else {
|
||||
fam = 6
|
||||
}
|
||||
}
|
||||
}
|
||||
porti, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
port := uint16(porti)
|
||||
if port == 0 {
|
||||
port, err = m.pickEphemPort()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
ipp := netaddr.IPPort{IP: ip, Port: port}
|
||||
|
||||
c := &conn{
|
||||
m: m,
|
||||
fam: fam,
|
||||
ipp: ipp,
|
||||
in: make(chan *Packet, 100), // arbitrary
|
||||
}
|
||||
switch c.fam {
|
||||
case 0:
|
||||
if err := m.registerConn4(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.registerConn6(c); err != nil {
|
||||
m.unregisterConn4(c)
|
||||
return nil, err
|
||||
}
|
||||
case 4:
|
||||
if err := m.registerConn4(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case 6:
|
||||
if err := m.registerConn6(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// conn is our net.PacketConn implementation
|
||||
type conn struct {
|
||||
m *Machine
|
||||
fam uint8 // 0, 4, or 6
|
||||
ipp netaddr.IPPort
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
readDeadline time.Time
|
||||
activeReads map[*activeRead]bool
|
||||
in chan *Packet
|
||||
}
|
||||
|
||||
type activeRead struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// canRead reports whether we can do a read.
|
||||
func (c *conn) canRead() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return errors.New("closed network connection") // sadface: magic string used by other; don't change
|
||||
}
|
||||
if !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now()) {
|
||||
return errors.New("read deadline exceeded")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) registerActiveRead(ar *activeRead, active bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.activeReads == nil {
|
||||
c.activeReads = make(map[*activeRead]bool)
|
||||
}
|
||||
if active {
|
||||
c.activeReads[ar] = true
|
||||
} else {
|
||||
delete(c.activeReads, ar)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
switch c.fam {
|
||||
case 0:
|
||||
c.m.unregisterConn4(c)
|
||||
c.m.unregisterConn6(c)
|
||||
case 4:
|
||||
c.m.unregisterConn4(c)
|
||||
case 6:
|
||||
c.m.unregisterConn6(c)
|
||||
}
|
||||
c.breakActiveReadsLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) breakActiveReadsLocked() {
|
||||
for ar := range c.activeReads {
|
||||
ar.cancel()
|
||||
}
|
||||
c.activeReads = nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.ipp.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ar := &activeRead{cancel: cancel}
|
||||
|
||||
if err := c.canRead(); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
c.registerActiveRead(ar, true)
|
||||
defer c.registerActiveRead(ar, false)
|
||||
|
||||
select {
|
||||
case pkt := <-c.in:
|
||||
n = copy(p, pkt.Payload)
|
||||
pkt.Trace("PacketConn.ReadFrom")
|
||||
return n, pkt.Src.UDPAddr(), nil
|
||||
case <-ctx.Done():
|
||||
return 0, nil, context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
ipp, err := netaddr.ParseIPPort(addr.String())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String())
|
||||
}
|
||||
pkt := &Packet{
|
||||
Src: c.ipp,
|
||||
Dst: ipp,
|
||||
Payload: append([]byte(nil), p...),
|
||||
}
|
||||
pkt.setLocator("mach=%s", c.m.Name)
|
||||
pkt.Trace("PacketConn.WriteTo")
|
||||
return c.m.writePacket(pkt)
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
panic("SetWriteDeadline unsupported; TODO when needed")
|
||||
}
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
panic("SetWriteDeadline unsupported; TODO when needed")
|
||||
}
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if t.After(now) {
|
||||
panic("SetReadDeadline in the future not yet supported; TODO?")
|
||||
}
|
||||
|
||||
if !t.IsZero() && t.Before(now) {
|
||||
c.breakActiveReadsLocked()
|
||||
}
|
||||
c.readDeadline = t
|
||||
|
||||
return nil
|
||||
}
|
||||
509
tstest/natlab/natlab_test.go
Normal file
509
tstest/natlab/natlab_test.go
Normal file
@@ -0,0 +1,509 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package natlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
func TestAllocIPs(t *testing.T) {
|
||||
n := NewInternet()
|
||||
saw := map[netaddr.IP]bool{}
|
||||
for i := 0; i < 255; i++ {
|
||||
for _, f := range []func(*Interface) netaddr.IP{n.allocIPv4, n.allocIPv6} {
|
||||
ip := f(nil)
|
||||
if saw[ip] {
|
||||
t.Fatalf("got duplicate %v", ip)
|
||||
}
|
||||
saw[ip] = true
|
||||
}
|
||||
}
|
||||
|
||||
// This should work:
|
||||
n.allocIPv6(nil)
|
||||
|
||||
// But allocating another IPv4 should panic, exhausting the
|
||||
// limited /24 range:
|
||||
defer func() {
|
||||
if e := recover(); fmt.Sprint(e) != "pool exhausted" {
|
||||
t.Errorf("unexpected panic: %v", e)
|
||||
}
|
||||
}()
|
||||
n.allocIPv4(nil)
|
||||
t.Fatalf("expected panic from IPv4")
|
||||
}
|
||||
|
||||
func TestSendPacket(t *testing.T) {
|
||||
internet := NewInternet()
|
||||
|
||||
foo := &Machine{Name: "foo"}
|
||||
bar := &Machine{Name: "bar"}
|
||||
ifFoo := foo.Attach("eth0", internet)
|
||||
ifBar := bar.Attach("enp0s1", internet)
|
||||
|
||||
fooAddr := netaddr.IPPort{IP: ifFoo.V4(), Port: 123}
|
||||
barAddr := netaddr.IPPort{IP: ifBar.V4(), Port: 456}
|
||||
|
||||
ctx := context.Background()
|
||||
fooPC, err := foo.ListenPacket(ctx, "udp4", fooAddr.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
barPC, err := bar.ListenPacket(ctx, "udp4", barAddr.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const msg = "some message"
|
||||
if _, err := fooPC.WriteTo([]byte(msg), barAddr.UDPAddr()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1500) // TODO: care about MTUs in the natlab package somewhere
|
||||
n, addr, err := barPC.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf = buf[:n]
|
||||
if string(buf) != msg {
|
||||
t.Errorf("read %q; want %q", buf, msg)
|
||||
}
|
||||
if addr.String() != fooAddr.String() {
|
||||
t.Errorf("addr = %q; want %q", addr, fooAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiNetwork(t *testing.T) {
|
||||
lan := &Network{
|
||||
Name: "lan",
|
||||
Prefix4: mustPrefix("192.168.0.0/24"),
|
||||
}
|
||||
internet := NewInternet()
|
||||
|
||||
client := &Machine{Name: "client"}
|
||||
nat := &Machine{Name: "nat"}
|
||||
server := &Machine{Name: "server"}
|
||||
|
||||
ifClient := client.Attach("eth0", lan)
|
||||
ifNATWAN := nat.Attach("ethwan", internet)
|
||||
ifNATLAN := nat.Attach("ethlan", lan)
|
||||
ifServer := server.Attach("eth0", internet)
|
||||
|
||||
ctx := context.Background()
|
||||
clientPC, err := client.ListenPacket(ctx, "udp", ":123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
natPC, err := nat.ListenPacket(ctx, "udp", ":456")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverPC, err := server.ListenPacket(ctx, "udp", ":789")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
clientAddr := netaddr.IPPort{IP: ifClient.V4(), Port: 123}
|
||||
natLANAddr := netaddr.IPPort{IP: ifNATLAN.V4(), Port: 456}
|
||||
natWANAddr := netaddr.IPPort{IP: ifNATWAN.V4(), Port: 456}
|
||||
serverAddr := netaddr.IPPort{IP: ifServer.V4(), Port: 789}
|
||||
|
||||
const msg1, msg2 = "hello", "world"
|
||||
if _, err := natPC.WriteTo([]byte(msg1), clientAddr.UDPAddr()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := natPC.WriteTo([]byte(msg2), serverAddr.UDPAddr()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
n, addr, err := clientPC.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(buf[:n]) != msg1 {
|
||||
t.Errorf("read %q; want %q", buf[:n], msg1)
|
||||
}
|
||||
if addr.String() != natLANAddr.String() {
|
||||
t.Errorf("addr = %q; want %q", addr, natLANAddr)
|
||||
}
|
||||
|
||||
n, addr, err = serverPC.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(buf[:n]) != msg2 {
|
||||
t.Errorf("read %q; want %q", buf[:n], msg2)
|
||||
}
|
||||
if addr.String() != natWANAddr.String() {
|
||||
t.Errorf("addr = %q; want %q", addr, natLANAddr)
|
||||
}
|
||||
}
|
||||
|
||||
type trivialNAT struct {
|
||||
clientIP netaddr.IP
|
||||
lanIf, wanIf *Interface
|
||||
}
|
||||
|
||||
func (n *trivialNAT) HandleIn(p *Packet, iface *Interface) *Packet {
|
||||
if iface == n.wanIf && p.Dst.IP == n.wanIf.V4() {
|
||||
p.Dst.IP = n.clientIP
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (n trivialNAT) HandleOut(p *Packet, iface *Interface) *Packet {
|
||||
return p
|
||||
}
|
||||
|
||||
func (n *trivialNAT) HandleForward(p *Packet, iif, oif *Interface) *Packet {
|
||||
// Outbound from LAN -> apply NAT, continue
|
||||
if iif == n.lanIf && oif == n.wanIf {
|
||||
if p.Src.IP == n.clientIP {
|
||||
p.Src.IP = n.wanIf.V4()
|
||||
}
|
||||
return p
|
||||
}
|
||||
// Return traffic to LAN, allow if right dst.
|
||||
if iif == n.wanIf && oif == n.lanIf && p.Dst.IP == n.clientIP {
|
||||
return p
|
||||
}
|
||||
// Else drop.
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPacketHandler(t *testing.T) {
|
||||
lan := &Network{
|
||||
Name: "lan",
|
||||
Prefix4: mustPrefix("192.168.0.0/24"),
|
||||
Prefix6: mustPrefix("fd00:916::/64"),
|
||||
}
|
||||
internet := NewInternet()
|
||||
|
||||
client := &Machine{Name: "client"}
|
||||
nat := &Machine{Name: "nat"}
|
||||
server := &Machine{Name: "server"}
|
||||
|
||||
ifClient := client.Attach("eth0", lan)
|
||||
ifNATWAN := nat.Attach("wan", internet)
|
||||
ifNATLAN := nat.Attach("lan", lan)
|
||||
ifServer := server.Attach("server", internet)
|
||||
|
||||
lan.SetDefaultGateway(ifNATLAN)
|
||||
|
||||
nat.PacketHandler = &trivialNAT{
|
||||
clientIP: ifClient.V4(),
|
||||
lanIf: ifNATLAN,
|
||||
wanIf: ifNATWAN,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
clientPC, err := client.ListenPacket(ctx, "udp4", ":123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serverPC, err := server.ListenPacket(ctx, "udp4", ":456")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const msg = "some message"
|
||||
serverAddr := netaddr.IPPort{IP: ifServer.V4(), Port: 456}
|
||||
if _, err := clientPC.WriteTo([]byte(msg), serverAddr.UDPAddr()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1500) // TODO: care about MTUs in the natlab package somewhere
|
||||
n, addr, err := serverPC.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf = buf[:n]
|
||||
if string(buf) != msg {
|
||||
t.Errorf("read %q; want %q", buf, msg)
|
||||
}
|
||||
mappedAddr := netaddr.IPPort{IP: ifNATWAN.V4(), Port: 123}
|
||||
if addr.String() != mappedAddr.String() {
|
||||
t.Errorf("addr = %q; want %q", addr, mappedAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirewall(t *testing.T) {
|
||||
wan := NewInternet()
|
||||
lan := &Network{
|
||||
Name: "lan",
|
||||
Prefix4: mustPrefix("10.0.0.0/8"),
|
||||
}
|
||||
m := &Machine{Name: "test"}
|
||||
trust := m.Attach("trust", lan)
|
||||
untrust := m.Attach("untrust", wan)
|
||||
|
||||
client := ipp("192.168.0.2:1234")
|
||||
serverA := ipp("2.2.2.2:5678")
|
||||
serverB1 := ipp("7.7.7.7:9012")
|
||||
serverB2 := ipp("7.7.7.7:3456")
|
||||
|
||||
t.Run("ip_port_dependent", func(t *testing.T) {
|
||||
f := &Firewall{
|
||||
TrustedInterface: trust,
|
||||
SessionTimeout: 30 * time.Second,
|
||||
Type: AddressAndPortDependentFirewall,
|
||||
}
|
||||
testFirewall(t, f, []fwTest{
|
||||
// client -> A authorizes A -> client
|
||||
{trust, untrust, client, serverA, true},
|
||||
{untrust, trust, serverA, client, true},
|
||||
{untrust, trust, serverA, client, true},
|
||||
|
||||
// B1 -> client fails until client -> B1
|
||||
{untrust, trust, serverB1, client, false},
|
||||
{trust, untrust, client, serverB1, true},
|
||||
{untrust, trust, serverB1, client, true},
|
||||
|
||||
// B2 -> client still fails
|
||||
{untrust, trust, serverB2, client, false},
|
||||
})
|
||||
})
|
||||
t.Run("ip_dependent", func(t *testing.T) {
|
||||
f := &Firewall{
|
||||
TrustedInterface: trust,
|
||||
SessionTimeout: 30 * time.Second,
|
||||
Type: AddressDependentFirewall,
|
||||
}
|
||||
testFirewall(t, f, []fwTest{
|
||||
// client -> A authorizes A -> client
|
||||
{trust, untrust, client, serverA, true},
|
||||
{untrust, trust, serverA, client, true},
|
||||
{untrust, trust, serverA, client, true},
|
||||
|
||||
// B1 -> client fails until client -> B1
|
||||
{untrust, trust, serverB1, client, false},
|
||||
{trust, untrust, client, serverB1, true},
|
||||
{untrust, trust, serverB1, client, true},
|
||||
|
||||
// B2 -> client also works now
|
||||
{untrust, trust, serverB2, client, true},
|
||||
})
|
||||
})
|
||||
t.Run("endpoint_independent", func(t *testing.T) {
|
||||
f := &Firewall{
|
||||
TrustedInterface: trust,
|
||||
SessionTimeout: 30 * time.Second,
|
||||
Type: EndpointIndependentFirewall,
|
||||
}
|
||||
testFirewall(t, f, []fwTest{
|
||||
// client -> A authorizes A -> client
|
||||
{trust, untrust, client, serverA, true},
|
||||
{untrust, trust, serverA, client, true},
|
||||
{untrust, trust, serverA, client, true},
|
||||
|
||||
// B1 -> client also works
|
||||
{untrust, trust, serverB1, client, true},
|
||||
|
||||
// B2 -> client also works
|
||||
{untrust, trust, serverB2, client, true},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type fwTest struct {
|
||||
iif, oif *Interface
|
||||
src, dst netaddr.IPPort
|
||||
ok bool
|
||||
}
|
||||
|
||||
func testFirewall(t *testing.T, f *Firewall, tests []fwTest) {
|
||||
t.Helper()
|
||||
clock := &tstest.Clock{}
|
||||
f.TimeNow = clock.Now
|
||||
for _, test := range tests {
|
||||
clock.Advance(time.Second)
|
||||
p := &Packet{
|
||||
Src: test.src,
|
||||
Dst: test.dst,
|
||||
Payload: []byte{},
|
||||
}
|
||||
got := f.HandleForward(p, test.iif, test.oif)
|
||||
gotOK := got != nil
|
||||
if gotOK != test.ok {
|
||||
t.Errorf("iif=%s oif=%s src=%s dst=%s got ok=%v, want ok=%v", test.iif, test.oif, test.src, test.dst, gotOK, test.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ipp(str string) netaddr.IPPort {
|
||||
ipp, err := netaddr.ParseIPPort(str)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ipp
|
||||
}
|
||||
|
||||
func TestNAT(t *testing.T) {
|
||||
internet := NewInternet()
|
||||
lan := &Network{
|
||||
Name: "LAN",
|
||||
Prefix4: mustPrefix("192.168.0.0/24"),
|
||||
}
|
||||
m := &Machine{Name: "NAT"}
|
||||
wanIf := m.Attach("wan", internet)
|
||||
lanIf := m.Attach("lan", lan)
|
||||
|
||||
t.Run("endpoint_independent_mapping", func(t *testing.T) {
|
||||
n := &SNAT44{
|
||||
Machine: m,
|
||||
ExternalInterface: wanIf,
|
||||
Type: EndpointIndependentNAT,
|
||||
Firewall: &Firewall{
|
||||
TrustedInterface: lanIf,
|
||||
},
|
||||
}
|
||||
testNAT(t, n, lanIf, wanIf, []natTest{
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("2.2.2.2:5678"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("7.7.7.7:9012"),
|
||||
wantNewMapping: false,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:2345"),
|
||||
dst: ipp("7.7.7.7:9012"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("address_dependent_mapping", func(t *testing.T) {
|
||||
n := &SNAT44{
|
||||
Machine: m,
|
||||
ExternalInterface: wanIf,
|
||||
Type: AddressDependentNAT,
|
||||
Firewall: &Firewall{
|
||||
TrustedInterface: lanIf,
|
||||
},
|
||||
}
|
||||
testNAT(t, n, lanIf, wanIf, []natTest{
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("2.2.2.2:5678"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("2.2.2.2:9012"),
|
||||
wantNewMapping: false,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("7.7.7.7:9012"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("7.7.7.7:1234"),
|
||||
wantNewMapping: false,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("address_and_port_dependent_mapping", func(t *testing.T) {
|
||||
n := &SNAT44{
|
||||
Machine: m,
|
||||
ExternalInterface: wanIf,
|
||||
Type: AddressAndPortDependentNAT,
|
||||
Firewall: &Firewall{
|
||||
TrustedInterface: lanIf,
|
||||
},
|
||||
}
|
||||
testNAT(t, n, lanIf, wanIf, []natTest{
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("2.2.2.2:5678"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("2.2.2.2:9012"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("7.7.7.7:9012"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
{
|
||||
src: ipp("192.168.0.20:1234"),
|
||||
dst: ipp("7.7.7.7:1234"),
|
||||
wantNewMapping: true,
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type natTest struct {
|
||||
src, dst netaddr.IPPort
|
||||
wantNewMapping bool
|
||||
}
|
||||
|
||||
func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) {
|
||||
clock := &tstest.Clock{}
|
||||
n.TimeNow = clock.Now
|
||||
|
||||
mappings := map[netaddr.IPPort]bool{}
|
||||
for _, test := range tests {
|
||||
clock.Advance(time.Second)
|
||||
p := &Packet{
|
||||
Src: test.src,
|
||||
Dst: test.dst,
|
||||
Payload: []byte("foo"),
|
||||
}
|
||||
gotPacket := n.HandleForward(p.Clone(), lanIf, wanIf)
|
||||
if gotPacket == nil {
|
||||
t.Errorf("n.HandleForward(%v) dropped packet", p)
|
||||
continue
|
||||
}
|
||||
|
||||
if gotPacket.Dst != p.Dst {
|
||||
t.Errorf("n.HandleForward(%v) mutated dest ip:port, got %v", p, gotPacket.Dst)
|
||||
}
|
||||
gotNewMapping := !mappings[gotPacket.Src]
|
||||
if gotNewMapping != test.wantNewMapping {
|
||||
t.Errorf("n.HandleForward(%v) mapping was new=%v, want %v", p, gotNewMapping, test.wantNewMapping)
|
||||
}
|
||||
mappings[gotPacket.Src] = true
|
||||
|
||||
// Check that the return path works and translates back
|
||||
// correctly.
|
||||
clock.Advance(time.Second)
|
||||
p2 := &Packet{
|
||||
Src: test.dst,
|
||||
Dst: gotPacket.Src,
|
||||
Payload: []byte("bar"),
|
||||
}
|
||||
gotPacket2 := n.HandleIn(p2.Clone(), wanIf)
|
||||
|
||||
if gotPacket2 == nil {
|
||||
t.Errorf("return packet was dropped")
|
||||
continue
|
||||
}
|
||||
|
||||
if gotPacket2.Src != test.dst {
|
||||
t.Errorf("return packet has src=%v, want %v", gotPacket2.Src, test.dst)
|
||||
}
|
||||
if gotPacket2.Dst != test.src {
|
||||
t.Errorf("return packet has dst=%v, want %v", gotPacket2.Dst, test.src)
|
||||
}
|
||||
}
|
||||
}
|
||||
12
tstest/staticcheck/staticcheck.go
Normal file
12
tstest/staticcheck/staticcheck.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This file exists just so go mod tidy won't remove
|
||||
// staticcheck's module from our go.mod.
|
||||
|
||||
package tstest
|
||||
|
||||
import (
|
||||
_ "honnef.co/go/tools/staticcheck"
|
||||
)
|
||||
134
tsweb/jsonhandler.go
Normal file
134
tsweb/jsonhandler.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tsweb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type response struct {
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func responseSuccess(data interface{}) *response {
|
||||
return &response{
|
||||
Status: "success",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func responseError(e string) *response {
|
||||
return &response{
|
||||
Status: "error",
|
||||
Error: e,
|
||||
}
|
||||
}
|
||||
|
||||
func writeResponse(w http.ResponseWriter, s int, resp *response) {
|
||||
b, _ := json.Marshal(resp)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(s)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func checkFn(t reflect.Type) {
|
||||
h := reflect.TypeOf(http.HandlerFunc(nil))
|
||||
switch t.NumIn() {
|
||||
case 2, 3:
|
||||
if !t.In(0).AssignableTo(h.In(0)) {
|
||||
panic("first argument must be http.ResponseWriter")
|
||||
}
|
||||
if !t.In(1).AssignableTo(h.In(1)) {
|
||||
panic("second argument must be *http.Request")
|
||||
}
|
||||
default:
|
||||
panic("JSONHandler: number of input parameter should be 2 or 3")
|
||||
}
|
||||
|
||||
switch t.NumOut() {
|
||||
case 1:
|
||||
if !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
panic("return value must be error")
|
||||
}
|
||||
case 2:
|
||||
if !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
panic("second return value must be error")
|
||||
}
|
||||
default:
|
||||
panic("JSONHandler: number of return values should be 1 or 2")
|
||||
}
|
||||
}
|
||||
|
||||
// JSONHandler wraps an HTTP handler function with a version that automatically
|
||||
// unmarshals and marshals requests and responses respectively into fn's arguments
|
||||
// and results.
|
||||
//
|
||||
// The fn parameter is a function. It must take two or three input arguments.
|
||||
// The first two arguments must be http.ResponseWriter and *http.Request.
|
||||
// The optional third argument can be of any type representing the JSON input.
|
||||
// The function's results can be either (error) or (T, error), where T is the
|
||||
// JSON-marshalled result type.
|
||||
//
|
||||
// For example:
|
||||
// fn := func(w http.ResponseWriter, r *http.Request, in *Req) (*Res, error) { ... }
|
||||
func JSONHandler(fn interface{}) http.Handler {
|
||||
v := reflect.ValueOf(fn)
|
||||
t := v.Type()
|
||||
checkFn(t)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wv := reflect.ValueOf(w)
|
||||
rv := reflect.ValueOf(r)
|
||||
var vs []reflect.Value
|
||||
|
||||
switch t.NumIn() {
|
||||
case 2:
|
||||
vs = v.Call([]reflect.Value{wv, rv})
|
||||
case 3:
|
||||
dv := reflect.New(t.In(2))
|
||||
err := json.NewDecoder(r.Body).Decode(dv.Interface())
|
||||
if err != nil {
|
||||
writeResponse(w, http.StatusBadRequest, responseError("bad json"))
|
||||
return
|
||||
}
|
||||
vs = v.Call([]reflect.Value{wv, rv, dv.Elem()})
|
||||
default:
|
||||
panic("JSONHandler: number of input parameter should be 2 or 3")
|
||||
}
|
||||
|
||||
var e reflect.Value
|
||||
switch len(vs) {
|
||||
case 1:
|
||||
// todo support other error types
|
||||
if vs[0].IsZero() {
|
||||
writeResponse(w, http.StatusOK, responseSuccess(nil))
|
||||
return
|
||||
}
|
||||
e = vs[0]
|
||||
case 2:
|
||||
if vs[1].IsZero() {
|
||||
if !vs[0].IsZero() {
|
||||
writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface()))
|
||||
}
|
||||
return
|
||||
}
|
||||
e = vs[1]
|
||||
default:
|
||||
panic("JSONHandler: number of return values should be 1 or 2")
|
||||
}
|
||||
|
||||
if e.Type().AssignableTo(reflect.TypeOf(HTTPError{})) {
|
||||
err := e.Interface().(HTTPError)
|
||||
writeResponse(w, err.Code, responseError(err.Error()))
|
||||
} else {
|
||||
err := e.Interface().(error)
|
||||
writeResponse(w, http.StatusBadRequest, responseError(err.Error()))
|
||||
}
|
||||
})
|
||||
}
|
||||
213
tsweb/jsonhandler_test.go
Normal file
213
tsweb/jsonhandler_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tsweb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type Data struct {
|
||||
Name string
|
||||
Price int
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Status string
|
||||
Error string
|
||||
Data *Data
|
||||
}
|
||||
|
||||
func TestNewJSONHandler(t *testing.T) {
|
||||
checkStatus := func(w *httptest.ResponseRecorder, status string) *Response {
|
||||
d := &Response{
|
||||
Data: &Data{},
|
||||
}
|
||||
|
||||
t.Logf("%s", w.Body.Bytes())
|
||||
err := json.Unmarshal(w.Body.Bytes(), d)
|
||||
if err != nil {
|
||||
t.Logf(err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.Status == status {
|
||||
t.Logf("ok: %s", d.Status)
|
||||
} else {
|
||||
t.Fatalf("wrong status: %s %s", d.Status, status)
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Type") != "application/json" {
|
||||
t.Fatalf("wrong content type: %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
// 2 1
|
||||
h21 := JSONHandler(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
t.Run("2 1 simple", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h21.ServeHTTP(w, r)
|
||||
checkStatus(w, "success")
|
||||
})
|
||||
|
||||
t.Run("2 1 HTTPError", func(t *testing.T) {
|
||||
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) HTTPError {
|
||||
return Error(http.StatusForbidden, "forbidden", nil)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("wrong code: %d %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
// 2 2
|
||||
h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) {
|
||||
return &Data{Name: "tailscale"}, nil
|
||||
})
|
||||
t.Run("2 2 get data", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h22.ServeHTTP(w, r)
|
||||
checkStatus(w, "success")
|
||||
})
|
||||
|
||||
// 3 1
|
||||
h31 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) error {
|
||||
if d.Name == "" {
|
||||
return errors.New("name is empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
t.Run("3 1 post data", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`))
|
||||
h31.ServeHTTP(w, r)
|
||||
checkStatus(w, "success")
|
||||
})
|
||||
|
||||
t.Run("3 1 bad json", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`))
|
||||
h31.ServeHTTP(w, r)
|
||||
checkStatus(w, "error")
|
||||
})
|
||||
|
||||
t.Run("3 1 post data error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||
h31.ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "error")
|
||||
if resp.Error != "name is empty" {
|
||||
t.Fatalf("wrong error")
|
||||
}
|
||||
})
|
||||
|
||||
// 3 2
|
||||
h32 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) (*Data, error) {
|
||||
if d.Price == 0 {
|
||||
return nil, errors.New("price is empty")
|
||||
}
|
||||
|
||||
return &Data{Price: d.Price * 2}, nil
|
||||
})
|
||||
t.Run("3 2 post data", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
||||
h32.ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "success")
|
||||
t.Log(resp.Data)
|
||||
if resp.Data.Price != 20 {
|
||||
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("3 2 post data error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||
h32.ServeHTTP(w, r)
|
||||
resp := checkStatus(w, "error")
|
||||
if resp.Error != "price is empty" {
|
||||
t.Fatalf("wrong error")
|
||||
}
|
||||
})
|
||||
|
||||
// fn check
|
||||
shouldPanic := func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatalf("should panic")
|
||||
}
|
||||
t.Log(r)
|
||||
}
|
||||
|
||||
t.Run("2 0 panic", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) {})
|
||||
})
|
||||
|
||||
t.Run("2 1 panic return value", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) string {
|
||||
return ""
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("2 1 panic arguments", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("3 1 panic arguments", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("3 2 panic return value", func(t *testing.T) {
|
||||
defer shouldPanic()
|
||||
//lint:ignore ST1008 intentional
|
||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) (error, string) {
|
||||
return nil, "panic"
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("2 2 forbidden", func(t *testing.T) {
|
||||
code := http.StatusForbidden
|
||||
body := []byte("forbidden")
|
||||
h := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) {
|
||||
w.WriteHeader(code)
|
||||
w.Write(body)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("wrong code: %d %d", w.Code, code)
|
||||
}
|
||||
if !bytes.Equal(w.Body.Bytes(), []byte("forbidden")) {
|
||||
t.Fatalf("wrong body: %s %s", w.Body.Bytes(), body)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -6,9 +6,11 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"go4.org/mem"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
@@ -20,6 +22,15 @@ type Private [32]byte
|
||||
// Private reports whether p is the zero value.
|
||||
func (p Private) IsZero() bool { return p == Private{} }
|
||||
|
||||
// NewPrivate returns a new private key.
|
||||
func NewPrivate() Private {
|
||||
var p Private
|
||||
if _, err := io.ReadFull(crand.Reader, p[:]); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// B32 returns k as the *[32]byte type that's used by the
|
||||
// golang.org/x/crypto packages. This allocates; it might
|
||||
// not be appropriate for performance-sensitive paths.
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -63,6 +64,13 @@ type limitData struct {
|
||||
|
||||
var disableRateLimit = os.Getenv("TS_DEBUG_LOG_RATE") == "all"
|
||||
|
||||
// rateFreePrefix are format string prefixes that are exempt from rate limiting.
|
||||
// Things should not be added to this unless they're already limited otherwise.
|
||||
var rateFreePrefix = []string{
|
||||
"magicsock: disco: ",
|
||||
"magicsock: CreateEndpoint:",
|
||||
}
|
||||
|
||||
// RateLimitedFn returns a rate-limiting Logf wrapping the given logf.
|
||||
// Messages are allowed through at a maximum of one message every f (where f is a time.Duration), in
|
||||
// bursts of up to burst messages at a time. Up to maxCache strings will be held at a time.
|
||||
@@ -85,6 +93,12 @@ func RateLimitedFn(logf Logf, f time.Duration, burst int, maxCache int) Logf {
|
||||
)
|
||||
|
||||
judge := func(format string) verdict {
|
||||
for _, pfx := range rateFreePrefix {
|
||||
if strings.HasPrefix(format, pfx) {
|
||||
return allow
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
rl, ok := msgLim[format]
|
||||
@@ -127,18 +141,23 @@ func RateLimitedFn(logf Logf, f time.Duration, burst int, maxCache int) Logf {
|
||||
// since the last time this identical line was logged.
|
||||
func LogOnChange(logf Logf, maxInterval time.Duration, timeNow func() time.Time) Logf {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
sLastLogged string
|
||||
tLastLogged = timeNow()
|
||||
)
|
||||
|
||||
return func(format string, args ...interface{}) {
|
||||
s := fmt.Sprintf(format, args...)
|
||||
|
||||
mu.Lock()
|
||||
if s == sLastLogged && timeNow().Sub(tLastLogged) < maxInterval {
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
sLastLogged = s
|
||||
tLastLogged = timeNow()
|
||||
mu.Unlock()
|
||||
|
||||
logf(s)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -59,7 +60,7 @@ func TestRateLimiter(t *testing.T) {
|
||||
lg("templated format string no. %d", i)
|
||||
if i == 4 {
|
||||
lg("Make sure this string makes it through the rest (that are blocked) %d", i)
|
||||
prefixed = WithPrefix(lg, string('0'+i))
|
||||
prefixed = WithPrefix(lg, string(rune('0'+i)))
|
||||
prefixed(" shouldn't get filtered.")
|
||||
}
|
||||
}
|
||||
@@ -117,3 +118,31 @@ func TestArgWriter(t *testing.T) {
|
||||
t.Errorf("got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynchronization(t *testing.T) {
|
||||
timeNow := testTimer(1 * time.Second)
|
||||
tests := []struct {
|
||||
name string
|
||||
logf Logf
|
||||
}{
|
||||
{"RateLimitedFn", RateLimitedFn(t.Logf, 1*time.Minute, 2, 50)},
|
||||
{"LogOnChange", LogOnChange(t.Logf, 5*time.Second, timeNow)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
f := func() {
|
||||
tt.logf("1 2 3 4 5")
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
go f()
|
||||
go f()
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user