Compare commits
82 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c03eaba1a6 | ||
|
|
abaedc675b | ||
|
|
70512da940 | ||
|
|
0710fca0cd | ||
|
|
aa9d7f4665 | ||
|
|
a5dd0bcb09 | ||
|
|
b65eee0745 | ||
|
|
1ebbaaaebb | ||
|
|
eccc167733 | ||
|
|
8f76548fd9 | ||
|
|
5b338bf011 | ||
|
|
acade77c86 | ||
|
|
5d96ecd5e6 | ||
|
|
c8939ab7c7 | ||
|
|
883a11f2a8 | ||
|
|
d9e2edb5ae | ||
|
|
3c508a58cc | ||
|
|
51c8fd1dfc | ||
|
|
ff50ddf1ee | ||
|
|
fc8bc76e58 | ||
|
|
7a01cd27ca | ||
|
|
45d96788b5 | ||
|
|
000347d4cf | ||
|
|
b0526e8284 | ||
|
|
efad55cf86 | ||
|
|
cccdd81441 | ||
|
|
2eb474dd8d | ||
|
|
ce45f4f3ff | ||
|
|
3fdae12f0c | ||
|
|
47380ebcfb | ||
|
|
5062131aad | ||
|
|
2d604b3791 | ||
|
|
04ff3c91ee | ||
|
|
fac2b30eff | ||
|
|
a664aac877 | ||
|
|
a2d78b4d3e | ||
|
|
97e82c6cc0 | ||
|
|
19b0cfe89e | ||
|
|
258b680bc5 | ||
|
|
563d43b2a5 | ||
|
|
b246810377 | ||
|
|
c03543dbe2 | ||
|
|
0050070493 | ||
|
|
f99f6608ff | ||
|
|
a38e28da07 | ||
|
|
c2cc3acbaf | ||
|
|
d7ee3096dd | ||
|
|
9ef39af2f2 | ||
|
|
22bf48f37c | ||
|
|
55b1221db2 | ||
|
|
89894c6930 | ||
|
|
d192bd0f86 | ||
|
|
d21956436a | ||
|
|
450cfedeba | ||
|
|
e7ac9a4b90 | ||
|
|
6e52633c53 | ||
|
|
093431f5dd | ||
|
|
c48253e63b | ||
|
|
7a54910990 | ||
|
|
76d99cf01a | ||
|
|
b950bd60bf | ||
|
|
a8589636a8 | ||
|
|
b3634f020d | ||
|
|
7988f75b87 | ||
|
|
427bf2134f | ||
|
|
19df6a2ee2 | ||
|
|
ebd96bf4a9 | ||
|
|
e9bca0c00b | ||
|
|
b1de2020d7 | ||
|
|
b4e19b95ed | ||
|
|
8f30fa67aa | ||
|
|
3aa68cd397 | ||
|
|
119101962c | ||
|
|
bda53897b5 | ||
|
|
782e07c0ae | ||
|
|
4f4e84236a | ||
|
|
6bcb466096 | ||
|
|
696e160cfc | ||
|
|
946c1edb42 | ||
|
|
fb9f80cd61 | ||
|
|
ed17f5ddae | ||
|
|
39bbb86b09 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
# Binaries for programs and plugins
|
||||
*~
|
||||
*.tmp
|
||||
*.exe
|
||||
*.dll
|
||||
*.so
|
||||
@@ -18,3 +19,7 @@ cmd/tailscaled/tailscaled
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# direnv config, this may be different for other people so it's probably safer
|
||||
# to make this nonspecific.
|
||||
.envrc
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
# $ docker exec tailscaled tailscale status
|
||||
|
||||
|
||||
FROM golang:1.14-alpine AS build-env
|
||||
FROM golang:1.15-alpine AS build-env
|
||||
|
||||
WORKDIR /go/src/tailscale
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
1.1.0 f81233524fddeec450940af8dc1a0dd8841bf28c
|
||||
1.3.0
|
||||
|
||||
@@ -9,12 +9,8 @@
|
||||
# this script, or executing equivalent commands in your
|
||||
# distro-specific build system.
|
||||
|
||||
set -euo pipefail
|
||||
set -eu
|
||||
|
||||
describe=$(./version/describe.sh)
|
||||
commit=$(git rev-parse --verify --quiet HEAD)
|
||||
eval $(./version/version.sh)
|
||||
|
||||
long=$(./version/mkversion.sh long "$describe" "")
|
||||
short=$(./version/mkversion.sh short "$describe" "")
|
||||
|
||||
exec go build -tags xversion -ldflags "-X tailscale.com/version.Long=${long} -X tailscale.com/version.Short=${short} -X tailscale.com/version.GitCommit=${commit}" "$@"
|
||||
exec go build -tags xversion -ldflags "-X tailscale.com/version.Long=${VERSION_LONG} -X tailscale.com/version.Short=${VERSION_SHORT} -X tailscale.com/version.GitCommit=${VERSION_GIT_HASH}" "$@"
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
"github.com/peterbourgon/ff/v2/ffcli"
|
||||
"tailscale.com/derp/derpmap"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/netcheck"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/logger"
|
||||
@@ -44,9 +43,7 @@ var netcheckArgs struct {
|
||||
}
|
||||
|
||||
func runNetcheck(ctx context.Context, args []string) error {
|
||||
c := &netcheck.Client{
|
||||
DNSCache: dnscache.Get(),
|
||||
}
|
||||
c := &netcheck.Client{}
|
||||
if netcheckArgs.verbose {
|
||||
c.Logf = logger.WithPrefix(log.Printf, "netcheck: ")
|
||||
c.Verbose = true
|
||||
|
||||
@@ -182,18 +182,11 @@ func runUp(ctx context.Context, args []string) error {
|
||||
var tags []string
|
||||
if upArgs.advertiseTags != "" {
|
||||
tags = strings.Split(upArgs.advertiseTags, ",")
|
||||
for i, tag := range tags {
|
||||
if strings.HasPrefix(tag, "tag:") {
|
||||
// Accept fully-qualified tags (starting with
|
||||
// "tag:"), as we do in the ACL file.
|
||||
err := tailcfg.CheckTag(tag)
|
||||
if err != nil {
|
||||
fatalf("tag: %q: %v", tag, err)
|
||||
}
|
||||
} else if err := tailcfg.CheckTagSuffix(tag); err != nil {
|
||||
fatalf("tag: %q: %v", tag, err)
|
||||
for _, tag := range tags {
|
||||
err := tailcfg.CheckTag(tag)
|
||||
if err != nil {
|
||||
fatalf("tag: %q: %s", tag, err)
|
||||
}
|
||||
tags[i] = "tag:" + tag
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/iphlpapi from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/namespaceapi from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/nci from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/registry from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/registry from github.com/tailscale/wireguard-go/tun/wintun+
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/setupapi from github.com/tailscale/wireguard-go/tun/wintun
|
||||
github.com/tailscale/wireguard-go/wgcfg from github.com/tailscale/wireguard-go/conn+
|
||||
github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck
|
||||
@@ -53,10 +53,11 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/log/logheap from tailscale.com/control/controlclient
|
||||
tailscale.com/logtail/backoff from tailscale.com/control/controlclient+
|
||||
tailscale.com/metrics from tailscale.com/derp
|
||||
tailscale.com/net/dnscache from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/net/dnscache from tailscale.com/control/controlclient+
|
||||
💣 tailscale.com/net/interfaces from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/net/netcheck from tailscale.com/cmd/tailscale/cli+
|
||||
💣 tailscale.com/net/netns from tailscale.com/control/controlclient+
|
||||
tailscale.com/net/netns from tailscale.com/control/controlclient+
|
||||
tailscale.com/net/packet from tailscale.com/wgengine+
|
||||
tailscale.com/net/stun from tailscale.com/net/netcheck+
|
||||
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
|
||||
tailscale.com/net/tsaddr from tailscale.com/ipn+
|
||||
@@ -66,15 +67,15 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/safesocket from tailscale.com/cmd/tailscale/cli
|
||||
💣 tailscale.com/syncs from tailscale.com/net/interfaces+
|
||||
tailscale.com/tailcfg from tailscale.com/cmd/tailscale/cli+
|
||||
DW tailscale.com/tempfork/osexec from tailscale.com/portlist
|
||||
W tailscale.com/tsconst from tailscale.com/net/interfaces
|
||||
tailscale.com/types/empty from tailscale.com/control/controlclient+
|
||||
tailscale.com/types/key from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/types/logger from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/types/nettype from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/types/opt from tailscale.com/control/controlclient+
|
||||
tailscale.com/types/strbuilder from tailscale.com/wgengine/packet
|
||||
tailscale.com/types/strbuilder from tailscale.com/net/packet
|
||||
tailscale.com/types/structs from tailscale.com/control/controlclient+
|
||||
W tailscale.com/util/endian from tailscale.com/net/netns
|
||||
tailscale.com/util/lineread from tailscale.com/control/controlclient+
|
||||
tailscale.com/version from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/version/distro from tailscale.com/cmd/tailscale/cli+
|
||||
@@ -82,7 +83,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/wgengine/filter from tailscale.com/control/controlclient+
|
||||
tailscale.com/wgengine/magicsock from tailscale.com/wgengine
|
||||
💣 tailscale.com/wgengine/monitor from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/wgengine/packet from tailscale.com/wgengine+
|
||||
tailscale.com/wgengine/router from tailscale.com/cmd/tailscale/cli+
|
||||
💣 tailscale.com/wgengine/router/dns from tailscale.com/ipn+
|
||||
tailscale.com/wgengine/tsdns from tailscale.com/ipn+
|
||||
@@ -170,6 +170,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
hash/adler32 from compress/zlib
|
||||
hash/crc32 from compress/gzip+
|
||||
hash/fnv from tailscale.com/wgengine/magicsock
|
||||
hash/maphash from go4.org/mem
|
||||
html from tailscale.com/ipn/ipnstate
|
||||
io from bufio+
|
||||
io/ioutil from crypto/tls+
|
||||
|
||||
@@ -33,7 +33,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/iphlpapi from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/namespaceapi from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/nci from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/registry from github.com/tailscale/wireguard-go/tun/wintun
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/registry from github.com/tailscale/wireguard-go/tun/wintun+
|
||||
W 💣 github.com/tailscale/wireguard-go/tun/wintun/setupapi from github.com/tailscale/wireguard-go/tun/wintun
|
||||
github.com/tailscale/wireguard-go/wgcfg from github.com/tailscale/wireguard-go/conn+
|
||||
github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck
|
||||
@@ -58,11 +58,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/logtail/backoff from tailscale.com/control/controlclient+
|
||||
tailscale.com/logtail/filch from tailscale.com/logpolicy
|
||||
tailscale.com/metrics from tailscale.com/derp
|
||||
tailscale.com/net/dnscache from tailscale.com/derp/derphttp+
|
||||
tailscale.com/net/dnscache from tailscale.com/control/controlclient+
|
||||
💣 tailscale.com/net/interfaces from tailscale.com/ipn+
|
||||
tailscale.com/net/netcheck from tailscale.com/wgengine/magicsock
|
||||
💣 tailscale.com/net/netns from tailscale.com/control/controlclient+
|
||||
tailscale.com/net/netns from tailscale.com/control/controlclient+
|
||||
💣 tailscale.com/net/netstat from tailscale.com/ipn/ipnserver
|
||||
tailscale.com/net/packet from tailscale.com/wgengine+
|
||||
tailscale.com/net/stun from tailscale.com/net/netcheck+
|
||||
tailscale.com/net/tlsdial from tailscale.com/control/controlclient+
|
||||
tailscale.com/net/tsaddr from tailscale.com/ipn+
|
||||
@@ -73,7 +74,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/smallzstd from tailscale.com/ipn/ipnserver+
|
||||
💣 tailscale.com/syncs from tailscale.com/net/interfaces+
|
||||
tailscale.com/tailcfg from tailscale.com/control/controlclient+
|
||||
DW tailscale.com/tempfork/osexec from tailscale.com/portlist
|
||||
W tailscale.com/tsconst from tailscale.com/net/interfaces
|
||||
tailscale.com/types/empty from tailscale.com/control/controlclient+
|
||||
tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled
|
||||
@@ -81,17 +81,18 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/types/logger from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/types/nettype from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/types/opt from tailscale.com/control/controlclient+
|
||||
tailscale.com/types/strbuilder from tailscale.com/wgengine/packet
|
||||
tailscale.com/types/strbuilder from tailscale.com/net/packet
|
||||
tailscale.com/types/structs from tailscale.com/control/controlclient+
|
||||
W tailscale.com/util/endian from tailscale.com/net/netns+
|
||||
tailscale.com/util/lineread from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/pidowner from tailscale.com/ipn/ipnserver
|
||||
tailscale.com/util/racebuild from tailscale.com/logpolicy
|
||||
tailscale.com/version from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/version/distro from tailscale.com/control/controlclient+
|
||||
tailscale.com/wgengine from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/wgengine/filter from tailscale.com/control/controlclient+
|
||||
tailscale.com/wgengine/magicsock from tailscale.com/cmd/tailscaled+
|
||||
💣 tailscale.com/wgengine/monitor from tailscale.com/wgengine
|
||||
tailscale.com/wgengine/packet from tailscale.com/wgengine+
|
||||
tailscale.com/wgengine/router from tailscale.com/cmd/tailscaled+
|
||||
💣 tailscale.com/wgengine/router/dns from tailscale.com/ipn+
|
||||
tailscale.com/wgengine/tsdns from tailscale.com/ipn+
|
||||
@@ -180,6 +181,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
hash/adler32 from compress/zlib
|
||||
hash/crc32 from compress/gzip+
|
||||
hash/fnv from tailscale.com/wgengine/magicsock
|
||||
hash/maphash from go4.org/mem
|
||||
html from html/template+
|
||||
html/template from net/http/pprof
|
||||
io from bufio+
|
||||
|
||||
@@ -6,6 +6,7 @@ After=network-pre.target
|
||||
|
||||
[Service]
|
||||
EnvironmentFile=/etc/default/tailscaled
|
||||
ExecStartPre=/usr/sbin/tailscaled --cleanup
|
||||
ExecStart=/usr/sbin/tailscaled --state=/var/lib/tailscale/tailscaled.state --socket=/run/tailscale/tailscaled.sock --port $PORT $FLAGS
|
||||
ExecStopPost=/usr/sbin/tailscaled --cleanup
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
@@ -35,6 +36,7 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/log/logheap"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/tlsdial"
|
||||
"tailscale.com/net/tshttpproxy"
|
||||
@@ -91,9 +93,14 @@ func (p *Persist) Pretty() string {
|
||||
if !p.PrivateNodeKey.IsZero() {
|
||||
nk = p.PrivateNodeKey.Public()
|
||||
}
|
||||
ss := func(k wgcfg.Key) string {
|
||||
if k.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return k.ShortString()
|
||||
}
|
||||
return fmt.Sprintf("Persist{lm=%v, o=%v, n=%v u=%#v}",
|
||||
mk.ShortString(), ok.ShortString(), nk.ShortString(),
|
||||
p.LoginName)
|
||||
ss(mk), ss(ok), ss(nk), p.LoginName)
|
||||
}
|
||||
|
||||
// Direct is the client that connects to a tailcontrol server for a node.
|
||||
@@ -166,11 +173,15 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
|
||||
httpc := opts.HTTPTestClient
|
||||
if httpc == nil {
|
||||
dnsCache := &dnscache.Resolver{
|
||||
Forward: dnscache.Get().Forward, // use default cache's forwarder
|
||||
UseLastGood: true,
|
||||
}
|
||||
dialer := netns.NewDialer()
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.Proxy = tshttpproxy.ProxyFromEnvironment
|
||||
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
|
||||
tr.DialContext = dialer.DialContext
|
||||
tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache)
|
||||
tr.ForceAttemptHTTP2 = true
|
||||
tr.TLSClientConfig = tlsdial.Config(serverURL.Host, tr.TLSClientConfig)
|
||||
httpc = &http.Client{Transport: tr}
|
||||
@@ -539,6 +550,10 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
|
||||
Hostinfo: hostinfo,
|
||||
DebugFlags: c.debugFlags,
|
||||
}
|
||||
if hostinfo != nil && ipForwardingBroken(hostinfo.RoutableIPs) {
|
||||
old := request.DebugFlags
|
||||
request.DebugFlags = append(old[:len(old):len(old)], "warn-ip-forwarding-off")
|
||||
}
|
||||
if c.newDecompressor != nil {
|
||||
request.Compress = "zstd"
|
||||
}
|
||||
@@ -781,6 +796,8 @@ func decode(res *http.Response, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg
|
||||
|
||||
var debugMap, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_MAP"))
|
||||
|
||||
var jsonEscapedZero = []byte(`\u0000`)
|
||||
|
||||
func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
|
||||
c.mu.Lock()
|
||||
serverKey := c.serverKey
|
||||
@@ -809,6 +826,10 @@ func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
|
||||
json.Indent(&buf, b, "", " ")
|
||||
log.Printf("MapResponse: %s", buf.Bytes())
|
||||
}
|
||||
|
||||
if bytes.Contains(b, jsonEscapedZero) {
|
||||
log.Printf("[unexpected] zero byte in controlclient.Direct.decodeMsg into %T: %q", v, b)
|
||||
}
|
||||
if err := json.Unmarshal(b, v); err != nil {
|
||||
return fmt.Errorf("response: %v", err)
|
||||
}
|
||||
@@ -821,6 +842,9 @@ func decodeMsg(msg []byte, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.Priv
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bytes.Contains(decrypted, jsonEscapedZero) {
|
||||
log.Printf("[unexpected] zero byte in controlclient decodeMsg into %T: %q", v, decrypted)
|
||||
}
|
||||
if err := json.Unmarshal(decrypted, v); err != nil {
|
||||
return fmt.Errorf("response: %v", err)
|
||||
}
|
||||
@@ -1051,3 +1075,34 @@ func TrimWGConfig() opt.Bool {
|
||||
v, _ := controlTrimWGConfig.Load().(opt.Bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// ipForwardingBroken reports whether the system's IP forwarding is disabled
|
||||
// and will definitely not work for the routes provided.
|
||||
//
|
||||
// It should not return false positives.
|
||||
func ipForwardingBroken(routes []wgcfg.CIDR) bool {
|
||||
if len(routes) == 0 {
|
||||
// Nothing to route, so no need to warn.
|
||||
return false
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
// We only do subnet routing on Linux for now.
|
||||
// It might work on darwin/macOS when building from source, so
|
||||
// don't return true for other OSes. We can OS-based warnings
|
||||
// already in the admin panel.
|
||||
return false
|
||||
}
|
||||
out, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward")
|
||||
if err != nil {
|
||||
// Try another way.
|
||||
out, err = exec.Command("sysctl", "-n", "net.ipv4.ip_forward").Output()
|
||||
}
|
||||
if err != nil {
|
||||
// Oh well, we tried. This is just for debugging.
|
||||
// We don't want false positives.
|
||||
// TODO: maybe we want a different warning for inability to check?
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(string(out)) == "0"
|
||||
// TODO: also check IPv6 if 'routes' contains any IPv6 routes
|
||||
}
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
// Parse a backward-compatible FilterRule used by control's wire format,
|
||||
// producing the most current filter.Matches format.
|
||||
func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) filter.Matches {
|
||||
// Parse a backward-compatible FilterRule used by control's wire
|
||||
// format, producing the most current filter format.
|
||||
func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) []filter.Match {
|
||||
mm, err := filter.MatchesFromFilterRules(pf)
|
||||
if err != nil {
|
||||
c.logf("parsePacketFilter: %s\n", err)
|
||||
|
||||
@@ -34,7 +34,7 @@ type NetworkMap struct {
|
||||
Peers []*tailcfg.Node // sorted by Node.ID
|
||||
DNS tailcfg.DNSConfig
|
||||
Hostinfo tailcfg.Hostinfo
|
||||
PacketFilter filter.Matches
|
||||
PacketFilter []filter.Match
|
||||
|
||||
// DERPMap is the last DERP server map received. It's reused
|
||||
// between updates and should not be modified.
|
||||
|
||||
10
go.mod
10
go.mod
@@ -26,15 +26,15 @@ require (
|
||||
github.com/tailscale/wireguard-go v0.0.0-20201021041318-a6168fd06b3f
|
||||
github.com/tcnksm/go-httpstat v0.2.0
|
||||
github.com/toqueteos/webbrowser v1.2.0
|
||||
go4.org/mem v0.0.0-20200706164138-185c595c3ecc
|
||||
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202
|
||||
go4.org/mem v0.0.0-20201119185036-c04c5a6ff174
|
||||
golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208
|
||||
golang.org/x/sys v0.0.0-20200812155832-6a926be9bd1d
|
||||
golang.org/x/sys v0.0.0-20201112073958-5cba982894dd
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
|
||||
golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d
|
||||
golang.zx2c4.com/wireguard/windows v0.1.2-0.20201004085714-dd60d0447f81
|
||||
golang.zx2c4.com/wireguard/windows v0.1.2-0.20201113162609-9b85be97fdf8
|
||||
honnef.co/go/tools v0.0.1-2020.1.4
|
||||
inet.af/netaddr v0.0.0-20200810144936-56928fe48a98
|
||||
rsc.io/goversion v1.2.0
|
||||
|
||||
40
go.sum
40
go.sum
@@ -1,8 +1,12 @@
|
||||
cloud.google.com/go v0.0.0-20170206221025-ce650573d812/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20190129172621-c8b1d7a94ddf/go.mod h1:aJ4qN3TfrelA6NZ6AXsXRfmEVaYin3EDbSPJrKS8OXo=
|
||||
github.com/Masterminds/semver/v3 v3.0.3 h1:znjIyLfpXEDQjOIEWh+ehwpTU14UzUPub3c3sm36u14=
|
||||
github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
|
||||
github.com/aclements/go-gg v0.0.0-20170118225347-6dbb4e4fefb0/go.mod h1:55qNq4vcpkIuHowELi5C8e+1yUHtoLoOUR9QU5j7Tes=
|
||||
github.com/aclements/go-moremath v0.0.0-20161014184102-0ff62e0875ff/go.mod h1:idZL3yvz4kzx1dsBOAC+oYv6L92P1oFEhUXUB1A/lwQ=
|
||||
github.com/alecthomas/kingpin v2.2.6+incompatible/go.mod h1:59OFYbFVLKQKq+mqrL6Rw5bR0c3ACQaawgXx0QYndlE=
|
||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
|
||||
@@ -31,19 +35,27 @@ github.com/go-multierror/multierror v1.0.2 h1:AwsKbEXkmf49ajdFJgcFXqSG0aLo0HEyAE
|
||||
github.com/go-multierror/multierror v1.0.2/go.mod h1:U7SZR/D9jHgt2nkSj8XcbCWdmVM2igraCHQ3HC1HiKY=
|
||||
github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI=
|
||||
github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM=
|
||||
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||
github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME=
|
||||
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/gonum/blas v0.0.0-20181208220705-f22b278b28ac/go.mod h1:P32wAyui1PQ58Oce/KYkOqQv8cVw1zAapXOl+dRFGbc=
|
||||
github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82/go.mod h1:PxC8OnwL11+aosOB5+iEPoV3picfs8tUpkVd0pDo+Kg=
|
||||
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029/go.mod h1:Pu4dmpkhSyOzRwuXkOgAvijx4o+4YMUJJo9OvPYMkks=
|
||||
github.com/gonum/lapack v0.0.0-20181123203213-e4cdc5a0bff9/go.mod h1:XA3DeT6rxh2EAE789SSiSJNqxPaC0aE9J8NTOI0Jo/A=
|
||||
github.com/gonum/matrix v0.0.0-20181209220409-c518dec07be9/go.mod h1:0EXg4mc1CNP0HCqCz+K4ts155PXIlUywf0wqN+GfPZw=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/rpmpack v0.0.0-20191226140753-aa36bfddb3a0 h1:BW6OvS3kpT5UEPbCZ+KyX/OB4Ks9/MNMhWjqPPkZxsE=
|
||||
github.com/google/rpmpack v0.0.0-20191226140753-aa36bfddb3a0/go.mod h1:RaTPr0KUf2K7fnZYLNDrr8rxAamWs3iNywJLtQ2AzBg=
|
||||
github.com/googleapis/gax-go v0.0.0-20161107002406-da06d194a00e/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY=
|
||||
github.com/goreleaser/nfpm v1.1.10 h1:0nwzKUJTcygNxTzVKq2Dh9wpVP1W2biUH6SNKmoxR3w=
|
||||
github.com/goreleaser/nfpm v1.1.10/go.mod h1:oOcoGRVwvKIODz57NUfiRwFWGfn00NXdgnn6MrYtO5k=
|
||||
github.com/imdario/mergo v0.3.8 h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ=
|
||||
@@ -61,7 +73,10 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lxn/walk v0.0.0-20191128110447-55ccb3a9f5c1/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ=
|
||||
github.com/lxn/walk v0.0.0-20201110160827-18ea5e372cdb/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ=
|
||||
github.com/lxn/win v0.0.0-20191128105842-2da648fda5b4/go.mod h1:ouWl4wViUNh8tPSIwxTVMuS014WakR1hqvBc2I0bMoA=
|
||||
github.com/lxn/win v0.0.0-20201111105847-2a20daff6a55/go.mod h1:KxxjdtRkfNoYDCUP5ryK7XJJNTnpC8atvtmTheChOtk=
|
||||
github.com/mattn/go-sqlite3 v0.0.0-20161215041557-2d44decb4941/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-zglob v0.0.1 h1:xsEx/XUoVlI6yXjqBK062zYhRTZltCNmYPx6v+8DNaY=
|
||||
github.com/mattn/go-zglob v0.0.1/go.mod h1:9fxibJccNxU2cnpIKLRRFA7zX7qhkJIQWBb449FYHOo=
|
||||
github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA=
|
||||
@@ -110,6 +125,8 @@ github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMx
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go4.org/mem v0.0.0-20200706164138-185c595c3ecc h1:paujszgN6SpsO/UsXC7xax3gQAKz/XQKCYZLQdU34Tw=
|
||||
go4.org/mem v0.0.0-20200706164138-185c595c3ecc/go.mod h1:NEYvpHWemiG/E5UWfaN5QAIGZeT1sa0Z2UNk6oeMb/k=
|
||||
go4.org/mem v0.0.0-20201119185036-c04c5a6ff174 h1:vSug/WNOi2+4jrKdivxayTN/zd8EA1UrStjpWvvo1jk=
|
||||
go4.org/mem v0.0.0-20201119185036-c04c5a6ff174/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
@@ -119,6 +136,9 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnk
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 h1:DZhuSZLsGlFL4CmhA8BcRA0mnthyA/nZ00AqCUo7vHg=
|
||||
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9 h1:umElSU9WZirRdgu2yFHY0ayQkEnKiOC1TtM3fWXFnoU=
|
||||
golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
|
||||
@@ -137,9 +157,16 @@ golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/
|
||||
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/oauth2 v0.0.0-20170207211851-4464e7848382/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/perf v0.0.0-20200918155509-d949658356f9 h1:yVBHF5pcQLKR9B+y+dOJ6y68nqJBDWaZ9DhB1Ohg0qE=
|
||||
golang.org/x/perf v0.0.0-20200918155509-d949658356f9/go.mod h1:FrqOtQDO3iMDVUtw5nNTDFpR1HUCGh00M3kj2wiSzLQ=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -159,10 +186,17 @@ golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200812155832-6a926be9bd1d h1:QQrM/CCYEzTs91GZylDCQjGHudbPTxF/1fvXdVh5lMo=
|
||||
golang.org/x/sys v0.0.0-20200812155832-6a926be9bd1d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201107080550-4d91cf3a1aaf/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201112073958-5cba982894dd h1:5CtCZbICpIOFdgO940moixOPjc0178IU44m4EjOO5IY=
|
||||
golang.org/x/sys v0.0.0-20201112073958-5cba982894dd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4 h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -179,10 +213,16 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1N
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wireguard v0.0.20200321-0.20200715051853-507f148e1c42 h1:SrR1hmxGKKarHEEDvaHxatwnqE3uT+7jvMcin6SHOkw=
|
||||
golang.zx2c4.com/wireguard v0.0.20200321-0.20200715051853-507f148e1c42/go.mod h1:GJvYs5O24/ASlwPiRklVnjMx2xQzrOic0DuU6GvYJL4=
|
||||
golang.zx2c4.com/wireguard v0.0.20200321-0.20201111175144-60b3766b89b9 h1:qowcZ56hhpeoESmWzI4Exhx4Y78TpCyXUJur4/c0CoE=
|
||||
golang.zx2c4.com/wireguard v0.0.20200321-0.20201111175144-60b3766b89b9/go.mod h1:LMeNfjlcPZTrBC1juwgbQyA4Zy2XVcsrdO/fIJxwyuA=
|
||||
golang.zx2c4.com/wireguard/windows v0.1.2-0.20201004085714-dd60d0447f81 h1:cT2oWlz8v9g7bjFZclT362akxJJfGv9d7ccKu6GQUbA=
|
||||
golang.zx2c4.com/wireguard/windows v0.1.2-0.20201004085714-dd60d0447f81/go.mod h1:GaK5zcgr5XE98WaRzIDilumDBp5/yP8j2kG/LCDnvAM=
|
||||
golang.zx2c4.com/wireguard/windows v0.1.2-0.20201113162609-9b85be97fdf8 h1:nlXPqGA98n+qcq1pwZ28KjM5EsFQvamKS00A+VUeVjs=
|
||||
golang.zx2c4.com/wireguard/windows v0.1.2-0.20201113162609-9b85be97fdf8/go.mod h1:psva4yDnAHLuh7lUzOK7J7bLYxNFfo0iKWz+mi9gzkA=
|
||||
google.golang.org/api v0.0.0-20170206182103-3d017632ea10/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/grpc v0.0.0-20170208002647-2a6bf6142e96/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
|
||||
@@ -724,6 +724,10 @@ func BabysitProc(ctx context.Context, args []string, logf logger.Logf) {
|
||||
// pipe. We'll make a new one when we restart the subproc.
|
||||
wStdin.Close()
|
||||
|
||||
if os.Getenv("TS_DEBUG_RESTART_CRASHED") == "0" {
|
||||
log.Fatalf("Process ended.")
|
||||
}
|
||||
|
||||
if time.Since(startTime) < 60*time.Second {
|
||||
bo.BackOff(ctx, fmt.Errorf("subproc early exit: %v", err))
|
||||
} else {
|
||||
|
||||
89
ipn/local.go
89
ipn/local.go
@@ -419,7 +419,9 @@ func (b *LocalBackend) Start(opts Options) error {
|
||||
b.serverURL = b.prefs.ControlURL
|
||||
hostinfo.RoutableIPs = append(hostinfo.RoutableIPs, b.prefs.AdvertiseRoutes...)
|
||||
hostinfo.RequestTags = append(hostinfo.RequestTags, b.prefs.AdvertiseTags...)
|
||||
b.logf("Start: serverMode=%v; stateKey=%q; tags=%q; routes=%v; url=%v", b.inServerMode, b.stateKey, b.prefs.AdvertiseTags, b.prefs.AdvertiseRoutes, b.prefs.ControlURL)
|
||||
if b.inServerMode || runtime.GOOS == "windows" {
|
||||
b.logf("Start: serverMode=%v", b.inServerMode)
|
||||
}
|
||||
applyPrefsToHostinfo(hostinfo, b.prefs)
|
||||
|
||||
b.notify = opts.Notify
|
||||
@@ -521,7 +523,7 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap, prefs *Pre
|
||||
var (
|
||||
haveNetmap = netMap != nil
|
||||
addrs []wgcfg.CIDR
|
||||
packetFilter filter.Matches
|
||||
packetFilter []filter.Match
|
||||
advRoutes []wgcfg.CIDR
|
||||
shieldsUp = prefs == nil || prefs.ShieldsUp // Be conservative when not ready
|
||||
)
|
||||
@@ -544,12 +546,12 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap, prefs *Pre
|
||||
return
|
||||
}
|
||||
|
||||
localNets := wgCIDRsToFilter(netMap.Addresses, advRoutes)
|
||||
localNets := wgCIDRsToNetaddr(netMap.Addresses, advRoutes)
|
||||
|
||||
if shieldsUp {
|
||||
b.logf("netmap packet filter: (shields up)")
|
||||
var prevFilter *filter.Filter // don't reuse old filter state
|
||||
b.e.SetFilter(filter.New(filter.Matches{}, localNets, prevFilter, b.logf))
|
||||
b.e.SetFilter(filter.New(nil, localNets, prevFilter, b.logf))
|
||||
} else {
|
||||
b.logf("netmap packet filter: %v", packetFilter)
|
||||
b.e.SetFilter(filter.New(packetFilter, localNets, b.e.GetFilter(), b.logf))
|
||||
@@ -704,6 +706,7 @@ func (b *LocalBackend) popBrowserAuthNow() {
|
||||
// initMachineKeyLocked is called to initialize b.machinePrivKey.
|
||||
//
|
||||
// b.prefs must already be initialized.
|
||||
// b.stateKey should be set too, but just for nicer log messages.
|
||||
// b.mu must be held.
|
||||
func (b *LocalBackend) initMachineKeyLocked() (err error) {
|
||||
if temporarilySetMachineKeyInPersist() {
|
||||
@@ -748,7 +751,11 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) {
|
||||
// have a legacy machine key, use that. Otherwise generate a
|
||||
// new one.
|
||||
if !legacyMachineKey.IsZero() {
|
||||
b.logf("using frontend-provided legacy machine key")
|
||||
if b.stateKey == "" {
|
||||
b.logf("using frontend-provided legacy machine key")
|
||||
} else {
|
||||
b.logf("using legacy machine key from state key %q", b.stateKey)
|
||||
}
|
||||
b.machinePrivKey = legacyMachineKey
|
||||
} else {
|
||||
b.logf("generating new machine key")
|
||||
@@ -801,23 +808,32 @@ func (b *LocalBackend) writeServerModeStartState(userID string, prefs *Prefs) {
|
||||
// loadStateLocked sets b.prefs and b.stateKey based on a complex
|
||||
// combination of key, prefs, and legacyPath. b.mu must be held when
|
||||
// calling.
|
||||
func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath string) error {
|
||||
func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath string) (err error) {
|
||||
if prefs == nil && key == "" {
|
||||
panic("state key and prefs are both unset")
|
||||
}
|
||||
|
||||
// Optimistically set stateKey (for initMachineKeyLocked's
|
||||
// logging), but revert it if we return an error so a later SetPrefs
|
||||
// call can't pick it up if it's bogus.
|
||||
b.stateKey = key
|
||||
defer func() {
|
||||
if err != nil {
|
||||
b.stateKey = ""
|
||||
}
|
||||
}()
|
||||
|
||||
if key == "" {
|
||||
// Frontend owns the state, we just need to obey it.
|
||||
//
|
||||
// If the frontend (e.g. on Windows) supplied the
|
||||
// optional/legacy machine key then it's used as the
|
||||
// value instead of making up a new one.
|
||||
b.logf("Using frontend prefs")
|
||||
b.logf("using frontend prefs: %s", prefs.Pretty())
|
||||
b.prefs = prefs.Clone()
|
||||
if err := b.initMachineKeyLocked(); err != nil {
|
||||
return fmt.Errorf("initMachineKeyLocked: %w", err)
|
||||
}
|
||||
b.stateKey = ""
|
||||
b.writeServerModeStartState(b.userID, b.prefs)
|
||||
return nil
|
||||
}
|
||||
@@ -825,13 +841,13 @@ func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath st
|
||||
if prefs != nil {
|
||||
// Backend owns the state, but frontend is trying to migrate
|
||||
// state into the backend.
|
||||
b.logf("Importing frontend prefs into backend store")
|
||||
b.logf("importing frontend prefs into backend store; frontend prefs: %s", prefs.Pretty())
|
||||
if err := b.store.WriteState(key, prefs.ToBytes()); err != nil {
|
||||
return fmt.Errorf("store.WriteState: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.logf("Using backend prefs")
|
||||
b.logf("using backend prefs")
|
||||
bs, err := b.store.ReadState(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrStateNotExist) {
|
||||
@@ -843,16 +859,15 @@ func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath st
|
||||
}
|
||||
b.prefs = NewPrefs()
|
||||
} else {
|
||||
b.logf("Imported state from relaynode for %q", key)
|
||||
b.logf("imported prefs from relaynode for %q: %v", key, b.prefs.Pretty())
|
||||
}
|
||||
} else {
|
||||
b.prefs = NewPrefs()
|
||||
b.logf("Created empty state for %q", key)
|
||||
b.logf("created empty state for %q: %s", key, b.prefs.Pretty())
|
||||
}
|
||||
if err := b.initMachineKeyLocked(); err != nil {
|
||||
return fmt.Errorf("initMachineKeyLocked: %w", err)
|
||||
}
|
||||
b.stateKey = key
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("store.ReadState(%q): %v", key, err)
|
||||
@@ -861,7 +876,7 @@ func (b *LocalBackend) loadStateLocked(key StateKey, prefs *Prefs, legacyPath st
|
||||
if err != nil {
|
||||
return fmt.Errorf("PrefsFromBytes: %v", err)
|
||||
}
|
||||
b.stateKey = key
|
||||
b.logf("backend prefs for %q: %s", key, b.prefs.Pretty())
|
||||
if err := b.initMachineKeyLocked(); err != nil {
|
||||
return fmt.Errorf("initMachineKeyLocked: %w", err)
|
||||
}
|
||||
@@ -1139,6 +1154,7 @@ func (b *LocalBackend) authReconfig() {
|
||||
uc := b.prefs
|
||||
nm := b.netMap
|
||||
hasPAC := b.prevIfState.HasPAC()
|
||||
disableSubnetsIfPAC := nm != nil && nm.Debug != nil && nm.Debug.DisableSubnetsIfPAC.EqualBool(true)
|
||||
b.mu.Unlock()
|
||||
|
||||
if blocked {
|
||||
@@ -1163,13 +1179,7 @@ func (b *LocalBackend) authReconfig() {
|
||||
if uc.AllowSingleHosts {
|
||||
flags |= controlclient.AllowSingleHosts
|
||||
}
|
||||
if hasPAC {
|
||||
// TODO(bradfitz): make this policy configurable per
|
||||
// domain, flesh out all the edge cases where subnet
|
||||
// routes might shadow corp HTTP proxies, DNS servers,
|
||||
// domain controllers, etc. For now we just want
|
||||
// Tailscale to stay enabled while laptops roam
|
||||
// between corp & non-corp networks.
|
||||
if hasPAC && disableSubnetsIfPAC {
|
||||
if flags&controlclient.AllowSubnetRoutes != 0 {
|
||||
b.logf("authReconfig: have PAC; disabling subnet routes")
|
||||
flags &^= controlclient.AllowSubnetRoutes
|
||||
@@ -1251,14 +1261,14 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
|
||||
}
|
||||
|
||||
rs := &router.Config{
|
||||
LocalAddrs: wgCIDRToNetaddr(addrs),
|
||||
SubnetRoutes: wgCIDRToNetaddr(prefs.AdvertiseRoutes),
|
||||
LocalAddrs: wgCIDRsToNetaddr(addrs),
|
||||
SubnetRoutes: wgCIDRsToNetaddr(prefs.AdvertiseRoutes),
|
||||
SNATSubnetRoutes: !prefs.NoSNAT,
|
||||
NetfilterMode: prefs.NetfilterMode,
|
||||
}
|
||||
|
||||
for _, peer := range cfg.Peers {
|
||||
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
|
||||
rs.Routes = append(rs.Routes, wgCIDRsToNetaddr(peer.AllowedIPs)...)
|
||||
}
|
||||
|
||||
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
|
||||
@@ -1269,35 +1279,20 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
|
||||
return rs
|
||||
}
|
||||
|
||||
// wgCIDRsToFilter converts lists of wgcfg.CIDR into a single list of
|
||||
// filter.Net.
|
||||
func wgCIDRsToFilter(cidrLists ...[]wgcfg.CIDR) (ret []filter.Net) {
|
||||
func wgCIDRsToNetaddr(cidrLists ...[]wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
||||
for _, cidrs := range cidrLists {
|
||||
for _, cidr := range cidrs {
|
||||
if !cidr.IP.Is4() {
|
||||
continue
|
||||
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
|
||||
}
|
||||
ret = append(ret, filter.Net{
|
||||
IP: filter.NewIP(cidr.IP.IP()),
|
||||
Mask: filter.Netmask(int(cidr.Mask)),
|
||||
})
|
||||
ncidr.IP = ncidr.IP.Unmap()
|
||||
ret = append(ret, ncidr)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
||||
for _, cidr := range cidrs {
|
||||
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
|
||||
}
|
||||
ncidr.IP = ncidr.IP.Unmap()
|
||||
ret = append(ret, ncidr)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *Prefs) {
|
||||
if h := prefs.Hostname; h != "" {
|
||||
hi.Hostname = h
|
||||
@@ -1308,6 +1303,7 @@ func applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *Prefs) {
|
||||
if m := prefs.DeviceModel; m != "" {
|
||||
hi.DeviceModel = m
|
||||
}
|
||||
hi.ShieldsUp = prefs.ShieldsUp
|
||||
}
|
||||
|
||||
// enterState transitions the backend into newState, updating internal
|
||||
@@ -1554,7 +1550,8 @@ func (b *LocalBackend) TestOnlyPublicKeys() (machineKey tailcfg.MachineKey, node
|
||||
// clients. We can't do that until 1.0.x is no longer supported.
|
||||
func temporarilySetMachineKeyInPersist() bool {
|
||||
//lint:ignore S1008 for comments
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "android" {
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "ios", "android":
|
||||
// iOS, macOS, Android users can't downgrade anyway.
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package ipn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -19,6 +20,8 @@ import (
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
var jsonEscapedZero = []byte(`\u0000`)
|
||||
|
||||
type NoArgs struct{}
|
||||
|
||||
type StartArgs struct {
|
||||
@@ -85,6 +88,9 @@ func (bs *BackendServer) send(n Notify) {
|
||||
if err != nil {
|
||||
log.Fatalf("Failed json.Marshal(notify): %v\n%#v", err, n)
|
||||
}
|
||||
if bytes.Contains(b, jsonEscapedZero) {
|
||||
log.Printf("[unexpected] zero byte in BackendServer.send notify message: %q", b)
|
||||
}
|
||||
bs.sendNotifyMsg(b)
|
||||
}
|
||||
|
||||
@@ -204,6 +210,9 @@ func (bc *BackendClient) GotNotifyMsg(b []byte) {
|
||||
// not interesting
|
||||
return
|
||||
}
|
||||
if bytes.Contains(b, jsonEscapedZero) {
|
||||
log.Printf("[unexpected] zero byte in BackendClient.GotNotifyMsg message: %q", b)
|
||||
}
|
||||
n := Notify{}
|
||||
if err := json.Unmarshal(b, &n); err != nil {
|
||||
log.Fatalf("BackendClient.Notify: cannot decode message (length=%d)\n%#v", len(b), string(b))
|
||||
@@ -230,6 +239,9 @@ func (bc *BackendClient) send(cmd Command) {
|
||||
if err != nil {
|
||||
log.Fatalf("Failed json.Marshal(cmd): %v\n%#v\n", err, cmd)
|
||||
}
|
||||
if bytes.Contains(b, jsonEscapedZero) {
|
||||
log.Printf("[unexpected] zero byte in BackendClient.send command: %q", b)
|
||||
}
|
||||
bc.sendCommandMsg(b)
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ type Prefs struct {
|
||||
// ShieldsUp indicates whether to block all incoming connections,
|
||||
// regardless of the control-provided packet filter. If false, we
|
||||
// use the packet filter as provided. If true, we block incoming
|
||||
// connections.
|
||||
// connections. This overrides tailcfg.Hostinfo's ShieldsUp.
|
||||
ShieldsUp bool
|
||||
|
||||
// AdvertiseTags specifies groups that this node wants to join, for
|
||||
@@ -152,9 +152,15 @@ func (p *Prefs) pretty(goos string) string {
|
||||
if len(p.AdvertiseRoutes) > 0 || p.NoSNAT {
|
||||
fmt.Fprintf(&sb, "snat=%v ", !p.NoSNAT)
|
||||
}
|
||||
if len(p.AdvertiseTags) > 0 {
|
||||
fmt.Fprintf(&sb, "tags=%s ", strings.Join(p.AdvertiseTags, ","))
|
||||
}
|
||||
if goos == "linux" {
|
||||
fmt.Fprintf(&sb, "nf=%v ", p.NetfilterMode)
|
||||
}
|
||||
if p.ControlURL != "" && p.ControlURL != "https://login.tailscale.com" {
|
||||
fmt.Fprintf(&sb, "url=%q ", p.ControlURL)
|
||||
}
|
||||
if p.Persist != nil {
|
||||
sb.WriteString(p.Persist.Pretty())
|
||||
} else {
|
||||
|
||||
@@ -326,6 +326,32 @@ func TestPrefsPretty(t *testing.T) {
|
||||
"windows",
|
||||
"Prefs{ra=false dns=false want=true server=true Persist=nil}",
|
||||
},
|
||||
{
|
||||
Prefs{
|
||||
AllowSingleHosts: true,
|
||||
WantRunning: true,
|
||||
ControlURL: "http://localhost:1234",
|
||||
AdvertiseTags: []string{"tag:foo", "tag:bar"},
|
||||
},
|
||||
"darwin",
|
||||
`Prefs{ra=false dns=false want=true tags=tag:foo,tag:bar url="http://localhost:1234" Persist=nil}`,
|
||||
},
|
||||
{
|
||||
Prefs{
|
||||
Persist: &controlclient.Persist{},
|
||||
},
|
||||
"linux",
|
||||
`Prefs{ra=false mesh=false dns=false want=false routes=[] nf=off Persist{lm=, o=, n= u=""}}`,
|
||||
},
|
||||
{
|
||||
Prefs{
|
||||
Persist: &controlclient.Persist{
|
||||
PrivateNodeKey: wgcfg.PrivateKey{1: 1},
|
||||
},
|
||||
},
|
||||
"linux",
|
||||
`Prefs{ra=false mesh=false dns=false want=false routes=[] nf=off Persist{lm=, o=, n=[B1VKl] u=""}}`,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
got := tt.p.pretty(tt.os)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
@@ -100,6 +101,14 @@ func (s *FileStore) String() string { return fmt.Sprintf("FileStore(%q)", s.path
|
||||
// NewFileStore returns a new file store that persists to path.
|
||||
func NewFileStore(path string) (*FileStore, error) {
|
||||
bs, err := ioutil.ReadFile(path)
|
||||
|
||||
// Treat an empty file as a missing file.
|
||||
// (https://github.com/tailscale/tailscale/issues/895#issuecomment-723255589)
|
||||
if err == nil && len(bs) == 0 {
|
||||
log.Printf("ipn.NewFileStore(%q): file empty; treating it like a missing file [warning]", path)
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Write out an initial file, to verify that we can write
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"tailscale.com/paths"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/racebuild"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
@@ -396,7 +397,7 @@ func New(collection string) *Policy {
|
||||
|
||||
log.Printf("Program starting: v%v, Go %v: %#v",
|
||||
version.Long,
|
||||
strings.TrimPrefix(runtime.Version(), "go"),
|
||||
goVersion(),
|
||||
os.Args)
|
||||
log.Printf("LogID: %v", newc.PublicID)
|
||||
if filchErr != nil {
|
||||
@@ -479,3 +480,11 @@ func newLogtailTransport(host string) *http.Transport {
|
||||
|
||||
return tr
|
||||
}
|
||||
|
||||
func goVersion() string {
|
||||
v := strings.TrimPrefix(runtime.Version(), "go")
|
||||
if racebuild.On {
|
||||
return v + "-race"
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ type Config struct {
|
||||
Buffer Buffer // temp storage, if nil a MemoryBuffer
|
||||
NewZstdEncoder func() Encoder // if set, used to compress logs for transmission
|
||||
|
||||
// DrainLogs, if non-nil, disables autmatic uploading of new logs,
|
||||
// DrainLogs, if non-nil, disables automatic uploading of new logs,
|
||||
// so that logs are only uploaded when a token is sent to DrainLogs.
|
||||
DrainLogs <-chan struct{}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,11 @@ package dnscache
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -25,7 +28,7 @@ func preferGoResolver() bool {
|
||||
// There does not appear to be a local resolver running
|
||||
// on iOS, and NetworkExtension is good at isolating DNS.
|
||||
// So do not use the Go resolver on macOS/iOS.
|
||||
if runtime.GOOS == "darwin" {
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -42,8 +45,6 @@ func preferGoResolver() bool {
|
||||
// Get returns a caching Resolver singleton.
|
||||
func Get() *Resolver { return single }
|
||||
|
||||
const fixedTTL = 10 * time.Minute
|
||||
|
||||
// Resolver is a minimal DNS caching resolver.
|
||||
//
|
||||
// The TTL is always fixed for now. It's not intended for general use.
|
||||
@@ -54,6 +55,15 @@ type Resolver struct {
|
||||
// If nil, net.DefaultResolver is used.
|
||||
Forward *net.Resolver
|
||||
|
||||
// TTL is how long to keep entries cached
|
||||
//
|
||||
// If zero, a default (currently 10 minutes) is used.
|
||||
TTL time.Duration
|
||||
|
||||
// UseLastGood controls whether a cached entry older than TTL is used
|
||||
// if a refresh fails.
|
||||
UseLastGood bool
|
||||
|
||||
sf singleflight.Group
|
||||
|
||||
mu sync.Mutex
|
||||
@@ -72,16 +82,31 @@ func (r *Resolver) fwd() *net.Resolver {
|
||||
return net.DefaultResolver
|
||||
}
|
||||
|
||||
func (r *Resolver) ttl() time.Duration {
|
||||
if r.TTL > 0 {
|
||||
return r.TTL
|
||||
}
|
||||
return 10 * time.Minute
|
||||
}
|
||||
|
||||
var debug, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_DNS_CACHE"))
|
||||
|
||||
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
|
||||
func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4, nil
|
||||
}
|
||||
if debug {
|
||||
log.Printf("dnscache: %q is an IP", host)
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
if ip, ok := r.lookupIPCache(host); ok {
|
||||
if debug {
|
||||
log.Printf("dnscache: %q = %v (cached)", host, ip)
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
@@ -95,10 +120,24 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
|
||||
select {
|
||||
case res := <-ch:
|
||||
if res.Err != nil {
|
||||
if r.UseLastGood {
|
||||
if ip, ok := r.lookupIPCacheExpired(host); ok {
|
||||
if debug {
|
||||
log.Printf("dnscache: %q using %v after error", host, ip)
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
if debug {
|
||||
log.Printf("dnscache: error resolving %q: %v", host, res.Err)
|
||||
}
|
||||
return nil, res.Err
|
||||
}
|
||||
return res.Val.(net.IP), nil
|
||||
case <-ctx.Done():
|
||||
if debug {
|
||||
log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err())
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
@@ -112,12 +151,41 @@ func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *Resolver) lookupIPCacheExpired(host string) (ip net.IP, ok bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if ent, ok := r.ipCache[host]; ok {
|
||||
return ent.ip, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *Resolver) lookupTimeoutForHost(host string) time.Duration {
|
||||
if r.UseLastGood {
|
||||
if _, ok := r.lookupIPCacheExpired(host); ok {
|
||||
// If we have some previous good value for this host,
|
||||
// don't give this DNS lookup much time. If we're in a
|
||||
// situation where the user's DNS server is unreachable
|
||||
// (e.g. their corp DNS server is behind a subnet router
|
||||
// that can't come up due to Tailscale needing to
|
||||
// connect to itself), then we want to fail fast and let
|
||||
// our caller (who set UseLastGood) fall back to using
|
||||
// the last-known-good IP address.
|
||||
return 3 * time.Second
|
||||
}
|
||||
}
|
||||
return 10 * time.Second
|
||||
}
|
||||
|
||||
func (r *Resolver) lookupIP(host string) (net.IP, error) {
|
||||
if ip, ok := r.lookupIPCache(host); ok {
|
||||
if debug {
|
||||
log.Printf("dnscache: %q found in cache as %v", host, ip)
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host))
|
||||
defer cancel()
|
||||
ips, err := r.fwd().LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
@@ -129,19 +197,26 @@ func (r *Resolver) lookupIP(host string) (net.IP, error) {
|
||||
|
||||
for _, ipa := range ips {
|
||||
if ip4 := ipa.IP.To4(); ip4 != nil {
|
||||
return r.addIPCache(host, ip4, fixedTTL), nil
|
||||
return r.addIPCache(host, ip4, r.ttl()), nil
|
||||
}
|
||||
}
|
||||
return r.addIPCache(host, ips[0].IP, fixedTTL), nil
|
||||
return r.addIPCache(host, ips[0].IP, r.ttl()), nil
|
||||
}
|
||||
|
||||
func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP {
|
||||
if isPrivateIP(ip) {
|
||||
// Don't cache obviously wrong entries from captive portals.
|
||||
// TODO: use DoH or DoT for the forwarding resolver?
|
||||
if debug {
|
||||
log.Printf("dnscache: %q resolved to private IP %v; using but not caching", host, ip)
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
if debug {
|
||||
log.Printf("dnscache: %q resolved to IP %v; caching", host, ip)
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.ipCache == nil {
|
||||
@@ -168,3 +243,26 @@ var (
|
||||
private2 = mustCIDR("172.16.0.0/12")
|
||||
private3 = mustCIDR("192.168.0.0/16")
|
||||
)
|
||||
|
||||
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
|
||||
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// Bogus. But just let the real dialer return an error rather than
|
||||
// inventing a similar one.
|
||||
return fwd(ctx, network, address)
|
||||
}
|
||||
ip, err := dnsCache.LookupIP(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
|
||||
}
|
||||
dst := net.JoinHostPort(ip.String(), port)
|
||||
if debug {
|
||||
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
|
||||
}
|
||||
return fwd(ctx, network, dst)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/exec"
|
||||
|
||||
"go4.org/mem"
|
||||
@@ -62,8 +63,12 @@ func likelyHomeRouterIPDarwinExec() (ret netaddr.IP, ok bool) {
|
||||
ip, err := netaddr.ParseIP(string(mem.Append(nil, ipm)))
|
||||
if err == nil && isPrivateIP(ip) {
|
||||
ret = ip
|
||||
// We've found what we're looking for.
|
||||
return stopReadingNetstatTable
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return ret, !ret.IsZero()
|
||||
}
|
||||
|
||||
var stopReadingNetstatTable = errors.New("found private gateway")
|
||||
|
||||
@@ -7,6 +7,7 @@ package interfaces
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@@ -176,7 +177,12 @@ func getPACWindows() string {
|
||||
return ""
|
||||
}
|
||||
defer globalFree.Call(uintptr(unsafe.Pointer(res)))
|
||||
return windows.UTF16PtrToString(res)
|
||||
s := windows.UTF16PtrToString(res)
|
||||
if _, err := url.Parse(s); err != nil {
|
||||
log.Printf("getPACWindows: invalid URL %q from winhttp; ignoring", s)
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
const (
|
||||
ERROR_WINHTTP_AUTODETECTION_FAILED = 12180
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"go4.org/mem"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/stun"
|
||||
@@ -134,10 +133,6 @@ func cloneDurationMap(m map[int]time.Duration) map[int]time.Duration {
|
||||
|
||||
// Client generates a netcheck Report.
|
||||
type Client struct {
|
||||
// DNSCache optionally specifies a DNSCache to use.
|
||||
// If nil, a DNS cache is not used.
|
||||
DNSCache *dnscache.Resolver
|
||||
|
||||
// Verbose enables verbose logging.
|
||||
Verbose bool
|
||||
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
package netns
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/util/endian"
|
||||
)
|
||||
|
||||
func interfaceIndex(iface *winipcfg.IPAdapterAddresses) uint32 {
|
||||
@@ -114,7 +114,8 @@ func bindSocket6(c syscall.RawConn, ifidx uint32) error {
|
||||
// representation, suitable for passing to Windows APIs that require a
|
||||
// mangled uint32.
|
||||
func nativeToBigEndian(i uint32) uint32 {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], i)
|
||||
return *(*uint32)(unsafe.Pointer(&b[0]))
|
||||
if endian.Big {
|
||||
return i
|
||||
}
|
||||
return bits.ReverseBytes32(i)
|
||||
}
|
||||
|
||||
@@ -6,14 +6,15 @@
|
||||
package netstat
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/util/endian"
|
||||
)
|
||||
|
||||
// See https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable
|
||||
@@ -92,7 +93,7 @@ func (t *Table) addEntries(fam int) error {
|
||||
}
|
||||
buf = buf[:size]
|
||||
|
||||
numEntries := *(*uint32)(unsafe.Pointer(&buf[0]))
|
||||
numEntries := endian.Native.Uint32(buf[:4])
|
||||
buf = buf[4:]
|
||||
|
||||
var recSize int
|
||||
@@ -153,9 +154,11 @@ func state(v uint32) string {
|
||||
}
|
||||
|
||||
func ipport4(addr uint32, port uint16) netaddr.IPPort {
|
||||
a4 := (*[4]byte)(unsafe.Pointer(&addr))
|
||||
if !endian.Big {
|
||||
addr = bits.ReverseBytes32(addr)
|
||||
}
|
||||
return netaddr.IPPort{
|
||||
IP: netaddr.IPv4(a4[0], a4[1], a4[2], a4[3]),
|
||||
IP: netaddr.IPv4(byte(addr>>24), byte(addr>>16), byte(addr>>8), byte(addr)),
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
@@ -173,6 +176,8 @@ func ipport6(addr [16]byte, scope uint32, port uint16) netaddr.IPPort {
|
||||
}
|
||||
|
||||
func port(v *uint32) uint16 {
|
||||
p := (*[4]byte)(unsafe.Pointer(v))
|
||||
return binary.BigEndian.Uint16(p[:2])
|
||||
if !endian.Big {
|
||||
return uint16(bits.ReverseBytes32(*v) >> 16)
|
||||
}
|
||||
return uint16(*v >> 16)
|
||||
}
|
||||
|
||||
16
net/packet/doc.go
Normal file
16
net/packet/doc.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package packet contains packet parsing and marshaling utilities.
|
||||
//
|
||||
// Parsed provides allocation-free minimal packet header decoding, for
|
||||
// use in packet filtering. The other types in the package are for
|
||||
// constructing and marshaling packets into []bytes.
|
||||
//
|
||||
// To support allocation-free parsing, this package defines IPv4 and
|
||||
// IPv6 address types. You should prefer to use netaddr's types,
|
||||
// except where you absolutely need allocation-free IP handling
|
||||
// (i.e. in the tunnel datapath) and are willing to implement all
|
||||
// codepaths and data structures twice, once per IP family.
|
||||
package packet
|
||||
@@ -16,27 +16,34 @@ const tcpHeaderLength = 20
|
||||
const maxPacketLength = math.MaxUint16
|
||||
|
||||
var (
|
||||
// errSmallBuffer is returned when Marshal receives a buffer
|
||||
// too small to contain the header to marshal.
|
||||
errSmallBuffer = errors.New("buffer too small")
|
||||
// errLargePacket is returned when Marshal receives a payload
|
||||
// larger than the maximum representable size in header
|
||||
// fields.
|
||||
errLargePacket = errors.New("packet too large")
|
||||
)
|
||||
|
||||
// Header is a packet header capable of marshaling itself into a byte buffer.
|
||||
// Header is a packet header capable of marshaling itself into a byte
|
||||
// buffer.
|
||||
type Header interface {
|
||||
// Len returns the length of the header after marshaling.
|
||||
// Len returns the length of the marshaled packet.
|
||||
Len() int
|
||||
// Marshal serializes the header into buf in wire format.
|
||||
// It clobbers the header region, which is the first h.Length() bytes of buf.
|
||||
// It explicitly initializes every byte of the header region,
|
||||
// so pre-zeroing it on reuse is not required. It does not allocate memory.
|
||||
// It fails if and only if len(buf) < Length().
|
||||
// Marshal serializes the header into buf, which must be at
|
||||
// least Len() bytes long. Implementations of Marshal assume
|
||||
// that bytes after the first Len() are payload bytes for the
|
||||
// purpose of computing length and checksum fields. Marshal
|
||||
// implementations must not allocate memory.
|
||||
Marshal(buf []byte) error
|
||||
// ToResponse transforms the header into one for a response packet.
|
||||
// For instance, this swaps the source and destination IPs.
|
||||
ToResponse()
|
||||
}
|
||||
|
||||
// Generate generates a new packet with the given header and payload.
|
||||
// Unlike Header.Marshal, this does allocate memory.
|
||||
// Generate generates a new packet with the given Header and
|
||||
// payload. This function allocates memory, see Header.Marshal for an
|
||||
// allocation-free option.
|
||||
func Generate(h Header, payload []byte) []byte {
|
||||
hlen := h.Len()
|
||||
buf := make([]byte, hlen+len(payload))
|
||||
90
net/packet/icmp4.go
Normal file
90
net/packet/icmp4.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// icmp4HeaderLength is the size of the ICMPv4 packet header, not
|
||||
// including the outer IP layer or the variable "response data"
|
||||
// trailer.
|
||||
const icmp4HeaderLength = 4
|
||||
|
||||
// ICMP4Type is an ICMPv4 type, as specified in
|
||||
// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
|
||||
type ICMP4Type uint8
|
||||
|
||||
const (
|
||||
ICMP4EchoReply ICMP4Type = 0x00
|
||||
ICMP4EchoRequest ICMP4Type = 0x08
|
||||
ICMP4Unreachable ICMP4Type = 0x03
|
||||
ICMP4TimeExceeded ICMP4Type = 0x0b
|
||||
)
|
||||
|
||||
func (t ICMP4Type) String() string {
|
||||
switch t {
|
||||
case ICMP4EchoReply:
|
||||
return "EchoReply"
|
||||
case ICMP4EchoRequest:
|
||||
return "EchoRequest"
|
||||
case ICMP4Unreachable:
|
||||
return "Unreachable"
|
||||
case ICMP4TimeExceeded:
|
||||
return "TimeExceeded"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ICMP4Code is an ICMPv4 code, as specified in
|
||||
// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
|
||||
type ICMP4Code uint8
|
||||
|
||||
const (
|
||||
ICMP4NoCode ICMP4Code = 0
|
||||
)
|
||||
|
||||
// ICMP4Header is an IPv4+ICMPv4 header.
|
||||
type ICMP4Header struct {
|
||||
IP4Header
|
||||
Type ICMP4Type
|
||||
Code ICMP4Code
|
||||
}
|
||||
|
||||
// Len implements Header.
|
||||
func (h ICMP4Header) Len() int {
|
||||
return h.IP4Header.Len() + icmp4HeaderLength
|
||||
}
|
||||
|
||||
// Marshal implements Header.
|
||||
func (h ICMP4Header) Marshal(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if len(buf) > maxPacketLength {
|
||||
return errLargePacket
|
||||
}
|
||||
// The caller does not need to set this.
|
||||
h.IPProto = ICMPv4
|
||||
|
||||
buf[20] = uint8(h.Type)
|
||||
buf[21] = uint8(h.Code)
|
||||
|
||||
h.IP4Header.Marshal(buf)
|
||||
|
||||
binary.BigEndian.PutUint16(buf[22:24], ip4Checksum(buf))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToResponse implements Header. TODO: it doesn't implement it
|
||||
// correctly, instead it statically generates an ICMP Echo Reply
|
||||
// packet.
|
||||
func (h *ICMP4Header) ToResponse() {
|
||||
// TODO: this doesn't implement ToResponse correctly, as it
|
||||
// assumes the ICMP request type.
|
||||
h.Type = ICMP4EchoReply
|
||||
h.Code = ICMP4NoCode
|
||||
h.IP4Header.ToResponse()
|
||||
}
|
||||
44
net/packet/icmp6.go
Normal file
44
net/packet/icmp6.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
// icmp6HeaderLength is the size of the ICMPv6 packet header, not
|
||||
// including the outer IP layer or the variable "response data"
|
||||
// trailer.
|
||||
const icmp6HeaderLength = 4
|
||||
|
||||
// ICMP6Type is an ICMPv6 type, as specified in
|
||||
// https://www.iana.org/assignments/icmpv6-parameters/icmpv6-parameters.xhtml
|
||||
type ICMP6Type uint8
|
||||
|
||||
const (
|
||||
ICMP6Unreachable ICMP6Type = 1
|
||||
ICMP6TimeExceeded ICMP6Type = 3
|
||||
ICMP6EchoRequest ICMP6Type = 128
|
||||
ICMP6EchoReply ICMP6Type = 129
|
||||
)
|
||||
|
||||
func (t ICMP6Type) String() string {
|
||||
switch t {
|
||||
case ICMP6Unreachable:
|
||||
return "Unreachable"
|
||||
case ICMP6TimeExceeded:
|
||||
return "TimeExceeded"
|
||||
case ICMP6EchoRequest:
|
||||
return "EchoRequest"
|
||||
case ICMP6EchoReply:
|
||||
return "EchoReply"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ICMP6Code is an ICMPv6 code, as specified in
|
||||
// https://www.iana.org/assignments/icmpv6-parameters/icmpv6-parameters.xhtml
|
||||
type ICMP6Code uint8
|
||||
|
||||
const (
|
||||
ICMP6NoCode ICMP6Code = 0
|
||||
)
|
||||
53
net/packet/ip.go
Normal file
53
net/packet/ip.go
Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
// IPProto is an IP subprotocol as defined by the IANA protocol
|
||||
// numbers list
|
||||
// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml),
|
||||
// or the special values Unknown or Fragment.
|
||||
type IPProto uint8
|
||||
|
||||
const (
|
||||
// Unknown represents an unknown or unsupported protocol; it's
|
||||
// deliberately the zero value. Strictly speaking the zero
|
||||
// value is IPv6 hop-by-hop extensions, but we don't support
|
||||
// those, so this is still technically correct.
|
||||
Unknown IPProto = 0x00
|
||||
|
||||
// Values from the IANA registry.
|
||||
ICMPv4 IPProto = 0x01
|
||||
IGMP IPProto = 0x02
|
||||
ICMPv6 IPProto = 0x3a
|
||||
TCP IPProto = 0x06
|
||||
UDP IPProto = 0x11
|
||||
|
||||
// Fragment represents any non-first IP fragment, for which we
|
||||
// don't have the sub-protocol header (and therefore can't
|
||||
// figure out what the sub-protocol is).
|
||||
//
|
||||
// 0xFF is reserved in the IANA registry, so we steal it for
|
||||
// internal use.
|
||||
Fragment IPProto = 0xFF
|
||||
)
|
||||
|
||||
func (p IPProto) String() string {
|
||||
switch p {
|
||||
case Fragment:
|
||||
return "Frag"
|
||||
case ICMPv4:
|
||||
return "ICMPv4"
|
||||
case IGMP:
|
||||
return "IGMP"
|
||||
case ICMPv6:
|
||||
return "ICMPv6"
|
||||
case UDP:
|
||||
return "UDP"
|
||||
case TCP:
|
||||
return "TCP"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
144
net/packet/ip4.go
Normal file
144
net/packet/ip4.go
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// IP4 is an IPv4 address.
|
||||
type IP4 uint32
|
||||
|
||||
// IPFromNetaddr converts a netaddr.IP to an IP4. Panics if !ip.Is4.
|
||||
func IP4FromNetaddr(ip netaddr.IP) IP4 {
|
||||
ipbytes := ip.As4()
|
||||
return IP4(binary.BigEndian.Uint32(ipbytes[:]))
|
||||
}
|
||||
|
||||
// Netaddr converts ip to a netaddr.IP.
|
||||
func (ip IP4) Netaddr() netaddr.IP {
|
||||
return netaddr.IPv4(byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
|
||||
}
|
||||
|
||||
func (ip IP4) String() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
|
||||
}
|
||||
|
||||
// IsMulticast returns whether ip is a multicast address.
|
||||
func (ip IP4) IsMulticast() bool {
|
||||
return byte(ip>>24)&0xf0 == 0xe0
|
||||
}
|
||||
|
||||
// IsLinkLocalUnicast returns whether ip is a link-local unicast
|
||||
// address.
|
||||
func (ip IP4) IsLinkLocalUnicast() bool {
|
||||
return byte(ip>>24) == 169 && byte(ip>>16) == 254
|
||||
}
|
||||
|
||||
// IsMostLinkLocalUnicast returns whether ip is a link-local unicast
|
||||
// address other than the magical "169.254.169.254" address used by
|
||||
// GCP DNS.
|
||||
func (ip IP4) IsMostLinkLocalUnicast() bool {
|
||||
return ip.IsLinkLocalUnicast() && ip != 0xA9FEA9FE
|
||||
}
|
||||
|
||||
// ip4HeaderLength is the length of an IPv4 header with no IP options.
|
||||
const ip4HeaderLength = 20
|
||||
|
||||
// IP4Header represents an IPv4 packet header.
|
||||
type IP4Header struct {
|
||||
IPProto IPProto
|
||||
IPID uint16
|
||||
SrcIP IP4
|
||||
DstIP IP4
|
||||
}
|
||||
|
||||
// Len implements Header.
|
||||
func (h IP4Header) Len() int {
|
||||
return ip4HeaderLength
|
||||
}
|
||||
|
||||
// Marshal implements Header.
|
||||
func (h IP4Header) Marshal(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if len(buf) > maxPacketLength {
|
||||
return errLargePacket
|
||||
}
|
||||
|
||||
buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL
|
||||
buf[1] = 0x00 // DSCP + ECN
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length
|
||||
binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID
|
||||
binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset
|
||||
buf[8] = 64 // TTL
|
||||
buf[9] = uint8(h.IPProto) // Inner protocol
|
||||
// Blank checksum. This is necessary even though we overwrite
|
||||
// it later, because the checksum computation runs over these
|
||||
// bytes and expects them to be zero.
|
||||
binary.BigEndian.PutUint16(buf[10:12], 0)
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(h.SrcIP)) // Src
|
||||
binary.BigEndian.PutUint32(buf[16:20], uint32(h.DstIP)) // Dst
|
||||
|
||||
binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToResponse implements Header.
|
||||
func (h *IP4Header) ToResponse() {
|
||||
h.SrcIP, h.DstIP = h.DstIP, h.SrcIP
|
||||
// Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
|
||||
h.IPID = ^h.IPID
|
||||
}
|
||||
|
||||
// ip4Checksum computes an IPv4 checksum, as specified in
|
||||
// https://tools.ietf.org/html/rfc1071
|
||||
func ip4Checksum(b []byte) uint16 {
|
||||
var ac uint32
|
||||
i := 0
|
||||
n := len(b)
|
||||
for n >= 2 {
|
||||
ac += uint32(binary.BigEndian.Uint16(b[i : i+2]))
|
||||
n -= 2
|
||||
i += 2
|
||||
}
|
||||
if n == 1 {
|
||||
ac += uint32(b[i]) << 8
|
||||
}
|
||||
for (ac >> 16) > 0 {
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
}
|
||||
return uint16(^ac)
|
||||
}
|
||||
|
||||
// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP
|
||||
// pseudo-header is smaller than the real IPv4 header.
|
||||
const ip4PseudoHeaderOffset = 8
|
||||
|
||||
// marshalPseudo serializes h into buf in the "pseudo-header" form
|
||||
// required when calculating UDP checksums. The pseudo-header starts
|
||||
// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP
|
||||
// header, while leaving enough space in buf for a full IPv4 header.
|
||||
func (h IP4Header) marshalPseudo(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if len(buf) > maxPacketLength {
|
||||
return errLargePacket
|
||||
}
|
||||
|
||||
length := len(buf) - h.Len()
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(h.SrcIP))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(h.DstIP))
|
||||
buf[16] = 0x0
|
||||
buf[17] = uint8(h.IPProto)
|
||||
binary.BigEndian.PutUint16(buf[18:20], uint16(length))
|
||||
return nil
|
||||
}
|
||||
113
net/packet/ip6.go
Normal file
113
net/packet/ip6.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// IP6 is an IPv6 address.
|
||||
type IP6 struct {
|
||||
Hi, Lo uint64
|
||||
}
|
||||
|
||||
// IP6FromNetaddr converts a netaddr.IP to an IP6. Panics if !ip.Is6.
|
||||
func IP6FromNetaddr(ip netaddr.IP) IP6 {
|
||||
if !ip.Is6() {
|
||||
panic(fmt.Sprintf("IP6FromNetaddr called with non-v6 addr %q", ip))
|
||||
}
|
||||
b := ip.As16()
|
||||
return IP6{binary.BigEndian.Uint64(b[:8]), binary.BigEndian.Uint64(b[8:])}
|
||||
}
|
||||
|
||||
// Netaddr converts ip to a netaddr.IP.
|
||||
func (ip IP6) Netaddr() netaddr.IP {
|
||||
var b [16]byte
|
||||
binary.BigEndian.PutUint64(b[:8], ip.Hi)
|
||||
binary.BigEndian.PutUint64(b[8:], ip.Lo)
|
||||
return netaddr.IPFrom16(b)
|
||||
}
|
||||
|
||||
func (ip IP6) String() string {
|
||||
return ip.Netaddr().String()
|
||||
}
|
||||
|
||||
func (ip IP6) IsMulticast() bool {
|
||||
return (ip.Hi >> 56) == 0xFF
|
||||
}
|
||||
|
||||
func (ip IP6) IsLinkLocalUnicast() bool {
|
||||
return (ip.Hi >> 48) == 0xFE80
|
||||
}
|
||||
|
||||
// ip6HeaderLength is the length of an IPv6 header with no IP options.
|
||||
const ip6HeaderLength = 40
|
||||
|
||||
// IP6Header represents an IPv6 packet header.
|
||||
type IP6Header struct {
|
||||
IPProto IPProto
|
||||
IPID uint32 // only lower 20 bits used
|
||||
SrcIP IP6
|
||||
DstIP IP6
|
||||
}
|
||||
|
||||
// Len implements Header.
|
||||
func (h IP6Header) Len() int {
|
||||
return ip6HeaderLength
|
||||
}
|
||||
|
||||
// Marshal implements Header.
|
||||
func (h IP6Header) Marshal(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if len(buf) > maxPacketLength {
|
||||
return errLargePacket
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF)
|
||||
buf[0] = 0x60
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length
|
||||
buf[6] = uint8(h.IPProto) // Inner protocol
|
||||
buf[7] = 64 // TTL
|
||||
binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Hi)
|
||||
binary.BigEndian.PutUint64(buf[16:24], h.SrcIP.Lo)
|
||||
binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Hi)
|
||||
binary.BigEndian.PutUint64(buf[32:40], h.DstIP.Lo)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToResponse implements Header.
|
||||
func (h *IP6Header) ToResponse() {
|
||||
h.SrcIP, h.DstIP = h.DstIP, h.SrcIP
|
||||
// Flip the bits in the IPID. If incoming IPIDs are distinct, so are these.
|
||||
h.IPID = (^h.IPID) & 0x000FFFFF
|
||||
}
|
||||
|
||||
// marshalPseudo serializes h into buf in the "pseudo-header" form
|
||||
// required when calculating UDP checksums.
|
||||
func (h IP6Header) marshalPseudo(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if len(buf) > maxPacketLength {
|
||||
return errLargePacket
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint64(buf[:8], h.SrcIP.Hi)
|
||||
binary.BigEndian.PutUint64(buf[8:16], h.SrcIP.Lo)
|
||||
binary.BigEndian.PutUint64(buf[16:24], h.DstIP.Hi)
|
||||
binary.BigEndian.PutUint64(buf[24:32], h.DstIP.Lo)
|
||||
binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len()))
|
||||
buf[36] = 0
|
||||
buf[37] = 0
|
||||
buf[38] = 0
|
||||
buf[39] = 17 // NextProto
|
||||
return nil
|
||||
}
|
||||
434
net/packet/packet.go
Normal file
434
net/packet/packet.go
Normal file
@@ -0,0 +1,434 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/types/strbuilder"
|
||||
)
|
||||
|
||||
// RFC1858: prevent overlapping fragment attacks.
|
||||
const minFrag = 60 + 20 // max IPv4 header + basic TCP header
|
||||
|
||||
const (
|
||||
TCPSyn = 0x02
|
||||
TCPAck = 0x10
|
||||
TCPSynAck = TCPSyn | TCPAck
|
||||
)
|
||||
|
||||
// Parsed is a minimal decoding of a packet suitable for use in filters.
|
||||
type Parsed struct {
|
||||
// b is the byte buffer that this decodes.
|
||||
b []byte
|
||||
// subofs is the offset of IP subprotocol.
|
||||
subofs int
|
||||
// dataofs is the offset of IP subprotocol payload.
|
||||
dataofs int
|
||||
// length is the total length of the packet.
|
||||
// This is not the same as len(b) because b can have trailing zeros.
|
||||
length int
|
||||
|
||||
// IPVersion is the IP protocol version of the packet (4 or
|
||||
// 6), or 0 if the packet doesn't look like IPv4 or IPv6.
|
||||
IPVersion uint8
|
||||
// IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0.
|
||||
IPProto IPProto
|
||||
// SrcIP4 is the IPv4 source address. Valid iff IPVersion == 4.
|
||||
SrcIP4 IP4
|
||||
// DstIP4 is the IPv4 destination address. Valid iff IPVersion == 4.
|
||||
DstIP4 IP4
|
||||
// SrcIP6 is the IPv6 source address. Valid iff IPVersion == 6.
|
||||
SrcIP6 IP6
|
||||
// DstIP6 is the IPv6 destination address. Valid iff IPVersion == 6.
|
||||
DstIP6 IP6
|
||||
// SrcPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP.
|
||||
SrcPort uint16
|
||||
// DstPort is the TCP/UDP source port. Valid iff IPProto == TCP || IPProto == UDP.
|
||||
DstPort uint16
|
||||
// TCPFlags is the packet's TCP flag bigs. Valid iff IPProto == TCP.
|
||||
TCPFlags uint8
|
||||
}
|
||||
|
||||
func (p *Parsed) String() string {
|
||||
switch p.IPVersion {
|
||||
case 4:
|
||||
sb := strbuilder.Get()
|
||||
sb.WriteString(p.IPProto.String())
|
||||
sb.WriteByte('{')
|
||||
writeIP4Port(sb, p.SrcIP4, p.SrcPort)
|
||||
sb.WriteString(" > ")
|
||||
writeIP4Port(sb, p.DstIP4, p.DstPort)
|
||||
sb.WriteByte('}')
|
||||
return sb.String()
|
||||
case 6:
|
||||
sb := strbuilder.Get()
|
||||
sb.WriteString(p.IPProto.String())
|
||||
sb.WriteByte('{')
|
||||
writeIP6Port(sb, p.SrcIP6, p.SrcPort)
|
||||
sb.WriteString(" > ")
|
||||
writeIP6Port(sb, p.DstIP6, p.DstPort)
|
||||
sb.WriteByte('}')
|
||||
return sb.String()
|
||||
default:
|
||||
return "Unknown{???}"
|
||||
}
|
||||
}
|
||||
|
||||
func writeIP4Port(sb *strbuilder.Builder, ip IP4, port uint16) {
|
||||
sb.WriteUint(uint64(byte(ip >> 24)))
|
||||
sb.WriteByte('.')
|
||||
sb.WriteUint(uint64(byte(ip >> 16)))
|
||||
sb.WriteByte('.')
|
||||
sb.WriteUint(uint64(byte(ip >> 8)))
|
||||
sb.WriteByte('.')
|
||||
sb.WriteUint(uint64(byte(ip)))
|
||||
sb.WriteByte(':')
|
||||
sb.WriteUint(uint64(port))
|
||||
}
|
||||
|
||||
func writeIP6Port(sb *strbuilder.Builder, ip IP6, port uint16) {
|
||||
sb.WriteByte('[')
|
||||
sb.WriteString(ip.Netaddr().String()) // TODO: faster?
|
||||
sb.WriteString("]:")
|
||||
sb.WriteUint(uint64(port))
|
||||
}
|
||||
|
||||
// Decode extracts data from the packet in b into q.
|
||||
// It performs extremely simple packet decoding for basic IPv4 packet types.
|
||||
// It extracts only the subprotocol id, IP addresses, and (if any) ports,
|
||||
// and shouldn't need any memory allocation.
|
||||
func (q *Parsed) Decode(b []byte) {
|
||||
q.b = b
|
||||
|
||||
if len(b) < 1 {
|
||||
q.IPVersion = 0
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
|
||||
q.IPVersion = b[0] >> 4
|
||||
switch q.IPVersion {
|
||||
case 4:
|
||||
q.decode4(b)
|
||||
case 6:
|
||||
q.decode6(b)
|
||||
default:
|
||||
q.IPVersion = 0
|
||||
q.IPProto = Unknown
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Parsed) decode4(b []byte) {
|
||||
if len(b) < ip4HeaderLength {
|
||||
q.IPVersion = 0
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
|
||||
// Check that it's IPv4.
|
||||
q.IPProto = IPProto(b[9])
|
||||
q.length = int(binary.BigEndian.Uint16(b[2:4]))
|
||||
if len(b) < q.length {
|
||||
// Packet was cut off before full IPv4 length.
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
|
||||
// If it's valid IPv4, then the IP addresses are valid
|
||||
q.SrcIP4 = IP4(binary.BigEndian.Uint32(b[12:16]))
|
||||
q.DstIP4 = IP4(binary.BigEndian.Uint32(b[16:20]))
|
||||
|
||||
q.subofs = int((b[0] & 0x0F) << 2)
|
||||
if q.subofs > q.length {
|
||||
// next-proto starts beyond end of packet.
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
sub := b[q.subofs:]
|
||||
sub = sub[:len(sub):len(sub)] // help the compiler do bounds check elimination
|
||||
|
||||
// We don't care much about IP fragmentation, except insofar as it's
|
||||
// used for firewall bypass attacks. The trick is make the first
|
||||
// fragment of a TCP or UDP packet so short that it doesn't fit
|
||||
// the TCP or UDP header, so we can't read the port, in hope that
|
||||
// it'll sneak past. Then subsequent fragments fill it in, but we're
|
||||
// missing the first part of the header, so we can't read that either.
|
||||
//
|
||||
// A "perfectly correct" implementation would have to reassemble
|
||||
// fragments before deciding what to do. But the truth is there's
|
||||
// zero reason to send such a short first fragment, so we can treat
|
||||
// it as Unknown. We can also treat any subsequent fragment that starts
|
||||
// at such a low offset as Unknown.
|
||||
fragFlags := binary.BigEndian.Uint16(b[6:8])
|
||||
moreFrags := (fragFlags & 0x20) != 0
|
||||
fragOfs := fragFlags & 0x1FFF
|
||||
if fragOfs == 0 {
|
||||
// This is the first fragment
|
||||
if moreFrags && len(sub) < minFrag {
|
||||
// Suspiciously short first fragment, dump it.
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
// otherwise, this is either non-fragmented (the usual case)
|
||||
// or a big enough initial fragment that we can read the
|
||||
// whole subprotocol header.
|
||||
switch q.IPProto {
|
||||
case ICMPv4:
|
||||
if len(sub) < icmp4HeaderLength {
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
q.SrcPort = 0
|
||||
q.DstPort = 0
|
||||
q.dataofs = q.subofs + icmp4HeaderLength
|
||||
return
|
||||
case IGMP:
|
||||
// Keep IPProto, but don't parse anything else
|
||||
// out.
|
||||
return
|
||||
case TCP:
|
||||
if len(sub) < tcpHeaderLength {
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
|
||||
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
|
||||
q.TCPFlags = sub[13] & 0x3F
|
||||
headerLength := (sub[12] & 0xF0) >> 2
|
||||
q.dataofs = q.subofs + int(headerLength)
|
||||
return
|
||||
case UDP:
|
||||
if len(sub) < udpHeaderLength {
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
|
||||
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
|
||||
q.dataofs = q.subofs + udpHeaderLength
|
||||
return
|
||||
default:
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// This is a fragment other than the first one.
|
||||
if fragOfs < minFrag {
|
||||
// First frag was suspiciously short, so we can't
|
||||
// trust the followup either.
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
// otherwise, we have to permit the fragment to slide through.
|
||||
// Second and later fragments don't have sub-headers.
|
||||
// Ideally, we would drop fragments that we can't identify,
|
||||
// but that would require statefulness. Anyway, receivers'
|
||||
// kernels know to drop fragments where the initial fragment
|
||||
// doesn't arrive.
|
||||
q.IPProto = Fragment
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Parsed) decode6(b []byte) {
|
||||
if len(b) < ip6HeaderLength {
|
||||
q.IPVersion = 0
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
|
||||
q.IPProto = IPProto(b[6])
|
||||
q.length = int(binary.BigEndian.Uint16(b[4:6])) + ip6HeaderLength
|
||||
if len(b) < q.length {
|
||||
// Packet was cut off before the full IPv6 length.
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
|
||||
q.SrcIP6.Hi = binary.BigEndian.Uint64(b[8:16])
|
||||
q.SrcIP6.Lo = binary.BigEndian.Uint64(b[16:24])
|
||||
q.DstIP6.Hi = binary.BigEndian.Uint64(b[24:32])
|
||||
q.DstIP6.Lo = binary.BigEndian.Uint64(b[32:40])
|
||||
|
||||
// We don't support any IPv6 extension headers. Don't try to
|
||||
// be clever. Therefore, the IP subprotocol always starts at
|
||||
// byte 40.
|
||||
//
|
||||
// Note that this means we don't support fragmentation in
|
||||
// IPv6. This is fine, because IPv6 strongly mandates that you
|
||||
// should not fragment, which makes fragmentation on the open
|
||||
// internet extremely uncommon.
|
||||
//
|
||||
// This also means we don't support IPSec headers (AH/ESP), or
|
||||
// IPv6 jumbo frames. Those will get marked Unknown and
|
||||
// dropped.
|
||||
q.subofs = 40
|
||||
sub := b[q.subofs:]
|
||||
sub = sub[:len(sub):len(sub)] // help the compiler do bounds check elimination
|
||||
|
||||
switch q.IPProto {
|
||||
case ICMPv6:
|
||||
if len(sub) < icmp6HeaderLength {
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
q.SrcPort = 0
|
||||
q.DstPort = 0
|
||||
q.dataofs = q.subofs + icmp6HeaderLength
|
||||
case TCP:
|
||||
if len(sub) < tcpHeaderLength {
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
|
||||
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
|
||||
q.TCPFlags = sub[13] & 0x3F
|
||||
headerLength := (sub[12] & 0xF0) >> 2
|
||||
q.dataofs = q.subofs + int(headerLength)
|
||||
return
|
||||
case UDP:
|
||||
if len(sub) < udpHeaderLength {
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
|
||||
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
|
||||
q.dataofs = q.subofs + udpHeaderLength
|
||||
default:
|
||||
q.IPProto = Unknown
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Parsed) IP4Header() IP4Header {
|
||||
if q.IPVersion != 4 {
|
||||
panic("IP4Header called on non-IPv4 Parsed")
|
||||
}
|
||||
ipid := binary.BigEndian.Uint16(q.b[4:6])
|
||||
return IP4Header{
|
||||
IPID: ipid,
|
||||
IPProto: q.IPProto,
|
||||
SrcIP: q.SrcIP4,
|
||||
DstIP: q.DstIP4,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Parsed) ICMP4Header() ICMP4Header {
|
||||
if q.IPVersion != 4 {
|
||||
panic("IP4Header called on non-IPv4 Parsed")
|
||||
}
|
||||
return ICMP4Header{
|
||||
IP4Header: q.IP4Header(),
|
||||
Type: ICMP4Type(q.b[q.subofs+0]),
|
||||
Code: ICMP4Code(q.b[q.subofs+1]),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Parsed) UDP4Header() UDP4Header {
|
||||
if q.IPVersion != 4 {
|
||||
panic("IP4Header called on non-IPv4 Parsed")
|
||||
}
|
||||
return UDP4Header{
|
||||
IP4Header: q.IP4Header(),
|
||||
SrcPort: q.SrcPort,
|
||||
DstPort: q.DstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer returns the entire packet buffer.
|
||||
// This is a read-only view; that is, q retains the ownership of the buffer.
|
||||
func (q *Parsed) Buffer() []byte {
|
||||
return q.b
|
||||
}
|
||||
|
||||
// Payload returns the payload of the IP subprotocol section.
|
||||
// This is a read-only view; that is, q retains the ownership of the buffer.
|
||||
func (q *Parsed) Payload() []byte {
|
||||
return q.b[q.dataofs:q.length]
|
||||
}
|
||||
|
||||
// IsTCPSyn reports whether q is a TCP SYN packet
|
||||
// (i.e. the first packet in a new connection).
|
||||
func (q *Parsed) IsTCPSyn() bool {
|
||||
return (q.TCPFlags & TCPSynAck) == TCPSyn
|
||||
}
|
||||
|
||||
// IsError reports whether q is an ICMP "Error" packet.
|
||||
func (q *Parsed) IsError() bool {
|
||||
switch q.IPProto {
|
||||
case ICMPv4:
|
||||
if len(q.b) < q.subofs+8 {
|
||||
return false
|
||||
}
|
||||
t := ICMP4Type(q.b[q.subofs])
|
||||
return t == ICMP4Unreachable || t == ICMP4TimeExceeded
|
||||
case ICMPv6:
|
||||
if len(q.b) < q.subofs+8 {
|
||||
return false
|
||||
}
|
||||
t := ICMP6Type(q.b[q.subofs])
|
||||
return t == ICMP6Unreachable || t == ICMP6TimeExceeded
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsEchoRequest reports whether q is an ICMP Echo Request.
|
||||
func (q *Parsed) IsEchoRequest() bool {
|
||||
switch q.IPProto {
|
||||
case ICMPv4:
|
||||
return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoRequest && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
|
||||
case ICMPv6:
|
||||
return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoRequest && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsEchoRequest reports whether q is an IPv4 ICMP Echo Response.
|
||||
func (q *Parsed) IsEchoResponse() bool {
|
||||
switch q.IPProto {
|
||||
case ICMPv4:
|
||||
return len(q.b) >= q.subofs+8 && ICMP4Type(q.b[q.subofs]) == ICMP4EchoReply && ICMP4Code(q.b[q.subofs+1]) == ICMP4NoCode
|
||||
case ICMPv6:
|
||||
return len(q.b) >= q.subofs+8 && ICMP6Type(q.b[q.subofs]) == ICMP6EchoReply && ICMP6Code(q.b[q.subofs+1]) == ICMP6NoCode
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func Hexdump(b []byte) string {
|
||||
out := new(strings.Builder)
|
||||
for i := 0; i < len(b); i += 16 {
|
||||
if i > 0 {
|
||||
fmt.Fprintf(out, "\n")
|
||||
}
|
||||
fmt.Fprintf(out, " %04x ", i)
|
||||
j := 0
|
||||
for ; j < 16 && i+j < len(b); j++ {
|
||||
if j == 8 {
|
||||
fmt.Fprintf(out, " ")
|
||||
}
|
||||
fmt.Fprintf(out, "%02x ", b[i+j])
|
||||
}
|
||||
for ; j < 16; j++ {
|
||||
if j == 8 {
|
||||
fmt.Fprintf(out, " ")
|
||||
}
|
||||
fmt.Fprintf(out, " ")
|
||||
}
|
||||
fmt.Fprintf(out, " ")
|
||||
for j = 0; j < 16 && i+j < len(b); j++ {
|
||||
if b[i+j] >= 32 && b[i+j] < 128 {
|
||||
fmt.Fprintf(out, "%c", b[i+j])
|
||||
} else {
|
||||
fmt.Fprintf(out, ".")
|
||||
}
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
502
net/packet/packet_test.go
Normal file
502
net/packet/packet_test.go
Normal file
@@ -0,0 +1,502 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func mustIP4(s string) IP4 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return IP4FromNetaddr(ip)
|
||||
}
|
||||
|
||||
func mustIP6(s string) IP6 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return IP6FromNetaddr(ip)
|
||||
}
|
||||
|
||||
func TestIP4String(t *testing.T) {
|
||||
const str = "1.2.3.4"
|
||||
ip := mustIP4(str)
|
||||
|
||||
var got string
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
got = ip.String()
|
||||
})
|
||||
|
||||
if got != str {
|
||||
t.Errorf("got %q; want %q", got, str)
|
||||
}
|
||||
if allocs != 1 {
|
||||
t.Errorf("allocs = %v; want 1", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIP6String(t *testing.T) {
|
||||
const str = "2607:f8b0:400a:809::200e"
|
||||
ip := mustIP6(str)
|
||||
|
||||
var got string
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
got = ip.String()
|
||||
})
|
||||
|
||||
if got != str {
|
||||
t.Errorf("got %q; want %q", got, str)
|
||||
}
|
||||
if allocs != 2 {
|
||||
t.Errorf("allocs = %v; want 1", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
var icmp4RequestBuffer = []byte{
|
||||
// IP header up to checksum
|
||||
0x45, 0x00, 0x00, 0x27, 0xde, 0xad, 0x00, 0x00, 0x40, 0x01, 0x8c, 0x15,
|
||||
// source ip
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
// destination ip
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
// ICMP header
|
||||
0x08, 0x00, 0x7d, 0x22,
|
||||
// "request_payload"
|
||||
0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
|
||||
}
|
||||
|
||||
var icmp4RequestDecode = Parsed{
|
||||
b: icmp4RequestBuffer,
|
||||
subofs: 20,
|
||||
dataofs: 24,
|
||||
length: len(icmp4RequestBuffer),
|
||||
|
||||
IPVersion: 4,
|
||||
IPProto: ICMPv4,
|
||||
SrcIP4: mustIP4("1.2.3.4"),
|
||||
DstIP4: mustIP4("5.6.7.8"),
|
||||
SrcPort: 0,
|
||||
DstPort: 0,
|
||||
}
|
||||
|
||||
var icmp4ReplyBuffer = []byte{
|
||||
0x45, 0x00, 0x00, 0x25, 0x21, 0x52, 0x00, 0x00, 0x40, 0x01, 0x49, 0x73,
|
||||
// source ip
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
// destination ip
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
// ICMP header
|
||||
0x00, 0x00, 0xe6, 0x9e,
|
||||
// "reply_payload"
|
||||
0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
|
||||
}
|
||||
|
||||
var icmp4ReplyDecode = Parsed{
|
||||
b: icmp4ReplyBuffer,
|
||||
subofs: 20,
|
||||
dataofs: 24,
|
||||
length: len(icmp4ReplyBuffer),
|
||||
|
||||
IPVersion: 4,
|
||||
IPProto: ICMPv4,
|
||||
SrcIP4: mustIP4("1.2.3.4"),
|
||||
DstIP4: mustIP4("5.6.7.8"),
|
||||
SrcPort: 0,
|
||||
DstPort: 0,
|
||||
}
|
||||
|
||||
// ICMPv6 Router Solicitation
|
||||
var icmp6PacketBuffer = []byte{
|
||||
0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x3a, 0xff,
|
||||
0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0xfb, 0x57, 0x1d, 0xea, 0x9c, 0x39, 0x8f, 0xb7,
|
||||
0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
|
||||
0x85, 0x00, 0x38, 0x04, 0x00, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
var icmp6PacketDecode = Parsed{
|
||||
b: icmp6PacketBuffer,
|
||||
subofs: 40,
|
||||
dataofs: 44,
|
||||
length: len(icmp6PacketBuffer),
|
||||
IPVersion: 6,
|
||||
IPProto: ICMPv6,
|
||||
SrcIP6: mustIP6("fe80::fb57:1dea:9c39:8fb7"),
|
||||
DstIP6: mustIP6("ff02::2"),
|
||||
}
|
||||
|
||||
// This is a malformed IPv4 packet.
|
||||
// Namely, the string "tcp_payload" follows the first byte of the IPv4 header.
|
||||
var unknownPacketBuffer = []byte{
|
||||
0x45, 0x74, 0x63, 0x70, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
|
||||
}
|
||||
|
||||
var unknownPacketDecode = Parsed{
|
||||
b: unknownPacketBuffer,
|
||||
IPVersion: 0,
|
||||
IPProto: Unknown,
|
||||
}
|
||||
|
||||
var tcp4PacketBuffer = []byte{
|
||||
// IP header up to checksum
|
||||
0x45, 0x00, 0x00, 0x37, 0xde, 0xad, 0x00, 0x00, 0x40, 0x06, 0x49, 0x5f,
|
||||
// source ip
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
// destination ip
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
// TCP header with SYN, ACK set
|
||||
0x00, 0x7b, 0x02, 0x37, 0x00, 0x00, 0x12, 0x34, 0x00, 0x00, 0x00, 0x00,
|
||||
0x50, 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
// "request_payload"
|
||||
0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
|
||||
}
|
||||
|
||||
var tcp4PacketDecode = Parsed{
|
||||
b: tcp4PacketBuffer,
|
||||
subofs: 20,
|
||||
dataofs: 40,
|
||||
length: len(tcp4PacketBuffer),
|
||||
|
||||
IPVersion: 4,
|
||||
IPProto: TCP,
|
||||
SrcIP4: mustIP4("1.2.3.4"),
|
||||
DstIP4: mustIP4("5.6.7.8"),
|
||||
SrcPort: 123,
|
||||
DstPort: 567,
|
||||
TCPFlags: TCPSynAck,
|
||||
}
|
||||
|
||||
var tcp6RequestBuffer = []byte{
|
||||
// IPv6 header up to hop limit
|
||||
0x60, 0x06, 0xef, 0xcc, 0x00, 0x28, 0x06, 0x40,
|
||||
// Src addr
|
||||
0x20, 0x01, 0x05, 0x59, 0xbc, 0x13, 0x54, 0x00, 0x17, 0x49, 0x46, 0x28, 0x39, 0x34, 0x0e, 0x1b,
|
||||
// Dst addr
|
||||
0x26, 0x07, 0xf8, 0xb0, 0x40, 0x0a, 0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x0e,
|
||||
// TCP SYN segment, no payload
|
||||
0xa4, 0x60, 0x00, 0x50, 0xf3, 0x82, 0xa1, 0x25, 0x00, 0x00, 0x00, 0x00, 0xa0, 0x02, 0xfd, 0x20,
|
||||
0xb1, 0xc6, 0x00, 0x00, 0x02, 0x04, 0x05, 0xa0, 0x04, 0x02, 0x08, 0x0a, 0xca, 0x76, 0xa6, 0x8e,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07,
|
||||
}
|
||||
|
||||
var tcp6RequestDecode = Parsed{
|
||||
b: tcp6RequestBuffer,
|
||||
subofs: 40,
|
||||
dataofs: len(tcp6RequestBuffer),
|
||||
length: len(tcp6RequestBuffer),
|
||||
|
||||
IPVersion: 6,
|
||||
IPProto: TCP,
|
||||
SrcIP6: mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"),
|
||||
DstIP6: mustIP6("2607:f8b0:400a:809::200e"),
|
||||
SrcPort: 42080,
|
||||
DstPort: 80,
|
||||
TCPFlags: TCPSyn,
|
||||
}
|
||||
|
||||
var udp4RequestBuffer = []byte{
|
||||
// IP header up to checksum
|
||||
0x45, 0x00, 0x00, 0x2b, 0xde, 0xad, 0x00, 0x00, 0x40, 0x11, 0x8c, 0x01,
|
||||
// source ip
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
// destination ip
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
// UDP header
|
||||
0x00, 0x7b, 0x02, 0x37, 0x00, 0x17, 0x72, 0x1d,
|
||||
// "request_payload"
|
||||
0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
|
||||
}
|
||||
|
||||
var udp4RequestDecode = Parsed{
|
||||
b: udp4RequestBuffer,
|
||||
subofs: 20,
|
||||
dataofs: 28,
|
||||
length: len(udp4RequestBuffer),
|
||||
|
||||
IPVersion: 4,
|
||||
IPProto: UDP,
|
||||
SrcIP4: mustIP4("1.2.3.4"),
|
||||
DstIP4: mustIP4("5.6.7.8"),
|
||||
SrcPort: 123,
|
||||
DstPort: 567,
|
||||
}
|
||||
|
||||
var invalid4RequestBuffer = []byte{
|
||||
// IP header up to checksum. IHL field points beyond end of packet.
|
||||
0x4a, 0x00, 0x00, 0x14, 0xde, 0xad, 0x00, 0x00, 0x40, 0x11, 0x8c, 0x01,
|
||||
// source ip
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
// destination ip
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
}
|
||||
|
||||
// Regression check for the IHL field pointing beyond the end of the
|
||||
// packet.
|
||||
var invalid4RequestDecode = Parsed{
|
||||
b: invalid4RequestBuffer,
|
||||
subofs: 40,
|
||||
length: len(invalid4RequestBuffer),
|
||||
|
||||
IPVersion: 4,
|
||||
IPProto: Unknown,
|
||||
SrcIP4: mustIP4("1.2.3.4"),
|
||||
DstIP4: mustIP4("5.6.7.8"),
|
||||
}
|
||||
|
||||
var udp6RequestBuffer = []byte{
|
||||
// IPv6 header up to hop limit
|
||||
0x60, 0x0e, 0xc9, 0x67, 0x00, 0x29, 0x11, 0x40,
|
||||
// Src addr
|
||||
0x20, 0x01, 0x05, 0x59, 0xbc, 0x13, 0x54, 0x00, 0x17, 0x49, 0x46, 0x28, 0x39, 0x34, 0x0e, 0x1b,
|
||||
// Dst addr
|
||||
0x26, 0x07, 0xf8, 0xb0, 0x40, 0x0a, 0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x0e,
|
||||
// UDP header
|
||||
0xd4, 0x04, 0x01, 0xbb, 0x00, 0x29, 0x96, 0x84,
|
||||
// Payload
|
||||
0x5c, 0x06, 0xae, 0x85, 0x02, 0xf5, 0xdb, 0x90, 0xe0, 0xe0, 0x93, 0xed, 0x9a, 0xd9, 0x92, 0x69, 0xbe, 0x36, 0x8a, 0x7d, 0xd7, 0xce, 0xd0, 0x8a, 0xf2, 0x51, 0x95, 0xff, 0xb6, 0x92, 0x70, 0x10, 0xd7,
|
||||
}
|
||||
|
||||
var udp6RequestDecode = Parsed{
|
||||
b: udp6RequestBuffer,
|
||||
subofs: 40,
|
||||
dataofs: 48,
|
||||
length: len(udp6RequestBuffer),
|
||||
|
||||
IPVersion: 6,
|
||||
IPProto: UDP,
|
||||
SrcIP6: mustIP6("2001:559:bc13:5400:1749:4628:3934:e1b"),
|
||||
DstIP6: mustIP6("2607:f8b0:400a:809::200e"),
|
||||
SrcPort: 54276,
|
||||
DstPort: 443,
|
||||
}
|
||||
|
||||
var udp4ReplyBuffer = []byte{
|
||||
// IP header up to checksum
|
||||
0x45, 0x00, 0x00, 0x29, 0x21, 0x52, 0x00, 0x00, 0x40, 0x11, 0x49, 0x5f,
|
||||
// source ip
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
// destination ip
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
// UDP header
|
||||
0x02, 0x37, 0x00, 0x7b, 0x00, 0x15, 0xd3, 0x9d,
|
||||
// "reply_payload"
|
||||
0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
|
||||
}
|
||||
|
||||
var udp4ReplyDecode = Parsed{
|
||||
b: udp4ReplyBuffer,
|
||||
subofs: 20,
|
||||
dataofs: 28,
|
||||
length: len(udp4ReplyBuffer),
|
||||
|
||||
IPProto: UDP,
|
||||
SrcIP4: mustIP4("1.2.3.4"),
|
||||
DstIP4: mustIP4("5.6.7.8"),
|
||||
SrcPort: 567,
|
||||
DstPort: 123,
|
||||
}
|
||||
|
||||
var igmpPacketBuffer = []byte{
|
||||
// IP header up to checksum
|
||||
0x46, 0xc0, 0x00, 0x20, 0x00, 0x00, 0x40, 0x00, 0x01, 0x02, 0x41, 0x22,
|
||||
// source IP
|
||||
0xc0, 0xa8, 0x01, 0x52,
|
||||
// destination IP
|
||||
0xe0, 0x00, 0x00, 0xfb,
|
||||
// IGMP Membership Report
|
||||
0x94, 0x04, 0x00, 0x00, 0x16, 0x00, 0x09, 0x04, 0xe0, 0x00, 0x00, 0xfb,
|
||||
//0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
var igmpPacketDecode = Parsed{
|
||||
b: igmpPacketBuffer,
|
||||
subofs: 24,
|
||||
length: len(igmpPacketBuffer),
|
||||
|
||||
IPVersion: 4,
|
||||
IPProto: IGMP,
|
||||
SrcIP4: mustIP4("192.168.1.82"),
|
||||
DstIP4: mustIP4("224.0.0.251"),
|
||||
}
|
||||
|
||||
func TestParsed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
qdecode Parsed
|
||||
want string
|
||||
}{
|
||||
{"tcp4", tcp4PacketDecode, "TCP{1.2.3.4:123 > 5.6.7.8:567}"},
|
||||
{"tcp6", tcp6RequestDecode, "TCP{[2001:559:bc13:5400:1749:4628:3934:e1b]:42080 > [2607:f8b0:400a:809::200e]:80}"},
|
||||
{"udp4", udp4RequestDecode, "UDP{1.2.3.4:123 > 5.6.7.8:567}"},
|
||||
{"udp6", udp6RequestDecode, "UDP{[2001:559:bc13:5400:1749:4628:3934:e1b]:54276 > [2607:f8b0:400a:809::200e]:443}"},
|
||||
{"icmp4", icmp4RequestDecode, "ICMPv4{1.2.3.4:0 > 5.6.7.8:0}"},
|
||||
{"icmp6", icmp6PacketDecode, "ICMPv6{[fe80::fb57:1dea:9c39:8fb7]:0 > [ff02::2]:0}"},
|
||||
{"igmp", igmpPacketDecode, "IGMP{192.168.1.82:0 > 224.0.0.251:0}"},
|
||||
{"unknown", unknownPacketDecode, "Unknown{???}"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.qdecode.String()
|
||||
if got != tt.want {
|
||||
t.Errorf("got %q; want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var sink string
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
sink = tests[0].qdecode.String()
|
||||
})
|
||||
_ = sink
|
||||
if allocs != 1 {
|
||||
t.Errorf("allocs = %v; want 1", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
buf []byte
|
||||
want Parsed
|
||||
}{
|
||||
{"icmp4", icmp4RequestBuffer, icmp4RequestDecode},
|
||||
{"icmp6", icmp6PacketBuffer, icmp6PacketDecode},
|
||||
{"tcp4", tcp4PacketBuffer, tcp4PacketDecode},
|
||||
{"tcp6", tcp6RequestBuffer, tcp6RequestDecode},
|
||||
{"udp4", udp4RequestBuffer, udp4RequestDecode},
|
||||
{"udp6", udp6RequestBuffer, udp6RequestDecode},
|
||||
{"igmp", igmpPacketBuffer, igmpPacketDecode},
|
||||
{"unknown", unknownPacketBuffer, unknownPacketDecode},
|
||||
{"invalid4", invalid4RequestBuffer, invalid4RequestDecode},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Parsed
|
||||
got.Decode(tt.buf)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("mismatch\n got: %#v\nwant: %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
var got Parsed
|
||||
got.Decode(tests[0].buf)
|
||||
})
|
||||
if allocs != 0 {
|
||||
t.Errorf("allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecode(b *testing.B) {
|
||||
benches := []struct {
|
||||
name string
|
||||
buf []byte
|
||||
}{
|
||||
{"tcp4", tcp4PacketBuffer},
|
||||
{"tcp6", tcp6RequestBuffer},
|
||||
{"udp4", udp4RequestBuffer},
|
||||
{"udp6", udp6RequestBuffer},
|
||||
{"icmp4", icmp4RequestBuffer},
|
||||
{"icmp6", icmp6PacketBuffer},
|
||||
{"igmp", igmpPacketBuffer},
|
||||
{"unknown", unknownPacketBuffer},
|
||||
}
|
||||
|
||||
for _, bench := range benches {
|
||||
b.Run(bench.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var p Parsed
|
||||
p.Decode(bench.buf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalRequest(t *testing.T) {
|
||||
// Too small to hold our packets, but only barely.
|
||||
var small [20]byte
|
||||
var large [64]byte
|
||||
|
||||
icmpHeader := icmp4RequestDecode.ICMP4Header()
|
||||
udpHeader := udp4RequestDecode.UDP4Header()
|
||||
tests := []struct {
|
||||
name string
|
||||
header Header
|
||||
want []byte
|
||||
}{
|
||||
{"icmp", &icmpHeader, icmp4RequestBuffer},
|
||||
{"udp", &udpHeader, udp4RequestBuffer},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.header.Marshal(small[:])
|
||||
if err != errSmallBuffer {
|
||||
t.Errorf("got err: nil; want: %s", errSmallBuffer)
|
||||
}
|
||||
|
||||
dataOffset := tt.header.Len()
|
||||
dataLength := copy(large[dataOffset:], []byte("request_payload"))
|
||||
end := dataOffset + dataLength
|
||||
err = tt.header.Marshal(large[:end])
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("got err: %s; want nil", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(large[:end], tt.want) {
|
||||
t.Errorf("got %x; want %x", large[:end], tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalResponse(t *testing.T) {
|
||||
var buf [64]byte
|
||||
|
||||
icmpHeader := icmp4RequestDecode.ICMP4Header()
|
||||
udpHeader := udp4RequestDecode.UDP4Header()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header Header
|
||||
want []byte
|
||||
}{
|
||||
{"icmp", &icmpHeader, icmp4ReplyBuffer},
|
||||
{"udp", &udpHeader, udp4ReplyBuffer},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.header.ToResponse()
|
||||
|
||||
dataOffset := tt.header.Len()
|
||||
dataLength := copy(buf[dataOffset:], []byte("reply_payload"))
|
||||
end := dataOffset + dataLength
|
||||
err := tt.header.Marshal(buf[:end])
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("got err: %s; want nil", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(buf[:end], tt.want) {
|
||||
t.Errorf("got %x; want %x", buf[:end], tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
55
net/packet/udp4.go
Normal file
55
net/packet/udp4.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// udpHeaderLength is the size of the UDP packet header, not including
|
||||
// the outer IP header.
|
||||
const udpHeaderLength = 8
|
||||
|
||||
// UDP4Header is an IPv4+UDP header.
|
||||
type UDP4Header struct {
|
||||
IP4Header
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// Len implements Header.
|
||||
func (h UDP4Header) Len() int {
|
||||
return h.IP4Header.Len() + udpHeaderLength
|
||||
}
|
||||
|
||||
// Marshal implements Header.
|
||||
func (h UDP4Header) Marshal(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if len(buf) > maxPacketLength {
|
||||
return errLargePacket
|
||||
}
|
||||
// The caller does not need to set this.
|
||||
h.IPProto = UDP
|
||||
|
||||
length := len(buf) - h.IP4Header.Len()
|
||||
binary.BigEndian.PutUint16(buf[20:22], h.SrcPort)
|
||||
binary.BigEndian.PutUint16(buf[22:24], h.DstPort)
|
||||
binary.BigEndian.PutUint16(buf[24:26], uint16(length))
|
||||
binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum
|
||||
|
||||
// UDP checksum with IP pseudo header.
|
||||
h.IP4Header.marshalPseudo(buf)
|
||||
binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:]))
|
||||
|
||||
h.IP4Header.Marshal(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToResponse implements Header.
|
||||
func (h *UDP4Header) ToResponse() {
|
||||
h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
|
||||
h.IP4Header.ToResponse()
|
||||
}
|
||||
51
net/packet/udp6.go
Normal file
51
net/packet/udp6.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// UDP6Header is an IPv6+UDP header.
|
||||
type UDP6Header struct {
|
||||
IP6Header
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// Len implements Header.
|
||||
func (h UDP6Header) Len() int {
|
||||
return h.IP6Header.Len() + udpHeaderLength
|
||||
}
|
||||
|
||||
// Marshal implements Header.
|
||||
func (h UDP6Header) Marshal(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if len(buf) > maxPacketLength {
|
||||
return errLargePacket
|
||||
}
|
||||
// The caller does not need to set this.
|
||||
h.IPProto = UDP
|
||||
|
||||
length := len(buf) - h.IP6Header.Len()
|
||||
binary.BigEndian.PutUint16(buf[40:42], h.SrcPort)
|
||||
binary.BigEndian.PutUint16(buf[42:44], h.DstPort)
|
||||
binary.BigEndian.PutUint16(buf[44:46], uint16(length))
|
||||
binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum
|
||||
|
||||
// UDP checksum with IP pseudo header.
|
||||
h.IP6Header.marshalPseudo(buf)
|
||||
binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:]))
|
||||
|
||||
h.IP6Header.Marshal(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToResponse implements Header.
|
||||
func (h *UDP6Header) ToResponse() {
|
||||
h.SrcPort, h.DstPort = h.DstPort, h.SrcPort
|
||||
h.IP6Header.ToResponse()
|
||||
}
|
||||
@@ -8,9 +8,8 @@ package portlist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
exec "tailscale.com/tempfork/osexec"
|
||||
)
|
||||
|
||||
var osHideWindow func(*exec.Cmd) // non-nil on Windows; see portlist_windows.go
|
||||
|
||||
@@ -7,7 +7,6 @@ package portlist
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"tailscale.com/version"
|
||||
@@ -34,7 +33,7 @@ type Poller struct {
|
||||
// NewPoller returns a new portlist Poller. It returns an error
|
||||
// if the portlist couldn't be obtained.
|
||||
func NewPoller() (*Poller, error) {
|
||||
if runtime.GOOS == "darwin" && version.IsMobile() {
|
||||
if version.OS() == "iOS" {
|
||||
return nil, errors.New("not available on iOS")
|
||||
}
|
||||
p := &Poller{
|
||||
|
||||
@@ -12,11 +12,10 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
exec "tailscale.com/tempfork/osexec"
|
||||
)
|
||||
|
||||
// We have to run netstat, which is a bit expensive, so don't do it too often.
|
||||
|
||||
@@ -5,11 +5,11 @@
|
||||
package portlist
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
exec "tailscale.com/tempfork/osexec"
|
||||
)
|
||||
|
||||
// Forking on Windows is insanely expensive, so don't do it too often.
|
||||
|
||||
24
shell.nix
Normal file
24
shell.nix
Normal file
@@ -0,0 +1,24 @@
|
||||
# This is a shell.nix file used to describe the environment that tailscale needs
|
||||
# for development. This includes a lot of the basic tools that you need in order
|
||||
# to get started. We hope this file will be useful for users of Nix on macOS or
|
||||
# Linux.
|
||||
#
|
||||
# For more information about this and why this file is useful, see here:
|
||||
# https://nixos.org/guides/nix-pills/developing-with-nix-shell.html
|
||||
#
|
||||
# Also look into direnv: https://direnv.net/, this can make it so that you can
|
||||
# automatically get your environment set up when you change folders into the
|
||||
# project.
|
||||
{ pkgs ? import <nixpkgs> {} }:
|
||||
|
||||
pkgs.mkShell {
|
||||
# This specifies the tools that are needed for people to get started with
|
||||
# development. These tools include:
|
||||
# - The Go compiler toolchain (and all additional tooling with it)
|
||||
# - goimports, a robust formatting tool for Go source code
|
||||
# - gopls, the language server for Go to increase editor integration
|
||||
# - git, the version control program (used in some scripts)
|
||||
buildInputs = with pkgs; [
|
||||
go goimports gopls git
|
||||
];
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"go4.org/mem"
|
||||
@@ -211,10 +210,15 @@ func (m MachineStatus) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
func isNum(r rune) bool { return r >= '0' && r <= '9' }
|
||||
func isAlpha(r rune) bool { return (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') }
|
||||
func isNum(b byte) bool {
|
||||
return b >= '0' && b <= '9'
|
||||
}
|
||||
|
||||
// CheckTag valids whether a given string can be used as an ACL tag.
|
||||
func isAlpha(b byte) bool {
|
||||
return (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z')
|
||||
}
|
||||
|
||||
// CheckTag validates tag for use as an ACL tag.
|
||||
// For now we allow only ascii alphanumeric tags, and they need to start
|
||||
// with a letter. No unicode shenanigans allowed, and we reserve punctuation
|
||||
// marks other than '-' for a possible future URI scheme.
|
||||
@@ -227,34 +231,33 @@ func CheckTag(tag string) error {
|
||||
if !strings.HasPrefix(tag, "tag:") {
|
||||
return errors.New("tags must start with 'tag:'")
|
||||
}
|
||||
suffix := tag[len("tag:"):]
|
||||
if err := CheckTagSuffix(suffix); err != nil {
|
||||
return fmt.Errorf("invalid tag %q: %w", tag, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckTagSuffix checks whether tag is a valid tag suffix (the part
|
||||
// appearing after "tag:"). The error message does not reference
|
||||
// "tag:", so it's suitable for use by the "tailscale up" CLI tool
|
||||
// where the "tag:" isn't required. The returned error also does not
|
||||
// reference the tag itself, so the caller can wrap it as needed with
|
||||
// either the full or short form.
|
||||
func CheckTagSuffix(tag string) error {
|
||||
tag = tag[4:]
|
||||
if tag == "" {
|
||||
return errors.New("tag names must not be empty")
|
||||
}
|
||||
if i := strings.IndexFunc(tag, func(r rune) bool { return r >= utf8.RuneSelf }); i != -1 {
|
||||
return errors.New("tag names must only contain ASCII")
|
||||
if !isAlpha(tag[0]) {
|
||||
return errors.New("tag names must start with a letter, after 'tag:'")
|
||||
}
|
||||
if !isAlpha(rune(tag[0])) {
|
||||
return errors.New("tag name must start with a letter")
|
||||
}
|
||||
for _, r := range tag {
|
||||
if !isNum(r) && !isAlpha(r) && r != '-' {
|
||||
|
||||
for _, b := range []byte(tag) {
|
||||
if !isNum(b) && !isAlpha(b) && b != '-' {
|
||||
return errors.New("tag names can only contain numbers, letters, or dashes")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckRequestTags checks that all of h.RequestTags are valid.
|
||||
func (h *Hostinfo) CheckRequestTags() error {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
for _, tag := range h.RequestTags {
|
||||
if err := CheckTag(tag); err != nil {
|
||||
return fmt.Errorf("tag(%#v): %w", tag, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -291,6 +294,7 @@ type Hostinfo struct {
|
||||
OSVersion string `json:",omitempty"` // operating system version, with optional distro prefix ("Debian 10.4", "Windows 10 Pro 10.0.19041")
|
||||
DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone 11 Pro")
|
||||
Hostname string // name of the host the client runs on
|
||||
ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections
|
||||
GoArch string `json:",omitempty"` // the host's GOARCH value (of the running binary)
|
||||
RoutableIPs []wgcfg.CIDR `json:",omitempty"` // set of IP ranges this client can route
|
||||
RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim
|
||||
@@ -509,6 +513,12 @@ type MapRequest struct {
|
||||
// added and removed all the time during development, and offer no
|
||||
// compatibility promise. To roll out semantic changes, bump
|
||||
// Version instead.
|
||||
//
|
||||
// Current DebugFlags values are:
|
||||
// * "warn-ip-forwarding-off": client is trying to be a subnet
|
||||
// router but their IP forwarding is broken.
|
||||
// * "v6-overlay": IPv6 development flag to have control send
|
||||
// v6 node addrs
|
||||
DebugFlags []string `json:",omitempty"`
|
||||
}
|
||||
|
||||
@@ -529,9 +539,28 @@ type NetPortRange struct {
|
||||
}
|
||||
|
||||
// FilterRule represents one rule in a packet filter.
|
||||
//
|
||||
// A rule is logically a set of source CIDRs to match (described by
|
||||
// SrcIPs and SrcBits), and a set of destination targets that are then
|
||||
// allowed if a source IP is mathces of those CIDRs.
|
||||
type FilterRule struct {
|
||||
SrcIPs []string // "*" means all
|
||||
SrcBits []int
|
||||
// SrcIPs are the source IPs/networks to match.
|
||||
// The special value "*" means to match all.
|
||||
SrcIPs []string
|
||||
|
||||
// SrcBits values correspond to the SrcIPs above.
|
||||
//
|
||||
// If present at the same index, it changes the SrcIP above to
|
||||
// be a network with /n CIDR bits. If the slice is nil or
|
||||
// insufficiently long, the default value (for an IPv4
|
||||
// address) for a position is 32, as if the SrcIPs above were
|
||||
// a /32 mask. For a "*" SrcIPs value, the corresponding
|
||||
// SrcBits value is ignored.
|
||||
// TODO: for IPv6, clarify default bits length.
|
||||
SrcBits []int
|
||||
|
||||
// DstPorts are the port ranges to allow once a source IP
|
||||
// matches (is in the CIDR described by SrcIPs & SrcBits).
|
||||
DstPorts []NetPortRange
|
||||
}
|
||||
|
||||
@@ -635,6 +664,10 @@ type Debug struct {
|
||||
// TrimWGConfig controls whether Tailscale does lazy, on-demand
|
||||
// wireguard configuration of peers.
|
||||
TrimWGConfig opt.Bool `json:",omitempty"`
|
||||
|
||||
// DisableSubnetsIfPAC controls whether subnet routers should be
|
||||
// disabled if WPAD is present on the network.
|
||||
DisableSubnetsIfPAC opt.Bool `json:",omitempty"`
|
||||
}
|
||||
|
||||
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
|
||||
|
||||
@@ -105,6 +105,7 @@ var _HostinfoNeedsRegeneration = Hostinfo(struct {
|
||||
OSVersion string
|
||||
DeviceModel string
|
||||
Hostname string
|
||||
ShieldsUp bool
|
||||
GoArch string
|
||||
RoutableIPs []wgcfg.CIDR
|
||||
RequestTags []string
|
||||
|
||||
@@ -24,8 +24,8 @@ func fieldsOf(t reflect.Type) (fields []string) {
|
||||
func TestHostinfoEqual(t *testing.T) {
|
||||
hiHandles := []string{
|
||||
"IPNVersion", "FrontendLogID", "BackendLogID", "OS", "OSVersion",
|
||||
"DeviceModel", "Hostname", "GoArch", "RoutableIPs", "RequestTags", "Services",
|
||||
"NetInfo",
|
||||
"DeviceModel", "Hostname", "ShieldsUp", "GoArch", "RoutableIPs",
|
||||
"RequestTags", "Services", "NetInfo",
|
||||
}
|
||||
if have := fieldsOf(reflect.TypeOf(Hostinfo{})); !reflect.DeepEqual(have, hiHandles) {
|
||||
t.Errorf("Hostinfo.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
This is a temporary fork of Go 1.13's os/exec package,
|
||||
to work around https://github.com/golang/go/issues/36644.
|
||||
|
||||
The main modification (outside of removing some tests that require
|
||||
internal-only packages to run) is:
|
||||
|
||||
```
|
||||
commit 3c66be240f1ee1f1b5f03bed79eb0d9f8c08965a
|
||||
Author: Avery Pennarun <apenwarr@gmail.com>
|
||||
Date: Sun Jan 19 03:17:30 2020 -0500
|
||||
|
||||
Cmd.Wait(): handle EINTR return code from os.Process.Wait().
|
||||
|
||||
This is probably not actually the correct fix; most likely
|
||||
os.Process.Wait() itself should be fixed to retry on EINTR so that it
|
||||
never leaks out of that function. But if we're going to patch a
|
||||
particular module, it's safer to patch a higher-level one like os/exec
|
||||
rather than the os module itself.
|
||||
|
||||
diff --git a/exec.go b/exec.go
|
||||
index 17ef003e..5375e673 100644
|
||||
--- a/exec.go
|
||||
+++ b/exec.go
|
||||
@@ -498,7 +498,21 @@ func (c *Cmd) Wait() error {
|
||||
}
|
||||
c.finished = true
|
||||
|
||||
- state, err := c.Process.Wait()
|
||||
+ var err error
|
||||
+ var state *os.ProcessState
|
||||
+ for {
|
||||
+ state, err = c.Process.Wait()
|
||||
+ if err != nil {
|
||||
+ xe, ok := err.(*os.SyscallError)
|
||||
+ if ok {
|
||||
+ if xe.Unwrap() == syscall.EINTR {
|
||||
+ // temporary error, retry wait syscall
|
||||
+ continue
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ break
|
||||
+ }
|
||||
if c.waitDone != nil {
|
||||
close(c.waitDone)
|
||||
}
|
||||
```
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkExecHostname(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
path, err := LookPath("hostname")
|
||||
if err != nil {
|
||||
b.Fatalf("could not find hostname: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := Command(path).Run(); err != nil {
|
||||
b.Fatalf("hostname: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDedupEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
noCase bool
|
||||
in []string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
noCase: true,
|
||||
in: []string{"k1=v1", "k2=v2", "K1=v3"},
|
||||
want: []string{"K1=v3", "k2=v2"},
|
||||
},
|
||||
{
|
||||
noCase: false,
|
||||
in: []string{"k1=v1", "K1=V2", "k1=v3"},
|
||||
want: []string{"k1=v3", "K1=V2"},
|
||||
},
|
||||
{
|
||||
in: []string{"=a", "=b", "foo", "bar"},
|
||||
want: []string{"=b", "foo", "bar"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := dedupEnvCase(tt.noCase, tt.in)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Dedup(%v, %q) = %q; want %q", tt.noCase, tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ExampleLookPath() {
|
||||
path, err := exec.LookPath("fortune")
|
||||
if err != nil {
|
||||
log.Fatal("installing fortune is in your future")
|
||||
}
|
||||
fmt.Printf("fortune is available at %s\n", path)
|
||||
}
|
||||
|
||||
func ExampleCommand() {
|
||||
cmd := exec.Command("tr", "a-z", "A-Z")
|
||||
cmd.Stdin = strings.NewReader("some input")
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("in all caps: %q\n", out.String())
|
||||
}
|
||||
|
||||
func ExampleCommand_environment() {
|
||||
cmd := exec.Command("prog")
|
||||
cmd.Env = append(os.Environ(),
|
||||
"FOO=duplicate_value", // ignored
|
||||
"FOO=actual_value", // this value is used
|
||||
)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleCmd_Output() {
|
||||
out, err := exec.Command("date").Output()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("The date is %s\n", out)
|
||||
}
|
||||
|
||||
func ExampleCmd_Run() {
|
||||
cmd := exec.Command("sleep", "1")
|
||||
log.Printf("Running command and waiting for it to finish...")
|
||||
err := cmd.Run()
|
||||
log.Printf("Command finished with error: %v", err)
|
||||
}
|
||||
|
||||
func ExampleCmd_Start() {
|
||||
cmd := exec.Command("sleep", "5")
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("Waiting for command to finish...")
|
||||
err = cmd.Wait()
|
||||
log.Printf("Command finished with error: %v", err)
|
||||
}
|
||||
|
||||
func ExampleCmd_StdoutPipe() {
|
||||
cmd := exec.Command("echo", "-n", `{"Name": "Bob", "Age": 32}`)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
var person struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
if err := json.NewDecoder(stdout).Decode(&person); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s is %d years old\n", person.Name, person.Age)
|
||||
}
|
||||
|
||||
func ExampleCmd_StdinPipe() {
|
||||
cmd := exec.Command("cat")
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stdin.Close()
|
||||
io.WriteString(stdin, "values written to stdin are passed to cmd's standard input")
|
||||
}()
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%s\n", out)
|
||||
}
|
||||
|
||||
func ExampleCmd_StderrPipe() {
|
||||
cmd := exec.Command("sh", "-c", "echo stdout; echo 1>&2 stderr")
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
slurp, _ := ioutil.ReadAll(stderr)
|
||||
fmt.Printf("%s\n", slurp)
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleCmd_CombinedOutput() {
|
||||
cmd := exec.Command("sh", "-c", "echo stdout; echo 1>&2 stderr")
|
||||
stdoutStderr, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s\n", stdoutStderr)
|
||||
}
|
||||
|
||||
func ExampleCommandContext() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := exec.CommandContext(ctx, "sleep", "5").Run(); err != nil {
|
||||
// This will fail after 100 milliseconds. The 5 second sleep
|
||||
// will be interrupted.
|
||||
}
|
||||
}
|
||||
@@ -1,797 +0,0 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package exec runs external commands. It wraps os.StartProcess to make it
|
||||
// easier to remap stdin and stdout, connect I/O with pipes, and do other
|
||||
// adjustments.
|
||||
//
|
||||
// Unlike the "system" library call from C and other languages, the
|
||||
// os/exec package intentionally does not invoke the system shell and
|
||||
// does not expand any glob patterns or handle other expansions,
|
||||
// pipelines, or redirections typically done by shells. The package
|
||||
// behaves more like C's "exec" family of functions. To expand glob
|
||||
// patterns, either call the shell directly, taking care to escape any
|
||||
// dangerous input, or use the path/filepath package's Glob function.
|
||||
// To expand environment variables, use package os's ExpandEnv.
|
||||
//
|
||||
// Note that the examples in this package assume a Unix system.
|
||||
// They may not run on Windows, and they do not run in the Go Playground
|
||||
// used by golang.org and godoc.org.
|
||||
package exec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Error is returned by LookPath when it fails to classify a file as an
|
||||
// executable.
|
||||
type Error struct {
|
||||
// Name is the file name for which the error occurred.
|
||||
Name string
|
||||
// Err is the underlying error.
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return "exec: " + strconv.Quote(e.Name) + ": " + e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error { return e.Err }
|
||||
|
||||
// Cmd represents an external command being prepared or run.
|
||||
//
|
||||
// A Cmd cannot be reused after calling its Run, Output or CombinedOutput
|
||||
// methods.
|
||||
type Cmd struct {
|
||||
// Path is the path of the command to run.
|
||||
//
|
||||
// This is the only field that must be set to a non-zero
|
||||
// value. If Path is relative, it is evaluated relative
|
||||
// to Dir.
|
||||
Path string
|
||||
|
||||
// Args holds command line arguments, including the command as Args[0].
|
||||
// If the Args field is empty or nil, Run uses {Path}.
|
||||
//
|
||||
// In typical use, both Path and Args are set by calling Command.
|
||||
Args []string
|
||||
|
||||
// Env specifies the environment of the process.
|
||||
// Each entry is of the form "key=value".
|
||||
// If Env is nil, the new process uses the current process's
|
||||
// environment.
|
||||
// If Env contains duplicate environment keys, only the last
|
||||
// value in the slice for each duplicate key is used.
|
||||
// As a special case on Windows, SYSTEMROOT is always added if
|
||||
// missing and not explicitly set to the empty string.
|
||||
Env []string
|
||||
|
||||
// Dir specifies the working directory of the command.
|
||||
// If Dir is the empty string, Run runs the command in the
|
||||
// calling process's current directory.
|
||||
Dir string
|
||||
|
||||
// Stdin specifies the process's standard input.
|
||||
//
|
||||
// If Stdin is nil, the process reads from the null device (os.DevNull).
|
||||
//
|
||||
// If Stdin is an *os.File, the process's standard input is connected
|
||||
// directly to that file.
|
||||
//
|
||||
// Otherwise, during the execution of the command a separate
|
||||
// goroutine reads from Stdin and delivers that data to the command
|
||||
// over a pipe. In this case, Wait does not complete until the goroutine
|
||||
// stops copying, either because it has reached the end of Stdin
|
||||
// (EOF or a read error) or because writing to the pipe returned an error.
|
||||
Stdin io.Reader
|
||||
|
||||
// Stdout and Stderr specify the process's standard output and error.
|
||||
//
|
||||
// If either is nil, Run connects the corresponding file descriptor
|
||||
// to the null device (os.DevNull).
|
||||
//
|
||||
// If either is an *os.File, the corresponding output from the process
|
||||
// is connected directly to that file.
|
||||
//
|
||||
// Otherwise, during the execution of the command a separate goroutine
|
||||
// reads from the process over a pipe and delivers that data to the
|
||||
// corresponding Writer. In this case, Wait does not complete until the
|
||||
// goroutine reaches EOF or encounters an error.
|
||||
//
|
||||
// If Stdout and Stderr are the same writer, and have a type that can
|
||||
// be compared with ==, at most one goroutine at a time will call Write.
|
||||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
|
||||
// ExtraFiles specifies additional open files to be inherited by the
|
||||
// new process. It does not include standard input, standard output, or
|
||||
// standard error. If non-nil, entry i becomes file descriptor 3+i.
|
||||
//
|
||||
// ExtraFiles is not supported on Windows.
|
||||
ExtraFiles []*os.File
|
||||
|
||||
// SysProcAttr holds optional, operating system-specific attributes.
|
||||
// Run passes it to os.StartProcess as the os.ProcAttr's Sys field.
|
||||
SysProcAttr *syscall.SysProcAttr
|
||||
|
||||
// Process is the underlying process, once started.
|
||||
Process *os.Process
|
||||
|
||||
// ProcessState contains information about an exited process,
|
||||
// available after a call to Wait or Run.
|
||||
ProcessState *os.ProcessState
|
||||
|
||||
ctx context.Context // nil means none
|
||||
lookPathErr error // LookPath error, if any.
|
||||
finished bool // when Wait was called
|
||||
childFiles []*os.File
|
||||
closeAfterStart []io.Closer
|
||||
closeAfterWait []io.Closer
|
||||
goroutine []func() error
|
||||
errch chan error // one send per goroutine
|
||||
waitDone chan struct{}
|
||||
}
|
||||
|
||||
// Command returns the Cmd struct to execute the named program with
|
||||
// the given arguments.
|
||||
//
|
||||
// It sets only the Path and Args in the returned structure.
|
||||
//
|
||||
// If name contains no path separators, Command uses LookPath to
|
||||
// resolve name to a complete path if possible. Otherwise it uses name
|
||||
// directly as Path.
|
||||
//
|
||||
// The returned Cmd's Args field is constructed from the command name
|
||||
// followed by the elements of arg, so arg should not include the
|
||||
// command name itself. For example, Command("echo", "hello").
|
||||
// Args[0] is always name, not the possibly resolved Path.
|
||||
//
|
||||
// On Windows, processes receive the whole command line as a single string
|
||||
// and do their own parsing. Command combines and quotes Args into a command
|
||||
// line string with an algorithm compatible with applications using
|
||||
// CommandLineToArgvW (which is the most common way). Notable exceptions are
|
||||
// msiexec.exe and cmd.exe (and thus, all batch files), which have a different
|
||||
// unquoting algorithm. In these or other similar cases, you can do the
|
||||
// quoting yourself and provide the full command line in SysProcAttr.CmdLine,
|
||||
// leaving Args empty.
|
||||
func Command(name string, arg ...string) *Cmd {
|
||||
cmd := &Cmd{
|
||||
Path: name,
|
||||
Args: append([]string{name}, arg...),
|
||||
}
|
||||
if filepath.Base(name) == name {
|
||||
if lp, err := LookPath(name); err != nil {
|
||||
cmd.lookPathErr = err
|
||||
} else {
|
||||
cmd.Path = lp
|
||||
}
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// CommandContext is like Command but includes a context.
|
||||
//
|
||||
// The provided context is used to kill the process (by calling
|
||||
// os.Process.Kill) if the context becomes done before the command
|
||||
// completes on its own.
|
||||
func CommandContext(ctx context.Context, name string, arg ...string) *Cmd {
|
||||
if ctx == nil {
|
||||
panic("nil Context")
|
||||
}
|
||||
cmd := Command(name, arg...)
|
||||
cmd.ctx = ctx
|
||||
return cmd
|
||||
}
|
||||
|
||||
// String returns a human-readable description of c.
|
||||
// It is intended only for debugging.
|
||||
// In particular, it is not suitable for use as input to a shell.
|
||||
// The output of String may vary across Go releases.
|
||||
func (c *Cmd) String() string {
|
||||
if c.lookPathErr != nil {
|
||||
// failed to resolve path; report the original requested path (plus args)
|
||||
return strings.Join(c.Args, " ")
|
||||
}
|
||||
// report the exact executable path (plus args)
|
||||
b := new(strings.Builder)
|
||||
b.WriteString(c.Path)
|
||||
for _, a := range c.Args[1:] {
|
||||
b.WriteByte(' ')
|
||||
b.WriteString(a)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// interfaceEqual protects against panics from doing equality tests on
|
||||
// two interfaces with non-comparable underlying types.
|
||||
func interfaceEqual(a, b interface{}) bool {
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
return a == b
|
||||
}
|
||||
|
||||
func (c *Cmd) envv() []string {
|
||||
if c.Env != nil {
|
||||
return c.Env
|
||||
}
|
||||
return os.Environ()
|
||||
}
|
||||
|
||||
func (c *Cmd) argv() []string {
|
||||
if len(c.Args) > 0 {
|
||||
return c.Args
|
||||
}
|
||||
return []string{c.Path}
|
||||
}
|
||||
|
||||
// skipStdinCopyError optionally specifies a function which reports
|
||||
// whether the provided stdin copy error should be ignored.
|
||||
// It is non-nil everywhere but Plan 9, which lacks EPIPE. See exec_posix.go.
|
||||
var skipStdinCopyError func(error) bool
|
||||
|
||||
func (c *Cmd) stdin() (f *os.File, err error) {
|
||||
if c.Stdin == nil {
|
||||
f, err = os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.closeAfterStart = append(c.closeAfterStart, f)
|
||||
return
|
||||
}
|
||||
|
||||
if f, ok := c.Stdin.(*os.File); ok {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.closeAfterStart = append(c.closeAfterStart, pr)
|
||||
c.closeAfterWait = append(c.closeAfterWait, pw)
|
||||
c.goroutine = append(c.goroutine, func() error {
|
||||
_, err := io.Copy(pw, c.Stdin)
|
||||
if skip := skipStdinCopyError; skip != nil && skip(err) {
|
||||
err = nil
|
||||
}
|
||||
if err1 := pw.Close(); err == nil {
|
||||
err = err1
|
||||
}
|
||||
return err
|
||||
})
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
func (c *Cmd) stdout() (f *os.File, err error) {
|
||||
return c.writerDescriptor(c.Stdout)
|
||||
}
|
||||
|
||||
func (c *Cmd) stderr() (f *os.File, err error) {
|
||||
if c.Stderr != nil && interfaceEqual(c.Stderr, c.Stdout) {
|
||||
return c.childFiles[1], nil
|
||||
}
|
||||
return c.writerDescriptor(c.Stderr)
|
||||
}
|
||||
|
||||
func (c *Cmd) writerDescriptor(w io.Writer) (f *os.File, err error) {
|
||||
if w == nil {
|
||||
f, err = os.OpenFile(os.DevNull, os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.closeAfterStart = append(c.closeAfterStart, f)
|
||||
return
|
||||
}
|
||||
|
||||
if f, ok := w.(*os.File); ok {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.closeAfterStart = append(c.closeAfterStart, pw)
|
||||
c.closeAfterWait = append(c.closeAfterWait, pr)
|
||||
c.goroutine = append(c.goroutine, func() error {
|
||||
_, err := io.Copy(w, pr)
|
||||
pr.Close() // in case io.Copy stopped due to write error
|
||||
return err
|
||||
})
|
||||
return pw, nil
|
||||
}
|
||||
|
||||
func (c *Cmd) closeDescriptors(closers []io.Closer) {
|
||||
for _, fd := range closers {
|
||||
fd.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the specified command and waits for it to complete.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the command starts but does not complete successfully, the error is of
|
||||
// type *ExitError. Other error types may be returned for other situations.
|
||||
//
|
||||
// If the calling goroutine has locked the operating system thread
|
||||
// with runtime.LockOSThread and modified any inheritable OS-level
|
||||
// thread state (for example, Linux or Plan 9 name spaces), the new
|
||||
// process will inherit the caller's thread state.
|
||||
func (c *Cmd) Run() error {
|
||||
if err := c.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Wait()
|
||||
}
|
||||
|
||||
// lookExtensions finds windows executable by its dir and path.
|
||||
// It uses LookPath to try appropriate extensions.
|
||||
// lookExtensions does not search PATH, instead it converts `prog` into `.\prog`.
|
||||
func lookExtensions(path, dir string) (string, error) {
|
||||
if filepath.Base(path) == path {
|
||||
path = filepath.Join(".", path)
|
||||
}
|
||||
if dir == "" {
|
||||
return LookPath(path)
|
||||
}
|
||||
if filepath.VolumeName(path) != "" {
|
||||
return LookPath(path)
|
||||
}
|
||||
if len(path) > 1 && os.IsPathSeparator(path[0]) {
|
||||
return LookPath(path)
|
||||
}
|
||||
dirandpath := filepath.Join(dir, path)
|
||||
// We assume that LookPath will only add file extension.
|
||||
lp, err := LookPath(dirandpath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ext := strings.TrimPrefix(lp, dirandpath)
|
||||
return path + ext, nil
|
||||
}
|
||||
|
||||
// Start starts the specified command but does not wait for it to complete.
|
||||
//
|
||||
// The Wait method will return the exit code and release associated resources
|
||||
// once the command exits.
|
||||
func (c *Cmd) Start() error {
|
||||
if c.lookPathErr != nil {
|
||||
c.closeDescriptors(c.closeAfterStart)
|
||||
c.closeDescriptors(c.closeAfterWait)
|
||||
return c.lookPathErr
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
lp, err := lookExtensions(c.Path, c.Dir)
|
||||
if err != nil {
|
||||
c.closeDescriptors(c.closeAfterStart)
|
||||
c.closeDescriptors(c.closeAfterWait)
|
||||
return err
|
||||
}
|
||||
c.Path = lp
|
||||
}
|
||||
if c.Process != nil {
|
||||
return errors.New("exec: already started")
|
||||
}
|
||||
if c.ctx != nil {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
c.closeDescriptors(c.closeAfterStart)
|
||||
c.closeDescriptors(c.closeAfterWait)
|
||||
return c.ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
c.childFiles = make([]*os.File, 0, 3+len(c.ExtraFiles))
|
||||
type F func(*Cmd) (*os.File, error)
|
||||
for _, setupFd := range []F{(*Cmd).stdin, (*Cmd).stdout, (*Cmd).stderr} {
|
||||
fd, err := setupFd(c)
|
||||
if err != nil {
|
||||
c.closeDescriptors(c.closeAfterStart)
|
||||
c.closeDescriptors(c.closeAfterWait)
|
||||
return err
|
||||
}
|
||||
c.childFiles = append(c.childFiles, fd)
|
||||
}
|
||||
c.childFiles = append(c.childFiles, c.ExtraFiles...)
|
||||
|
||||
var err error
|
||||
c.Process, err = os.StartProcess(c.Path, c.argv(), &os.ProcAttr{
|
||||
Dir: c.Dir,
|
||||
Files: c.childFiles,
|
||||
Env: addCriticalEnv(dedupEnv(c.envv())),
|
||||
Sys: c.SysProcAttr,
|
||||
})
|
||||
if err != nil {
|
||||
c.closeDescriptors(c.closeAfterStart)
|
||||
c.closeDescriptors(c.closeAfterWait)
|
||||
return err
|
||||
}
|
||||
|
||||
c.closeDescriptors(c.closeAfterStart)
|
||||
|
||||
// Don't allocate the channel unless there are goroutines to fire.
|
||||
if len(c.goroutine) > 0 {
|
||||
c.errch = make(chan error, len(c.goroutine))
|
||||
for _, fn := range c.goroutine {
|
||||
go func(fn func() error) {
|
||||
c.errch <- fn()
|
||||
}(fn)
|
||||
}
|
||||
}
|
||||
|
||||
if c.ctx != nil {
|
||||
c.waitDone = make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
c.Process.Kill()
|
||||
case <-c.waitDone:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// An ExitError reports an unsuccessful exit by a command.
|
||||
type ExitError struct {
|
||||
*os.ProcessState
|
||||
|
||||
// Stderr holds a subset of the standard error output from the
|
||||
// Cmd.Output method if standard error was not otherwise being
|
||||
// collected.
|
||||
//
|
||||
// If the error output is long, Stderr may contain only a prefix
|
||||
// and suffix of the output, with the middle replaced with
|
||||
// text about the number of omitted bytes.
|
||||
//
|
||||
// Stderr is provided for debugging, for inclusion in error messages.
|
||||
// Users with other needs should redirect Cmd.Stderr as needed.
|
||||
Stderr []byte
|
||||
}
|
||||
|
||||
func (e *ExitError) Error() string {
|
||||
return e.ProcessState.String()
|
||||
}
|
||||
|
||||
// Wait waits for the command to exit and waits for any copying to
|
||||
// stdin or copying from stdout or stderr to complete.
|
||||
//
|
||||
// The command must have been started by Start.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the command fails to run or doesn't complete successfully, the
|
||||
// error is of type *ExitError. Other error types may be
|
||||
// returned for I/O problems.
|
||||
//
|
||||
// If any of c.Stdin, c.Stdout or c.Stderr are not an *os.File, Wait also waits
|
||||
// for the respective I/O loop copying to or from the process to complete.
|
||||
//
|
||||
// Wait releases any resources associated with the Cmd.
|
||||
func (c *Cmd) Wait() error {
|
||||
if c.Process == nil {
|
||||
return errors.New("exec: not started")
|
||||
}
|
||||
if c.finished {
|
||||
return errors.New("exec: Wait was already called")
|
||||
}
|
||||
c.finished = true
|
||||
|
||||
var err error
|
||||
var state *os.ProcessState
|
||||
for {
|
||||
state, err = c.Process.Wait()
|
||||
if err != nil {
|
||||
xe, ok := err.(*os.SyscallError)
|
||||
if ok {
|
||||
if xe.Unwrap() == syscall.EINTR {
|
||||
// temporary error, retry wait syscall
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if c.waitDone != nil {
|
||||
close(c.waitDone)
|
||||
}
|
||||
c.ProcessState = state
|
||||
|
||||
var copyError error
|
||||
for range c.goroutine {
|
||||
if err := <-c.errch; err != nil && copyError == nil {
|
||||
copyError = err
|
||||
}
|
||||
}
|
||||
|
||||
c.closeDescriptors(c.closeAfterWait)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !state.Success() {
|
||||
return &ExitError{ProcessState: state}
|
||||
}
|
||||
|
||||
return copyError
|
||||
}
|
||||
|
||||
// Output runs the command and returns its standard output.
|
||||
// Any returned error will usually be of type *ExitError.
|
||||
// If c.Stderr was nil, Output populates ExitError.Stderr.
|
||||
func (c *Cmd) Output() ([]byte, error) {
|
||||
if c.Stdout != nil {
|
||||
return nil, errors.New("exec: Stdout already set")
|
||||
}
|
||||
var stdout bytes.Buffer
|
||||
c.Stdout = &stdout
|
||||
|
||||
captureErr := c.Stderr == nil
|
||||
if captureErr {
|
||||
c.Stderr = &prefixSuffixSaver{N: 32 << 10}
|
||||
}
|
||||
|
||||
err := c.Run()
|
||||
if err != nil && captureErr {
|
||||
if ee, ok := err.(*ExitError); ok {
|
||||
ee.Stderr = c.Stderr.(*prefixSuffixSaver).Bytes()
|
||||
}
|
||||
}
|
||||
return stdout.Bytes(), err
|
||||
}
|
||||
|
||||
// CombinedOutput runs the command and returns its combined standard
|
||||
// output and standard error.
|
||||
func (c *Cmd) CombinedOutput() ([]byte, error) {
|
||||
if c.Stdout != nil {
|
||||
return nil, errors.New("exec: Stdout already set")
|
||||
}
|
||||
if c.Stderr != nil {
|
||||
return nil, errors.New("exec: Stderr already set")
|
||||
}
|
||||
var b bytes.Buffer
|
||||
c.Stdout = &b
|
||||
c.Stderr = &b
|
||||
err := c.Run()
|
||||
return b.Bytes(), err
|
||||
}
|
||||
|
||||
// StdinPipe returns a pipe that will be connected to the command's
|
||||
// standard input when the command starts.
|
||||
// The pipe will be closed automatically after Wait sees the command exit.
|
||||
// A caller need only call Close to force the pipe to close sooner.
|
||||
// For example, if the command being run will not exit until standard input
|
||||
// is closed, the caller must close the pipe.
|
||||
func (c *Cmd) StdinPipe() (io.WriteCloser, error) {
|
||||
if c.Stdin != nil {
|
||||
return nil, errors.New("exec: Stdin already set")
|
||||
}
|
||||
if c.Process != nil {
|
||||
return nil, errors.New("exec: StdinPipe after process started")
|
||||
}
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Stdin = pr
|
||||
c.closeAfterStart = append(c.closeAfterStart, pr)
|
||||
wc := &closeOnce{File: pw}
|
||||
c.closeAfterWait = append(c.closeAfterWait, wc)
|
||||
return wc, nil
|
||||
}
|
||||
|
||||
type closeOnce struct {
|
||||
*os.File
|
||||
|
||||
once sync.Once
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *closeOnce) Close() error {
|
||||
c.once.Do(c.close)
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *closeOnce) close() {
|
||||
c.err = c.File.Close()
|
||||
}
|
||||
|
||||
// StdoutPipe returns a pipe that will be connected to the command's
|
||||
// standard output when the command starts.
|
||||
//
|
||||
// Wait will close the pipe after seeing the command exit, so most callers
|
||||
// need not close the pipe themselves; however, an implication is that
|
||||
// it is incorrect to call Wait before all reads from the pipe have completed.
|
||||
// For the same reason, it is incorrect to call Run when using StdoutPipe.
|
||||
// See the example for idiomatic usage.
|
||||
func (c *Cmd) StdoutPipe() (io.ReadCloser, error) {
|
||||
if c.Stdout != nil {
|
||||
return nil, errors.New("exec: Stdout already set")
|
||||
}
|
||||
if c.Process != nil {
|
||||
return nil, errors.New("exec: StdoutPipe after process started")
|
||||
}
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Stdout = pw
|
||||
c.closeAfterStart = append(c.closeAfterStart, pw)
|
||||
c.closeAfterWait = append(c.closeAfterWait, pr)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
// StderrPipe returns a pipe that will be connected to the command's
|
||||
// standard error when the command starts.
|
||||
//
|
||||
// Wait will close the pipe after seeing the command exit, so most callers
|
||||
// need not close the pipe themselves; however, an implication is that
|
||||
// it is incorrect to call Wait before all reads from the pipe have completed.
|
||||
// For the same reason, it is incorrect to use Run when using StderrPipe.
|
||||
// See the StdoutPipe example for idiomatic usage.
|
||||
func (c *Cmd) StderrPipe() (io.ReadCloser, error) {
|
||||
if c.Stderr != nil {
|
||||
return nil, errors.New("exec: Stderr already set")
|
||||
}
|
||||
if c.Process != nil {
|
||||
return nil, errors.New("exec: StderrPipe after process started")
|
||||
}
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Stderr = pw
|
||||
c.closeAfterStart = append(c.closeAfterStart, pw)
|
||||
c.closeAfterWait = append(c.closeAfterWait, pr)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
// prefixSuffixSaver is an io.Writer which retains the first N bytes
|
||||
// and the last N bytes written to it. The Bytes() methods reconstructs
|
||||
// it with a pretty error message.
|
||||
type prefixSuffixSaver struct {
|
||||
N int // max size of prefix or suffix
|
||||
prefix []byte
|
||||
suffix []byte // ring buffer once len(suffix) == N
|
||||
suffixOff int // offset to write into suffix
|
||||
skipped int64
|
||||
|
||||
// TODO(bradfitz): we could keep one large []byte and use part of it for
|
||||
// the prefix, reserve space for the '... Omitting N bytes ...' message,
|
||||
// then the ring buffer suffix, and just rearrange the ring buffer
|
||||
// suffix when Bytes() is called, but it doesn't seem worth it for
|
||||
// now just for error messages. It's only ~64KB anyway.
|
||||
}
|
||||
|
||||
func (w *prefixSuffixSaver) Write(p []byte) (n int, err error) {
|
||||
lenp := len(p)
|
||||
p = w.fill(&w.prefix, p)
|
||||
|
||||
// Only keep the last w.N bytes of suffix data.
|
||||
if overage := len(p) - w.N; overage > 0 {
|
||||
p = p[overage:]
|
||||
w.skipped += int64(overage)
|
||||
}
|
||||
p = w.fill(&w.suffix, p)
|
||||
|
||||
// w.suffix is full now if p is non-empty. Overwrite it in a circle.
|
||||
for len(p) > 0 { // 0, 1, or 2 iterations.
|
||||
n := copy(w.suffix[w.suffixOff:], p)
|
||||
p = p[n:]
|
||||
w.skipped += int64(n)
|
||||
w.suffixOff += n
|
||||
if w.suffixOff == w.N {
|
||||
w.suffixOff = 0
|
||||
}
|
||||
}
|
||||
return lenp, nil
|
||||
}
|
||||
|
||||
// fill appends up to len(p) bytes of p to *dst, such that *dst does not
|
||||
// grow larger than w.N. It returns the un-appended suffix of p.
|
||||
func (w *prefixSuffixSaver) fill(dst *[]byte, p []byte) (pRemain []byte) {
|
||||
if remain := w.N - len(*dst); remain > 0 {
|
||||
add := minInt(len(p), remain)
|
||||
*dst = append(*dst, p[:add]...)
|
||||
p = p[add:]
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (w *prefixSuffixSaver) Bytes() []byte {
|
||||
if w.suffix == nil {
|
||||
return w.prefix
|
||||
}
|
||||
if w.skipped == 0 {
|
||||
return append(w.prefix, w.suffix...)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(len(w.prefix) + len(w.suffix) + 50)
|
||||
buf.Write(w.prefix)
|
||||
buf.WriteString("\n... omitting ")
|
||||
buf.WriteString(strconv.FormatInt(w.skipped, 10))
|
||||
buf.WriteString(" bytes ...\n")
|
||||
buf.Write(w.suffix[w.suffixOff:])
|
||||
buf.Write(w.suffix[:w.suffixOff])
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// dedupEnv returns a copy of env with any duplicates removed, in favor of
|
||||
// later values.
|
||||
// Items not of the normal environment "key=value" form are preserved unchanged.
|
||||
func dedupEnv(env []string) []string {
|
||||
return dedupEnvCase(runtime.GOOS == "windows", env)
|
||||
}
|
||||
|
||||
// dedupEnvCase is dedupEnv with a case option for testing.
|
||||
// If caseInsensitive is true, the case of keys is ignored.
|
||||
func dedupEnvCase(caseInsensitive bool, env []string) []string {
|
||||
out := make([]string, 0, len(env))
|
||||
saw := make(map[string]int, len(env)) // key => index into out
|
||||
for _, kv := range env {
|
||||
eq := strings.Index(kv, "=")
|
||||
if eq < 0 {
|
||||
out = append(out, kv)
|
||||
continue
|
||||
}
|
||||
k := kv[:eq]
|
||||
if caseInsensitive {
|
||||
k = strings.ToLower(k)
|
||||
}
|
||||
if dupIdx, isDup := saw[k]; isDup {
|
||||
out[dupIdx] = kv
|
||||
continue
|
||||
}
|
||||
saw[k] = len(out)
|
||||
out = append(out, kv)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// addCriticalEnv adds any critical environment variables that are required
|
||||
// (or at least almost always required) on the operating system.
|
||||
// Currently this is only used for Windows.
|
||||
func addCriticalEnv(env []string) []string {
|
||||
if runtime.GOOS != "windows" {
|
||||
return env
|
||||
}
|
||||
for _, kv := range env {
|
||||
eq := strings.Index(kv, "=")
|
||||
if eq < 0 {
|
||||
continue
|
||||
}
|
||||
k := kv[:eq]
|
||||
if strings.EqualFold(k, "SYSTEMROOT") {
|
||||
// We already have it.
|
||||
return env
|
||||
}
|
||||
}
|
||||
return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !plan9,!windows
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func init() {
|
||||
skipStdinCopyError = func(err error) bool {
|
||||
// Ignore EPIPE errors copying to stdin if the program
|
||||
// completed successfully otherwise.
|
||||
// See Issue 9173.
|
||||
pe, ok := err.(*os.PathError)
|
||||
return ok &&
|
||||
pe.Op == "write" && pe.Path == "|1" &&
|
||||
pe.Err == syscall.EPIPE
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func init() {
|
||||
skipStdinCopyError = func(err error) bool {
|
||||
// Ignore ERROR_BROKEN_PIPE and ERROR_NO_DATA errors copying
|
||||
// to stdin if the program completed successfully otherwise.
|
||||
// See Issue 20445.
|
||||
const _ERROR_NO_DATA = syscall.Errno(0xe8)
|
||||
pe, ok := err.(*os.PathError)
|
||||
return ok &&
|
||||
pe.Op == "write" && pe.Path == "|1" &&
|
||||
(pe.Err == syscall.ERROR_BROKEN_PIPE || pe.Err == _ERROR_NO_DATA)
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrefixSuffixSaver(t *testing.T) {
|
||||
tests := []struct {
|
||||
N int
|
||||
writes []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
N: 2,
|
||||
writes: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
N: 2,
|
||||
writes: []string{"a"},
|
||||
want: "a",
|
||||
},
|
||||
{
|
||||
N: 2,
|
||||
writes: []string{"abc", "d"},
|
||||
want: "abcd",
|
||||
},
|
||||
{
|
||||
N: 2,
|
||||
writes: []string{"abc", "d", "e"},
|
||||
want: "ab\n... omitting 1 bytes ...\nde",
|
||||
},
|
||||
{
|
||||
N: 2,
|
||||
writes: []string{"ab______________________yz"},
|
||||
want: "ab\n... omitting 22 bytes ...\nyz",
|
||||
},
|
||||
{
|
||||
N: 2,
|
||||
writes: []string{"ab_______________________y", "z"},
|
||||
want: "ab\n... omitting 23 bytes ...\nyz",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
w := &prefixSuffixSaver{N: tt.N}
|
||||
for _, s := range tt.writes {
|
||||
n, err := io.WriteString(w, s)
|
||||
if err != nil || n != len(s) {
|
||||
t.Errorf("%d. WriteString(%q) = %v, %v; want %v, %v", i, s, n, err, len(s), nil)
|
||||
}
|
||||
}
|
||||
if got := string(w.Bytes()); got != tt.want {
|
||||
t.Errorf("%d. Bytes = %q; want %q", i, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build js,wasm
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrNotFound is the error resulting if a path search failed to find an executable file.
|
||||
var ErrNotFound = errors.New("executable file not found in $PATH")
|
||||
|
||||
// LookPath searches for an executable named file in the
|
||||
// directories named by the PATH environment variable.
|
||||
// If file contains a slash, it is tried directly and the PATH is not consulted.
|
||||
// The result may be an absolute path or a path relative to the current directory.
|
||||
func LookPath(file string) (string, error) {
|
||||
// Wasm can not execute processes, so act as if there are no executables at all.
|
||||
return "", &Error{file, ErrNotFound}
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrNotFound is the error resulting if a path search failed to find an executable file.
|
||||
var ErrNotFound = errors.New("executable file not found in $path")
|
||||
|
||||
func findExecutable(file string) error {
|
||||
d, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m := d.Mode(); !m.IsDir() && m&0111 != 0 {
|
||||
return nil
|
||||
}
|
||||
return os.ErrPermission
|
||||
}
|
||||
|
||||
// LookPath searches for an executable named file in the
|
||||
// directories named by the path environment variable.
|
||||
// If file begins with "/", "#", "./", or "../", it is tried
|
||||
// directly and the path is not consulted.
|
||||
// The result may be an absolute path or a path relative to the current directory.
|
||||
func LookPath(file string) (string, error) {
|
||||
// skip the path lookup for these prefixes
|
||||
skip := []string{"/", "#", "./", "../"}
|
||||
|
||||
for _, p := range skip {
|
||||
if strings.HasPrefix(file, p) {
|
||||
err := findExecutable(file)
|
||||
if err == nil {
|
||||
return file, nil
|
||||
}
|
||||
return "", &Error{file, err}
|
||||
}
|
||||
}
|
||||
|
||||
path := os.Getenv("path")
|
||||
for _, dir := range filepath.SplitList(path) {
|
||||
path := filepath.Join(dir, file)
|
||||
if err := findExecutable(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
return "", &Error{file, ErrNotFound}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var nonExistentPaths = []string{
|
||||
"some-non-existent-path",
|
||||
"non-existent-path/slashed",
|
||||
}
|
||||
|
||||
func TestLookPathNotFound(t *testing.T) {
|
||||
for _, name := range nonExistentPaths {
|
||||
path, err := LookPath(name)
|
||||
if err == nil {
|
||||
t.Fatalf("LookPath found %q in $PATH", name)
|
||||
}
|
||||
if path != "" {
|
||||
t.Fatalf("LookPath path == %q when err != nil", path)
|
||||
}
|
||||
perr, ok := err.(*Error)
|
||||
if !ok {
|
||||
t.Fatal("LookPath error is not an exec.Error")
|
||||
}
|
||||
if perr.Name != name {
|
||||
t.Fatalf("want Error name %q, got %q", name, perr.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrNotFound is the error resulting if a path search failed to find an executable file.
|
||||
var ErrNotFound = errors.New("executable file not found in $PATH")
|
||||
|
||||
func findExecutable(file string) error {
|
||||
d, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m := d.Mode(); !m.IsDir() && m&0111 != 0 {
|
||||
return nil
|
||||
}
|
||||
return os.ErrPermission
|
||||
}
|
||||
|
||||
// LookPath searches for an executable named file in the
|
||||
// directories named by the PATH environment variable.
|
||||
// If file contains a slash, it is tried directly and the PATH is not consulted.
|
||||
// The result may be an absolute path or a path relative to the current directory.
|
||||
func LookPath(file string) (string, error) {
|
||||
// NOTE(rsc): I wish we could use the Plan 9 behavior here
|
||||
// (only bypass the path if file begins with / or ./ or ../)
|
||||
// but that would not match all the Unix shells.
|
||||
|
||||
if strings.Contains(file, "/") {
|
||||
err := findExecutable(file)
|
||||
if err == nil {
|
||||
return file, nil
|
||||
}
|
||||
return "", &Error{file, err}
|
||||
}
|
||||
path := os.Getenv("PATH")
|
||||
for _, dir := range filepath.SplitList(path) {
|
||||
if dir == "" {
|
||||
// Unix shell semantics: path element "" means "."
|
||||
dir = "."
|
||||
}
|
||||
path := filepath.Join(dir, file)
|
||||
if err := findExecutable(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
return "", &Error{file, ErrNotFound}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLookPathUnixEmptyPath(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal("Getwd failed: ", err)
|
||||
}
|
||||
err = os.Chdir(tmp)
|
||||
if err != nil {
|
||||
t.Fatal("Chdir failed: ", err)
|
||||
}
|
||||
defer os.Chdir(wd)
|
||||
|
||||
f, err := os.OpenFile("exec_me", os.O_CREATE|os.O_EXCL, 0700)
|
||||
if err != nil {
|
||||
t.Fatal("OpenFile failed: ", err)
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
t.Fatal("Close failed: ", err)
|
||||
}
|
||||
|
||||
pathenv := os.Getenv("PATH")
|
||||
defer os.Setenv("PATH", pathenv)
|
||||
|
||||
err = os.Setenv("PATH", "")
|
||||
if err != nil {
|
||||
t.Fatal("Setenv failed: ", err)
|
||||
}
|
||||
|
||||
path, err := LookPath("exec_me")
|
||||
if err == nil {
|
||||
t.Fatal("LookPath found exec_me in empty $PATH")
|
||||
}
|
||||
if path != "" {
|
||||
t.Fatalf("LookPath path == %q when err != nil", path)
|
||||
}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package exec
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrNotFound is the error resulting if a path search failed to find an executable file.
|
||||
var ErrNotFound = errors.New("executable file not found in %PATH%")
|
||||
|
||||
func chkStat(file string) error {
|
||||
d, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return os.ErrPermission
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasExt(file string) bool {
|
||||
i := strings.LastIndex(file, ".")
|
||||
if i < 0 {
|
||||
return false
|
||||
}
|
||||
return strings.LastIndexAny(file, `:\/`) < i
|
||||
}
|
||||
|
||||
func findExecutable(file string, exts []string) (string, error) {
|
||||
if len(exts) == 0 {
|
||||
return file, chkStat(file)
|
||||
}
|
||||
if hasExt(file) {
|
||||
if chkStat(file) == nil {
|
||||
return file, nil
|
||||
}
|
||||
}
|
||||
for _, e := range exts {
|
||||
if f := file + e; chkStat(f) == nil {
|
||||
return f, nil
|
||||
}
|
||||
}
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
// LookPath searches for an executable named file in the
|
||||
// directories named by the PATH environment variable.
|
||||
// If file contains a slash, it is tried directly and the PATH is not consulted.
|
||||
// LookPath also uses PATHEXT environment variable to match
|
||||
// a suitable candidate.
|
||||
// The result may be an absolute path or a path relative to the current directory.
|
||||
func LookPath(file string) (string, error) {
|
||||
var exts []string
|
||||
x := os.Getenv(`PATHEXT`)
|
||||
if x != "" {
|
||||
for _, e := range strings.Split(strings.ToLower(x), `;`) {
|
||||
if e == "" {
|
||||
continue
|
||||
}
|
||||
if e[0] != '.' {
|
||||
e = "." + e
|
||||
}
|
||||
exts = append(exts, e)
|
||||
}
|
||||
} else {
|
||||
exts = []string{".com", ".exe", ".bat", ".cmd"}
|
||||
}
|
||||
|
||||
if strings.ContainsAny(file, `:\/`) {
|
||||
if f, err := findExecutable(file, exts); err == nil {
|
||||
return f, nil
|
||||
} else {
|
||||
return "", &Error{file, err}
|
||||
}
|
||||
}
|
||||
if f, err := findExecutable(filepath.Join(".", file), exts); err == nil {
|
||||
return f, nil
|
||||
}
|
||||
path := os.Getenv("path")
|
||||
for _, dir := range filepath.SplitList(path) {
|
||||
if f, err := findExecutable(filepath.Join(dir, file), exts); err == nil {
|
||||
return f, nil
|
||||
}
|
||||
}
|
||||
return "", &Error{file, ErrNotFound}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry
|
||||
|
||||
func (k Key) SetValue(name string, valtype uint32, data []byte) error {
|
||||
return k.setValue(name, valtype, data)
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package registry_test
|
||||
|
||||
// Tailscale's redo-based build system doesn't know how to skip running Go tests
|
||||
// in directories that don't contain files for the current OS.
|
||||
//
|
||||
// https://github.com/tailscale/corp/issues/293
|
||||
//
|
||||
// So this is a dummy file for now to make it happy.
|
||||
@@ -1,204 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
// Package registry provides access to the Windows registry.
|
||||
//
|
||||
// Here is a simple example, opening a registry key and reading a string value from it.
|
||||
//
|
||||
// k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// defer k.Close()
|
||||
//
|
||||
// s, _, err := k.GetStringValue("SystemRoot")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// fmt.Printf("Windows system root is %q\n", s)
|
||||
//
|
||||
package registry
|
||||
|
||||
import (
|
||||
"io"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Registry key security and access rights.
|
||||
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms724878.aspx
|
||||
// for details.
|
||||
ALL_ACCESS = 0xf003f
|
||||
CREATE_LINK = 0x00020
|
||||
CREATE_SUB_KEY = 0x00004
|
||||
ENUMERATE_SUB_KEYS = 0x00008
|
||||
EXECUTE = 0x20019
|
||||
NOTIFY = 0x00010
|
||||
QUERY_VALUE = 0x00001
|
||||
READ = 0x20019
|
||||
SET_VALUE = 0x00002
|
||||
WOW64_32KEY = 0x00200
|
||||
WOW64_64KEY = 0x00100
|
||||
WRITE = 0x20006
|
||||
)
|
||||
|
||||
// Key is a handle to an open Windows registry key.
|
||||
// Keys can be obtained by calling OpenKey; there are
|
||||
// also some predefined root keys such as CURRENT_USER.
|
||||
// Keys can be used directly in the Windows API.
|
||||
type Key syscall.Handle
|
||||
|
||||
const (
|
||||
// Windows defines some predefined root keys that are always open.
|
||||
// An application can use these keys as entry points to the registry.
|
||||
// Normally these keys are used in OpenKey to open new keys,
|
||||
// but they can also be used anywhere a Key is required.
|
||||
CLASSES_ROOT = Key(syscall.HKEY_CLASSES_ROOT)
|
||||
CURRENT_USER = Key(syscall.HKEY_CURRENT_USER)
|
||||
LOCAL_MACHINE = Key(syscall.HKEY_LOCAL_MACHINE)
|
||||
USERS = Key(syscall.HKEY_USERS)
|
||||
CURRENT_CONFIG = Key(syscall.HKEY_CURRENT_CONFIG)
|
||||
PERFORMANCE_DATA = Key(syscall.HKEY_PERFORMANCE_DATA)
|
||||
)
|
||||
|
||||
// Close closes open key k.
|
||||
func (k Key) Close() error {
|
||||
return syscall.RegCloseKey(syscall.Handle(k))
|
||||
}
|
||||
|
||||
// WaitChange waits for k to change using RegNotifyChangeKeyValue.
|
||||
// The subtree parameter is whether subtrees should also be watched.
|
||||
func (k Key) WaitChange(subtree bool) error {
|
||||
return regNotifyChangeKeyValue(syscall.Handle(k), subtree, 0, 0, false)
|
||||
}
|
||||
|
||||
// OpenKey opens a new key with path name relative to key k.
|
||||
// It accepts any open key, including CURRENT_USER and others,
|
||||
// and returns the new key and an error.
|
||||
// The access parameter specifies desired access rights to the
|
||||
// key to be opened.
|
||||
func OpenKey(k Key, path string, access uint32) (Key, error) {
|
||||
p, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var subkey syscall.Handle
|
||||
err = syscall.RegOpenKeyEx(syscall.Handle(k), p, 0, access, &subkey)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return Key(subkey), nil
|
||||
}
|
||||
|
||||
// OpenRemoteKey opens a predefined registry key on another
|
||||
// computer pcname. The key to be opened is specified by k, but
|
||||
// can only be one of LOCAL_MACHINE, PERFORMANCE_DATA or USERS.
|
||||
// If pcname is "", OpenRemoteKey returns local computer key.
|
||||
func OpenRemoteKey(pcname string, k Key) (Key, error) {
|
||||
var err error
|
||||
var p *uint16
|
||||
if pcname != "" {
|
||||
p, err = syscall.UTF16PtrFromString(`\\` + pcname)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
var remoteKey syscall.Handle
|
||||
err = regConnectRegistry(p, syscall.Handle(k), &remoteKey)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return Key(remoteKey), nil
|
||||
}
|
||||
|
||||
// ReadSubKeyNames returns the names of subkeys of key k.
|
||||
// The parameter n controls the number of returned names,
|
||||
// analogous to the way os.File.Readdirnames works.
|
||||
func (k Key) ReadSubKeyNames(n int) ([]string, error) {
|
||||
names := make([]string, 0)
|
||||
// Registry key size limit is 255 bytes and described there:
|
||||
// https://msdn.microsoft.com/library/windows/desktop/ms724872.aspx
|
||||
buf := make([]uint16, 256) //plus extra room for terminating zero byte
|
||||
loopItems:
|
||||
for i := uint32(0); ; i++ {
|
||||
if n > 0 {
|
||||
if len(names) == n {
|
||||
return names, nil
|
||||
}
|
||||
}
|
||||
l := uint32(len(buf))
|
||||
for {
|
||||
err := syscall.RegEnumKeyEx(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err == syscall.ERROR_MORE_DATA {
|
||||
// Double buffer size and try again.
|
||||
l = uint32(2 * len(buf))
|
||||
buf = make([]uint16, l)
|
||||
continue
|
||||
}
|
||||
if err == _ERROR_NO_MORE_ITEMS {
|
||||
break loopItems
|
||||
}
|
||||
return names, err
|
||||
}
|
||||
names = append(names, syscall.UTF16ToString(buf[:l]))
|
||||
}
|
||||
if n > len(names) {
|
||||
return names, io.EOF
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// CreateKey creates a key named path under open key k.
|
||||
// CreateKey returns the new key and a boolean flag that reports
|
||||
// whether the key already existed.
|
||||
// The access parameter specifies the access rights for the key
|
||||
// to be created.
|
||||
func CreateKey(k Key, path string, access uint32) (newk Key, openedExisting bool, err error) {
|
||||
var h syscall.Handle
|
||||
var d uint32
|
||||
err = regCreateKeyEx(syscall.Handle(k), syscall.StringToUTF16Ptr(path),
|
||||
0, nil, _REG_OPTION_NON_VOLATILE, access, nil, &h, &d)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return Key(h), d == _REG_OPENED_EXISTING_KEY, nil
|
||||
}
|
||||
|
||||
// DeleteKey deletes the subkey path of key k and its values.
|
||||
func DeleteKey(k Key, path string) error {
|
||||
return regDeleteKey(syscall.Handle(k), syscall.StringToUTF16Ptr(path))
|
||||
}
|
||||
|
||||
// A KeyInfo describes the statistics of a key. It is returned by Stat.
|
||||
type KeyInfo struct {
|
||||
SubKeyCount uint32
|
||||
MaxSubKeyLen uint32 // size of the key's subkey with the longest name, in Unicode characters, not including the terminating zero byte
|
||||
ValueCount uint32
|
||||
MaxValueNameLen uint32 // size of the key's longest value name, in Unicode characters, not including the terminating zero byte
|
||||
MaxValueLen uint32 // longest data component among the key's values, in bytes
|
||||
lastWriteTime syscall.Filetime
|
||||
}
|
||||
|
||||
// ModTime returns the key's last write time.
|
||||
func (ki *KeyInfo) ModTime() time.Time {
|
||||
return time.Unix(0, ki.lastWriteTime.Nanoseconds())
|
||||
}
|
||||
|
||||
// Stat retrieves information about the open key k.
|
||||
func (k Key) Stat() (*KeyInfo, error) {
|
||||
var ki KeyInfo
|
||||
err := syscall.RegQueryInfoKey(syscall.Handle(k), nil, nil, nil,
|
||||
&ki.SubKeyCount, &ki.MaxSubKeyLen, nil, &ki.ValueCount,
|
||||
&ki.MaxValueNameLen, &ki.MaxValueLen, nil, &ki.lastWriteTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ki, nil
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build generate
|
||||
|
||||
package registry
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall.go
|
||||
@@ -1,701 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"tailscale.com/tempfork/registry"
|
||||
)
|
||||
|
||||
func randKeyName(prefix string) string {
|
||||
const numbers = "0123456789"
|
||||
buf := make([]byte, 10)
|
||||
rand.Read(buf)
|
||||
for i, b := range buf {
|
||||
buf[i] = numbers[b%byte(len(numbers))]
|
||||
}
|
||||
return prefix + string(buf)
|
||||
}
|
||||
|
||||
func TestReadSubKeyNames(t *testing.T) {
|
||||
k, err := registry.OpenKey(registry.CLASSES_ROOT, "TypeLib", registry.ENUMERATE_SUB_KEYS)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
names, err := k.ReadSubKeyNames(-1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var foundStdOle bool
|
||||
for _, name := range names {
|
||||
// Every PC has "stdole 2.0 OLE Automation" library installed.
|
||||
if name == "{00020430-0000-0000-C000-000000000046}" {
|
||||
foundStdOle = true
|
||||
}
|
||||
}
|
||||
if !foundStdOle {
|
||||
t.Fatal("could not find stdole 2.0 OLE Automation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOpenDeleteKey(t *testing.T) {
|
||||
k, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
testKName := randKeyName("TestCreateOpenDeleteKey_")
|
||||
|
||||
testK, exist, err := registry.CreateKey(k, testKName, registry.CREATE_SUB_KEY)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testK.Close()
|
||||
|
||||
if exist {
|
||||
t.Fatalf("key %q already exists", testKName)
|
||||
}
|
||||
|
||||
testKAgain, exist, err := registry.CreateKey(k, testKName, registry.CREATE_SUB_KEY)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testKAgain.Close()
|
||||
|
||||
if !exist {
|
||||
t.Fatalf("key %q should already exist", testKName)
|
||||
}
|
||||
|
||||
testKOpened, err := registry.OpenKey(k, testKName, registry.ENUMERATE_SUB_KEYS)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testKOpened.Close()
|
||||
|
||||
err = registry.DeleteKey(k, testKName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testKOpenedAgain, err := registry.OpenKey(k, testKName, registry.ENUMERATE_SUB_KEYS)
|
||||
if err == nil {
|
||||
defer testKOpenedAgain.Close()
|
||||
t.Fatalf("key %q should already been deleted", testKName)
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Fatalf(`unexpected error ("not exist" expected): %v`, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatch(t *testing.T) {
|
||||
k, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE|registry.WRITE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
testKName := randKeyName("TestWatch_")
|
||||
testK, _, err := registry.CreateKey(k, testKName, registry.CREATE_SUB_KEY|registry.NOTIFY|registry.WRITE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testK.Close()
|
||||
|
||||
timer := time.AfterFunc(100*time.Millisecond, func() {
|
||||
err := registry.DeleteKey(k, testKName)
|
||||
t.Logf("DeleteKey: %v", err)
|
||||
})
|
||||
defer timer.Stop()
|
||||
t.Logf("pre-wait")
|
||||
t0 := time.Now()
|
||||
err = testK.WaitChange(true)
|
||||
t.Logf("WaitChange after %v: %v", time.Since(t0).Round(time.Millisecond), err)
|
||||
}
|
||||
|
||||
func equalStringSlice(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if a == nil {
|
||||
return true
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type ValueTest struct {
|
||||
Type uint32
|
||||
Name string
|
||||
Value interface{}
|
||||
WillFail bool
|
||||
}
|
||||
|
||||
var ValueTests = []ValueTest{
|
||||
{Type: registry.SZ, Name: "String1", Value: ""},
|
||||
{Type: registry.SZ, Name: "String2", Value: "\000", WillFail: true},
|
||||
{Type: registry.SZ, Name: "String3", Value: "Hello World"},
|
||||
{Type: registry.SZ, Name: "String4", Value: "Hello World\000", WillFail: true},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString1", Value: ""},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString2", Value: "\000", WillFail: true},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString3", Value: "Hello World"},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString4", Value: "Hello\000World", WillFail: true},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString5", Value: "%PATH%"},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString6", Value: "%NO_SUCH_VARIABLE%"},
|
||||
{Type: registry.EXPAND_SZ, Name: "ExpString7", Value: "%PATH%;."},
|
||||
{Type: registry.BINARY, Name: "Binary1", Value: []byte{}},
|
||||
{Type: registry.BINARY, Name: "Binary2", Value: []byte{1, 2, 3}},
|
||||
{Type: registry.BINARY, Name: "Binary3", Value: []byte{3, 2, 1, 0, 1, 2, 3}},
|
||||
{Type: registry.DWORD, Name: "Dword1", Value: uint64(0)},
|
||||
{Type: registry.DWORD, Name: "Dword2", Value: uint64(1)},
|
||||
{Type: registry.DWORD, Name: "Dword3", Value: uint64(0xff)},
|
||||
{Type: registry.DWORD, Name: "Dword4", Value: uint64(0xffff)},
|
||||
{Type: registry.QWORD, Name: "Qword1", Value: uint64(0)},
|
||||
{Type: registry.QWORD, Name: "Qword2", Value: uint64(1)},
|
||||
{Type: registry.QWORD, Name: "Qword3", Value: uint64(0xff)},
|
||||
{Type: registry.QWORD, Name: "Qword4", Value: uint64(0xffff)},
|
||||
{Type: registry.QWORD, Name: "Qword5", Value: uint64(0xffffff)},
|
||||
{Type: registry.QWORD, Name: "Qword6", Value: uint64(0xffffffff)},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString1", Value: []string{"a", "b", "c"}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString2", Value: []string{"abc", "", "cba"}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString3", Value: []string{""}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString4", Value: []string{"abcdef"}},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString5", Value: []string{"\000"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString6", Value: []string{"a\000b"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString7", Value: []string{"ab", "\000", "cd"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString8", Value: []string{"\000", "cd"}, WillFail: true},
|
||||
{Type: registry.MULTI_SZ, Name: "MultiString9", Value: []string{"ab", "\000"}, WillFail: true},
|
||||
}
|
||||
|
||||
func setValues(t *testing.T, k registry.Key) {
|
||||
for _, test := range ValueTests {
|
||||
var err error
|
||||
switch test.Type {
|
||||
case registry.SZ:
|
||||
err = k.SetStringValue(test.Name, test.Value.(string))
|
||||
case registry.EXPAND_SZ:
|
||||
err = k.SetExpandStringValue(test.Name, test.Value.(string))
|
||||
case registry.MULTI_SZ:
|
||||
err = k.SetStringsValue(test.Name, test.Value.([]string))
|
||||
case registry.BINARY:
|
||||
err = k.SetBinaryValue(test.Name, test.Value.([]byte))
|
||||
case registry.DWORD:
|
||||
err = k.SetDWordValue(test.Name, uint32(test.Value.(uint64)))
|
||||
case registry.QWORD:
|
||||
err = k.SetQWordValue(test.Name, test.Value.(uint64))
|
||||
default:
|
||||
t.Fatalf("unsupported type %d for %s value", test.Type, test.Name)
|
||||
}
|
||||
if test.WillFail {
|
||||
if err == nil {
|
||||
t.Fatalf("setting %s value %q should fail, but succeeded", test.Name, test.Value)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enumerateValues(t *testing.T, k registry.Key) {
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
haveNames := make(map[string]bool)
|
||||
for _, n := range names {
|
||||
haveNames[n] = false
|
||||
}
|
||||
for _, test := range ValueTests {
|
||||
wantFound := !test.WillFail
|
||||
_, haveFound := haveNames[test.Name]
|
||||
if wantFound && !haveFound {
|
||||
t.Errorf("value %s is not found while enumerating", test.Name)
|
||||
}
|
||||
if haveFound && !wantFound {
|
||||
t.Errorf("value %s is found while enumerating, but expected to fail", test.Name)
|
||||
}
|
||||
if haveFound {
|
||||
delete(haveNames, test.Name)
|
||||
}
|
||||
}
|
||||
for n, v := range haveNames {
|
||||
t.Errorf("value %s (%v) is found while enumerating, but has not been cretaed", n, v)
|
||||
}
|
||||
}
|
||||
|
||||
func testErrNotExist(t *testing.T, name string, err error) {
|
||||
if err == nil {
|
||||
t.Errorf("%s value should not exist", name)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Errorf("reading %s value should return 'not exist' error, but got: %s", name, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testErrUnexpectedType(t *testing.T, test ValueTest, gottype uint32, err error) {
|
||||
if err == nil {
|
||||
t.Errorf("GetXValue(%q) should not succeed", test.Name)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrUnexpectedType {
|
||||
t.Errorf("reading %s value should return 'unexpected key value type' error, but got: %s", test.Name, err)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetStringValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetStringValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetStringValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if got != test.Value {
|
||||
t.Errorf("want %s value %q, got %q", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
if gottype == registry.EXPAND_SZ {
|
||||
_, err = registry.ExpandString(got)
|
||||
if err != nil {
|
||||
t.Errorf("ExpandString(%s) failed: %v", got, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testGetIntegerValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetIntegerValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetIntegerValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if got != test.Value.(uint64) {
|
||||
t.Errorf("want %s value %v, got %v", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetBinaryValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetBinaryValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetBinaryValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(got, test.Value.([]byte)) {
|
||||
t.Errorf("want %s value %v, got %v", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetStringsValue(t *testing.T, k registry.Key, test ValueTest) {
|
||||
got, gottype, err := k.GetStringsValue(test.Name)
|
||||
if err != nil {
|
||||
t.Errorf("GetStringsValue(%s) failed: %v", test.Name, err)
|
||||
return
|
||||
}
|
||||
if !equalStringSlice(got, test.Value.([]string)) {
|
||||
t.Errorf("want %s value %#v, got %#v", test.Name, test.Value, got)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testGetValue(t *testing.T, k registry.Key, test ValueTest, size int) {
|
||||
if size <= 0 {
|
||||
return
|
||||
}
|
||||
// read data with no buffer
|
||||
gotsize, gottype, err := k.GetValue(test.Name, nil)
|
||||
if err != nil {
|
||||
t.Errorf("GetValue(%s, [%d]byte) failed: %v", test.Name, size, err)
|
||||
return
|
||||
}
|
||||
if gotsize != size {
|
||||
t.Errorf("want %s value size of %d, got %v", test.Name, size, gotsize)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
// read data with short buffer
|
||||
gotsize, gottype, err = k.GetValue(test.Name, make([]byte, size-1))
|
||||
if err == nil {
|
||||
t.Errorf("GetValue(%s, [%d]byte) should fail, but succeeded", test.Name, size-1)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrShortBuffer {
|
||||
t.Errorf("reading %s value should return 'short buffer' error, but got: %s", test.Name, err)
|
||||
return
|
||||
}
|
||||
if gotsize != size {
|
||||
t.Errorf("want %s value size of %d, got %v", test.Name, size, gotsize)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
// read full data
|
||||
gotsize, gottype, err = k.GetValue(test.Name, make([]byte, size))
|
||||
if err != nil {
|
||||
t.Errorf("GetValue(%s, [%d]byte) failed: %v", test.Name, size, err)
|
||||
return
|
||||
}
|
||||
if gotsize != size {
|
||||
t.Errorf("want %s value size of %d, got %v", test.Name, size, gotsize)
|
||||
return
|
||||
}
|
||||
if gottype != test.Type {
|
||||
t.Errorf("want %s value type %v, got %v", test.Name, test.Type, gottype)
|
||||
return
|
||||
}
|
||||
// check GetValue returns ErrNotExist as required
|
||||
_, _, err = k.GetValue(test.Name+"_not_there", make([]byte, size))
|
||||
if err == nil {
|
||||
t.Errorf("GetValue(%q) should not succeed", test.Name)
|
||||
return
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Errorf("GetValue(%q) should return 'not exist' error, but got: %s", test.Name, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func testValues(t *testing.T, k registry.Key) {
|
||||
for _, test := range ValueTests {
|
||||
switch test.Type {
|
||||
case registry.SZ, registry.EXPAND_SZ:
|
||||
if test.WillFail {
|
||||
_, _, err := k.GetStringValue(test.Name)
|
||||
testErrNotExist(t, test.Name, err)
|
||||
} else {
|
||||
testGetStringValue(t, k, test)
|
||||
_, gottype, err := k.GetIntegerValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
// Size of utf16 string in bytes is not perfect,
|
||||
// but correct for current test values.
|
||||
// Size also includes terminating 0.
|
||||
testGetValue(t, k, test, (len(test.Value.(string))+1)*2)
|
||||
}
|
||||
_, _, err := k.GetStringValue(test.Name + "_string_not_created")
|
||||
testErrNotExist(t, test.Name+"_string_not_created", err)
|
||||
case registry.DWORD, registry.QWORD:
|
||||
testGetIntegerValue(t, k, test)
|
||||
_, gottype, err := k.GetBinaryValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
_, _, err = k.GetIntegerValue(test.Name + "_int_not_created")
|
||||
testErrNotExist(t, test.Name+"_int_not_created", err)
|
||||
size := 8
|
||||
if test.Type == registry.DWORD {
|
||||
size = 4
|
||||
}
|
||||
testGetValue(t, k, test, size)
|
||||
case registry.BINARY:
|
||||
testGetBinaryValue(t, k, test)
|
||||
_, gottype, err := k.GetStringsValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
_, _, err = k.GetBinaryValue(test.Name + "_byte_not_created")
|
||||
testErrNotExist(t, test.Name+"_byte_not_created", err)
|
||||
testGetValue(t, k, test, len(test.Value.([]byte)))
|
||||
case registry.MULTI_SZ:
|
||||
if test.WillFail {
|
||||
_, _, err := k.GetStringsValue(test.Name)
|
||||
testErrNotExist(t, test.Name, err)
|
||||
} else {
|
||||
testGetStringsValue(t, k, test)
|
||||
_, gottype, err := k.GetStringValue(test.Name)
|
||||
testErrUnexpectedType(t, test, gottype, err)
|
||||
size := 0
|
||||
for _, s := range test.Value.([]string) {
|
||||
size += len(s) + 1 // nil terminated
|
||||
}
|
||||
size += 1 // extra nil at the end
|
||||
size *= 2 // count bytes, not uint16
|
||||
testGetValue(t, k, test, size)
|
||||
}
|
||||
_, _, err := k.GetStringsValue(test.Name + "_strings_not_created")
|
||||
testErrNotExist(t, test.Name+"_strings_not_created", err)
|
||||
default:
|
||||
t.Errorf("unsupported type %d for %s value", test.Type, test.Name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testStat(t *testing.T, k registry.Key) {
|
||||
subk, _, err := registry.CreateKey(k, "subkey", registry.CREATE_SUB_KEY)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer subk.Close()
|
||||
|
||||
defer registry.DeleteKey(k, "subkey")
|
||||
|
||||
ki, err := k.Stat()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if ki.SubKeyCount != 1 {
|
||||
t.Error("key must have 1 subkey")
|
||||
}
|
||||
if ki.MaxSubKeyLen != 6 {
|
||||
t.Error("key max subkey name length must be 6")
|
||||
}
|
||||
if ki.ValueCount != 24 {
|
||||
t.Errorf("key must have 24 values, but is %d", ki.ValueCount)
|
||||
}
|
||||
if ki.MaxValueNameLen != 12 {
|
||||
t.Errorf("key max value name length must be 10, but is %d", ki.MaxValueNameLen)
|
||||
}
|
||||
if ki.MaxValueLen != 38 {
|
||||
t.Errorf("key max value length must be 38, but is %d", ki.MaxValueLen)
|
||||
}
|
||||
if mt, ct := ki.ModTime(), time.Now(); ct.Sub(mt) > 100*time.Millisecond {
|
||||
t.Errorf("key mod time is not close to current time: mtime=%v current=%v delta=%v", mt, ct, ct.Sub(mt))
|
||||
}
|
||||
}
|
||||
|
||||
func deleteValues(t *testing.T, k registry.Key) {
|
||||
for _, test := range ValueTests {
|
||||
if test.WillFail {
|
||||
continue
|
||||
}
|
||||
err := k.DeleteValue(test.Name)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if len(names) != 0 {
|
||||
t.Errorf("some values remain after deletion: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
softwareK, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer softwareK.Close()
|
||||
|
||||
testKName := randKeyName("TestValues_")
|
||||
|
||||
k, exist, err := registry.CreateKey(softwareK, testKName, registry.CREATE_SUB_KEY|registry.QUERY_VALUE|registry.SET_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
if exist {
|
||||
t.Fatalf("key %q already exists", testKName)
|
||||
}
|
||||
|
||||
defer registry.DeleteKey(softwareK, testKName)
|
||||
|
||||
setValues(t, k)
|
||||
|
||||
enumerateValues(t, k)
|
||||
|
||||
testValues(t, k)
|
||||
|
||||
testStat(t, k)
|
||||
|
||||
deleteValues(t, k)
|
||||
}
|
||||
|
||||
func TestExpandString(t *testing.T) {
|
||||
got, err := registry.ExpandString("%PATH%")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := os.Getenv("PATH")
|
||||
if got != want {
|
||||
t.Errorf("want %q string expanded, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidValues(t *testing.T) {
|
||||
softwareK, err := registry.OpenKey(registry.CURRENT_USER, "Software", registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer softwareK.Close()
|
||||
|
||||
testKName := randKeyName("TestInvalidValues_")
|
||||
|
||||
k, exist, err := registry.CreateKey(softwareK, testKName, registry.CREATE_SUB_KEY|registry.QUERY_VALUE|registry.SET_VALUE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer k.Close()
|
||||
|
||||
if exist {
|
||||
t.Fatalf("key %q already exists", testKName)
|
||||
}
|
||||
|
||||
defer registry.DeleteKey(softwareK, testKName)
|
||||
|
||||
var tests = []struct {
|
||||
Type uint32
|
||||
Name string
|
||||
Data []byte
|
||||
}{
|
||||
{registry.DWORD, "Dword1", nil},
|
||||
{registry.DWORD, "Dword2", []byte{1, 2, 3}},
|
||||
{registry.QWORD, "Qword1", nil},
|
||||
{registry.QWORD, "Qword2", []byte{1, 2, 3}},
|
||||
{registry.QWORD, "Qword3", []byte{1, 2, 3, 4, 5, 6, 7}},
|
||||
{registry.MULTI_SZ, "MultiString1", nil},
|
||||
{registry.MULTI_SZ, "MultiString2", []byte{0}},
|
||||
{registry.MULTI_SZ, "MultiString3", []byte{'a', 'b', 0}},
|
||||
{registry.MULTI_SZ, "MultiString4", []byte{'a', 0, 0, 'b', 0}},
|
||||
{registry.MULTI_SZ, "MultiString5", []byte{'a', 0, 0}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
err := k.SetValue(test.Name, test.Type, test.Data)
|
||||
if err != nil {
|
||||
t.Fatalf("SetValue for %q failed: %v", test.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
switch test.Type {
|
||||
case registry.DWORD, registry.QWORD:
|
||||
value, valType, err := k.GetIntegerValue(test.Name)
|
||||
if err == nil {
|
||||
t.Errorf("GetIntegerValue(%q) succeeded. Returns type=%d value=%v", test.Name, valType, value)
|
||||
}
|
||||
case registry.MULTI_SZ:
|
||||
value, valType, err := k.GetStringsValue(test.Name)
|
||||
if err == nil {
|
||||
if len(value) != 0 {
|
||||
t.Errorf("GetStringsValue(%q) succeeded. Returns type=%d value=%v", test.Name, valType, value)
|
||||
}
|
||||
}
|
||||
default:
|
||||
t.Errorf("unsupported type %d for %s value", test.Type, test.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMUIStringValue(t *testing.T) {
|
||||
if err := registry.LoadRegLoadMUIString(); err != nil {
|
||||
t.Skip("regLoadMUIString not supported; skipping")
|
||||
}
|
||||
if err := procGetDynamicTimeZoneInformation.Find(); err != nil {
|
||||
t.Skipf("%s not supported; skipping", procGetDynamicTimeZoneInformation.Name)
|
||||
}
|
||||
var dtzi DynamicTimezoneinformation
|
||||
if _, err := GetDynamicTimeZoneInformation(&dtzi); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tzKeyName := syscall.UTF16ToString(dtzi.TimeZoneKeyName[:])
|
||||
timezoneK, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones\`+tzKeyName, registry.READ)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer timezoneK.Close()
|
||||
|
||||
type testType struct {
|
||||
name string
|
||||
want string
|
||||
}
|
||||
var tests = []testType{
|
||||
{"MUI_Std", syscall.UTF16ToString(dtzi.StandardName[:])},
|
||||
}
|
||||
if dtzi.DynamicDaylightTimeDisabled == 0 {
|
||||
tests = append(tests, testType{"MUI_Dlt", syscall.UTF16ToString(dtzi.DaylightName[:])})
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
got, err := timezoneK.GetMUIStringValue(test.name)
|
||||
if err != nil {
|
||||
t.Error("GetMUIStringValue:", err)
|
||||
}
|
||||
|
||||
if got != test.want {
|
||||
t.Errorf("GetMUIStringValue: %s: Got %q, want %q", test.name, got, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type DynamicTimezoneinformation struct {
|
||||
Bias int32
|
||||
StandardName [32]uint16
|
||||
StandardDate syscall.Systemtime
|
||||
StandardBias int32
|
||||
DaylightName [32]uint16
|
||||
DaylightDate syscall.Systemtime
|
||||
DaylightBias int32
|
||||
TimeZoneKeyName [128]uint16
|
||||
DynamicDaylightTimeDisabled uint8
|
||||
}
|
||||
|
||||
var (
|
||||
kernel32DLL = syscall.NewLazyDLL("kernel32")
|
||||
|
||||
procGetDynamicTimeZoneInformation = kernel32DLL.NewProc("GetDynamicTimeZoneInformation")
|
||||
)
|
||||
|
||||
func GetDynamicTimeZoneInformation(dtzi *DynamicTimezoneinformation) (rc uint32, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procGetDynamicTimeZoneInformation.Addr(), 1, uintptr(unsafe.Pointer(dtzi)), 0, 0)
|
||||
rc = uint32(r0)
|
||||
if rc == 0xffffffff {
|
||||
if e1 != 0 {
|
||||
err = error(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry
|
||||
|
||||
import "syscall"
|
||||
|
||||
const (
|
||||
_REG_OPTION_NON_VOLATILE = 0
|
||||
|
||||
_REG_CREATED_NEW_KEY = 1
|
||||
_REG_OPENED_EXISTING_KEY = 2
|
||||
|
||||
_ERROR_NO_MORE_ITEMS syscall.Errno = 259
|
||||
)
|
||||
|
||||
func LoadRegLoadMUIString() error {
|
||||
return procRegLoadMUIStringW.Find()
|
||||
}
|
||||
|
||||
//sys regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) = advapi32.RegCreateKeyExW
|
||||
//sys regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) = advapi32.RegDeleteKeyW
|
||||
//sys regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) = advapi32.RegSetValueExW
|
||||
//sys regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) = advapi32.RegEnumValueW
|
||||
//sys regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) = advapi32.RegDeleteValueW
|
||||
//sys regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) = advapi32.RegLoadMUIStringW
|
||||
//sys regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) = advapi32.RegConnectRegistryW
|
||||
//sys regNotifyChangeKeyValue(key syscall.Handle, watchSubtree bool, notifyFilter uint32, event syscall.Handle, async bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue
|
||||
|
||||
//sys expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) = kernel32.ExpandEnvironmentStringsW
|
||||
@@ -1,386 +0,0 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build windows
|
||||
|
||||
package registry
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Registry value types.
|
||||
NONE = 0
|
||||
SZ = 1
|
||||
EXPAND_SZ = 2
|
||||
BINARY = 3
|
||||
DWORD = 4
|
||||
DWORD_BIG_ENDIAN = 5
|
||||
LINK = 6
|
||||
MULTI_SZ = 7
|
||||
RESOURCE_LIST = 8
|
||||
FULL_RESOURCE_DESCRIPTOR = 9
|
||||
RESOURCE_REQUIREMENTS_LIST = 10
|
||||
QWORD = 11
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrShortBuffer is returned when the buffer was too short for the operation.
|
||||
ErrShortBuffer = syscall.ERROR_MORE_DATA
|
||||
|
||||
// ErrNotExist is returned when a registry key or value does not exist.
|
||||
ErrNotExist = syscall.ERROR_FILE_NOT_FOUND
|
||||
|
||||
// ErrUnexpectedType is returned by Get*Value when the value's type was unexpected.
|
||||
ErrUnexpectedType = errors.New("unexpected key value type")
|
||||
)
|
||||
|
||||
// GetValue retrieves the type and data for the specified value associated
|
||||
// with an open key k. It fills up buffer buf and returns the retrieved
|
||||
// byte count n. If buf is too small to fit the stored value it returns
|
||||
// ErrShortBuffer error along with the required buffer size n.
|
||||
// If no buffer is provided, it returns true and actual buffer size n.
|
||||
// If no buffer is provided, GetValue returns the value's type only.
|
||||
// If the value does not exist, the error returned is ErrNotExist.
|
||||
//
|
||||
// GetValue is a low level function. If value's type is known, use the appropriate
|
||||
// Get*Value function instead.
|
||||
func (k Key) GetValue(name string, buf []byte) (n int, valtype uint32, err error) {
|
||||
pname, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
var pbuf *byte
|
||||
if len(buf) > 0 {
|
||||
pbuf = (*byte)(unsafe.Pointer(&buf[0]))
|
||||
}
|
||||
l := uint32(len(buf))
|
||||
err = syscall.RegQueryValueEx(syscall.Handle(k), pname, nil, &valtype, pbuf, &l)
|
||||
if err != nil {
|
||||
return int(l), valtype, err
|
||||
}
|
||||
return int(l), valtype, nil
|
||||
}
|
||||
|
||||
func (k Key) getValue(name string, buf []byte) (data []byte, valtype uint32, err error) {
|
||||
p, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var t uint32
|
||||
n := uint32(len(buf))
|
||||
for {
|
||||
err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n)
|
||||
if err == nil {
|
||||
return buf[:n], t, nil
|
||||
}
|
||||
if err != syscall.ERROR_MORE_DATA {
|
||||
return nil, 0, err
|
||||
}
|
||||
if n <= uint32(len(buf)) {
|
||||
return nil, 0, err
|
||||
}
|
||||
buf = make([]byte, n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStringValue retrieves the string value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetStringValue returns ErrNotExist.
|
||||
// If value is not SZ or EXPAND_SZ, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetStringValue(name string) (val string, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 64))
|
||||
if err2 != nil {
|
||||
return "", typ, err2
|
||||
}
|
||||
switch typ {
|
||||
case SZ, EXPAND_SZ:
|
||||
default:
|
||||
return "", typ, ErrUnexpectedType
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return "", typ, nil
|
||||
}
|
||||
u := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2]
|
||||
return syscall.UTF16ToString(u), typ, nil
|
||||
}
|
||||
|
||||
// GetMUIStringValue retrieves the localized string value for
|
||||
// the specified value name associated with an open key k.
|
||||
// If the value name doesn't exist or the localized string value
|
||||
// can't be resolved, GetMUIStringValue returns ErrNotExist.
|
||||
// GetMUIStringValue panics if the system doesn't support
|
||||
// regLoadMUIString; use LoadRegLoadMUIString to check if
|
||||
// regLoadMUIString is supported before calling this function.
|
||||
func (k Key) GetMUIStringValue(name string) (string, error) {
|
||||
pname, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
buf := make([]uint16, 1024)
|
||||
var buflen uint32
|
||||
var pdir *uint16
|
||||
|
||||
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
|
||||
if err == syscall.ERROR_FILE_NOT_FOUND { // Try fallback path
|
||||
|
||||
// Try to resolve the string value using the system directory as
|
||||
// a DLL search path; this assumes the string value is of the form
|
||||
// @[path]\dllname,-strID but with no path given, e.g. @tzres.dll,-320.
|
||||
|
||||
// This approach works with tzres.dll but may have to be revised
|
||||
// in the future to allow callers to provide custom search paths.
|
||||
|
||||
var s string
|
||||
s, err = ExpandString("%SystemRoot%\\system32\\")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pdir, err = syscall.UTF16PtrFromString(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
|
||||
}
|
||||
|
||||
for err == syscall.ERROR_MORE_DATA { // Grow buffer if needed
|
||||
if buflen <= uint32(len(buf)) {
|
||||
break // Buffer not growing, assume race; break
|
||||
}
|
||||
buf = make([]uint16, buflen)
|
||||
err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return syscall.UTF16ToString(buf), nil
|
||||
}
|
||||
|
||||
// ExpandString expands environment-variable strings and replaces
|
||||
// them with the values defined for the current user.
|
||||
// Use ExpandString to expand EXPAND_SZ strings.
|
||||
func ExpandString(value string) (string, error) {
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
p, err := syscall.UTF16PtrFromString(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
r := make([]uint16, 100)
|
||||
for {
|
||||
n, err := expandEnvironmentStrings(p, &r[0], uint32(len(r)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n <= uint32(len(r)) {
|
||||
return syscall.UTF16ToString(r[:n]), nil
|
||||
}
|
||||
r = make([]uint16, n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStringsValue retrieves the []string value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetStringsValue returns ErrNotExist.
|
||||
// If value is not MULTI_SZ, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 64))
|
||||
if err2 != nil {
|
||||
return nil, typ, err2
|
||||
}
|
||||
if typ != MULTI_SZ {
|
||||
return nil, typ, ErrUnexpectedType
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, typ, nil
|
||||
}
|
||||
p := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2]
|
||||
if len(p) == 0 {
|
||||
return nil, typ, nil
|
||||
}
|
||||
if p[len(p)-1] == 0 {
|
||||
p = p[:len(p)-1] // remove terminating null
|
||||
}
|
||||
val = make([]string, 0, 5)
|
||||
from := 0
|
||||
for i, c := range p {
|
||||
if c == 0 {
|
||||
val = append(val, string(utf16.Decode(p[from:i])))
|
||||
from = i + 1
|
||||
}
|
||||
}
|
||||
return val, typ, nil
|
||||
}
|
||||
|
||||
// GetIntegerValue retrieves the integer value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetIntegerValue returns ErrNotExist.
|
||||
// If value is not DWORD or QWORD, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 8))
|
||||
if err2 != nil {
|
||||
return 0, typ, err2
|
||||
}
|
||||
switch typ {
|
||||
case DWORD:
|
||||
if len(data) != 4 {
|
||||
return 0, typ, errors.New("DWORD value is not 4 bytes long")
|
||||
}
|
||||
var val32 uint32
|
||||
copy((*[4]byte)(unsafe.Pointer(&val32))[:], data)
|
||||
return uint64(val32), DWORD, nil
|
||||
case QWORD:
|
||||
if len(data) != 8 {
|
||||
return 0, typ, errors.New("QWORD value is not 8 bytes long")
|
||||
}
|
||||
copy((*[8]byte)(unsafe.Pointer(&val))[:], data)
|
||||
return val, QWORD, nil
|
||||
default:
|
||||
return 0, typ, ErrUnexpectedType
|
||||
}
|
||||
}
|
||||
|
||||
// GetBinaryValue retrieves the binary value for the specified
|
||||
// value name associated with an open key k. It also returns the value's type.
|
||||
// If value does not exist, GetBinaryValue returns ErrNotExist.
|
||||
// If value is not BINARY, it will return the correct value
|
||||
// type and ErrUnexpectedType.
|
||||
func (k Key) GetBinaryValue(name string) (val []byte, valtype uint32, err error) {
|
||||
data, typ, err2 := k.getValue(name, make([]byte, 64))
|
||||
if err2 != nil {
|
||||
return nil, typ, err2
|
||||
}
|
||||
if typ != BINARY {
|
||||
return nil, typ, ErrUnexpectedType
|
||||
}
|
||||
return data, typ, nil
|
||||
}
|
||||
|
||||
func (k Key) setValue(name string, valtype uint32, data []byte) error {
|
||||
p, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return regSetValueEx(syscall.Handle(k), p, 0, valtype, nil, 0)
|
||||
}
|
||||
return regSetValueEx(syscall.Handle(k), p, 0, valtype, &data[0], uint32(len(data)))
|
||||
}
|
||||
|
||||
// SetDWordValue sets the data and type of a name value
|
||||
// under key k to value and DWORD.
|
||||
func (k Key) SetDWordValue(name string, value uint32) error {
|
||||
return k.setValue(name, DWORD, (*[4]byte)(unsafe.Pointer(&value))[:])
|
||||
}
|
||||
|
||||
// SetQWordValue sets the data and type of a name value
|
||||
// under key k to value and QWORD.
|
||||
func (k Key) SetQWordValue(name string, value uint64) error {
|
||||
return k.setValue(name, QWORD, (*[8]byte)(unsafe.Pointer(&value))[:])
|
||||
}
|
||||
|
||||
func (k Key) setStringValue(name string, valtype uint32, value string) error {
|
||||
v, err := syscall.UTF16FromString(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2]
|
||||
return k.setValue(name, valtype, buf)
|
||||
}
|
||||
|
||||
// SetStringValue sets the data and type of a name value
|
||||
// under key k to value and SZ. The value must not contain a zero byte.
|
||||
func (k Key) SetStringValue(name, value string) error {
|
||||
return k.setStringValue(name, SZ, value)
|
||||
}
|
||||
|
||||
// SetExpandStringValue sets the data and type of a name value
|
||||
// under key k to value and EXPAND_SZ. The value must not contain a zero byte.
|
||||
func (k Key) SetExpandStringValue(name, value string) error {
|
||||
return k.setStringValue(name, EXPAND_SZ, value)
|
||||
}
|
||||
|
||||
// SetStringsValue sets the data and type of a name value
|
||||
// under key k to value and MULTI_SZ. The value strings
|
||||
// must not contain a zero byte.
|
||||
func (k Key) SetStringsValue(name string, value []string) error {
|
||||
ss := ""
|
||||
for _, s := range value {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == 0 {
|
||||
return errors.New("string cannot have 0 inside")
|
||||
}
|
||||
}
|
||||
ss += s + "\x00"
|
||||
}
|
||||
v := utf16.Encode([]rune(ss + "\x00"))
|
||||
buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2]
|
||||
return k.setValue(name, MULTI_SZ, buf)
|
||||
}
|
||||
|
||||
// SetBinaryValue sets the data and type of a name value
|
||||
// under key k to value and BINARY.
|
||||
func (k Key) SetBinaryValue(name string, value []byte) error {
|
||||
return k.setValue(name, BINARY, value)
|
||||
}
|
||||
|
||||
// DeleteValue removes a named value from the key k.
|
||||
func (k Key) DeleteValue(name string) error {
|
||||
return regDeleteValue(syscall.Handle(k), syscall.StringToUTF16Ptr(name))
|
||||
}
|
||||
|
||||
// ReadValueNames returns the value names of key k.
|
||||
// The parameter n controls the number of returned names,
|
||||
// analogous to the way os.File.Readdirnames works.
|
||||
func (k Key) ReadValueNames(n int) ([]string, error) {
|
||||
ki, err := k.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
names := make([]string, 0, ki.ValueCount)
|
||||
buf := make([]uint16, ki.MaxValueNameLen+1) // extra room for terminating null character
|
||||
loopItems:
|
||||
for i := uint32(0); ; i++ {
|
||||
if n > 0 {
|
||||
if len(names) == n {
|
||||
return names, nil
|
||||
}
|
||||
}
|
||||
l := uint32(len(buf))
|
||||
for {
|
||||
err := regEnumValue(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err == syscall.ERROR_MORE_DATA {
|
||||
// Double buffer size and try again.
|
||||
l = uint32(2 * len(buf))
|
||||
buf = make([]uint16, l)
|
||||
continue
|
||||
}
|
||||
if err == _ERROR_NO_MORE_ITEMS {
|
||||
break loopItems
|
||||
}
|
||||
return names, err
|
||||
}
|
||||
names = append(names, syscall.UTF16ToString(buf[:l]))
|
||||
}
|
||||
if n > len(names) {
|
||||
return names, io.EOF
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package registry
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return nil
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
|
||||
procRegCreateKeyExW = modadvapi32.NewProc("RegCreateKeyExW")
|
||||
procRegDeleteKeyW = modadvapi32.NewProc("RegDeleteKeyW")
|
||||
procRegSetValueExW = modadvapi32.NewProc("RegSetValueExW")
|
||||
procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW")
|
||||
procRegDeleteValueW = modadvapi32.NewProc("RegDeleteValueW")
|
||||
procRegLoadMUIStringW = modadvapi32.NewProc("RegLoadMUIStringW")
|
||||
procRegConnectRegistryW = modadvapi32.NewProc("RegConnectRegistryW")
|
||||
procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue")
|
||||
procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW")
|
||||
)
|
||||
|
||||
func regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall9(procRegCreateKeyExW.Addr(), 9, uintptr(key), uintptr(unsafe.Pointer(subkey)), uintptr(reserved), uintptr(unsafe.Pointer(class)), uintptr(options), uintptr(desired), uintptr(unsafe.Pointer(sa)), uintptr(unsafe.Pointer(result)), uintptr(unsafe.Pointer(disposition)))
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall(procRegDeleteKeyW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(subkey)), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall6(procRegSetValueExW.Addr(), 6, uintptr(key), uintptr(unsafe.Pointer(valueName)), uintptr(reserved), uintptr(vtype), uintptr(unsafe.Pointer(buf)), uintptr(bufsize))
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valtype)), uintptr(unsafe.Pointer(buf)), uintptr(unsafe.Pointer(buflen)), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall(procRegDeleteValueW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(name)), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall9(procRegLoadMUIStringW.Addr(), 7, uintptr(key), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(unsafe.Pointer(buflenCopied)), uintptr(flags), uintptr(unsafe.Pointer(dir)), 0, 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) {
|
||||
r0, _, _ := syscall.Syscall(procRegConnectRegistryW.Addr(), 3, uintptr(unsafe.Pointer(machinename)), uintptr(key), uintptr(unsafe.Pointer(result)))
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func regNotifyChangeKeyValue(key syscall.Handle, watchSubtree bool, notifyFilter uint32, event syscall.Handle, async bool) (regerrno error) {
|
||||
var _p0 uint32
|
||||
if watchSubtree {
|
||||
_p0 = 1
|
||||
} else {
|
||||
_p0 = 0
|
||||
}
|
||||
var _p1 uint32
|
||||
if async {
|
||||
_p1 = 1
|
||||
} else {
|
||||
_p1 = 0
|
||||
}
|
||||
r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0)
|
||||
if r0 != 0 {
|
||||
regerrno = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procExpandEnvironmentStringsW.Addr(), 3, uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(size))
|
||||
n = uint32(r0)
|
||||
if n == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
119
tstime/tstime.go
119
tstime/tstime.go
@@ -8,109 +8,119 @@ package tstime
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
)
|
||||
|
||||
var memZ = mem.S("Z")
|
||||
|
||||
// zoneOf returns the RFC3339 zone suffix (either "Z" or like
|
||||
// "+08:30"), or the empty string if it's invalid or not something we
|
||||
// want to cache.
|
||||
func zoneOf(s string) string {
|
||||
if strings.HasSuffix(s, "Z") {
|
||||
return "Z"
|
||||
func zoneOf(s mem.RO) mem.RO {
|
||||
if mem.HasSuffix(s, memZ) {
|
||||
return memZ
|
||||
}
|
||||
if len(s) < len("2020-04-05T15:56:00+08:00") {
|
||||
if s.Len() < len("2020-04-05T15:56:00+08:00") {
|
||||
// Too short, invalid? Let time.Parse fail on it.
|
||||
return ""
|
||||
return mem.S("")
|
||||
}
|
||||
zone := s[len(s)-len("+08:00"):]
|
||||
if c := zone[0]; c == '+' || c == '-' {
|
||||
min := zone[len("+08:"):]
|
||||
switch min {
|
||||
case "00", "15", "30":
|
||||
zone := s.SliceFrom(s.Len() - len("+08:00"))
|
||||
if c := zone.At(0); c == '+' || c == '-' {
|
||||
min := zone.SliceFrom(len("+08:"))
|
||||
if min.EqualString("00") || min.EqualString("15") || min.EqualString("30") {
|
||||
return zone
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return mem.S("")
|
||||
}
|
||||
|
||||
// locCache maps from zone offset suffix string ("+08:00") =>
|
||||
// *time.Location (from FixedLocation).
|
||||
// locCache maps from hash of zone offset suffix string ("+08:00") =>
|
||||
// {zone string, *time.Location (from FixedLocation)}.
|
||||
var locCache sync.Map
|
||||
|
||||
func getLocation(zone, timeValue string) (*time.Location, error) {
|
||||
if zone == "Z" {
|
||||
type locCacheEntry struct {
|
||||
zone string
|
||||
loc *time.Location
|
||||
}
|
||||
|
||||
func getLocation(zone, timeValue mem.RO) (*time.Location, error) {
|
||||
if zone.EqualString("Z") {
|
||||
return time.UTC, nil
|
||||
}
|
||||
if loci, ok := locCache.Load(zone); ok {
|
||||
return loci.(*time.Location), nil
|
||||
key := zone.MapHash()
|
||||
if entry, ok := locCache.Load(key); ok {
|
||||
// We're keying only on a hash; double-check zone to ensure no spurious collisions.
|
||||
e := entry.(locCacheEntry)
|
||||
if zone.EqualString(e.zone) {
|
||||
return e.loc, nil
|
||||
}
|
||||
}
|
||||
// TODO(bradfitz): just parse it and call time.FixedLocation.
|
||||
// For now, just have time.Parse do it once:
|
||||
t, err := time.Parse(time.RFC3339Nano, timeValue)
|
||||
t, err := time.Parse(time.RFC3339Nano, timeValue.StringCopy())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loc := t.Location()
|
||||
locCache.LoadOrStore(zone, loc)
|
||||
locCache.LoadOrStore(key, locCacheEntry{zone: zone.StringCopy(), loc: loc})
|
||||
return loc, nil
|
||||
|
||||
}
|
||||
|
||||
// Parse3339 is a wrapper around time.Parse(time.RFC3339Nano, s) that caches
|
||||
// timezone Locations for future parses.
|
||||
func Parse3339(s string) (time.Time, error) {
|
||||
func parse3339m(s mem.RO) (time.Time, error) {
|
||||
zone := zoneOf(s)
|
||||
if zone == "" {
|
||||
if zone.Len() == 0 {
|
||||
// Invalid or weird timezone offset. Use slow path,
|
||||
// which'll probably return an error.
|
||||
return time.Parse(time.RFC3339Nano, s)
|
||||
return time.Parse(time.RFC3339Nano, s.StringCopy())
|
||||
}
|
||||
loc, err := getLocation(zone, s)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
s = s[:len(s)-len(zone)] // remove zone suffix
|
||||
s = s.SliceTo(s.Len() - zone.Len()) // remove zone suffix
|
||||
var year, mon, day, hr, min, sec, nsec int
|
||||
const baseLen = len("2020-04-05T15:56:00")
|
||||
if len(s) < baseLen ||
|
||||
!parseInt(s[:4], &year) ||
|
||||
s[4] != '-' ||
|
||||
!parseInt(s[5:7], &mon) ||
|
||||
s[7] != '-' ||
|
||||
!parseInt(s[8:10], &day) ||
|
||||
s[10] != 'T' ||
|
||||
!parseInt(s[11:13], &hr) ||
|
||||
s[13] != ':' ||
|
||||
!parseInt(s[14:16], &min) ||
|
||||
s[16] != ':' ||
|
||||
!parseInt(s[17:19], &sec) {
|
||||
if s.Len() < baseLen ||
|
||||
!parseInt(s.SliceTo(4), &year) ||
|
||||
s.At(4) != '-' ||
|
||||
!parseInt(s.Slice(5, 7), &mon) ||
|
||||
s.At(7) != '-' ||
|
||||
!parseInt(s.Slice(8, 10), &day) ||
|
||||
s.At(10) != 'T' ||
|
||||
!parseInt(s.Slice(11, 13), &hr) ||
|
||||
s.At(13) != ':' ||
|
||||
!parseInt(s.Slice(14, 16), &min) ||
|
||||
s.At(16) != ':' ||
|
||||
!parseInt(s.Slice(17, 19), &sec) {
|
||||
return time.Time{}, errors.New("invalid time")
|
||||
}
|
||||
nsStr := s[baseLen:]
|
||||
if nsStr != "" {
|
||||
if nsStr[0] != '.' {
|
||||
nsStr := s.SliceFrom(baseLen)
|
||||
if nsStr.Len() != 0 {
|
||||
if nsStr.At(0) != '.' {
|
||||
return time.Time{}, errors.New("invalid optional nanosecond prefix")
|
||||
}
|
||||
if !parseInt(nsStr[1:], &nsec) {
|
||||
return time.Time{}, fmt.Errorf("invalid optional nanosecond number %q", nsStr[1:])
|
||||
nsStr = nsStr.SliceFrom(1)
|
||||
if !parseInt(nsStr, &nsec) {
|
||||
return time.Time{}, fmt.Errorf("invalid optional nanosecond number %q", nsStr.StringCopy())
|
||||
}
|
||||
for i := 0; i < len("999999999")-(len(nsStr)-1); i++ {
|
||||
for i := 0; i < len("999999999")-nsStr.Len(); i++ {
|
||||
nsec *= 10
|
||||
}
|
||||
}
|
||||
return time.Date(year, time.Month(mon), day, hr, min, sec, nsec, loc), nil
|
||||
}
|
||||
|
||||
func parseInt(s string, dst *int) bool {
|
||||
if len(s) == 0 || len(s) > len("999999999") {
|
||||
func parseInt(s mem.RO, dst *int) bool {
|
||||
if s.Len() == 0 || s.Len() > len("999999999") {
|
||||
*dst = 0
|
||||
return false
|
||||
}
|
||||
n := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
d := s[i] - '0'
|
||||
for i := 0; i < s.Len(); i++ {
|
||||
d := s.At(i) - '0'
|
||||
if d > 9 {
|
||||
*dst = 0
|
||||
return false
|
||||
@@ -120,3 +130,14 @@ func parseInt(s string, dst *int) bool {
|
||||
*dst = n
|
||||
return true
|
||||
}
|
||||
|
||||
// Parse3339 is a wrapper around time.Parse(time.RFC3339Nano, s) that caches
|
||||
// timezone Locations for future parses.
|
||||
func Parse3339(s string) (time.Time, error) {
|
||||
return parse3339m(mem.S(s))
|
||||
}
|
||||
|
||||
// Parse3339B is Parse3339 but for byte slices.
|
||||
func Parse3339B(b []byte) (time.Time, error) {
|
||||
return parse3339m(mem.B(b))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ package tstime
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
)
|
||||
|
||||
func TestParse3339(t *testing.T) {
|
||||
@@ -70,8 +72,8 @@ func TestZoneOf(t *testing.T) {
|
||||
{"+08:00", ""}, // too short
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := zoneOf(tt.in); got != tt.want {
|
||||
t.Errorf("zoneOf(%q) = %q; want %q", tt.in, got, tt.want)
|
||||
if got := zoneOf(mem.S(tt.in)); !got.EqualString(tt.want) {
|
||||
t.Errorf("zoneOf(%q) = %q; want %q", tt.in, got.StringCopy(), tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,7 +95,7 @@ func TestParseInt(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
var got int
|
||||
gotRet := parseInt(tt.in, &got)
|
||||
gotRet := parseInt(mem.S(tt.in), &got)
|
||||
if gotRet != tt.ret || got != tt.want {
|
||||
t.Errorf("parseInt(%q) = %v, %d; want %v, %d", tt.in, gotRet, got, tt.ret, tt.want)
|
||||
}
|
||||
@@ -182,6 +184,6 @@ func BenchmarkParse3339(b *testing.B) {
|
||||
func BenchmarkParseInt(b *testing.B) {
|
||||
var out int
|
||||
for i := 0; i < b.N; i++ {
|
||||
parseInt("148487491", &out)
|
||||
parseInt(mem.S("148487491"), &out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type response struct {
|
||||
// JSONHandlerFunc is an HTTP ReturnHandler that writes JSON responses to the client.
|
||||
//
|
||||
// Return a HTTPError to show an error message, otherwise JSONHandlerFunc will
|
||||
// only report "internal server error" to the user.
|
||||
// only report "internal server error" to the user with status code 500.
|
||||
type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err error)
|
||||
|
||||
// ServeHTTPReturn implements the ReturnHandler interface.
|
||||
@@ -31,23 +31,12 @@ type JSONHandlerFunc func(r *http.Request) (status int, data interface{}, err er
|
||||
// return http.StatusBadRequest, nil, err
|
||||
// }
|
||||
//
|
||||
// See jsonhandler_text.go for examples.
|
||||
// See jsonhandler_test.go for examples.
|
||||
func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var resp *response
|
||||
status, data, err := fn(r)
|
||||
if status == 0 {
|
||||
status = http.StatusInternalServerError
|
||||
resp = &response{
|
||||
Status: "error",
|
||||
Error: "internal server error",
|
||||
}
|
||||
} else if err == nil {
|
||||
resp = &response{
|
||||
Status: "success",
|
||||
Data: data,
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
if werr, ok := err.(HTTPError); ok {
|
||||
resp = &response{
|
||||
Status: "error",
|
||||
@@ -61,12 +50,29 @@ func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request
|
||||
if werr.Msg != "" {
|
||||
err = fmt.Errorf("%s: %w", werr.Msg, err)
|
||||
}
|
||||
// take status from the HTTPError to encourage error handling in one location
|
||||
if status != 0 && status != werr.Code {
|
||||
err = fmt.Errorf("[unexpected] non-zero status that does not match HTTPError status, status: %d, HTTPError.code: %d: %w", status, werr.Code, err)
|
||||
}
|
||||
status = werr.Code
|
||||
} else {
|
||||
status = http.StatusInternalServerError
|
||||
resp = &response{
|
||||
Status: "error",
|
||||
Error: "internal server error",
|
||||
}
|
||||
}
|
||||
} else if status == 0 {
|
||||
status = http.StatusInternalServerError
|
||||
resp = &response{
|
||||
Status: "error",
|
||||
Error: "internal server error",
|
||||
}
|
||||
} else if err == nil {
|
||||
resp = &response{
|
||||
Status: "success",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
b, jerr := json.Marshal(resp)
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
type Data struct {
|
||||
@@ -40,11 +42,11 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
if d.Status == status {
|
||||
t.Logf("ok: %s", d.Status)
|
||||
} else {
|
||||
t.Fatalf("wrong status: %s %s", d.Status, status)
|
||||
t.Fatalf("wrong status: got: %s, want: %s", d.Status, status)
|
||||
}
|
||||
|
||||
if w.Code != code {
|
||||
t.Fatalf("wrong status code: %d %d", w.Code, code)
|
||||
t.Fatalf("wrong status code: got: %d, want: %d", w.Code, code)
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Type") != "application/json" {
|
||||
@@ -67,7 +69,7 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
|
||||
t.Run("403 HTTPError", func(t *testing.T) {
|
||||
h := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return http.StatusForbidden, nil, fmt.Errorf("forbidden")
|
||||
return 0, nil, Error(http.StatusForbidden, "forbidden", nil)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@@ -90,11 +92,11 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
h31 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
body := new(Data)
|
||||
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||
return http.StatusBadRequest, nil, err
|
||||
return 0, nil, Error(http.StatusBadRequest, err.Error(), err)
|
||||
}
|
||||
|
||||
if body.Name == "" {
|
||||
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "name is empty", nil)
|
||||
return 0, nil, Error(http.StatusBadRequest, "name is empty", nil)
|
||||
}
|
||||
|
||||
return http.StatusOK, nil, nil
|
||||
@@ -126,13 +128,13 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
h32 := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
body := new(Data)
|
||||
if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
||||
return http.StatusBadRequest, nil, err
|
||||
return 0, nil, Error(http.StatusBadRequest, err.Error(), err)
|
||||
}
|
||||
if body.Name == "root" {
|
||||
return http.StatusInternalServerError, nil, fmt.Errorf("invalid name")
|
||||
return 0, nil, fmt.Errorf("invalid name")
|
||||
}
|
||||
if body.Price == 0 {
|
||||
return http.StatusBadRequest, nil, Error(http.StatusBadGateway, "price is empty", nil)
|
||||
return 0, nil, Error(http.StatusBadRequest, "price is empty", nil)
|
||||
}
|
||||
|
||||
return http.StatusOK, &Data{Price: body.Price * 2}, nil
|
||||
@@ -159,7 +161,7 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("500 internal server error", func(t *testing.T) {
|
||||
t.Run("500 internal server error (unspecified error, not of type HTTPError)", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "root"}`))
|
||||
h32.ServeHTTPReturn(w, r)
|
||||
@@ -189,4 +191,41 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
}).ServeHTTPReturn(w, r)
|
||||
checkStatus(w, "error", http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError agree", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", nil)
|
||||
JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return http.StatusForbidden, nil, Error(http.StatusForbidden, "403 forbidden", nil)
|
||||
}).ServeHTTPReturn(w, r)
|
||||
want := &Response{
|
||||
Status: "error",
|
||||
Data: &Data{},
|
||||
Error: "403 forbidden",
|
||||
}
|
||||
got := checkStatus(w, "error", http.StatusForbidden)
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Fatalf(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("403 forbidden, status returned by JSONHandlerFunc and HTTPError do not agree", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", nil)
|
||||
err := JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return http.StatusInternalServerError, nil, Error(http.StatusForbidden, "403 forbidden", nil)
|
||||
}).ServeHTTPReturn(w, r)
|
||||
if !strings.HasPrefix(err.Error(), "[unexpected]") {
|
||||
t.Fatalf("returned error should have `[unexpected]` to note the disagreeing status codes: %v", err)
|
||||
}
|
||||
want := &Response{
|
||||
Status: "error",
|
||||
Data: &Data{},
|
||||
Error: "403 forbidden",
|
||||
}
|
||||
got := checkStatus(w, "error", http.StatusForbidden)
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Fatalf("(-want,+got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func rusageMaxRSS() float64 {
|
||||
}
|
||||
|
||||
rss := float64(ru.Maxrss)
|
||||
if runtime.GOOS == "darwin" {
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||
rss /= 1 << 20 // ru_maxrss is bytes on darwin
|
||||
} else {
|
||||
// ru_maxrss is kilobytes elsewhere (linux, openbsd, etc)
|
||||
|
||||
15
util/endian/big.go
Normal file
15
util/endian/big.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build mips mips64 ppc64 s390x
|
||||
|
||||
package endian
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// Big is whether the current platform is big endian.
|
||||
const Big = true
|
||||
|
||||
// Native is the platform's native byte order.
|
||||
var Native = binary.BigEndian
|
||||
6
util/endian/endian.go
Normal file
6
util/endian/endian.go
Normal file
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package endian exports a constant about whether the machine is big endian.
|
||||
package endian
|
||||
15
util/endian/little.go
Normal file
15
util/endian/little.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build 386 amd64 arm arm64 mips64le mipsle ppc64le riscv64 wasm
|
||||
|
||||
package endian
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// Big is whether the current platform is big endian.
|
||||
const Big = false
|
||||
|
||||
// Native is the platform's native byte order.
|
||||
var Native = binary.LittleEndian
|
||||
17
util/jsonutil/types.go
Normal file
17
util/jsonutil/types.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package jsonutil
|
||||
|
||||
// Bytes is a byte slice in a json-encoded struct.
|
||||
// encoding/json assumes that []byte fields are hex-encoded.
|
||||
// Bytes are not hex-encoded; they are treated the same as strings.
|
||||
// This can avoid unnecessary allocations due to a round trip through strings.
|
||||
type Bytes []byte
|
||||
|
||||
func (b *Bytes) UnmarshalText(text []byte) error {
|
||||
// Copy the contexts of text.
|
||||
*b = append(*b, text...)
|
||||
return nil
|
||||
}
|
||||
90
util/jsonutil/unmarshal.go
Normal file
90
util/jsonutil/unmarshal.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package jsonutil provides utilities to improve JSON performance.
|
||||
// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs
|
||||
// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte.
|
||||
package jsonutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// decoder is a re-usable json decoder.
|
||||
type decoder struct {
|
||||
dec *json.Decoder
|
||||
r *bytes.Reader
|
||||
}
|
||||
|
||||
var readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bytes.NewReader(nil)
|
||||
},
|
||||
}
|
||||
|
||||
var decoderPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
var d decoder
|
||||
d.r = readerPool.Get().(*bytes.Reader)
|
||||
d.dec = json.NewDecoder(d.r)
|
||||
return &d
|
||||
},
|
||||
}
|
||||
|
||||
// Unmarshal is similar to encoding/json.Unmarshal.
|
||||
// There are three major differences:
|
||||
//
|
||||
// On error, encoding/json.Unmarshal zeros v.
|
||||
// This Unmarshal may leave partial data in v.
|
||||
// Always check the error before using v!
|
||||
// (Future improvements may remove this bug.)
|
||||
//
|
||||
// The errors they return don't always match perfectly.
|
||||
// If you do error matching more precise than err != nil,
|
||||
// don't use this Unmarshal.
|
||||
//
|
||||
// This Unmarshal allocates considerably less memory.
|
||||
func Unmarshal(b []byte, v interface{}) error {
|
||||
d := decoderPool.Get().(*decoder)
|
||||
d.r.Reset(b)
|
||||
off := d.dec.InputOffset()
|
||||
err := d.dec.Decode(v)
|
||||
d.r.Reset(nil) // don't keep a reference to b
|
||||
// In case of error, report the offset in this byte slice,
|
||||
// instead of in the totality of all bytes this decoder has processed.
|
||||
// It is not possible to make all errors match json.Unmarshal exactly,
|
||||
// but we can at least try.
|
||||
switch jsonerr := err.(type) {
|
||||
case *json.SyntaxError:
|
||||
jsonerr.Offset -= off
|
||||
case *json.UnmarshalTypeError:
|
||||
jsonerr.Offset -= off
|
||||
case nil:
|
||||
// json.Unmarshal fails if there's any extra junk in the input.
|
||||
// json.Decoder does not; see https://github.com/golang/go/issues/36225.
|
||||
// We need to check for anything left over in the buffer.
|
||||
if d.dec.More() {
|
||||
// TODO: Provide a better error message.
|
||||
// Unfortunately, we can't set the msg field.
|
||||
// The offset doesn't perfectly match json:
|
||||
// Ours is at the end of the valid data,
|
||||
// and theirs is at the beginning of the extra data after whitespace.
|
||||
// Close enough, though.
|
||||
err = &json.SyntaxError{Offset: d.dec.InputOffset() - off}
|
||||
|
||||
// TODO: zero v. This is hard; see encoding/json.indirect.
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
decoderPool.Put(d)
|
||||
} else {
|
||||
// There might be junk left in the decoder's buffer.
|
||||
// There's no way to flush it, no Reset method.
|
||||
// Abandoned the decoder but reuse the reader.
|
||||
readerPool.Put(d.r)
|
||||
}
|
||||
return err
|
||||
}
|
||||
65
util/jsonutil/unmarshal_test.go
Normal file
65
util/jsonutil/unmarshal_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package jsonutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompareToStd(t *testing.T) {
|
||||
tests := []string{
|
||||
`{}`,
|
||||
`{"a": 1}`,
|
||||
`{]`,
|
||||
`"abc"`,
|
||||
`5`,
|
||||
`{"a": 1} `,
|
||||
`{"a": 1} {}`,
|
||||
`{} bad data`,
|
||||
`{"a": 1} "hello"`,
|
||||
`[]`,
|
||||
` {"x": {"t": [3,4,5]}}`,
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
b := []byte(test)
|
||||
var ourV, stdV interface{}
|
||||
ourErr := Unmarshal(b, &ourV)
|
||||
stdErr := json.Unmarshal(b, &stdV)
|
||||
if (ourErr == nil) != (stdErr == nil) {
|
||||
t.Errorf("Unmarshal(%q): our err = %#[2]v (%[2]T), std err = %#[3]v (%[3]T)", test, ourErr, stdErr)
|
||||
}
|
||||
// if !reflect.DeepEqual(ourErr, stdErr) {
|
||||
// t.Logf("Unmarshal(%q): our err = %#[2]v (%[2]T), std err = %#[3]v (%[3]T)", test, ourErr, stdErr)
|
||||
// }
|
||||
if ourErr != nil {
|
||||
// TODO: if we zero ourV on error, remove this continue.
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(ourV, stdV) {
|
||||
t.Errorf("Unmarshal(%q): our val = %v, std val = %v", test, ourV, stdV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmarshal(b *testing.B) {
|
||||
var m interface{}
|
||||
j := []byte("5")
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
Unmarshal(j, &m)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStdUnmarshal(b *testing.B) {
|
||||
var m interface{}
|
||||
j := []byte("5")
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
json.Unmarshal(j, &m)
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,11 @@ func File(name string, fn func(line []byte) error) error {
|
||||
return Reader(f, fn)
|
||||
}
|
||||
|
||||
// Reader calls fn for each line.
|
||||
// If fn returns an error, Reader stops reading and returns that error.
|
||||
// Reader may also return errors encountered reading and parsing from r.
|
||||
// To stop reading early, use a sentinel "stop" error value and ignore
|
||||
// it when returned from Reader.
|
||||
func Reader(r io.Reader, fn func(line []byte) error) error {
|
||||
bs := bufio.NewScanner(r)
|
||||
for bs.Scan() {
|
||||
|
||||
9
util/racebuild/off.go
Normal file
9
util/racebuild/off.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !race
|
||||
|
||||
package racebuild
|
||||
|
||||
const On = false
|
||||
9
util/racebuild/on.go
Normal file
9
util/racebuild/on.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build race
|
||||
|
||||
package racebuild
|
||||
|
||||
const On = true
|
||||
7
util/racebuild/racebuild.go
Normal file
7
util/racebuild/racebuild.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package racebuild exports a constant about whether the current binary
|
||||
// was built with the race detector.
|
||||
package racebuild
|
||||
@@ -70,12 +70,12 @@ func TestMkversion(t *testing.T) {
|
||||
VERSION_XCODE="101.15.129"
|
||||
VERSION_WINRES="1,15,129,0"`},
|
||||
{"abcdef", "", 1, 2, 0, 17, `
|
||||
VERSION_SHORT="0.0.0"
|
||||
VERSION_LONG="0.0.0-tabcdef"
|
||||
VERSION_SHORT="1.2.0"
|
||||
VERSION_LONG="1.2.0-17-tabcdef"
|
||||
VERSION_GIT_HASH="abcdef"
|
||||
VERSION_EXTRA_HASH=""
|
||||
VERSION_XCODE="100.0.0"
|
||||
VERSION_WINRES="0,0,0,0"`},
|
||||
VERSION_XCODE="101.2.0"
|
||||
VERSION_WINRES="1,2,0,0"`},
|
||||
{"abcdef", "defghi", 1, 15, 0, 129, `
|
||||
VERSION_SHORT="1.15.129"
|
||||
VERSION_LONG="1.15.129-tabcdef-gdefghi"
|
||||
|
||||
@@ -10,17 +10,16 @@ import "runtime"
|
||||
func IsMobile() bool {
|
||||
// Good enough heuristic for now, at least until Apple makes
|
||||
// ARM laptops...
|
||||
return runtime.GOOS == "android" ||
|
||||
(runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64"))
|
||||
return runtime.GOOS == "android" || isIOS
|
||||
}
|
||||
|
||||
// OS returns runtime.GOOS, except instead of returning "darwin" it
|
||||
// returns "iOS" or "macOS".
|
||||
func OS() string {
|
||||
if isIOS {
|
||||
return "iOS"
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
if IsMobile() {
|
||||
return "iOS"
|
||||
}
|
||||
return "macOS"
|
||||
}
|
||||
return runtime.GOOS
|
||||
|
||||
9
version/prop_ios.go
Normal file
9
version/prop_ios.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build ios
|
||||
|
||||
package version
|
||||
|
||||
const isIOS = true
|
||||
9
version/prop_notios.go
Normal file
9
version/prop_notios.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !ios
|
||||
|
||||
package version
|
||||
|
||||
const isIOS = false
|
||||
@@ -10,7 +10,7 @@ package version
|
||||
// Long is a full version number for this build, of the form
|
||||
// "x.y.z-commithash", or "date.yyyymmdd" if no actual version was
|
||||
// provided.
|
||||
const Long = "date.20200921"
|
||||
const Long = "date.20201107"
|
||||
|
||||
// Short is a short version number for this build, of the form
|
||||
// "x.y.z", or "date.yyyymmdd" if no actual version was provided.
|
||||
|
||||
@@ -1,22 +1,43 @@
|
||||
#!/bin/sh
|
||||
|
||||
set -eu
|
||||
|
||||
# Return the commitid of the given ref in the given repo dir. If the worktree
|
||||
# or index is dirty, also appends -dirty.
|
||||
#
|
||||
# $ git_hash_dirty ../.. HEAD
|
||||
# 1be01ddc6e430ca3aa9beea3587d16750efb3241-dirty
|
||||
git_hash_dirty() {
|
||||
(
|
||||
cd "$1"
|
||||
x=$(git rev-parse HEAD)
|
||||
if ! git diff-index --quiet HEAD; then
|
||||
x="$x-dirty"
|
||||
fi
|
||||
echo "$x"
|
||||
)
|
||||
}
|
||||
|
||||
case $# in
|
||||
0|1)
|
||||
# extra_hash describes a git repository other than the current
|
||||
# one. It gets embedded as an additional commit hash in built
|
||||
# extra_hash_or_dir is either:
|
||||
# - a git commitid
|
||||
# or
|
||||
# - the path to a git repo from which to calculate the real hash.
|
||||
#
|
||||
# It gets embedded as an additional commit hash in built
|
||||
# binaries, to help us locate the exact set of tools and code
|
||||
# that were used.
|
||||
extra_hash="${1:-}"
|
||||
if [ -z "$extra_hash" ]; then
|
||||
extra_hash_or_dir="${1:-}"
|
||||
if [ -z "$extra_hash_or_dir" ]; then
|
||||
# Nothing, empty extra hash is fine.
|
||||
extra_hash=""
|
||||
elif [ -d "$extra_hash/.git" ]; then
|
||||
extra_hash=$(cd "$extra_hash" && git describe --always --dirty --exclude '*' --abbrev=200)
|
||||
elif ! expr "$extra_hash" : "^[0-9a-f]*$"; then
|
||||
echo "Invalid extra hash '$extra_hash', must be a git commit hash or path to a git repo" >&2
|
||||
elif [ -d "$extra_hash_or_dir/.git" ]; then
|
||||
extra_hash=$(git_hash_dirty "$extra_hash_or_dir" HEAD)
|
||||
elif ! expr "$extra_hash_or_dir" : "^[0-9a-f]*$"; then
|
||||
echo "Invalid extra hash '$extra_hash_or_dir', must be a git commit or path to a git repo" >&2
|
||||
exit 1
|
||||
else
|
||||
extra_hash="$extra_hash_or_dir"
|
||||
fi
|
||||
|
||||
# Load the base version and optional corresponding git hash
|
||||
@@ -25,15 +46,12 @@ case $# in
|
||||
version_file="$(dirname $0)/../VERSION.txt"
|
||||
IFS=".$IFS" read -r major minor patch base_git_hash <"$version_file"
|
||||
if [ -z "$base_git_hash" ]; then
|
||||
base_git_hash=$(git rev-list --max-count=1 HEAD -- $version_file)
|
||||
base_git_hash=$(git rev-list --max-count=1 HEAD -- "$version_file")
|
||||
fi
|
||||
|
||||
# The full git has we're currently building at. --abbrev=200 is an
|
||||
# arbitrary large number larger than all currently-known hashes, so
|
||||
# that git displays the full commit hash.
|
||||
git_hash=$(git describe --always --dirty --exclude '*' --abbrev=200)
|
||||
git_hash=$(git_hash_dirty . HEAD)
|
||||
# The number of extra commits between the release base to git_hash.
|
||||
change_count=$(git rev-list ${base_git_hash}..HEAD | wc -l)
|
||||
change_count=$(git rev-list --count HEAD "^$base_git_hash")
|
||||
;;
|
||||
6)
|
||||
# Test mode: rather than run git commands and whatnot, take in
|
||||
@@ -46,14 +64,14 @@ case $# in
|
||||
change_count=$6
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 [extra-git-hash-or-checkout]"
|
||||
echo "Usage: $0 [extra-git-commitid-or-dir]"
|
||||
exit 1
|
||||
esac
|
||||
|
||||
# Shortened versions of git hashes, so that they fit neatly into an
|
||||
# "elongated" but still human-readable version number.
|
||||
short_git_hash=$(echo $git_hash | cut -c-9)
|
||||
short_extra_hash=$(echo $extra_hash | cut -c-9)
|
||||
short_git_hash=$(echo "$git_hash" | cut -c1-9)
|
||||
short_extra_hash=$(echo "$extra_hash" | cut -c1-9)
|
||||
|
||||
# Convert major/minor/patch/change_count into an adjusted
|
||||
# major/minor/patch. This block is where all our policies on
|
||||
@@ -62,25 +80,28 @@ if expr "$minor" : "[0-9]*[13579]$" >/dev/null; then
|
||||
# Odd minor numbers are unstable builds.
|
||||
if [ "$patch" != "0" ]; then
|
||||
# This is a fatal error, because a non-zero patch number
|
||||
# indicates that we created an unstable git tag in violation
|
||||
# indicates that we created an unstable VERSION.txt in violation
|
||||
# of our versioning policy, and we want to blow up loudly to
|
||||
# get that fixed.
|
||||
echo "Unstable release $major.$minor.$patch has a non-zero patch number, which is not allowed" >&2
|
||||
exit 1
|
||||
fi
|
||||
patch="$change_count"
|
||||
change_suffix=""
|
||||
elif [ "$change_count" != "0" ]; then
|
||||
# Even minor numbers are stable builds, but stable builds are
|
||||
# supposed to have a zero change count. Therefore, we're currently
|
||||
# describing a commit that's on a release branch, but hasn't been
|
||||
# tagged as a patch release yet. We allow these commits to build
|
||||
# for testing purposes, but force their version number to 0.0.0,
|
||||
# to reflect that they're an unreleasable build. The git hashes
|
||||
# still completely describe the build commit, so we can still
|
||||
# figure out what this build is if it escapes into the wild.
|
||||
major="0"
|
||||
minor="0"
|
||||
patch="0"
|
||||
# tagged as a patch release yet.
|
||||
#
|
||||
# We used to change the version number to 0.0.0 in that case, but that
|
||||
# caused some features to get disabled due to the low version number.
|
||||
# Instead, add yet another suffix to the version number, with a change
|
||||
# count.
|
||||
change_suffix="-$change_count"
|
||||
else
|
||||
# Even minor number with no extra changes.
|
||||
change_suffix=""
|
||||
fi
|
||||
|
||||
# Hack for 1.1: add 1000 to the patch number. We switched from using
|
||||
@@ -95,15 +116,15 @@ fi
|
||||
# policies. All that remains is to output the various vars that other
|
||||
# code can use to embed version data.
|
||||
if [ -z "$extra_hash" ]; then
|
||||
long_version_suffix="-t$short_git_hash"
|
||||
long_version_suffix="$change_suffix-t$short_git_hash"
|
||||
else
|
||||
long_version_suffix="-t${short_git_hash}-g${short_extra_hash}"
|
||||
long_version_suffix="$change_suffix-t$short_git_hash-g$short_extra_hash"
|
||||
fi
|
||||
cat <<EOF
|
||||
VERSION_SHORT="${major}.${minor}.${patch}"
|
||||
VERSION_LONG="${major}.${minor}.${patch}${long_version_suffix}"
|
||||
VERSION_GIT_HASH="${git_hash}"
|
||||
VERSION_EXTRA_HASH="${extra_hash}"
|
||||
VERSION_XCODE="$((major + 100)).${minor}.${patch}"
|
||||
VERSION_WINRES="${major},${minor},${patch},0"
|
||||
VERSION_SHORT="$major.$minor.$patch"
|
||||
VERSION_LONG="$major.$minor.$patch$long_version_suffix"
|
||||
VERSION_GIT_HASH="$git_hash"
|
||||
VERSION_EXTRA_HASH="$extra_hash"
|
||||
VERSION_XCODE="$((major + 100)).$minor.$patch"
|
||||
VERSION_WINRES="$major,$minor,$patch,0"
|
||||
EOF
|
||||
|
||||
@@ -2,55 +2,80 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package filter contains a stateful packet filter.
|
||||
// Package filter is a stateful packet filter.
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/groupcache/lru"
|
||||
"golang.org/x/time/rate"
|
||||
"tailscale.com/tailcfg"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
type filterState struct {
|
||||
mu sync.Mutex
|
||||
lru *lru.Cache // of tuple
|
||||
}
|
||||
|
||||
// Filter is a stateful packet filter.
|
||||
type Filter struct {
|
||||
logf logger.Logf
|
||||
// localNets is the list of IP prefixes that we know to be "local"
|
||||
// to this node. All packets coming in over tailscale must have a
|
||||
// destination within localNets, regardless of the policy filter
|
||||
// below. A nil localNets rejects all incoming traffic.
|
||||
localNets []Net
|
||||
// matches is a list of match->action rules applied to all packets
|
||||
// arriving over tailscale tunnels. Matches are checked in order,
|
||||
// and processing stops at the first matching rule. The default
|
||||
// policy if no rules match is to drop the packet.
|
||||
matches Matches
|
||||
// local4 and local6 are the lists of IP prefixes that we know
|
||||
// to be "local" to this node. All packets coming in over
|
||||
// tailscale must have a destination within local4 or local6,
|
||||
// regardless of the policy filter below. Zero values reject
|
||||
// all incoming traffic.
|
||||
local4 []net4
|
||||
local6 []net6
|
||||
// matches4 and matches6 are lists of match->action rules
|
||||
// applied to all packets arriving over tailscale
|
||||
// tunnels. Matches are checked in order, and processing stops
|
||||
// at the first matching rule. The default policy if no rules
|
||||
// match is to drop the packet.
|
||||
matches4 matches4
|
||||
matches6 matches6
|
||||
// state is the connection tracking state attached to this
|
||||
// filter. It is used to allow incoming traffic that is a response
|
||||
// to an outbound connection that this node made, even if those
|
||||
// incoming packets don't get accepted by matches above.
|
||||
state *filterState
|
||||
state4 *filterState
|
||||
state6 *filterState
|
||||
}
|
||||
|
||||
// Response is a verdict: either a Drop, Accept, or noVerdict skip to
|
||||
// continue processing.
|
||||
// tuple4 is a 4-tuple of source and destination IPv4 and port. It's
|
||||
// used as a lookup key in filterState.
|
||||
type tuple4 struct {
|
||||
SrcIP packet.IP4
|
||||
DstIP packet.IP4
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// tuple6 is a 4-tuple of source and destination IPv6 and port. It's
|
||||
// used as a lookup key in filterState.
|
||||
type tuple6 struct {
|
||||
SrcIP packet.IP6
|
||||
DstIP packet.IP6
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// filterState is a state cache of past seen packets.
|
||||
type filterState struct {
|
||||
mu sync.Mutex
|
||||
lru *lru.Cache // of tuple4 or tuple6
|
||||
}
|
||||
|
||||
// lruMax is the size of the LRU cache in filterState.
|
||||
const lruMax = 512
|
||||
|
||||
// Response is a verdict from the packet filter.
|
||||
type Response int
|
||||
|
||||
const (
|
||||
Drop Response = iota
|
||||
Accept
|
||||
noVerdict // Returned from subfilters to continue processing.
|
||||
Drop Response = iota // do not continue processing packet.
|
||||
Accept // continue processing packet.
|
||||
noVerdict // no verdict yet, continue running filter
|
||||
)
|
||||
|
||||
func (r Response) String() string {
|
||||
@@ -70,30 +95,46 @@ func (r Response) String() string {
|
||||
type RunFlags int
|
||||
|
||||
const (
|
||||
LogDrops RunFlags = 1 << iota
|
||||
LogAccepts
|
||||
HexdumpDrops
|
||||
HexdumpAccepts
|
||||
LogDrops RunFlags = 1 << iota // write dropped packet info to logf
|
||||
LogAccepts // write accepted packet info to logf
|
||||
HexdumpDrops // print packet hexdump when logging drops
|
||||
HexdumpAccepts // print packet hexdump when logging accepts
|
||||
)
|
||||
|
||||
type tuple struct {
|
||||
SrcIP packet.IP
|
||||
DstIP packet.IP
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
// NewAllowAllForTest returns a packet filter that accepts
|
||||
// everything. Use in tests only, as it permits some kinds of spoofing
|
||||
// attacks to reach the OS network stack.
|
||||
func NewAllowAllForTest(logf logger.Logf) *Filter {
|
||||
any4 := netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}
|
||||
any6 := netaddr.IPPrefix{IP: netaddr.IPFrom16([16]byte{}), Bits: 0}
|
||||
ms := []Match{
|
||||
{
|
||||
Srcs: []netaddr.IPPrefix{any4},
|
||||
Dsts: []NetPortRange{
|
||||
{
|
||||
Net: any4,
|
||||
Ports: PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Srcs: []netaddr.IPPrefix{any6},
|
||||
Dsts: []NetPortRange{
|
||||
{
|
||||
Net: any6,
|
||||
Ports: PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const lruMax = 512 // max entries in UDP LRU cache
|
||||
|
||||
// MatchAllowAll matches all packets.
|
||||
var MatchAllowAll = Matches{
|
||||
Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}},
|
||||
}
|
||||
|
||||
// NewAllowAll returns a packet filter that accepts everything to and
|
||||
// from localNets.
|
||||
func NewAllowAll(localNets []Net, logf logger.Logf) *Filter {
|
||||
return New(MatchAllowAll, localNets, nil, logf)
|
||||
return New(ms, []netaddr.IPPrefix{any4, any6}, nil, logf)
|
||||
}
|
||||
|
||||
// NewAllowNone returns a packet filter that rejects everything.
|
||||
@@ -104,22 +145,29 @@ func NewAllowNone(logf logger.Logf) *Filter {
|
||||
// New creates a new packet filter. The filter enforces that incoming
|
||||
// packets must be destined to an IP in localNets, and must be allowed
|
||||
// by matches. If shareStateWith is non-nil, the returned filter
|
||||
// shares state with the previous one, to enable rules to be changed
|
||||
// at runtime without breaking existing flows.
|
||||
func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||
var state *filterState
|
||||
// shares state with the previous one, to enable changing rules at
|
||||
// runtime without breaking existing stateful flows.
|
||||
func New(matches []Match, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||
var state4, state6 *filterState
|
||||
if shareStateWith != nil {
|
||||
state = shareStateWith.state
|
||||
state4 = shareStateWith.state4
|
||||
state6 = shareStateWith.state6
|
||||
} else {
|
||||
state = &filterState{
|
||||
state4 = &filterState{
|
||||
lru: lru.New(lruMax),
|
||||
}
|
||||
state6 = &filterState{
|
||||
lru: lru.New(lruMax),
|
||||
}
|
||||
}
|
||||
f := &Filter{
|
||||
logf: logf,
|
||||
matches: matches,
|
||||
localNets: localNets,
|
||||
state: state,
|
||||
logf: logf,
|
||||
matches4: newMatches4(matches),
|
||||
matches6: newMatches6(matches),
|
||||
local4: nets4FromIPPrefixes(localNets),
|
||||
local6: nets6FromIPPrefixes(localNets),
|
||||
state4: state4,
|
||||
state6: state6,
|
||||
}
|
||||
return f
|
||||
}
|
||||
@@ -131,79 +179,6 @@ func maybeHexdump(flag RunFlags, b []byte) string {
|
||||
return packet.Hexdump(b) + "\n"
|
||||
}
|
||||
|
||||
// MatchesFromFilterRules parse a number of wire-format FilterRule values into
|
||||
// the Matches format.
|
||||
// If an error is returned, the Matches result is still valid, containing the rules that
|
||||
// were successfully converted.
|
||||
func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) {
|
||||
mm := make([]Match, 0, len(pf))
|
||||
var erracc error
|
||||
|
||||
for _, r := range pf {
|
||||
m := Match{}
|
||||
|
||||
for i, s := range r.SrcIPs {
|
||||
bits := 32
|
||||
if len(r.SrcBits) > i {
|
||||
bits = r.SrcBits[i]
|
||||
}
|
||||
net, err := parseIP(s, bits)
|
||||
if err != nil && erracc == nil {
|
||||
erracc = err
|
||||
continue
|
||||
}
|
||||
m.Srcs = append(m.Srcs, net)
|
||||
}
|
||||
|
||||
for _, d := range r.DstPorts {
|
||||
bits := 32
|
||||
if d.Bits != nil {
|
||||
bits = *d.Bits
|
||||
}
|
||||
net, err := parseIP(d.IP, bits)
|
||||
if err != nil && erracc == nil {
|
||||
erracc = err
|
||||
continue
|
||||
}
|
||||
m.Dsts = append(m.Dsts, NetPortRange{
|
||||
Net: net,
|
||||
Ports: PortRange{
|
||||
First: d.Ports.First,
|
||||
Last: d.Ports.Last,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
mm = append(mm, m)
|
||||
}
|
||||
return mm, erracc
|
||||
}
|
||||
|
||||
func parseIP(host string, defaultBits int) (Net, error) {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.IsUnspecified() {
|
||||
// For clarity, reject 0.0.0.0 as an input
|
||||
return NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
|
||||
} else if ip == nil && host == "*" {
|
||||
// User explicitly requested wildcard dst ip
|
||||
return NetAny, nil
|
||||
} else {
|
||||
if ip != nil {
|
||||
ip = ip.To4()
|
||||
}
|
||||
if ip == nil || len(ip) != 4 {
|
||||
return NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
|
||||
}
|
||||
if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) {
|
||||
return NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
|
||||
}
|
||||
return Net{
|
||||
IP: NewIP(ip),
|
||||
Mask: Netmask(defaultBits),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
|
||||
// Logging is a quick way to record every newly opened TCP connection, but
|
||||
// we have to be cautious about flooding the logs vs letting people use
|
||||
@@ -212,7 +187,7 @@ func parseIP(host string, defaultBits int) (Net, error) {
|
||||
var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3)
|
||||
var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10)
|
||||
|
||||
func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, dir direction, r Response, why string) {
|
||||
func (f *Filter) logRateLimit(runflags RunFlags, q *packet.Parsed, dir direction, r Response, why string) {
|
||||
var verdict string
|
||||
|
||||
if r == Drop && omitDropLogging(q, dir) {
|
||||
@@ -235,8 +210,45 @@ func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, dir dir
|
||||
}
|
||||
}
|
||||
|
||||
// RunIn determines whether this node is allowed to receive q from a Tailscale peer.
|
||||
func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
// dummyPacket is a 20-byte slice of garbage, to pass the filter
|
||||
// pre-check when evaluating synthesized packets.
|
||||
var dummyPacket = []byte{
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
||||
}
|
||||
|
||||
// CheckTCP determines whether TCP traffic from srcIP to dstIP:dstPort
|
||||
// is allowed.
|
||||
func (f *Filter) CheckTCP(srcIP, dstIP netaddr.IP, dstPort uint16) Response {
|
||||
pkt := &packet.Parsed{}
|
||||
pkt.Decode(dummyPacket) // initialize private fields
|
||||
switch {
|
||||
case (srcIP.Is4() && dstIP.Is6()) || (srcIP.Is6() && srcIP.Is4()):
|
||||
// Mistmatched address families, no filters will
|
||||
// match.
|
||||
return Drop
|
||||
case srcIP.Is4():
|
||||
pkt.IPVersion = 4
|
||||
pkt.SrcIP4 = packet.IP4FromNetaddr(srcIP)
|
||||
pkt.DstIP4 = packet.IP4FromNetaddr(dstIP)
|
||||
case srcIP.Is6():
|
||||
pkt.IPVersion = 6
|
||||
pkt.SrcIP6 = packet.IP6FromNetaddr(srcIP)
|
||||
pkt.DstIP6 = packet.IP6FromNetaddr(dstIP)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
pkt.IPProto = packet.TCP
|
||||
pkt.TCPFlags = packet.TCPSyn
|
||||
pkt.SrcPort = 0
|
||||
pkt.DstPort = dstPort
|
||||
|
||||
return f.RunIn(pkt, 0)
|
||||
}
|
||||
|
||||
// RunIn determines whether this node is allowed to receive q from a
|
||||
// Tailscale peer.
|
||||
func (f *Filter) RunIn(q *packet.Parsed, rf RunFlags) Response {
|
||||
dir := in
|
||||
r := f.pre(q, rf, dir)
|
||||
if r == Accept || r == Drop {
|
||||
@@ -244,13 +256,22 @@ func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
return r
|
||||
}
|
||||
|
||||
r, why := f.runIn(q)
|
||||
var why string
|
||||
switch q.IPVersion {
|
||||
case 4:
|
||||
r, why = f.runIn4(q)
|
||||
case 6:
|
||||
r, why = f.runIn6(q)
|
||||
default:
|
||||
r, why = Drop, "not-ip"
|
||||
}
|
||||
f.logRateLimit(rf, q, dir, r, why)
|
||||
return r
|
||||
}
|
||||
|
||||
// RunOut determines whether this node is allowed to send q to a Tailscale peer.
|
||||
func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
// RunOut determines whether this node is allowed to send q to a
|
||||
// Tailscale peer.
|
||||
func (f *Filter) RunOut(q *packet.Parsed, rf RunFlags) Response {
|
||||
dir := out
|
||||
r := f.pre(q, rf, dir)
|
||||
if r == Drop || r == Accept {
|
||||
@@ -262,21 +283,16 @@ func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
return r
|
||||
}
|
||||
|
||||
func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
|
||||
// A compromised peer could try to send us packets for
|
||||
// destinations we didn't explicitly advertise. This check is to
|
||||
// prevent that.
|
||||
if !ipInList(q.DstIP, f.localNets) {
|
||||
if !ip4InList(q.DstIP4, f.local4) {
|
||||
return Drop, "destination not allowed"
|
||||
}
|
||||
|
||||
if q.IPVersion == 6 {
|
||||
// TODO: support IPv6.
|
||||
return Drop, "no rules matched"
|
||||
}
|
||||
|
||||
switch q.IPProto {
|
||||
case packet.ICMP:
|
||||
case packet.ICMPv4:
|
||||
if q.IsEchoResponse() || q.IsError() {
|
||||
// ICMP responses are allowed.
|
||||
// TODO(apenwarr): consider using conntrack state.
|
||||
@@ -284,7 +300,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
// related to an existing ICMP-Echo, TCP, or UDP
|
||||
// session.
|
||||
return Accept, "icmp response ok"
|
||||
} else if matchIPWithoutPorts(f.matches, q) {
|
||||
} else if f.matches4.matchIPsOnly(q) {
|
||||
// If any port is open to an IP, allow ICMP to it.
|
||||
return Accept, "icmp ok"
|
||||
}
|
||||
@@ -300,20 +316,20 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
|
||||
return Accept, "tcp non-syn"
|
||||
}
|
||||
if matchIPPorts(f.matches, q) {
|
||||
if f.matches4.match(q) {
|
||||
return Accept, "tcp ok"
|
||||
}
|
||||
case packet.UDP:
|
||||
t := tuple{q.SrcIP, q.DstIP, q.SrcPort, q.DstPort}
|
||||
t := tuple4{q.SrcIP4, q.DstIP4, q.SrcPort, q.DstPort}
|
||||
|
||||
f.state.mu.Lock()
|
||||
_, ok := f.state.lru.Get(t)
|
||||
f.state.mu.Unlock()
|
||||
f.state4.mu.Lock()
|
||||
_, ok := f.state4.lru.Get(t)
|
||||
f.state4.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
return Accept, "udp cached"
|
||||
}
|
||||
if matchIPPorts(f.matches, q) {
|
||||
if f.matches4.match(q) {
|
||||
return Accept, "udp ok"
|
||||
}
|
||||
default:
|
||||
@@ -322,24 +338,91 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
return Drop, "no rules matched"
|
||||
}
|
||||
|
||||
func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
|
||||
if q.IPProto == packet.UDP {
|
||||
t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort}
|
||||
var ti interface{} = t // allocate once, rather than twice inside mutex
|
||||
func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
|
||||
// A compromised peer could try to send us packets for
|
||||
// destinations we didn't explicitly advertise. This check is to
|
||||
// prevent that.
|
||||
if !ip6InList(q.DstIP6, f.local6) {
|
||||
return Drop, "destination not allowed"
|
||||
}
|
||||
|
||||
f.state.mu.Lock()
|
||||
f.state.lru.Add(ti, ti)
|
||||
f.state.mu.Unlock()
|
||||
switch q.IPProto {
|
||||
case packet.ICMPv6:
|
||||
if q.IsEchoResponse() || q.IsError() {
|
||||
// ICMP responses are allowed.
|
||||
// TODO(apenwarr): consider using conntrack state.
|
||||
// We could choose to reject all packets that aren't
|
||||
// related to an existing ICMP-Echo, TCP, or UDP
|
||||
// session.
|
||||
return Accept, "icmp response ok"
|
||||
} else if f.matches6.matchIPsOnly(q) {
|
||||
// If any port is open to an IP, allow ICMP to it.
|
||||
return Accept, "icmp ok"
|
||||
}
|
||||
case packet.TCP:
|
||||
// For TCP, we want to allow *outgoing* connections,
|
||||
// which means we want to allow return packets on those
|
||||
// connections. To make this restriction work, we need to
|
||||
// allow non-SYN packets (continuation of an existing session)
|
||||
// to arrive. This should be okay since a new incoming session
|
||||
// can't be initiated without first sending a SYN.
|
||||
// It happens to also be much faster.
|
||||
// TODO(apenwarr): Skip the rest of decoding in this path?
|
||||
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
|
||||
return Accept, "tcp non-syn"
|
||||
}
|
||||
if f.matches6.match(q) {
|
||||
return Accept, "tcp ok"
|
||||
}
|
||||
case packet.UDP:
|
||||
t := tuple6{q.SrcIP6, q.DstIP6, q.SrcPort, q.DstPort}
|
||||
|
||||
f.state6.mu.Lock()
|
||||
_, ok := f.state6.lru.Get(t)
|
||||
f.state6.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
return Accept, "udp cached"
|
||||
}
|
||||
if f.matches6.match(q) {
|
||||
return Accept, "udp ok"
|
||||
}
|
||||
default:
|
||||
return Drop, "Unknown proto"
|
||||
}
|
||||
return Drop, "no rules matched"
|
||||
}
|
||||
|
||||
// runIn runs the output-specific part of the filter logic.
|
||||
func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) {
|
||||
if q.IPProto != packet.UDP {
|
||||
return Accept, "ok out"
|
||||
}
|
||||
|
||||
switch q.IPVersion {
|
||||
case 4:
|
||||
t := tuple4{q.DstIP4, q.SrcIP4, q.DstPort, q.SrcPort}
|
||||
var ti interface{} = t // allocate once, rather than twice inside mutex
|
||||
f.state4.mu.Lock()
|
||||
f.state4.lru.Add(ti, ti)
|
||||
f.state4.mu.Unlock()
|
||||
case 6:
|
||||
t := tuple6{q.DstIP6, q.SrcIP6, q.DstPort, q.SrcPort}
|
||||
var ti interface{} = t // allocate once, rather than twice inside mutex
|
||||
f.state6.mu.Lock()
|
||||
f.state6.lru.Add(ti, ti)
|
||||
f.state6.mu.Unlock()
|
||||
}
|
||||
return Accept, "ok out"
|
||||
}
|
||||
|
||||
// direction is whether a packet was flowing in to this machine, or flowing out.
|
||||
// direction is whether a packet was flowing in to this machine, or
|
||||
// flowing out.
|
||||
type direction int
|
||||
|
||||
const (
|
||||
in direction = iota
|
||||
out
|
||||
in direction = iota // from Tailscale peer to local machine
|
||||
out // from local machine to Tailscale peer
|
||||
)
|
||||
|
||||
func (d direction) String() string {
|
||||
@@ -353,7 +436,9 @@ func (d direction) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Response {
|
||||
// pre runs the direction-agnostic filter logic. dir is only used for
|
||||
// logging.
|
||||
func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response {
|
||||
if len(q.Buffer()) == 0 {
|
||||
// wireguard keepalive packet, always permit.
|
||||
return Accept
|
||||
@@ -363,17 +448,25 @@ func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Respons
|
||||
return Drop
|
||||
}
|
||||
|
||||
if q.IPVersion == 6 {
|
||||
f.logRateLimit(rf, q, dir, Drop, "ipv6")
|
||||
return Drop
|
||||
}
|
||||
if q.DstIP.IsMulticast() {
|
||||
f.logRateLimit(rf, q, dir, Drop, "multicast")
|
||||
return Drop
|
||||
}
|
||||
if q.DstIP.IsLinkLocalUnicast() {
|
||||
f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
|
||||
return Drop
|
||||
switch q.IPVersion {
|
||||
case 4:
|
||||
if q.DstIP4.IsMulticast() {
|
||||
f.logRateLimit(rf, q, dir, Drop, "multicast")
|
||||
return Drop
|
||||
}
|
||||
if q.DstIP4.IsMostLinkLocalUnicast() {
|
||||
f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
|
||||
return Drop
|
||||
}
|
||||
case 6:
|
||||
if q.DstIP6.IsMulticast() {
|
||||
f.logRateLimit(rf, q, dir, Drop, "multicast")
|
||||
return Drop
|
||||
}
|
||||
if q.DstIP6.IsLinkLocalUnicast() {
|
||||
f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
|
||||
return Drop
|
||||
}
|
||||
}
|
||||
|
||||
switch q.IPProto {
|
||||
@@ -383,7 +476,7 @@ func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Respons
|
||||
return Drop
|
||||
case packet.Fragment:
|
||||
// Fragments after the first always need to be passed through.
|
||||
// Very small fragments are considered Junk by ParsedPacket.
|
||||
// Very small fragments are considered Junk by Parsed.
|
||||
f.logRateLimit(rf, q, dir, Accept, "fragment")
|
||||
return Accept
|
||||
}
|
||||
@@ -391,61 +484,21 @@ func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Respons
|
||||
return noVerdict
|
||||
}
|
||||
|
||||
const (
|
||||
// ipv6AllRoutersLinkLocal is ff02::2 (All link-local routers)
|
||||
ipv6AllRoutersLinkLocal = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
|
||||
// ipv6AllMLDv2CapableRouters is ff02::16 (All MLDv2-capable routers)
|
||||
ipv6AllMLDv2CapableRouters = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x16"
|
||||
)
|
||||
|
||||
// omitDropLogging reports whether packet p, which has already been
|
||||
// deemded a packet to Drop, should bypass the [rate-limited] logging.
|
||||
// We don't want to log scary & spammy reject warnings for packets that
|
||||
// are totally normal, like IPv6 route announcements.
|
||||
func omitDropLogging(p *packet.ParsedPacket, dir direction) bool {
|
||||
b := p.Buffer()
|
||||
switch dir {
|
||||
case out:
|
||||
switch p.IPVersion {
|
||||
case 4:
|
||||
// ParsedPacket.Decode zeros out ParsedPacket.IPProtocol for protocols
|
||||
// it doesn't know about, so parse it out ourselves if needed.
|
||||
ipProto := p.IPProto
|
||||
if ipProto == 0 && len(b) > 8 {
|
||||
ipProto = packet.IPProto(b[9])
|
||||
}
|
||||
// Omit logging about outgoing IGMP.
|
||||
if ipProto == packet.IGMP {
|
||||
return true
|
||||
}
|
||||
if p.DstIP.IsMulticast() || p.DstIP.IsLinkLocalUnicast() {
|
||||
return true
|
||||
}
|
||||
case 6:
|
||||
if len(b) < 40 {
|
||||
return false
|
||||
}
|
||||
src, dst := b[8:8+16], b[24:24+16]
|
||||
// Omit logging for outgoing IPv6 ICMP-v6 queries to ff02::2,
|
||||
// as sent by the OS, looking for routers.
|
||||
if p.IPProto == packet.ICMPv6 {
|
||||
if isLinkLocalV6(src) && string(dst) == ipv6AllRoutersLinkLocal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if string(dst) == ipv6AllMLDv2CapableRouters {
|
||||
return true
|
||||
}
|
||||
// Actually, just catch all multicast.
|
||||
if dst[0] == 0xff {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// deemed a packet to Drop, should bypass the [rate-limited] logging.
|
||||
// We don't want to log scary & spammy reject warnings for packets
|
||||
// that are totally normal, like IPv6 route announcements.
|
||||
func omitDropLogging(p *packet.Parsed, dir direction) bool {
|
||||
if dir != out {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isLinkLocalV6 reports whether src is in fe80::/10.
|
||||
func isLinkLocalV6(src []byte) bool {
|
||||
return len(src) == 16 && src[0] == 0xfe && src[1]>>6 == 0x80>>6
|
||||
switch p.IPVersion {
|
||||
case 4:
|
||||
return p.DstIP4.IsMulticast() || p.DstIP4.IsMostLinkLocalUnicast() || p.IPProto == packet.IGMP
|
||||
case 6:
|
||||
return p.DstIP6.IsMulticast() || p.DstIP6.IsLinkLocalUnicast()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,153 +5,184 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
// Type aliases only in test code: (but ideally nowhere)
|
||||
type ParsedPacket = packet.ParsedPacket
|
||||
type IP = packet.IP
|
||||
|
||||
var Unknown = packet.Unknown
|
||||
var ICMP = packet.ICMP
|
||||
var TCP = packet.TCP
|
||||
var UDP = packet.UDP
|
||||
var Fragment = packet.Fragment
|
||||
|
||||
func nets(ips []IP) []Net {
|
||||
out := make([]Net, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
out = append(out, Net{ip, Netmask(32)})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ippr(ip IP, start, end uint16) []NetPortRange {
|
||||
return []NetPortRange{
|
||||
NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}},
|
||||
}
|
||||
}
|
||||
|
||||
func netpr(ip IP, bits int, start, end uint16) []NetPortRange {
|
||||
return []NetPortRange{
|
||||
NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}},
|
||||
}
|
||||
}
|
||||
|
||||
var matches = Matches{
|
||||
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{
|
||||
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
|
||||
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
|
||||
}},
|
||||
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
|
||||
{Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
|
||||
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
|
||||
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
|
||||
{Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
|
||||
}
|
||||
|
||||
func newFilter(logf logger.Logf) *Filter {
|
||||
matches := []Match{
|
||||
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")},
|
||||
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")},
|
||||
{Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")},
|
||||
{Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")},
|
||||
{Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")},
|
||||
{Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")},
|
||||
{Srcs: nets("::1", "::2"), Dsts: netports("2001::1:22", "2001::2:22")},
|
||||
{Srcs: nets("::/0"), Dsts: netports("::/0:443")},
|
||||
}
|
||||
|
||||
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
||||
// 102.102.102.102, 119.119.119.119, 8.1.0.0/16
|
||||
localNets := nets([]IP{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777})
|
||||
localNets = append(localNets, Net{IP(0x08010000), Netmask(16)})
|
||||
localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16", "2001::/16")
|
||||
|
||||
return New(matches, localNets, nil, logf)
|
||||
}
|
||||
|
||||
func TestMarshal(t *testing.T) {
|
||||
for _, ent := range []Matches{Matches{matches[0]}, matches} {
|
||||
b, err := json.Marshal(ent)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
|
||||
mm2 := Matches{}
|
||||
if err := json.Unmarshal(b, &mm2); err != nil {
|
||||
t.Fatalf("unmarshal: %v (%v)", err, string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
acl := newFilter(t.Logf)
|
||||
// check packet filtering based on the table
|
||||
|
||||
type InOut struct {
|
||||
want Response
|
||||
p ParsedPacket
|
||||
p packet.Parsed
|
||||
}
|
||||
tests := []InOut{
|
||||
// Basic
|
||||
{Accept, parsed(TCP, 0x08010101, 0x01020304, 999, 22)},
|
||||
{Accept, parsed(UDP, 0x08010101, 0x01020304, 999, 22)},
|
||||
{Accept, parsed(ICMP, 0x08010101, 0x01020304, 0, 0)},
|
||||
{Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 0)},
|
||||
{Accept, parsed(TCP, 0x08010101, 0x01020304, 0, 22)},
|
||||
{Drop, parsed(TCP, 0x08010101, 0x01020304, 0, 21)},
|
||||
{Accept, parsed(TCP, 0x11223344, 0x08012233, 0, 443)},
|
||||
{Drop, parsed(TCP, 0x11223344, 0x08012233, 0, 444)},
|
||||
{Accept, parsed(TCP, 0x11223344, 0x647a6232, 0, 999)},
|
||||
{Accept, parsed(TCP, 0x11223344, 0x647a6232, 0, 0)},
|
||||
// allow 8.1.1.1 => 1.2.3.4:22
|
||||
{Accept, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22)},
|
||||
{Accept, parsed(packet.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0)},
|
||||
{Drop, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 0)},
|
||||
{Accept, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 22)},
|
||||
{Drop, parsed(packet.TCP, "8.1.1.1", "1.2.3.4", 0, 21)},
|
||||
// allow 8.2.2.2. => 1.2.3.4:22
|
||||
{Accept, parsed(packet.TCP, "8.2.2.2", "1.2.3.4", 0, 22)},
|
||||
{Drop, parsed(packet.TCP, "8.2.2.2", "1.2.3.4", 0, 23)},
|
||||
{Drop, parsed(packet.TCP, "8.3.3.3", "1.2.3.4", 0, 22)},
|
||||
// allow 8.1.1.1 => 5.6.7.8:23-24
|
||||
{Accept, parsed(packet.TCP, "8.1.1.1", "5.6.7.8", 0, 23)},
|
||||
{Accept, parsed(packet.TCP, "8.1.1.1", "5.6.7.8", 0, 24)},
|
||||
{Drop, parsed(packet.TCP, "8.1.1.3", "5.6.7.8", 0, 24)},
|
||||
{Drop, parsed(packet.TCP, "8.1.1.1", "5.6.7.8", 0, 22)},
|
||||
// allow * => *:443
|
||||
{Accept, parsed(packet.TCP, "17.34.51.68", "8.1.34.51", 0, 443)},
|
||||
{Drop, parsed(packet.TCP, "17.34.51.68", "8.1.34.51", 0, 444)},
|
||||
// allow * => 100.122.98.50:*
|
||||
{Accept, parsed(packet.TCP, "17.34.51.68", "100.122.98.50", 0, 999)},
|
||||
{Accept, parsed(packet.TCP, "17.34.51.68", "100.122.98.50", 0, 0)},
|
||||
|
||||
// allow ::1, ::2 => [2001::1]:22
|
||||
{Accept, parsed(packet.TCP, "::1", "2001::1", 0, 22)},
|
||||
{Accept, parsed(packet.ICMPv6, "::1", "2001::1", 0, 0)},
|
||||
{Accept, parsed(packet.TCP, "::2", "2001::1", 0, 22)},
|
||||
{Accept, parsed(packet.TCP, "::2", "2001::2", 0, 22)},
|
||||
{Drop, parsed(packet.TCP, "::1", "2001::1", 0, 23)},
|
||||
{Drop, parsed(packet.TCP, "::1", "2001::3", 0, 22)},
|
||||
{Drop, parsed(packet.TCP, "::3", "2001::1", 0, 22)},
|
||||
// allow * => *:443
|
||||
{Accept, parsed(packet.TCP, "::1", "2001::1", 0, 443)},
|
||||
{Drop, parsed(packet.TCP, "::1", "2001::1", 0, 444)},
|
||||
|
||||
// localNets prefilter - accepted by policy filter, but
|
||||
// unexpected dst IP.
|
||||
{Drop, parsed(TCP, 0x08010101, 0x10203040, 0, 443)},
|
||||
|
||||
// Stateful UDP. Note each packet is run through the input
|
||||
// filter, then the output filter (which sets conntrack
|
||||
// state).
|
||||
// Initially empty cache
|
||||
{Drop, parsed(UDP, 0x77777777, 0x66666666, 4242, 4343)},
|
||||
// Return packet from previous attempt is allowed
|
||||
{Accept, parsed(UDP, 0x66666666, 0x77777777, 4343, 4242)},
|
||||
// Because of the return above, initial attempt is allowed now
|
||||
{Accept, parsed(UDP, 0x77777777, 0x66666666, 4242, 4343)},
|
||||
{Drop, parsed(packet.TCP, "8.1.1.1", "16.32.48.64", 0, 443)},
|
||||
{Drop, parsed(packet.TCP, "1::", "2602::1", 0, 443)},
|
||||
}
|
||||
for i, test := range tests {
|
||||
if got, _ := acl.runIn(&test.p); test.want != got {
|
||||
t.Errorf("#%d got=%v want=%v packet:%v\n", i, got, test.want, test.p)
|
||||
aclFunc := acl.runIn4
|
||||
if test.p.IPVersion == 6 {
|
||||
aclFunc = acl.runIn6
|
||||
}
|
||||
if got, why := aclFunc(&test.p); test.want != got {
|
||||
t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
|
||||
}
|
||||
if test.p.IPProto == packet.TCP {
|
||||
var got Response
|
||||
if test.p.IPVersion == 4 {
|
||||
got = acl.CheckTCP(test.p.SrcIP4.Netaddr(), test.p.DstIP4.Netaddr(), test.p.DstPort)
|
||||
} else {
|
||||
got = acl.CheckTCP(test.p.SrcIP6.Netaddr(), test.p.DstIP6.Netaddr(), test.p.DstPort)
|
||||
}
|
||||
if test.want != got {
|
||||
t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
|
||||
}
|
||||
// TCP and UDP are treated equivalently in the filter - verify that.
|
||||
test.p.IPProto = packet.UDP
|
||||
if got, why := aclFunc(&test.p); test.want != got {
|
||||
t.Errorf("#%d runIn (UDP) got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
|
||||
}
|
||||
}
|
||||
// Update UDP state
|
||||
_, _ = acl.runOut(&test.p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPState(t *testing.T) {
|
||||
acl := newFilter(t.Logf)
|
||||
flags := LogDrops | LogAccepts
|
||||
|
||||
a4 := parsed(packet.UDP, "119.119.119.119", "102.102.102.102", 4242, 4343)
|
||||
b4 := parsed(packet.UDP, "102.102.102.102", "119.119.119.119", 4343, 4242)
|
||||
|
||||
// Unsollicited UDP traffic gets dropped
|
||||
if got := acl.RunIn(&a4, flags); got != Drop {
|
||||
t.Fatalf("incoming initial packet not dropped, got=%v: %v", got, a4)
|
||||
}
|
||||
// We talk to that peer
|
||||
if got := acl.RunOut(&b4, flags); got != Accept {
|
||||
t.Fatalf("outbound packet didn't egress, got=%v: %v", got, b4)
|
||||
}
|
||||
// Now, the same packet as before is allowed back.
|
||||
if got := acl.RunIn(&a4, flags); got != Accept {
|
||||
t.Fatalf("incoming response packet not accepted, got=%v: %v", got, a4)
|
||||
}
|
||||
|
||||
a6 := parsed(packet.UDP, "2001::2", "2001::1", 4242, 4343)
|
||||
b6 := parsed(packet.UDP, "2001::1", "2001::2", 4343, 4242)
|
||||
|
||||
// Unsollicited UDP traffic gets dropped
|
||||
if got := acl.RunIn(&a6, flags); got != Drop {
|
||||
t.Fatalf("incoming initial packet not dropped: %v", a4)
|
||||
}
|
||||
// We talk to that peer
|
||||
if got := acl.RunOut(&b6, flags); got != Accept {
|
||||
t.Fatalf("outbound packet didn't egress: %v", b4)
|
||||
}
|
||||
// Now, the same packet as before is allowed back.
|
||||
if got := acl.RunIn(&a6, flags); got != Accept {
|
||||
t.Fatalf("incoming response packet not accepted: %v", a4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoAllocs(t *testing.T) {
|
||||
acl := newFilter(t.Logf)
|
||||
|
||||
tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
tcp4Packet := raw4(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22, 0)
|
||||
udp4Packet := raw4(packet.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0)
|
||||
tcp6Packet := raw6(packet.TCP, "2001::1", "2001::2", 999, 22, 0)
|
||||
udp6Packet := raw6(packet.UDP, "2001::1", "2001::2", 999, 22, 0)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
dir direction
|
||||
want int
|
||||
packet []byte
|
||||
}{
|
||||
{"tcp_in", true, 0, tcpPacket},
|
||||
{"tcp_out", false, 0, tcpPacket},
|
||||
{"udp_in", true, 0, udpPacket},
|
||||
{"tcp4_in", in, 0, tcp4Packet},
|
||||
{"tcp6_in", in, 0, tcp6Packet},
|
||||
{"tcp4_out", out, 0, tcp4Packet},
|
||||
{"tcp6_out", out, 0, tcp6Packet},
|
||||
{"udp4_in", in, 0, udp4Packet},
|
||||
{"udp6_in", in, 0, udp6Packet},
|
||||
// One alloc is inevitable (an lru cache update)
|
||||
{"udp_out", false, 1, udpPacket},
|
||||
{"udp4_out", out, 1, udp4Packet},
|
||||
{"udp6_out", out, 1, udp6Packet},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := int(testing.AllocsPerRun(1000, func() {
|
||||
q := &ParsedPacket{}
|
||||
q := &packet.Parsed{}
|
||||
q.Decode(test.packet)
|
||||
if test.in {
|
||||
switch test.dir {
|
||||
case in:
|
||||
acl.RunIn(q, 0)
|
||||
} else {
|
||||
case out:
|
||||
acl.RunOut(q, 0)
|
||||
}
|
||||
}))
|
||||
@@ -167,15 +198,17 @@ func TestParseIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
host string
|
||||
bits int
|
||||
want Net
|
||||
want []netaddr.IPPrefix
|
||||
wantErr string
|
||||
}{
|
||||
{"8.8.8.8", 24, Net{IP: packet.NewIP(net.ParseIP("8.8.8.8")), Mask: packet.NewIP(net.ParseIP("255.255.255.0"))}, ""},
|
||||
{"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`},
|
||||
{"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`},
|
||||
{"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
|
||||
{"*", 24, NetAny, ""},
|
||||
{"fe80::1", 128, NetNone, `ports="fe80::1": invalid IPv4 address`},
|
||||
{"8.8.8.8", 24, pfx("8.8.8.8/24"), ""},
|
||||
{"2601:1234::", 64, pfx("2601:1234::/64"), ""},
|
||||
{"8.8.8.8", 33, nil, `invalid CIDR size 33 for host "8.8.8.8"`},
|
||||
{"8.8.8.8", -1, nil, `invalid CIDR size -1 for host "8.8.8.8"`},
|
||||
{"2601:1234::", 129, nil, `invalid CIDR size 129 for host "2601:1234::"`},
|
||||
{"0.0.0.0", 24, nil, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
|
||||
{"::", 64, nil, `ports="::": to allow all IP addresses, use *:port, not [::]:port`},
|
||||
{"*", 24, pfx("0.0.0.0/0", "::/0"), ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, err := parseIP(tt.host, tt.bits)
|
||||
@@ -185,45 +218,50 @@ func TestParseIP(t *testing.T) {
|
||||
}
|
||||
t.Errorf("parseIP(%q, %v) error: %v; want error %q", tt.host, tt.bits, err, tt.wantErr)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("parseIP(%q, %v) = %#v; want %#v", tt.host, tt.bits, got, tt.want)
|
||||
if diff := cmp.Diff(got, tt.want, cmp.Comparer(func(a, b netaddr.IP) bool { return a == b })); diff != "" {
|
||||
t.Errorf("parseIP(%q, %v) = %s; want %s", tt.host, tt.bits, got, tt.want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFilter(b *testing.B) {
|
||||
acl := newFilter(b.Logf)
|
||||
tcp4Packet := raw4(packet.TCP, "8.1.1.1", "1.2.3.4", 999, 22, 0)
|
||||
udp4Packet := raw4(packet.UDP, "8.1.1.1", "1.2.3.4", 999, 22, 0)
|
||||
icmp4Packet := raw4(packet.ICMPv4, "8.1.1.1", "1.2.3.4", 0, 0, 0)
|
||||
|
||||
tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
icmpPacket := rawpacket(ICMP, 0x08010101, 0x01020304, 0, 0, 0)
|
||||
|
||||
tcpSynPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
// TCP filtering is trivial (Accept) for non-SYN packets.
|
||||
tcpSynPacket[33] = packet.TCPSyn
|
||||
tcp6Packet := raw6(packet.TCP, "::1", "2001::1", 999, 22, 0)
|
||||
udp6Packet := raw6(packet.UDP, "::1", "2001::1", 999, 22, 0)
|
||||
icmp6Packet := raw6(packet.ICMPv6, "::1", "2001::1", 0, 0, 0)
|
||||
|
||||
benches := []struct {
|
||||
name string
|
||||
in bool
|
||||
dir direction
|
||||
packet []byte
|
||||
}{
|
||||
// Non-SYN TCP and ICMP have similar code paths in and out.
|
||||
{"icmp", true, icmpPacket},
|
||||
{"tcp", true, tcpPacket},
|
||||
{"tcp_syn_in", true, tcpSynPacket},
|
||||
{"tcp_syn_out", false, tcpSynPacket},
|
||||
{"udp_in", true, udpPacket},
|
||||
{"udp_out", false, udpPacket},
|
||||
{"icmp4", in, icmp4Packet},
|
||||
{"tcp4_syn_in", in, tcp4Packet},
|
||||
{"tcp4_syn_out", out, tcp4Packet},
|
||||
{"udp4_in", in, udp4Packet},
|
||||
{"udp4_out", out, udp4Packet},
|
||||
{"icmp6", in, icmp6Packet},
|
||||
{"tcp6_syn_in", in, tcp6Packet},
|
||||
{"tcp6_syn_out", out, tcp6Packet},
|
||||
{"udp6_in", in, udp6Packet},
|
||||
{"udp6_out", out, udp6Packet},
|
||||
}
|
||||
|
||||
for _, bench := range benches {
|
||||
b.Run(bench.name, func(b *testing.B) {
|
||||
acl := newFilter(b.Logf)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
q := &ParsedPacket{}
|
||||
q := &packet.Parsed{}
|
||||
q.Decode(bench.packet)
|
||||
// This branch seems to have no measurable impact on performance.
|
||||
if bench.in {
|
||||
if bench.dir == in {
|
||||
acl.RunIn(q, 0)
|
||||
} else {
|
||||
acl.RunOut(q, 0)
|
||||
@@ -241,15 +279,15 @@ func TestPreFilter(t *testing.T) {
|
||||
}{
|
||||
{"empty", Accept, []byte{}},
|
||||
{"short", Drop, []byte("short")},
|
||||
{"junk", Drop, rawdefault(Unknown, 10)},
|
||||
{"fragment", Accept, rawdefault(Fragment, 40)},
|
||||
{"tcp", noVerdict, rawdefault(TCP, 200)},
|
||||
{"udp", noVerdict, rawdefault(UDP, 200)},
|
||||
{"icmp", noVerdict, rawdefault(ICMP, 200)},
|
||||
{"junk", Drop, raw4default(packet.Unknown, 10)},
|
||||
{"fragment", Accept, raw4default(packet.Fragment, 40)},
|
||||
{"tcp", noVerdict, raw4default(packet.TCP, 0)},
|
||||
{"udp", noVerdict, raw4default(packet.UDP, 0)},
|
||||
{"icmp", noVerdict, raw4default(packet.ICMPv4, 0)},
|
||||
}
|
||||
f := NewAllowNone(t.Logf)
|
||||
for _, testPacket := range packets {
|
||||
p := &ParsedPacket{}
|
||||
p := &packet.Parsed{}
|
||||
p.Decode(testPacket.b)
|
||||
got := f.pre(p, LogDrops|LogAccepts, in)
|
||||
if got != testPacket.want {
|
||||
@@ -258,100 +296,16 @@ func TestPreFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func parsed(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) ParsedPacket {
|
||||
return ParsedPacket{
|
||||
IPProto: proto,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
SrcPort: sport,
|
||||
DstPort: dport,
|
||||
TCPFlags: packet.TCPSyn,
|
||||
}
|
||||
}
|
||||
|
||||
// rawpacket generates a packet with given source and destination ports and IPs
|
||||
// and resizes the header to trimLength if it is nonzero.
|
||||
func rawpacket(proto packet.IPProto, src, dst packet.IP, sport, dport uint16, trimLength int) []byte {
|
||||
var headerLength int
|
||||
|
||||
switch proto {
|
||||
case ICMP:
|
||||
headerLength = 24
|
||||
case TCP:
|
||||
headerLength = 40
|
||||
case UDP:
|
||||
headerLength = 28
|
||||
default:
|
||||
headerLength = 24
|
||||
}
|
||||
if trimLength > headerLength {
|
||||
headerLength = trimLength
|
||||
}
|
||||
if trimLength == 0 {
|
||||
trimLength = headerLength
|
||||
}
|
||||
|
||||
bin := binary.BigEndian
|
||||
hdr := make([]byte, headerLength)
|
||||
hdr[0] = 0x45
|
||||
bin.PutUint16(hdr[2:4], uint16(trimLength))
|
||||
hdr[8] = 64
|
||||
bin.PutUint32(hdr[12:16], uint32(src))
|
||||
bin.PutUint32(hdr[16:20], uint32(dst))
|
||||
// ports
|
||||
bin.PutUint16(hdr[20:22], sport)
|
||||
bin.PutUint16(hdr[22:24], dport)
|
||||
|
||||
switch proto {
|
||||
case ICMP:
|
||||
hdr[9] = 1
|
||||
case TCP:
|
||||
hdr[9] = 6
|
||||
case UDP:
|
||||
hdr[9] = 17
|
||||
case Fragment:
|
||||
hdr[9] = 6
|
||||
// flags + fragOff
|
||||
bin.PutUint16(hdr[6:8], (1<<13)|1234)
|
||||
case Unknown:
|
||||
default:
|
||||
panic("unknown protocol")
|
||||
}
|
||||
|
||||
// Trim the header if requested
|
||||
hdr = hdr[:trimLength]
|
||||
|
||||
return hdr
|
||||
}
|
||||
|
||||
// rawdefault calls rawpacket with default ports and IPs.
|
||||
func rawdefault(proto packet.IPProto, trimLength int) []byte {
|
||||
ip := IP(0x08080808) // 8.8.8.8
|
||||
port := uint16(53)
|
||||
return rawpacket(proto, ip, ip, port, port, trimLength)
|
||||
}
|
||||
|
||||
func parseHexPkt(t *testing.T, h string) *packet.ParsedPacket {
|
||||
t.Helper()
|
||||
b, err := hex.DecodeString(strings.ReplaceAll(h, " ", ""))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read hex %q: %v", h, err)
|
||||
}
|
||||
p := new(packet.ParsedPacket)
|
||||
p.Decode(b)
|
||||
return p
|
||||
}
|
||||
|
||||
func TestOmitDropLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pkt *packet.ParsedPacket
|
||||
pkt *packet.Parsed
|
||||
dir direction
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "v4_tcp_out",
|
||||
pkt: &packet.ParsedPacket{IPVersion: 4, IPProto: packet.TCP},
|
||||
pkt: &packet.Parsed{IPVersion: 4, IPProto: packet.TCP},
|
||||
dir: out,
|
||||
want: false,
|
||||
},
|
||||
@@ -381,19 +335,19 @@ func TestOmitDropLogging(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "v4_multicast_out_low",
|
||||
pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP(net.ParseIP("224.0.0.0"))},
|
||||
pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("224.0.0.0")},
|
||||
dir: out,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "v4_multicast_out_high",
|
||||
pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP(net.ParseIP("239.255.255.255"))},
|
||||
pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("239.255.255.255")},
|
||||
dir: out,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "v4_link_local_unicast",
|
||||
pkt: &packet.ParsedPacket{IPVersion: 4, DstIP: packet.NewIP(net.ParseIP("169.254.1.2"))},
|
||||
pkt: &packet.Parsed{IPVersion: 4, DstIP4: mustIP4("169.254.1.2")},
|
||||
dir: out,
|
||||
want: true,
|
||||
},
|
||||
@@ -408,3 +362,201 @@ func TestOmitDropLogging(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustIP(s string) netaddr.IP {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func parsed(proto packet.IPProto, src, dst string, sport, dport uint16) packet.Parsed {
|
||||
sip, dip := mustIP(src), mustIP(dst)
|
||||
|
||||
var ret packet.Parsed
|
||||
ret.Decode(dummyPacket)
|
||||
ret.IPProto = proto
|
||||
ret.SrcPort = sport
|
||||
ret.DstPort = dport
|
||||
ret.TCPFlags = packet.TCPSyn
|
||||
|
||||
if sip.Is4() {
|
||||
ret.IPVersion = 4
|
||||
ret.SrcIP4 = packet.IP4FromNetaddr(sip)
|
||||
ret.DstIP4 = packet.IP4FromNetaddr(dip)
|
||||
} else {
|
||||
ret.IPVersion = 6
|
||||
ret.SrcIP6 = packet.IP6FromNetaddr(sip)
|
||||
ret.DstIP6 = packet.IP6FromNetaddr(dip)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func raw6(proto packet.IPProto, src, dst string, sport, dport uint16, trimLen int) []byte {
|
||||
u := packet.UDP6Header{
|
||||
IP6Header: packet.IP6Header{
|
||||
SrcIP: packet.IP6FromNetaddr(mustIP(src)),
|
||||
DstIP: packet.IP6FromNetaddr(mustIP(dst)),
|
||||
},
|
||||
SrcPort: sport,
|
||||
DstPort: dport,
|
||||
}
|
||||
|
||||
payload := make([]byte, 12)
|
||||
// Set the right bit to look like a TCP SYN, if the packet ends up interpreted as TCP
|
||||
payload[5] = packet.TCPSyn
|
||||
|
||||
b := packet.Generate(&u, payload) // payload large enough to possibly be TCP
|
||||
|
||||
// UDP marshaling clobbers IPProto, so override it here.
|
||||
u.IP6Header.IPProto = proto
|
||||
if err := u.IP6Header.Marshal(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if trimLen > 0 {
|
||||
return b[:trimLen]
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func raw4(proto packet.IPProto, src, dst string, sport, dport uint16, trimLength int) []byte {
|
||||
u := packet.UDP4Header{
|
||||
IP4Header: packet.IP4Header{
|
||||
SrcIP: packet.IP4FromNetaddr(mustIP(src)),
|
||||
DstIP: packet.IP4FromNetaddr(mustIP(dst)),
|
||||
},
|
||||
SrcPort: sport,
|
||||
DstPort: dport,
|
||||
}
|
||||
|
||||
payload := make([]byte, 12)
|
||||
// Set the right bit to look like a TCP SYN, if the packet ends up interpreted as TCP
|
||||
payload[5] = packet.TCPSyn
|
||||
|
||||
b := packet.Generate(&u, payload) // payload large enough to possibly be TCP
|
||||
|
||||
// UDP marshaling clobbers IPProto, so override it here.
|
||||
switch proto {
|
||||
case packet.Unknown, packet.Fragment:
|
||||
default:
|
||||
u.IP4Header.IPProto = proto
|
||||
}
|
||||
if err := u.IP4Header.Marshal(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if proto == packet.Fragment {
|
||||
// Set some fragment offset. This makes the IP
|
||||
// checksum wrong, but we don't validate the checksum
|
||||
// when parsing.
|
||||
b[7] = 255
|
||||
}
|
||||
|
||||
if trimLength > 0 {
|
||||
return b[:trimLength]
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func raw4default(proto packet.IPProto, trimLength int) []byte {
|
||||
return raw4(proto, "8.8.8.8", "8.8.8.8", 53, 53, trimLength)
|
||||
}
|
||||
|
||||
func parseHexPkt(t *testing.T, h string) *packet.Parsed {
|
||||
t.Helper()
|
||||
b, err := hex.DecodeString(strings.ReplaceAll(h, " ", ""))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read hex %q: %v", h, err)
|
||||
}
|
||||
p := new(packet.Parsed)
|
||||
p.Decode(b)
|
||||
return p
|
||||
}
|
||||
|
||||
func mustIP4(s string) packet.IP4 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return packet.IP4FromNetaddr(ip)
|
||||
}
|
||||
|
||||
func pfx(strs ...string) (ret []netaddr.IPPrefix) {
|
||||
for _, s := range strs {
|
||||
pfx, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret = append(ret, pfx)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
||||
for _, s := range nets {
|
||||
if i := strings.IndexByte(s, '/'); i == -1 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bits := uint8(32)
|
||||
if ip.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
|
||||
} else {
|
||||
pfx, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret = append(ret, pfx)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func ports(s string) PortRange {
|
||||
if s == "*" {
|
||||
return PortRange{First: 0, Last: 65535}
|
||||
}
|
||||
|
||||
var fs, ls string
|
||||
i := strings.IndexByte(s, '-')
|
||||
if i == -1 {
|
||||
fs = s
|
||||
ls = fs
|
||||
} else {
|
||||
fs = s[:i]
|
||||
ls = s[i+1:]
|
||||
}
|
||||
first, err := strconv.ParseInt(fs, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
last, err := strconv.ParseInt(ls, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
return PortRange{uint16(first), uint16(last)}
|
||||
}
|
||||
|
||||
func netports(netPorts ...string) (ret []NetPortRange) {
|
||||
for _, s := range netPorts {
|
||||
i := strings.LastIndexByte(s, ':')
|
||||
if i == -1 {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
|
||||
npr := NetPortRange{
|
||||
Net: nets(s[:i])[0],
|
||||
Ports: ports(s[i+1:]),
|
||||
}
|
||||
ret = append(ret, npr)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -6,55 +6,16 @@ package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/wgengine/packet"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func NewIP(ip net.IP) packet.IP {
|
||||
return packet.NewIP(ip)
|
||||
}
|
||||
|
||||
type Net struct {
|
||||
IP packet.IP
|
||||
Mask packet.IP
|
||||
}
|
||||
|
||||
func (n Net) Includes(ip packet.IP) bool {
|
||||
return (n.IP & n.Mask) == (ip & n.Mask)
|
||||
}
|
||||
|
||||
func (n Net) Bits() int {
|
||||
return 32 - bits.TrailingZeros32(uint32(n.Mask))
|
||||
}
|
||||
|
||||
func (n Net) String() string {
|
||||
b := n.Bits()
|
||||
if b == 32 {
|
||||
return n.IP.String()
|
||||
} else if b == 0 {
|
||||
return "*"
|
||||
} else {
|
||||
return fmt.Sprintf("%s/%d", n.IP, b)
|
||||
}
|
||||
}
|
||||
|
||||
var NetAny = Net{0, 0}
|
||||
var NetNone = Net{^packet.IP(0), ^packet.IP(0)}
|
||||
|
||||
func Netmask(bits int) packet.IP {
|
||||
b := ^uint32((1 << (32 - bits)) - 1)
|
||||
return packet.IP(b)
|
||||
}
|
||||
|
||||
// PortRange is a range of TCP and UDP ports.
|
||||
type PortRange struct {
|
||||
First, Last uint16
|
||||
First, Last uint16 // inclusive
|
||||
}
|
||||
|
||||
var PortRangeAny = PortRange{0, 65535}
|
||||
|
||||
func (pr PortRange) String() string {
|
||||
if pr.First == 0 && pr.Last == 65535 {
|
||||
return "*"
|
||||
@@ -65,30 +26,26 @@ func (pr PortRange) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// contains returns whether port is in pr.
|
||||
func (pr PortRange) contains(port uint16) bool {
|
||||
return port >= pr.First && port <= pr.Last
|
||||
}
|
||||
|
||||
// NetPortRange combines an IP address prefix and PortRange.
|
||||
type NetPortRange struct {
|
||||
Net Net
|
||||
Net netaddr.IPPrefix
|
||||
Ports PortRange
|
||||
}
|
||||
|
||||
var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny}
|
||||
|
||||
func (ipr NetPortRange) String() string {
|
||||
return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
|
||||
func (npr NetPortRange) String() string {
|
||||
return fmt.Sprintf("%v:%v", npr.Net, npr.Ports)
|
||||
}
|
||||
|
||||
// Match matches packets from any IP address in Srcs to any ip:port in
|
||||
// Dsts.
|
||||
type Match struct {
|
||||
Dsts []NetPortRange
|
||||
Srcs []Net
|
||||
}
|
||||
|
||||
func (m Match) Clone() (res Match) {
|
||||
if m.Dsts != nil {
|
||||
res.Dsts = append([]NetPortRange{}, m.Dsts...)
|
||||
}
|
||||
if m.Srcs != nil {
|
||||
res.Srcs = append([]Net{}, m.Srcs...)
|
||||
}
|
||||
return res
|
||||
Srcs []netaddr.IPPrefix
|
||||
}
|
||||
|
||||
func (m Match) String() string {
|
||||
@@ -114,58 +71,3 @@ func (m Match) String() string {
|
||||
}
|
||||
return fmt.Sprintf("%v=>%v", ss, ds)
|
||||
}
|
||||
|
||||
type Matches []Match
|
||||
|
||||
func (m Matches) Clone() (res Matches) {
|
||||
for _, match := range m {
|
||||
res = append(res, match.Clone())
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func ipInList(ip packet.IP, netlist []Net) bool {
|
||||
for _, net := range netlist {
|
||||
if net.Includes(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchIPPorts(mm Matches, q *packet.ParsedPacket) bool {
|
||||
for _, acl := range mm {
|
||||
for _, dst := range acl.Dsts {
|
||||
if !dst.Net.Includes(q.DstIP) {
|
||||
continue
|
||||
}
|
||||
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
|
||||
continue
|
||||
}
|
||||
if !ipInList(q.SrcIP, acl.Srcs) {
|
||||
// Skip other dests in this acl, since
|
||||
// the src will never match.
|
||||
break
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchIPWithoutPorts(mm Matches, q *packet.ParsedPacket) bool {
|
||||
for _, acl := range mm {
|
||||
for _, dst := range acl.Dsts {
|
||||
if !dst.Net.Includes(q.DstIP) {
|
||||
continue
|
||||
}
|
||||
if !ipInList(q.SrcIP, acl.Srcs) {
|
||||
// Skip other dests in this acl, since
|
||||
// the src will never match.
|
||||
break
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
151
wgengine/filter/match4.go
Normal file
151
wgengine/filter/match4.go
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"strings"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
)
|
||||
|
||||
type net4 struct {
|
||||
ip packet.IP4
|
||||
mask packet.IP4
|
||||
}
|
||||
|
||||
func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 {
|
||||
if !pfx.IP.Is4() {
|
||||
panic("net4FromIPPrefix given non-ipv4 prefix")
|
||||
}
|
||||
return net4{
|
||||
ip: packet.IP4FromNetaddr(pfx.IP),
|
||||
mask: netmask4(pfx.Bits),
|
||||
}
|
||||
}
|
||||
|
||||
func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) {
|
||||
for _, pfx := range pfxs {
|
||||
if pfx.IP.Is4() {
|
||||
ret = append(ret, net4FromIPPrefix(pfx))
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (n net4) Contains(ip packet.IP4) bool {
|
||||
return (n.ip & n.mask) == (ip & n.mask)
|
||||
}
|
||||
|
||||
func (n net4) Bits() int {
|
||||
return 32 - bits.TrailingZeros32(uint32(n.mask))
|
||||
}
|
||||
|
||||
func (n net4) String() string {
|
||||
b := n.Bits()
|
||||
if b == 32 {
|
||||
return n.ip.String()
|
||||
} else if b == 0 {
|
||||
return "*"
|
||||
} else {
|
||||
return fmt.Sprintf("%s/%d", n.ip, b)
|
||||
}
|
||||
}
|
||||
|
||||
type npr4 struct {
|
||||
net net4
|
||||
ports PortRange
|
||||
}
|
||||
|
||||
func (npr npr4) String() string {
|
||||
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
|
||||
}
|
||||
|
||||
type match4 struct {
|
||||
srcs []net4
|
||||
dsts []npr4
|
||||
}
|
||||
|
||||
type matches4 []match4
|
||||
|
||||
func (ms matches4) String() string {
|
||||
var b strings.Builder
|
||||
for _, m := range ms {
|
||||
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func newMatches4(ms []Match) (ret matches4) {
|
||||
for _, m := range ms {
|
||||
var m4 match4
|
||||
for _, src := range m.Srcs {
|
||||
if src.IP.Is4() {
|
||||
m4.srcs = append(m4.srcs, net4FromIPPrefix(src))
|
||||
}
|
||||
}
|
||||
for _, dst := range m.Dsts {
|
||||
if dst.Net.IP.Is4() {
|
||||
m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports})
|
||||
}
|
||||
}
|
||||
if len(m4.srcs) > 0 && len(m4.dsts) > 0 {
|
||||
ret = append(ret, m4)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// match returns whether q's source IP and destination IP:port match
|
||||
// any of ms.
|
||||
func (ms matches4) match(q *packet.Parsed) bool {
|
||||
for _, m := range ms {
|
||||
if !ip4InList(q.SrcIP4, m.srcs) {
|
||||
continue
|
||||
}
|
||||
for _, dst := range m.dsts {
|
||||
if !dst.net.Contains(q.DstIP4) {
|
||||
continue
|
||||
}
|
||||
if !dst.ports.contains(q.DstPort) {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchIPsOnly returns whether q's source and destination IP match
|
||||
// any of ms.
|
||||
func (ms matches4) matchIPsOnly(q *packet.Parsed) bool {
|
||||
for _, m := range ms {
|
||||
if !ip4InList(q.SrcIP4, m.srcs) {
|
||||
continue
|
||||
}
|
||||
for _, dst := range m.dsts {
|
||||
if dst.net.Contains(q.DstIP4) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func netmask4(bits uint8) packet.IP4 {
|
||||
b := ^uint32((1 << (32 - bits)) - 1)
|
||||
return packet.IP4(b)
|
||||
}
|
||||
|
||||
func ip4InList(ip packet.IP4, netlist []net4) bool {
|
||||
for _, net := range netlist {
|
||||
if net.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
171
wgengine/filter/match6.go
Normal file
171
wgengine/filter/match6.go
Normal file
@@ -0,0 +1,171 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"strings"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
)
|
||||
|
||||
type net6 struct {
|
||||
ip packet.IP6
|
||||
mask packet.IP6
|
||||
}
|
||||
|
||||
func net6FromIPPrefix(pfx netaddr.IPPrefix) net6 {
|
||||
if !pfx.IP.Is6() {
|
||||
panic("net6FromIPPrefix given non-ipv6 prefix")
|
||||
}
|
||||
var mask packet.IP6
|
||||
if pfx.Bits > 64 {
|
||||
mask.Hi = ^uint64(0)
|
||||
mask.Lo = (^uint64(0) << (128 - pfx.Bits))
|
||||
} else {
|
||||
mask.Hi = (^uint64(0) << (64 - pfx.Bits))
|
||||
}
|
||||
|
||||
return net6{
|
||||
ip: packet.IP6FromNetaddr(pfx.IP),
|
||||
mask: mask,
|
||||
}
|
||||
}
|
||||
|
||||
func nets6FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net6) {
|
||||
for _, pfx := range pfxs {
|
||||
if pfx.IP.Is6() {
|
||||
ret = append(ret, net6FromIPPrefix(pfx))
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (n net6) Contains(ip packet.IP6) bool {
|
||||
// This is equivalent to the more straightforward implementation:
|
||||
// ((n.ip.Hi & n.mask.Hi) == (ip.Hi & n.mask.Hi) &&
|
||||
// (n.ip.Lo & n.mask.Lo) == (ip.Lo & n.mask.Lo))
|
||||
//
|
||||
// This implementation runs significantly faster because it
|
||||
// eliminates branches and minimizes the required
|
||||
// bit-twiddling.
|
||||
a := (n.ip.Hi ^ ip.Hi) & n.mask.Hi
|
||||
b := (n.ip.Lo ^ ip.Lo) & n.mask.Lo
|
||||
return (a | b) == 0
|
||||
}
|
||||
|
||||
func (n net6) Bits() int {
|
||||
return 128 - bits.TrailingZeros64(n.mask.Hi) - bits.TrailingZeros64(n.mask.Lo)
|
||||
}
|
||||
|
||||
func (n net6) String() string {
|
||||
switch n.Bits() {
|
||||
case 128:
|
||||
return n.ip.String()
|
||||
case 0:
|
||||
return "*"
|
||||
default:
|
||||
return fmt.Sprintf("%s/%d", n.ip, n.Bits())
|
||||
}
|
||||
}
|
||||
|
||||
type npr6 struct {
|
||||
net net6
|
||||
ports PortRange
|
||||
}
|
||||
|
||||
func (npr npr6) String() string {
|
||||
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
|
||||
}
|
||||
|
||||
type match6 struct {
|
||||
srcs []net6
|
||||
dsts []npr6
|
||||
}
|
||||
|
||||
type matches6 []match6
|
||||
|
||||
func (ms matches6) String() string {
|
||||
var b strings.Builder
|
||||
for _, m := range ms {
|
||||
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func newMatches6(ms []Match) (ret matches6) {
|
||||
for _, m := range ms {
|
||||
var m6 match6
|
||||
for _, src := range m.Srcs {
|
||||
if src.IP.Is6() {
|
||||
m6.srcs = append(m6.srcs, net6FromIPPrefix(src))
|
||||
}
|
||||
}
|
||||
for _, dst := range m.Dsts {
|
||||
if dst.Net.IP.Is6() {
|
||||
m6.dsts = append(m6.dsts, npr6{net6FromIPPrefix(dst.Net), dst.Ports})
|
||||
}
|
||||
}
|
||||
if len(m6.srcs) > 0 && len(m6.dsts) > 0 {
|
||||
ret = append(ret, m6)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (ms matches6) match(q *packet.Parsed) bool {
|
||||
outer:
|
||||
for i := range ms {
|
||||
srcs := ms[i].srcs
|
||||
for j := range srcs {
|
||||
if srcs[j].Contains(q.SrcIP6) {
|
||||
dsts := ms[i].dsts
|
||||
for k := range dsts {
|
||||
if dsts[k].net.Contains(q.DstIP6) && dsts[k].ports.contains(q.DstPort) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// We hit on src, but missed on all
|
||||
// dsts. No need to try other srcs,
|
||||
// they'll never fully match.
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ms matches6) matchIPsOnly(q *packet.Parsed) bool {
|
||||
outer:
|
||||
for i := range ms {
|
||||
srcs := ms[i].srcs
|
||||
for j := range srcs {
|
||||
if srcs[j].Contains(q.SrcIP6) {
|
||||
dsts := ms[i].dsts
|
||||
for k := range dsts {
|
||||
if dsts[k].net.Contains(q.DstIP6) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// We hit on src, but missed on all
|
||||
// dsts. No need to try other srcs,
|
||||
// they'll never fully match.
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ip6InList(ip packet.IP6, netlist []net6) bool {
|
||||
for _, net := range netlist {
|
||||
if net.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
37
wgengine/filter/match6_test.go
Normal file
37
wgengine/filter/match6_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import "testing"
|
||||
|
||||
// Verifies that the fast bit-twiddling implementation of Contains
|
||||
// works the same as the easy-to-read implementation. Since we can't
|
||||
// sensibly check it on 128 bits, the test runs over 4-bit
|
||||
// "IPs". Bit-twiddling is the same at any width, so this adequately
|
||||
// proves that the implementations are equivalent.
|
||||
func TestOptimizedContains(t *testing.T) {
|
||||
for ipHi := 0; ipHi < 0xf; ipHi++ {
|
||||
for ipLo := 0; ipLo < 0xf; ipLo++ {
|
||||
for nIPHi := 0; nIPHi < 0xf; nIPHi++ {
|
||||
for nIPLo := 0; nIPLo < 0xf; nIPLo++ {
|
||||
for maskHi := 0; maskHi < 0xf; maskHi++ {
|
||||
for maskLo := 0; maskLo < 0xf; maskLo++ {
|
||||
|
||||
a := (nIPHi ^ ipHi) & maskHi
|
||||
b := (nIPLo ^ ipLo) & maskLo
|
||||
got := (a | b) == 0
|
||||
|
||||
want := ((nIPHi&maskHi) == (ipHi&maskHi) && (nIPLo&maskLo) == (ipLo&maskLo))
|
||||
|
||||
if got != want {
|
||||
t.Errorf("mask %1x%1x/%1x%1x %1x%1x got=%v want=%v", nIPHi, nIPLo, maskHi, maskLo, ipHi, ipLo, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
99
wgengine/filter/tailcfg.go
Normal file
99
wgengine/filter/tailcfg.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// MatchesFromFilterRules converts tailcfg FilterRules into Matches.
|
||||
// If an error is returned, the Matches result is still valid,
|
||||
// containing the rules that were successfully converted.
|
||||
func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
|
||||
mm := make([]Match, 0, len(pf))
|
||||
var erracc error
|
||||
|
||||
for _, r := range pf {
|
||||
m := Match{}
|
||||
|
||||
for i, s := range r.SrcIPs {
|
||||
bits := 32
|
||||
if len(r.SrcBits) > i {
|
||||
bits = r.SrcBits[i]
|
||||
}
|
||||
nets, err := parseIP(s, bits)
|
||||
if err != nil && erracc == nil {
|
||||
erracc = err
|
||||
continue
|
||||
}
|
||||
m.Srcs = append(m.Srcs, nets...)
|
||||
}
|
||||
|
||||
for _, d := range r.DstPorts {
|
||||
bits := 32
|
||||
if d.Bits != nil {
|
||||
bits = *d.Bits
|
||||
}
|
||||
nets, err := parseIP(d.IP, bits)
|
||||
if err != nil && erracc == nil {
|
||||
erracc = err
|
||||
continue
|
||||
}
|
||||
for _, net := range nets {
|
||||
m.Dsts = append(m.Dsts, NetPortRange{
|
||||
Net: net,
|
||||
Ports: PortRange{
|
||||
First: d.Ports.First,
|
||||
Last: d.Ports.Last,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
mm = append(mm, m)
|
||||
}
|
||||
return mm, erracc
|
||||
}
|
||||
|
||||
var (
|
||||
zeroIP4 = netaddr.IPv4(0, 0, 0, 0)
|
||||
zeroIP6 = netaddr.IPFrom16([16]byte{})
|
||||
)
|
||||
|
||||
func parseIP(host string, defaultBits int) ([]netaddr.IPPrefix, error) {
|
||||
if host == "*" {
|
||||
// User explicitly requested wildcard dst ip.
|
||||
return []netaddr.IPPrefix{
|
||||
{IP: zeroIP4, Bits: 0},
|
||||
{IP: zeroIP6, Bits: 0},
|
||||
}, nil
|
||||
}
|
||||
|
||||
ip, err := netaddr.ParseIP(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ports=%#v: invalid IP address", host)
|
||||
}
|
||||
if ip == zeroIP4 {
|
||||
// For clarity, reject 0.0.0.0 as an input
|
||||
return nil, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
|
||||
}
|
||||
if ip == zeroIP6 {
|
||||
// For clarity, reject :: as an input
|
||||
return nil, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not [::]:port", host)
|
||||
}
|
||||
|
||||
if defaultBits < 0 || (ip.Is4() && defaultBits > 32) || (ip.Is6() && defaultBits > 128) {
|
||||
return nil, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
|
||||
}
|
||||
return []netaddr.IPPrefix{
|
||||
{
|
||||
IP: ip,
|
||||
Bits: uint8(defaultBits),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -2492,6 +2492,9 @@ func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
|
||||
host := ""
|
||||
if inTest() && !c.simulatedNetwork {
|
||||
host = "127.0.0.1"
|
||||
if which == "udp6" {
|
||||
host = "::1"
|
||||
}
|
||||
}
|
||||
var pc net.PacketConn
|
||||
var err error
|
||||
|
||||
@@ -158,7 +158,7 @@ func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, der
|
||||
|
||||
tun := tuntest.NewChannelTUN()
|
||||
tsTun := tstun.WrapTUN(logf, tun.TUN())
|
||||
tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf))
|
||||
tsTun.SetFilter(filter.NewAllowAllForTest(logf))
|
||||
|
||||
dev := device.NewDevice(tsTun, &device.DeviceOptions{
|
||||
Logger: &device.Logger{
|
||||
|
||||
@@ -7,6 +7,8 @@ package monitor
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -24,9 +26,15 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
iphlpapi = syscall.NewLazyDLL("iphlpapi.dll")
|
||||
notifyAddrChangeProc = iphlpapi.NewProc("NotifyAddrChange")
|
||||
notifyRouteChangeProc = iphlpapi.NewProc("NotifyRouteChange")
|
||||
iphlpapi = syscall.NewLazyDLL("iphlpapi.dll")
|
||||
notifyAddrChangeProc = iphlpapi.NewProc("NotifyAddrChange")
|
||||
notifyRouteChangeProc = iphlpapi.NewProc("NotifyRouteChange")
|
||||
cancelIPChangeNotifyProc = iphlpapi.NewProc("CancelIPChangeNotify")
|
||||
)
|
||||
|
||||
const (
|
||||
_STATUS_PENDING = 0x00000103 // 259
|
||||
_STATUS_WAIT_0 = 0
|
||||
)
|
||||
|
||||
type unspecifiedMessage struct{}
|
||||
@@ -43,27 +51,33 @@ type messageOrError struct {
|
||||
}
|
||||
|
||||
type winMon struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
messagec chan messageOrError
|
||||
logf logger.Logf
|
||||
pollTicker *time.Ticker
|
||||
lastState *interfaces.State
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
messagec chan messageOrError
|
||||
logf logger.Logf
|
||||
pollTicker *time.Ticker
|
||||
lastState *interfaces.State
|
||||
closeHandle windows.Handle // signaled upon close
|
||||
|
||||
mu sync.Mutex
|
||||
event windows.Handle
|
||||
lastNetChange time.Time
|
||||
inFastPoll bool // recent net change event made us go into fast polling mode (to detect proxy changes)
|
||||
}
|
||||
|
||||
func newOSMon(logf logger.Logf) (osMon, error) {
|
||||
closeHandle, err := windows.CreateEvent(nil, 1 /* manual reset */, 0 /* unsignaled */, nil /* no name */)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CreateEvent: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m := &winMon{
|
||||
logf: logf,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
messagec: make(chan messageOrError, 1),
|
||||
pollTicker: time.NewTicker(pollIntervalSlow),
|
||||
logf: logf,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
messagec: make(chan messageOrError, 1),
|
||||
pollTicker: time.NewTicker(pollIntervalSlow),
|
||||
closeHandle: closeHandle,
|
||||
}
|
||||
go m.awaitIPAndRouteChanges()
|
||||
return m, nil
|
||||
@@ -72,14 +86,7 @@ func newOSMon(logf logger.Logf) (osMon, error) {
|
||||
func (m *winMon) Close() error {
|
||||
m.cancel()
|
||||
m.pollTicker.Stop()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if h := m.event; h != 0 {
|
||||
// Wake up any reader blocked in Receive.
|
||||
windows.SetEvent(h)
|
||||
}
|
||||
|
||||
windows.SetEvent(m.closeHandle) // wakes up any reader blocked in Receive
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -136,52 +143,80 @@ func (m *winMon) getIPOrRouteChangeMessage() (message, error) {
|
||||
return nil, errClosed
|
||||
}
|
||||
|
||||
var o windows.Overlapped
|
||||
h, err := windows.CreateEvent(nil, 1 /* true*/, 0 /* unsignaled */, nil /* no name */)
|
||||
if err != nil {
|
||||
m.logf("CreateEvent: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer windows.CloseHandle(h)
|
||||
// TODO(bradfitz): locking ourselves to an OS thread here
|
||||
// likely isn't necessary, but also can't really hurt.
|
||||
// We'll be blocked in windows.WaitForMultipleObjects below
|
||||
// anyway, so might as well stay on this thread during the
|
||||
// notify calls and cancel funcs.
|
||||
// Given the past memory corruption from misuse of these APIs,
|
||||
// and my continued lack of understanding of Windows APIs,
|
||||
// I'll be paranoid. But perhaps we can remove this once
|
||||
// we understand more.
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
m.mu.Lock()
|
||||
m.event = h
|
||||
m.mu.Unlock()
|
||||
|
||||
o.HEvent = h
|
||||
|
||||
err = notifyAddrChange(&h, &o)
|
||||
addrHandle, oaddr, cancel, err := notifyAddrChange()
|
||||
if err != nil {
|
||||
m.logf("notifyAddrChange: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
err = notifyRouteChange(&h, &o)
|
||||
defer cancel()
|
||||
|
||||
routeHandle, oroute, cancel, err := notifyRouteChange()
|
||||
if err != nil {
|
||||
m.logf("notifyRouteChange: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
t0 := time.Now()
|
||||
_, err = windows.WaitForSingleObject(o.HEvent, windows.INFINITE)
|
||||
if m.ctx.Err() != nil {
|
||||
eventNum, err := windows.WaitForMultipleObjects([]windows.Handle{
|
||||
m.closeHandle, // eventNum 0
|
||||
addrHandle, // eventNum 1
|
||||
routeHandle, // eventNum 2
|
||||
}, false, windows.INFINITE)
|
||||
if m.ctx.Err() != nil || (err == nil && eventNum == 0) {
|
||||
return nil, errClosed
|
||||
}
|
||||
if err != nil {
|
||||
m.logf("waitForSingleObject: %v", err)
|
||||
m.logf("waitForMultipleObjects: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := time.Since(t0)
|
||||
m.logf("got windows change event after %v", d)
|
||||
var eventStr string
|
||||
|
||||
// notifyAddrChange and notifyRouteChange both seem to return the same
|
||||
// handle value. Determine which fired by looking at the "Internal" (sic)
|
||||
// field of the Ovelapped instead.
|
||||
// TODO(bradfitz): maybe clean this up; see TODO in callNotifyProc.
|
||||
if (eventNum == 1 || eventNum == 2) && addrHandle == routeHandle {
|
||||
if oaddr.Internal == _STATUS_WAIT_0 && oroute.Internal == _STATUS_PENDING {
|
||||
eventStr = "addr-o" // "-o" overlapped suffix to distinguish from "addr" below
|
||||
} else if oroute.Internal == _STATUS_WAIT_0 && oaddr.Internal == _STATUS_PENDING {
|
||||
eventStr = "route-o"
|
||||
} else {
|
||||
eventStr = fmt.Sprintf("[unexpected] addr.internal=%d; route.internal=%d", oaddr.Internal, oroute.Internal)
|
||||
}
|
||||
} else {
|
||||
switch eventNum {
|
||||
case 1:
|
||||
eventStr = "addr"
|
||||
case 2:
|
||||
eventStr = "route"
|
||||
default:
|
||||
eventStr = fmt.Sprintf("%d [unexpected]", eventNum)
|
||||
}
|
||||
}
|
||||
m.logf("got windows change event after %v: evt=%s", d, eventStr)
|
||||
|
||||
m.mu.Lock()
|
||||
{
|
||||
m.lastNetChange = time.Now()
|
||||
m.event = 0
|
||||
|
||||
// Something changed, so assume Windows is about to
|
||||
// discover its new proxy settings from WPAD, which
|
||||
// seems to take a bit. Poll heavily for awhile.
|
||||
m.logf("starting quick poll, waiting for WPAD change")
|
||||
m.inFastPoll = true
|
||||
m.pollTicker.Reset(pollIntervalFast)
|
||||
}
|
||||
@@ -190,23 +225,46 @@ func (m *winMon) getIPOrRouteChangeMessage() (message, error) {
|
||||
return unspecifiedMessage{}, nil
|
||||
}
|
||||
|
||||
func notifyAddrChange(h *windows.Handle, o *windows.Overlapped) error {
|
||||
return callNotifyProc(notifyAddrChangeProc, h, o)
|
||||
func notifyAddrChange() (h windows.Handle, o *windows.Overlapped, cancel func(), err error) {
|
||||
return callNotifyProc(notifyAddrChangeProc)
|
||||
}
|
||||
|
||||
func notifyRouteChange(h *windows.Handle, o *windows.Overlapped) error {
|
||||
return callNotifyProc(notifyRouteChangeProc, h, o)
|
||||
func notifyRouteChange() (h windows.Handle, o *windows.Overlapped, cancel func(), err error) {
|
||||
return callNotifyProc(notifyRouteChangeProc)
|
||||
}
|
||||
|
||||
func callNotifyProc(p *syscall.LazyProc, h *windows.Handle, o *windows.Overlapped) error {
|
||||
r1, _, e1 := p.Call(uintptr(unsafe.Pointer(h)), uintptr(unsafe.Pointer(o)))
|
||||
expect := uintptr(0)
|
||||
if h != nil || o != nil {
|
||||
const ERROR_IO_PENDING = 997
|
||||
expect = ERROR_IO_PENDING
|
||||
func callNotifyProc(p *syscall.LazyProc) (h windows.Handle, o *windows.Overlapped, cancel func(), err error) {
|
||||
o = new(windows.Overlapped)
|
||||
|
||||
// TODO(bradfitz): understand why this if-false code doesn't
|
||||
// work, even though the docs online suggest we should pass an
|
||||
// event in the overlapped.Hevent field.
|
||||
// The docs at
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-overlapped
|
||||
// says that o.HEvent can be zero, though, which seems to work.
|
||||
// Note that the returned windows.Handle returns the same value for both
|
||||
// notifyAddrChange and notifyRouteChange, which is why our caller needs
|
||||
// to look at the returned Overlapped's Internal field to see which case
|
||||
// fired. That's also worth understanding more.
|
||||
// See crawshaw's comment at https://github.com/tailscale/tailscale/pull/944#discussion_r526469186
|
||||
// too.
|
||||
if false {
|
||||
evt, err := windows.CreateEvent(nil, 0, 0, nil)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
o.HEvent = evt
|
||||
}
|
||||
if r1 == expect {
|
||||
return nil
|
||||
|
||||
r1, _, e1 := syscall.Syscall(p.Addr(), 2, uintptr(unsafe.Pointer(&h)), uintptr(unsafe.Pointer(o)), 0)
|
||||
|
||||
// We expect ERROR_IO_PENDING.
|
||||
if syscall.Errno(r1) != windows.ERROR_IO_PENDING {
|
||||
return 0, nil, nil, e1
|
||||
}
|
||||
return e1
|
||||
|
||||
cancel = func() {
|
||||
cancelIPChangeNotifyProc.Call(uintptr(unsafe.Pointer(o)))
|
||||
}
|
||||
return h, o, cancel, nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user