Compare commits
35 Commits
crawshaw/e
...
josh/tsweb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1b89db4dff | ||
|
|
1dc4151f8b | ||
|
|
8d6cf14456 | ||
|
|
b4947be0c8 | ||
|
|
01e8a152f7 | ||
|
|
2448c000b3 | ||
|
|
903988b392 | ||
|
|
8267ea0f80 | ||
|
|
8fe503057d | ||
|
|
5d9ab502f3 | ||
|
|
a19c110dd3 | ||
|
|
2db6cd1025 | ||
|
|
be9d564c29 | ||
|
|
3a94ece30c | ||
|
|
86a902b201 | ||
|
|
adda2d2a51 | ||
|
|
a80cef0c13 | ||
|
|
84046d6f7c | ||
|
|
ec62217f52 | ||
|
|
21358cf2f5 | ||
|
|
37e7a387ff | ||
|
|
15599323a1 | ||
|
|
60abeb027b | ||
|
|
b9c92b90db | ||
|
|
e206a3663f | ||
|
|
0173a50bf0 | ||
|
|
dbea8217ac | ||
|
|
82cd98609f | ||
|
|
39d173e5fc | ||
|
|
c8551c8a67 | ||
|
|
3a74f2d2d7 | ||
|
|
24c9dbd129 | ||
|
|
62db629227 | ||
|
|
3c481d6b18 | ||
|
|
b3d268c5a1 |
@@ -12,14 +12,11 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
dnsMu sync.Mutex
|
||||
dnsCache = map[string][]net.IP{}
|
||||
)
|
||||
var dnsCache atomic.Value // of []byte
|
||||
|
||||
var bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
|
||||
|
||||
@@ -37,6 +34,7 @@ func refreshBootstrapDNS() {
|
||||
if *bootstrapDNS == "" {
|
||||
return
|
||||
}
|
||||
dnsEntries := make(map[string][]net.IP)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
names := strings.Split(*bootstrapDNS, ",")
|
||||
@@ -47,23 +45,23 @@ func refreshBootstrapDNS() {
|
||||
log.Printf("bootstrap DNS lookup %q: %v", name, err)
|
||||
continue
|
||||
}
|
||||
dnsMu.Lock()
|
||||
dnsCache[name] = addrs
|
||||
dnsMu.Unlock()
|
||||
dnsEntries[name] = addrs
|
||||
}
|
||||
j, err := json.MarshalIndent(dnsEntries, "", "\t")
|
||||
if err != nil {
|
||||
// leave the old values in place
|
||||
return
|
||||
}
|
||||
dnsCache.Store(j)
|
||||
}
|
||||
|
||||
func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
|
||||
bootstrapDNSRequests.Add(1)
|
||||
dnsMu.Lock()
|
||||
j, err := json.MarshalIndent(dnsCache, "", "\t")
|
||||
dnsMu.Unlock()
|
||||
if err != nil {
|
||||
log.Printf("bootstrap DNS JSON: %v", err)
|
||||
http.Error(w, "JSON marshal error", 500)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
j, _ := dnsCache.Load().([]byte)
|
||||
// Bootstrap DNS requests occur cross-regions,
|
||||
// and are randomized per request,
|
||||
// so keeping a connection open is pointlessly expensive.
|
||||
w.Header().Set("Connection", "close")
|
||||
w.Write(j)
|
||||
}
|
||||
|
||||
35
cmd/derper/bootstrap_dns_test.go
Normal file
35
cmd/derper/bootstrap_dns_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkHandleBootstrapDNS(b *testing.B) {
|
||||
prev := *bootstrapDNS
|
||||
*bootstrapDNS = "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com"
|
||||
defer func() {
|
||||
*bootstrapDNS = prev
|
||||
}()
|
||||
refreshBootstrapDNS()
|
||||
w := new(bitbucketResponseWriter)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(b *testing.PB) {
|
||||
for b.Next() {
|
||||
handleBootstrapDNS(w, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type bitbucketResponseWriter struct{}
|
||||
|
||||
func (b *bitbucketResponseWriter) Header() http.Header { return make(http.Header) }
|
||||
|
||||
func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
|
||||
|
||||
func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
"tailscale.com/atomicfile"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/derp/derphttp"
|
||||
@@ -49,6 +51,9 @@ var (
|
||||
meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list")
|
||||
bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns")
|
||||
verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.")
|
||||
|
||||
acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection")
|
||||
acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection")
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -296,7 +301,7 @@ func main() {
|
||||
}
|
||||
}()
|
||||
}
|
||||
err = httpsrv.ListenAndServeTLS("", "")
|
||||
err = rateLimitedListenAndServeTLS(httpsrv)
|
||||
} else {
|
||||
log.Printf("derper: serving on %s", *addr)
|
||||
err = httpsrv.ListenAndServe()
|
||||
@@ -390,3 +395,63 @@ func defaultMeshPSKFile() string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func rateLimitedListenAndServeTLS(srv *http.Server) error {
|
||||
addr := srv.Addr
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
}
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rln := newRateLimitedListener(ln, rate.Limit(*acceptConnLimit), *acceptConnBurst)
|
||||
expvar.Publish("tls_listener", rln.ExpVar())
|
||||
defer rln.Close()
|
||||
return srv.ServeTLS(rln, "", "")
|
||||
}
|
||||
|
||||
type rateLimitedListener struct {
|
||||
// These are at the start of the struct to ensure 64-bit alignment
|
||||
// on 32-bit architecture regardless of what other fields may exist
|
||||
// in this package.
|
||||
numAccepts expvar.Int // does not include number of rejects
|
||||
numRejects expvar.Int
|
||||
|
||||
net.Listener
|
||||
|
||||
lim *rate.Limiter
|
||||
}
|
||||
|
||||
func newRateLimitedListener(ln net.Listener, limit rate.Limit, burst int) *rateLimitedListener {
|
||||
return &rateLimitedListener{Listener: ln, lim: rate.NewLimiter(limit, burst)}
|
||||
}
|
||||
|
||||
func (l *rateLimitedListener) ExpVar() expvar.Var {
|
||||
m := new(metrics.Set)
|
||||
m.Set("counter_accepted_connections", &l.numAccepts)
|
||||
m.Set("counter_rejected_connections", &l.numRejects)
|
||||
return m
|
||||
}
|
||||
|
||||
var errLimitedConn = errors.New("cannot accept connection; rate limited")
|
||||
|
||||
func (l *rateLimitedListener) Accept() (net.Conn, error) {
|
||||
// Even under a rate limited situation, we accept the connection immediately
|
||||
// and close it, rather than being slow at accepting new connections.
|
||||
// This provides two benefits: 1) it signals to the client that something
|
||||
// is going on on the server, and 2) it prevents new connections from
|
||||
// piling up and occupying resources in the OS kernel.
|
||||
// The client will retry as needing (with backoffs in place).
|
||||
cn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !l.lim.Allow() {
|
||||
l.numRejects.Add(1)
|
||||
cn.Close()
|
||||
return nil, errLimitedConn
|
||||
}
|
||||
l.numAccepts.Add(1)
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"tailscale.com/paths"
|
||||
"tailscale.com/safesocket"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
|
||||
var Stderr io.Writer = os.Stderr
|
||||
@@ -155,6 +156,9 @@ change in the future.
|
||||
if strSliceContains(args, "debug") {
|
||||
rootCmd.Subcommands = append(rootCmd.Subcommands, debugCmd)
|
||||
}
|
||||
if runtime.GOOS == "linux" && distro.Get() == distro.Synology {
|
||||
rootCmd.Subcommands = append(rootCmd.Subcommands, configureHostCmd)
|
||||
}
|
||||
|
||||
if err := rootCmd.Parse(args); err != nil {
|
||||
if errors.Is(err, flag.ErrHelp) {
|
||||
|
||||
85
cmd/tailscale/cli/configure-host.go
Normal file
85
cmd/tailscale/cli/configure-host.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/peterbourgon/ff/v3/ffcli"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
|
||||
var configureHostCmd = &ffcli.Command{
|
||||
Name: "configure-host",
|
||||
Exec: runConfigureHost,
|
||||
ShortHelp: "Configure Synology to enable more Tailscale features",
|
||||
LongHelp: strings.TrimSpace(`
|
||||
The 'configure-host' command is intended to run at boot as root
|
||||
to create the /dev/net/tun device and give the tailscaled binary
|
||||
permission to use it.
|
||||
|
||||
See: https://tailscale.com/kb/1152/synology-outbound/
|
||||
`),
|
||||
FlagSet: (func() *flag.FlagSet {
|
||||
fs := newFlagSet("configure-host")
|
||||
return fs
|
||||
})(),
|
||||
}
|
||||
|
||||
var configureHostArgs struct{}
|
||||
|
||||
func runConfigureHost(ctx context.Context, args []string) error {
|
||||
if len(args) > 0 {
|
||||
return errors.New("unknown arguments")
|
||||
}
|
||||
if runtime.GOOS != "linux" || distro.Get() != distro.Synology {
|
||||
return errors.New("only implemented on Synology")
|
||||
}
|
||||
if uid := os.Getuid(); uid != 0 {
|
||||
return fmt.Errorf("must be run as root, not %q (%v)", os.Getenv("USER"), uid)
|
||||
}
|
||||
osVer := hostinfo.GetOSVersion()
|
||||
isDSM6 := strings.HasPrefix(osVer, "Synology 6")
|
||||
isDSM7 := strings.HasPrefix(osVer, "Synology 7")
|
||||
if !isDSM6 && !isDSM7 {
|
||||
return fmt.Errorf("unsupported DSM version %q", osVer)
|
||||
}
|
||||
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll("/dev/net", 0755); err != nil {
|
||||
return fmt.Errorf("creating /dev/net: %v", err)
|
||||
}
|
||||
if out, err := exec.Command("/bin/mknod", "/dev/net/tun", "c", "10", "200").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("creating /dev/net/tun: %v, %s", err, out)
|
||||
}
|
||||
}
|
||||
if err := os.Chmod("/dev/net/tun", 0666); err != nil {
|
||||
return err
|
||||
}
|
||||
if isDSM6 {
|
||||
fmt.Printf("/dev/net/tun exists and has permissions 0666. Skipping setcap on DSM6.\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
const daemonBin = "/var/packages/Tailscale/target/bin/tailscaled"
|
||||
if _, err := os.Stat(daemonBin); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("tailscaled binary not found at %s. Is the Tailscale *.spk package installed?", daemonBin)
|
||||
}
|
||||
return err
|
||||
}
|
||||
if out, err := exec.Command("/bin/setcap", "cap_net_admin+eip", daemonBin).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("setcap: %v, %s", err, out)
|
||||
}
|
||||
fmt.Printf("Done. To restart Tailscale to use the new permissions, run:\n\n sudo synosystemctl restart pkgctl-Tailscale.service\n\n")
|
||||
return nil
|
||||
}
|
||||
@@ -39,7 +39,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
L tailscale.com/derp/wsconn from tailscale.com/derp/derphttp
|
||||
tailscale.com/disco from tailscale.com/derp
|
||||
tailscale.com/envknob from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/hostinfo from tailscale.com/net/interfaces
|
||||
tailscale.com/hostinfo from tailscale.com/net/interfaces+
|
||||
tailscale.com/ipn from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/ipn/ipnstate from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/kube from tailscale.com/ipn
|
||||
|
||||
@@ -203,6 +203,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/net/netknob from tailscale.com/logpolicy+
|
||||
tailscale.com/net/netns from tailscale.com/cmd/tailscaled+
|
||||
💣 tailscale.com/net/netstat from tailscale.com/ipn/ipnserver
|
||||
tailscale.com/net/netutil from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/net/packet from tailscale.com/net/tstun+
|
||||
tailscale.com/net/portmapper from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/net/proxymux from tailscale.com/cmd/tailscaled
|
||||
@@ -352,6 +353,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
hash/fnv from gvisor.dev/gvisor/pkg/tcpip/network/ipv6+
|
||||
hash/maphash from go4.org/mem
|
||||
html from net/http/pprof+
|
||||
html/template from tailscale.com/tsweb
|
||||
io from bufio+
|
||||
io/fs from crypto/rand+
|
||||
io/ioutil from github.com/aws/aws-sdk-go-v2/aws/protocol/query+
|
||||
@@ -390,6 +392,8 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
sync/atomic from context+
|
||||
syscall from crypto/rand+
|
||||
text/tabwriter from runtime/pprof
|
||||
text/template from html/template
|
||||
text/template/parse from html/template+
|
||||
time from compress/gzip+
|
||||
unicode from bytes+
|
||||
unicode/utf16 from crypto/x509+
|
||||
|
||||
@@ -74,7 +74,7 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch
|
||||
changes <- svc.Status{State: svc.StartPending}
|
||||
|
||||
svcAccepts := svc.AcceptStop
|
||||
if winutil.GetRegInteger("FlushDNSOnSessionUnlock", 0) != 0 {
|
||||
if winutil.GetPolicyInteger("FlushDNSOnSessionUnlock", 0) != 0 {
|
||||
svcAccepts |= svc.AcceptSessionChange
|
||||
}
|
||||
|
||||
|
||||
@@ -266,9 +266,9 @@ func (c *Auto) authRoutine() {
|
||||
goal := c.loginGoal
|
||||
ctx := c.authCtx
|
||||
if goal != nil {
|
||||
c.logf("authRoutine: %s; wantLoggedIn=%v", c.state, goal.wantLoggedIn)
|
||||
c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, goal.wantLoggedIn)
|
||||
} else {
|
||||
c.logf("authRoutine: %s; goal=nil paused=%v", c.state, c.paused)
|
||||
c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -414,7 +414,7 @@ func (c *Auto) mapRoutine() {
|
||||
}
|
||||
continue
|
||||
}
|
||||
c.logf("mapRoutine: %s", c.state)
|
||||
c.logf("[v1] mapRoutine: %s", c.state)
|
||||
loggedIn := c.loggedIn
|
||||
ctx := c.mapCtx
|
||||
c.mu.Unlock()
|
||||
@@ -445,9 +445,9 @@ func (c *Auto) mapRoutine() {
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.logf("mapRoutine: context done.")
|
||||
c.logf("[v1] mapRoutine: context done.")
|
||||
case <-c.newMapCh:
|
||||
c.logf("mapRoutine: new map needed while idle.")
|
||||
c.logf("[v1] mapRoutine: new map needed while idle.")
|
||||
}
|
||||
} else {
|
||||
// Be sure this is false when we're not inside
|
||||
|
||||
@@ -168,6 +168,10 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
tr.DialContext = dnscache.Dialer(dialer.DialContext, dnsCache)
|
||||
tr.DialTLSContext = dnscache.TLSDialer(dialer.DialContext, dnsCache, tr.TLSClientConfig)
|
||||
tr.ForceAttemptHTTP2 = true
|
||||
// Disable implicit gzip compression; the various
|
||||
// handlers (register, map, set-dns, etc) do their own
|
||||
// zstd compression per naclbox.
|
||||
tr.DisableCompression = true
|
||||
httpc = &http.Client{Transport: tr}
|
||||
}
|
||||
|
||||
@@ -210,7 +214,7 @@ func (c *Direct) SetHostinfo(hi *tailcfg.Hostinfo) bool {
|
||||
}
|
||||
c.hostinfo = hi.Clone()
|
||||
j, _ := json.Marshal(c.hostinfo)
|
||||
c.logf("HostInfo: %s", j)
|
||||
c.logf("[v1] HostInfo: %s", j)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -241,10 +245,10 @@ func (c *Direct) GetPersist() persist.Persist {
|
||||
}
|
||||
|
||||
func (c *Direct) TryLogout(ctx context.Context) error {
|
||||
c.logf("direct.TryLogout()")
|
||||
c.logf("[v1] direct.TryLogout()")
|
||||
|
||||
mustRegen, newURL, err := c.doLogin(ctx, loginOpt{Logout: true})
|
||||
c.logf("TryLogout control response: mustRegen=%v, newURL=%v, err=%v", mustRegen, newURL, err)
|
||||
c.logf("[v1] TryLogout control response: mustRegen=%v, newURL=%v, err=%v", mustRegen, newURL, err)
|
||||
|
||||
c.mu.Lock()
|
||||
c.persist = persist.Persist{}
|
||||
@@ -254,7 +258,7 @@ func (c *Direct) TryLogout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *Direct) TryLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags) (url string, err error) {
|
||||
c.logf("direct.TryLogin(token=%v, flags=%v)", t != nil, flags)
|
||||
c.logf("[v1] direct.TryLogin(token=%v, flags=%v)", t != nil, flags)
|
||||
return c.doLoginOrRegen(ctx, loginOpt{Token: t, Flags: flags})
|
||||
}
|
||||
|
||||
@@ -262,7 +266,7 @@ func (c *Direct) TryLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Log
|
||||
//
|
||||
// On success, newURL and err will both be nil.
|
||||
func (c *Direct) WaitLoginURL(ctx context.Context, url string) (newURL string, err error) {
|
||||
c.logf("direct.WaitLoginURL")
|
||||
c.logf("[v1] direct.WaitLoginURL")
|
||||
return c.doLoginOrRegen(ctx, loginOpt{URL: url})
|
||||
}
|
||||
|
||||
@@ -465,7 +469,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
if resp.AuthURL != "" {
|
||||
c.logf("AuthURL is %v", resp.AuthURL)
|
||||
} else {
|
||||
c.logf("No AuthURL")
|
||||
c.logf("[v1] No AuthURL")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -516,7 +520,7 @@ func (c *Direct) newEndpoints(localPort uint16, endpoints []tailcfg.Endpoint) (c
|
||||
for _, ep := range endpoints {
|
||||
epStrs = append(epStrs, ep.Addr.String())
|
||||
}
|
||||
c.logf("client.newEndpoints(%v, %v)", localPort, epStrs)
|
||||
c.logf("[v2] client.newEndpoints(%v, %v)", localPort, epStrs)
|
||||
c.localPort = localPort
|
||||
c.endpoints = append(c.endpoints[:0], endpoints...)
|
||||
if len(endpoints) > 0 {
|
||||
@@ -821,10 +825,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
|
||||
|
||||
if Debug.StripEndpoints {
|
||||
for _, p := range resp.Peers {
|
||||
// We need at least one endpoint here for now else
|
||||
// other code doesn't even create the discoEndpoint.
|
||||
// TODO(bradfitz): fix that and then just nil this out.
|
||||
p.Endpoints = []string{"127.9.9.9:456"}
|
||||
p.Endpoints = nil
|
||||
}
|
||||
}
|
||||
if Debug.StripCaps {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/certstore"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -73,23 +74,46 @@ func isSubjectInChain(subject string, chain []*x509.Certificate) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func selectIdentityFromSlice(subject string, ids []certstore.Identity) (certstore.Identity, []*x509.Certificate) {
|
||||
func selectIdentityFromSlice(subject string, ids []certstore.Identity, now time.Time) (certstore.Identity, []*x509.Certificate) {
|
||||
var bestCandidate struct {
|
||||
id certstore.Identity
|
||||
chain []*x509.Certificate
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
chain, err := id.CertificateChain()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(chain) < 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !isSupportedCertificate(chain[0]) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isSubjectInChain(subject, chain) {
|
||||
return id, chain
|
||||
if now.Before(chain[0].NotBefore) || now.After(chain[0].NotAfter) {
|
||||
// Certificate is not valid at this time
|
||||
continue
|
||||
}
|
||||
|
||||
if !isSubjectInChain(subject, chain) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Select the most recently issued certificate. If there is a tie, pick
|
||||
// one arbitrarily.
|
||||
if len(bestCandidate.chain) > 0 && bestCandidate.chain[0].NotBefore.After(chain[0].NotBefore) {
|
||||
continue
|
||||
}
|
||||
|
||||
bestCandidate.id = id
|
||||
bestCandidate.chain = chain
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return bestCandidate.id, bestCandidate.chain
|
||||
}
|
||||
|
||||
// findIdentity locates an identity from the Windows or Darwin certificate
|
||||
@@ -105,7 +129,7 @@ func findIdentity(subject string, st certstore.Store) (certstore.Identity, []*x5
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
selected, chain := selectIdentityFromSlice(subject, ids)
|
||||
selected, chain := selectIdentityFromSlice(subject, ids, time.Now())
|
||||
|
||||
for _, id := range ids {
|
||||
if id != selected {
|
||||
|
||||
238
control/controlclient/sign_supported_test.go
Normal file
238
control/controlclient/sign_supported_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows && cgo
|
||||
// +build windows,cgo
|
||||
|
||||
package controlclient
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/certstore"
|
||||
)
|
||||
|
||||
const (
|
||||
testRootCommonName = "testroot"
|
||||
testRootSubject = "CN=testroot"
|
||||
)
|
||||
|
||||
type testIdentity struct {
|
||||
chain []*x509.Certificate
|
||||
}
|
||||
|
||||
func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
|
||||
return []*x509.Certificate{
|
||||
{
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
PublicKeyAlgorithm: x509.RSA,
|
||||
},
|
||||
{
|
||||
Subject: pkix.Name{
|
||||
CommonName: rootCommonName,
|
||||
},
|
||||
PublicKeyAlgorithm: x509.RSA,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *testIdentity) Certificate() (*x509.Certificate, error) {
|
||||
return t.chain[0], nil
|
||||
}
|
||||
|
||||
func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
|
||||
return t.chain, nil
|
||||
}
|
||||
|
||||
func (t *testIdentity) Signer() (crypto.Signer, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (t *testIdentity) Delete() error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (t *testIdentity) Close() {}
|
||||
|
||||
func TestSelectIdentityFromSlice(t *testing.T) {
|
||||
var times []time.Time
|
||||
for _, ts := range []string{
|
||||
"2000-01-01T00:00:00Z",
|
||||
"2001-01-01T00:00:00Z",
|
||||
"2002-01-01T00:00:00Z",
|
||||
"2003-01-01T00:00:00Z",
|
||||
} {
|
||||
tm, err := time.Parse(time.RFC3339, ts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
times = append(times, tm)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
subject string
|
||||
ids []certstore.Identity
|
||||
now time.Time
|
||||
// wantIndex is an index into ids, or -1 for nil.
|
||||
wantIndex int
|
||||
}{
|
||||
{
|
||||
name: "single unexpired identity",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[2]),
|
||||
},
|
||||
},
|
||||
now: times[1],
|
||||
wantIndex: 0,
|
||||
},
|
||||
{
|
||||
name: "single expired identity",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: -1,
|
||||
},
|
||||
{
|
||||
name: "unrelated ids",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain("something", times[0], times[2]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[2]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain("else", times[0], times[2]),
|
||||
},
|
||||
},
|
||||
now: times[1],
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "expired with unrelated ids",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain("something", times[0], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain("else", times[0], times[3]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: -1,
|
||||
},
|
||||
{
|
||||
name: "one expired",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "two certs both unexpired",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "two unexpired one expired",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: 1,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
|
||||
|
||||
if gotId == nil && gotChain != nil {
|
||||
t.Error("id is nil: got non-nil chain, want nil chain")
|
||||
return
|
||||
}
|
||||
if gotId != nil && gotChain == nil {
|
||||
t.Error("id is not nil: got nil chain, want non-nil chain")
|
||||
return
|
||||
}
|
||||
if tt.wantIndex == -1 {
|
||||
if gotId != nil {
|
||||
t.Error("got non-nil id, want nil id")
|
||||
}
|
||||
return
|
||||
}
|
||||
if gotId == nil {
|
||||
t.Error("got nil id, want non-nil id")
|
||||
return
|
||||
}
|
||||
if gotId != tt.ids[tt.wantIndex] {
|
||||
found := -1
|
||||
for i := range tt.ids {
|
||||
if tt.ids[i] == gotId {
|
||||
found = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if found == -1 {
|
||||
t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
|
||||
} else {
|
||||
t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
|
||||
}
|
||||
}
|
||||
|
||||
tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
|
||||
if !ok {
|
||||
t.Error("got non-testIdentity, want testIdentity")
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tid.chain, gotChain) {
|
||||
t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -20,15 +20,51 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/types/opt"
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
set = map[string]string{}
|
||||
list []string
|
||||
)
|
||||
|
||||
func noteEnv(k, v string) {
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if _, ok := set[v]; !ok {
|
||||
list = append(list, k)
|
||||
}
|
||||
set[k] = v
|
||||
}
|
||||
|
||||
// logf is logger.Logf, but logger depends on envknob, so for circular
|
||||
// dependency reasons, make a type alias (so it's still assignable,
|
||||
// but has nice docs here).
|
||||
type logf = func(format string, args ...interface{})
|
||||
|
||||
// LogCurrent logs the currently set environment knobs.
|
||||
func LogCurrent(logf logf) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for _, k := range list {
|
||||
logf("envknob: %s=%q", k, set[k])
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the named environment variable, using os.Getenv.
|
||||
//
|
||||
// In the future it will also track usage for reporting on debug pages.
|
||||
// If the variable is non-empty, it's also tracked & logged as being
|
||||
// an in-use knob.
|
||||
func String(envVar string) string {
|
||||
return os.Getenv(envVar)
|
||||
v := os.Getenv(envVar)
|
||||
noteEnv(envVar, v)
|
||||
return v
|
||||
}
|
||||
|
||||
// Bool returns the boolean value of the named environment variable.
|
||||
@@ -51,9 +87,10 @@ func boolOr(envVar string, implicitValue bool) bool {
|
||||
}
|
||||
b, err := strconv.ParseBool(val)
|
||||
if err == nil {
|
||||
noteEnv(envVar, strconv.FormatBool(b)) // canonicalize
|
||||
return b
|
||||
}
|
||||
log.Fatalf("invalid environment variable %s value %q: %v", envVar, val, err)
|
||||
log.Fatalf("invalid boolean environment variable %s value %q", envVar, val)
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
@@ -69,7 +106,7 @@ func LookupBool(envVar string) (v bool, ok bool) {
|
||||
if err == nil {
|
||||
return b, true
|
||||
}
|
||||
log.Fatalf("invalid environment variable %s value %q: %v", envVar, val, err)
|
||||
log.Fatalf("invalid boolean environment variable %s value %q", envVar, val)
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
@@ -95,9 +132,10 @@ func LookupInt(envVar string) (v int, ok bool) {
|
||||
}
|
||||
v, err := strconv.Atoi(val)
|
||||
if err == nil {
|
||||
noteEnv(envVar, val)
|
||||
return v, true
|
||||
}
|
||||
log.Fatalf("invalid environment variable %s value %q: %v", envVar, val, err)
|
||||
log.Fatalf("invalid integer environment variable %s: %v", envVar, val)
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
|
||||
@@ -170,6 +170,10 @@ func NewLocalBackend(logf logger.Logf, logid string, store ipn.StateStore, diale
|
||||
if e == nil {
|
||||
panic("ipn.NewLocalBackend: engine must not be nil")
|
||||
}
|
||||
|
||||
hi := hostinfo.New()
|
||||
logf("Host: %s/%s, %s", hi.OS, hi.GoArch, hi.OSVersion)
|
||||
envknob.LogCurrent(logf)
|
||||
if dialer == nil {
|
||||
dialer = new(tsdial.Dialer)
|
||||
}
|
||||
@@ -607,7 +611,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) {
|
||||
if strings.TrimSpace(diff) == "" {
|
||||
b.logf("[v1] netmap diff: (none)")
|
||||
} else {
|
||||
b.logf("netmap diff:\n%v", diff)
|
||||
b.logf("[v1] netmap diff:\n%v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -906,7 +910,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
|
||||
timer := time.NewTimer(time.Second)
|
||||
select {
|
||||
case <-b.gotPortPollRes:
|
||||
b.logf("got initial portlist info in %v", time.Since(t0).Round(time.Millisecond))
|
||||
b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond))
|
||||
timer.Stop()
|
||||
case <-timer.C:
|
||||
b.logf("timeout waiting for initial portlist")
|
||||
@@ -1054,17 +1058,17 @@ func (b *LocalBackend) updateFilter(netMap *netmap.NetworkMap, prefs *ipn.Prefs)
|
||||
}
|
||||
|
||||
if !haveNetmap {
|
||||
b.logf("netmap packet filter: (not ready yet)")
|
||||
b.logf("[v1] netmap packet filter: (not ready yet)")
|
||||
b.setFilter(filter.NewAllowNone(b.logf, logNets))
|
||||
return
|
||||
}
|
||||
|
||||
oldFilter := b.e.GetFilter()
|
||||
if shieldsUp {
|
||||
b.logf("netmap packet filter: (shields up)")
|
||||
b.logf("[v1] netmap packet filter: (shields up)")
|
||||
b.setFilter(filter.NewShieldsUpFilter(localNets, logNets, oldFilter, b.logf))
|
||||
} else {
|
||||
b.logf("netmap packet filter: %v filters", len(packetFilter))
|
||||
b.logf("[v1] netmap packet filter: %v filters", len(packetFilter))
|
||||
b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf))
|
||||
}
|
||||
}
|
||||
@@ -1503,19 +1507,19 @@ func (b *LocalBackend) loadStateLocked(key ipn.StateKey, prefs *ipn.Prefs) (err
|
||||
}
|
||||
}
|
||||
|
||||
b.logf("using backend prefs")
|
||||
bs, err := b.store.ReadState(key)
|
||||
switch {
|
||||
case errors.Is(err, ipn.ErrStateNotExist):
|
||||
b.prefs = ipn.NewPrefs()
|
||||
b.prefs.WantRunning = false
|
||||
b.logf("created empty state for %q: %s", key, b.prefs.Pretty())
|
||||
b.logf("using backend prefs; created empty state for %q: %s", key, b.prefs.Pretty())
|
||||
return nil
|
||||
case err != nil:
|
||||
return fmt.Errorf("store.ReadState(%q): %v", key, err)
|
||||
return fmt.Errorf("backend prefs: store.ReadState(%q): %v", key, err)
|
||||
}
|
||||
b.prefs, err = ipn.PrefsFromBytes(bs, false)
|
||||
if err != nil {
|
||||
b.logf("using backend prefs for %q", key)
|
||||
return fmt.Errorf("PrefsFromBytes: %v", err)
|
||||
}
|
||||
|
||||
@@ -1538,7 +1542,7 @@ func (b *LocalBackend) loadStateLocked(key ipn.StateKey, prefs *ipn.Prefs) (err
|
||||
}
|
||||
}
|
||||
|
||||
b.logf("backend prefs for %q: %s", key, b.prefs.Pretty())
|
||||
b.logf("using backend prefs for %q: %s", key, b.prefs.Pretty())
|
||||
|
||||
b.sshAtomicBool.Set(b.prefs != nil && b.prefs.RunSSH)
|
||||
|
||||
@@ -1916,15 +1920,15 @@ func (b *LocalBackend) authReconfig() {
|
||||
b.mu.Unlock()
|
||||
|
||||
if blocked {
|
||||
b.logf("authReconfig: blocked, skipping.")
|
||||
b.logf("[v1] authReconfig: blocked, skipping.")
|
||||
return
|
||||
}
|
||||
if nm == nil {
|
||||
b.logf("authReconfig: netmap not yet valid. Skipping.")
|
||||
b.logf("[v1] authReconfig: netmap not yet valid. Skipping.")
|
||||
return
|
||||
}
|
||||
if !prefs.WantRunning {
|
||||
b.logf("authReconfig: skipping because !WantRunning.")
|
||||
b.logf("[v1] authReconfig: skipping because !WantRunning.")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ import (
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/dns/resolver"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/clientmetric"
|
||||
@@ -506,27 +507,9 @@ func (pln *peerAPIListener) ServeConn(src netaddr.IPPort, c net.Conn) {
|
||||
if addH2C != nil {
|
||||
addH2C(httpServer)
|
||||
}
|
||||
go httpServer.Serve(&oneConnListener{Listener: pln.ln, conn: c})
|
||||
go httpServer.Serve(netutil.NewOneConnListenerFrom(c, pln.ln))
|
||||
}
|
||||
|
||||
type oneConnListener struct {
|
||||
net.Listener
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Accept() (c net.Conn, err error) {
|
||||
c = l.conn
|
||||
if c == nil {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
l.conn = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Close() error { return nil }
|
||||
|
||||
// peerAPIHandler serves the Peer API for a source specific client.
|
||||
type peerAPIHandler struct {
|
||||
ps *peerAPIServer
|
||||
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
"tailscale.com/ipn/store/aws"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/netstat"
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/paths"
|
||||
"tailscale.com/safesocket"
|
||||
@@ -308,7 +309,7 @@ func (s *Server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) {
|
||||
ErrorLog: logger.StdLogger(logf),
|
||||
Handler: s.localhostHandler(ci),
|
||||
}
|
||||
httpServer.Serve(&oneConnListener{&protoSwitchConn{s: s, br: br, Conn: c}})
|
||||
httpServer.Serve(netutil.NewOneConnListener(&protoSwitchConn{s: s, br: br, Conn: c}))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1061,29 +1062,6 @@ func getEngineUntilItWorksWrapper(getEngine func() (wgengine.Engine, error)) fun
|
||||
}
|
||||
}
|
||||
|
||||
type dummyAddr string
|
||||
type oneConnListener struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Accept() (c net.Conn, err error) {
|
||||
c = l.conn
|
||||
if c == nil {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
l.conn = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Close() error { return nil }
|
||||
|
||||
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("unused-address") }
|
||||
|
||||
func (a dummyAddr) Network() string { return string(a) }
|
||||
func (a dummyAddr) String() string { return string(a) }
|
||||
|
||||
// protoSwitchConn is a net.Conn that's we want to speak HTTP to but
|
||||
// it's already had a few bytes read from it to determine that it's
|
||||
// HTTP. So we Read from its bufio.Reader. On Close, we we tell the
|
||||
|
||||
@@ -571,6 +571,16 @@ func New(collection string) *Policy {
|
||||
}
|
||||
}
|
||||
|
||||
// dialLog is used by NewLogtailTransport to log the happy path of its
|
||||
// own dialing.
|
||||
//
|
||||
// By default it goes nowhere and is only enabled when
|
||||
// tailscaled's in verbose mode.
|
||||
//
|
||||
// log.Printf isn't used so its own logs don't loop back into logtail
|
||||
// in the happy path, thus generating more logs.
|
||||
var dialLog = log.New(io.Discard, "logtail: ", log.LstdFlags|log.Lmsgprefix)
|
||||
|
||||
// SetVerbosityLevel controls the verbosity level that should be
|
||||
// written to stderr. 0 is the default (not verbose). Levels 1 or higher
|
||||
// are increasingly verbose.
|
||||
@@ -578,6 +588,9 @@ func New(collection string) *Policy {
|
||||
// It should not be changed concurrently with log writes.
|
||||
func (p *Policy) SetVerbosityLevel(level int) {
|
||||
p.Logtail.SetVerbosityLevel(level)
|
||||
if level > 0 {
|
||||
dialLog.SetOutput(os.Stderr)
|
||||
}
|
||||
}
|
||||
|
||||
// Close immediately shuts down the logger.
|
||||
@@ -624,7 +637,7 @@ func NewLogtailTransport(host string) *http.Transport {
|
||||
c, err := nd.DialContext(ctx, netw, addr)
|
||||
d := time.Since(t0).Round(time.Millisecond)
|
||||
if err == nil {
|
||||
log.Printf("logtail: dialed %q in %v", addr, d)
|
||||
dialLog.Printf("dialed %q in %v", addr, d)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -637,10 +650,10 @@ func NewLogtailTransport(host string) *http.Transport {
|
||||
err = errors.New(res.Status)
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("logtail: CONNECT response from tailscaled: %v", err)
|
||||
log.Printf("logtail: CONNECT response error from tailscaled: %v", err)
|
||||
c.Close()
|
||||
} else {
|
||||
log.Printf("logtail: connected via tailscaled")
|
||||
dialLog.Printf("connected via tailscaled")
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func (b *Backoff) BackOff(ctx context.Context, err error) {
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
if d >= b.LogLongerThan {
|
||||
b.logf("%s: backoff: %d msec", b.name, d.Milliseconds())
|
||||
b.logf("%s: [v1] backoff: %d msec", b.name, d.Milliseconds())
|
||||
}
|
||||
t := b.NewTimer(d)
|
||||
select {
|
||||
|
||||
@@ -113,6 +113,16 @@ func ParsePublicID(s string) (PublicID, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// MustParsePublicID calls ParsePublicID and panics in case of an error.
|
||||
// It is intended for use with constant strings, typically in tests.
|
||||
func MustParsePublicID(s string) PublicID {
|
||||
id, err := ParsePublicID(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (id PublicID) MarshalText() ([]byte, error) {
|
||||
b := make([]byte, hex.EncodedLen(len(id)))
|
||||
if i := hex.Encode(b, id[:]); i != len(b) {
|
||||
|
||||
@@ -255,7 +255,7 @@ func (l *Logger) drainPending(scratch []byte) (res []byte) {
|
||||
l.explainedRaw = true
|
||||
}
|
||||
fmt.Fprintf(l.stderr, "RAW-STDERR: %s", b)
|
||||
b = l.encodeText(b, true)
|
||||
b = l.encodeText(b, true, 0)
|
||||
}
|
||||
|
||||
if entries > 0 {
|
||||
@@ -418,7 +418,7 @@ func (l *Logger) send(jsonBlob []byte) (int, error) {
|
||||
|
||||
// TODO: instead of allocating, this should probably just append
|
||||
// directly into the output log buffer.
|
||||
func (l *Logger) encodeText(buf []byte, skipClientTime bool) []byte {
|
||||
func (l *Logger) encodeText(buf []byte, skipClientTime bool, level int) []byte {
|
||||
now := l.timeNow()
|
||||
|
||||
// Factor in JSON encoding overhead to try to only do one alloc
|
||||
@@ -463,6 +463,14 @@ func (l *Logger) encodeText(buf []byte, skipClientTime bool) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// Add the log level, if non-zero. Note that we only use log
|
||||
// levels 1 and 2 currently. It's unlikely we'll ever make it
|
||||
// past 9.
|
||||
if level > 0 && level < 10 {
|
||||
b = append(b, `"v":`...)
|
||||
b = append(b, '0'+byte(level))
|
||||
b = append(b, ',')
|
||||
}
|
||||
b = append(b, "\"text\": \""...)
|
||||
for _, c := range buf {
|
||||
switch c {
|
||||
@@ -493,9 +501,9 @@ func (l *Logger) encodeText(buf []byte, skipClientTime bool) []byte {
|
||||
return b
|
||||
}
|
||||
|
||||
func (l *Logger) encode(buf []byte) []byte {
|
||||
func (l *Logger) encode(buf []byte, level int) []byte {
|
||||
if buf[0] != '{' {
|
||||
return l.encodeText(buf, l.skipClientTime) // text fast-path
|
||||
return l.encodeText(buf, l.skipClientTime, level) // text fast-path
|
||||
}
|
||||
|
||||
now := l.timeNow()
|
||||
@@ -560,7 +568,7 @@ func (l *Logger) Write(buf []byte) (int, error) {
|
||||
l.stderr.Write(withNL)
|
||||
}
|
||||
}
|
||||
b := l.encode(buf)
|
||||
b := l.encode(buf, level)
|
||||
_, err := l.send(b)
|
||||
return len(buf), err
|
||||
}
|
||||
|
||||
@@ -217,7 +217,7 @@ func TestLoggerEncodeTextAllocs(t *testing.T) {
|
||||
lg := &Logger{timeNow: time.Now}
|
||||
inBuf := []byte("some text to encode")
|
||||
err := tstest.MinAllocsPerRun(t, 1, func() {
|
||||
sink = lg.encodeText(inBuf, false)
|
||||
sink = lg.encodeText(inBuf, false, 0)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -328,10 +328,60 @@ func unmarshalOne(t *testing.T, body []byte) map[string]interface{} {
|
||||
func TestEncodeTextTruncation(t *testing.T) {
|
||||
lg := &Logger{timeNow: time.Now, lowMem: true}
|
||||
in := bytes.Repeat([]byte("a"), 300)
|
||||
b := lg.encodeText(in, true)
|
||||
b := lg.encodeText(in, true, 0)
|
||||
got := string(b)
|
||||
want := `{"text": "` + strings.Repeat("a", 255) + `…+45"}` + "\n"
|
||||
if got != want {
|
||||
t.Errorf("got:\n%qwant:\n%q\n", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
type simpleMemBuf struct {
|
||||
Buffer
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *simpleMemBuf) Write(p []byte) (n int, err error) { return b.buf.Write(p) }
|
||||
|
||||
func TestEncode(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"normal",
|
||||
`{"logtail": {"client_time": "1970-01-01T00:02:03.000000456Z"}, "text": "normal"}` + "\n",
|
||||
},
|
||||
{
|
||||
"and a [v1] level one",
|
||||
`{"logtail": {"client_time": "1970-01-01T00:02:03.000000456Z"}, "v":1,"text": "and a level one"}` + "\n",
|
||||
},
|
||||
{
|
||||
"[v2] some verbose two",
|
||||
`{"logtail": {"client_time": "1970-01-01T00:02:03.000000456Z"}, "v":2,"text": "some verbose two"}` + "\n",
|
||||
},
|
||||
{
|
||||
"{}",
|
||||
`{"logtail":{"client_time":"1970-01-01T00:02:03.000000456Z"}}` + "\n",
|
||||
},
|
||||
{
|
||||
`{"foo":"bar"}`,
|
||||
`{"foo":"bar","logtail":{"client_time":"1970-01-01T00:02:03.000000456Z"}}` + "\n",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
buf := new(simpleMemBuf)
|
||||
lg := &Logger{
|
||||
timeNow: func() time.Time { return time.Unix(123, 456).UTC() },
|
||||
buffer: buf,
|
||||
}
|
||||
io.WriteString(lg, tt.in)
|
||||
got := buf.buf.String()
|
||||
if got != tt.want {
|
||||
t.Errorf("for %q,\n got: %#q\nwant: %#q\n", tt.in, got, tt.want)
|
||||
}
|
||||
if err := json.Compact(new(bytes.Buffer), buf.buf.Bytes()); err != nil {
|
||||
t.Errorf("invalid output JSON for %q: %s", tt.in, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,18 @@ func dnsMode(logf logger.Logf, env newOSConfigEnv) (ret string, err error) {
|
||||
logf("dns: %v", debug)
|
||||
}()
|
||||
|
||||
// Before we read /etc/resolv.conf (which might be in a broken
|
||||
// or symlink-dangling state), try to ping the D-Bus service
|
||||
// for systemd-resolved. If it's active on the machine, this
|
||||
// will make it start up and write the /etc/resolv.conf file
|
||||
// before it replies to the ping. (see how systemd's
|
||||
// src/resolve/resolved.c calls manager_write_resolv_conf
|
||||
// before the sd_event_loop starts)
|
||||
resolvedUp := env.dbusPing("org.freedesktop.resolve1", "/org/freedesktop/resolve1") == nil
|
||||
if resolvedUp {
|
||||
dbg("resolved-ping", "yes")
|
||||
}
|
||||
|
||||
bs, err := env.fs.ReadFile(resolvConf)
|
||||
if os.IsNotExist(err) {
|
||||
dbg("rc", "missing")
|
||||
@@ -99,7 +111,7 @@ func dnsMode(logf logger.Logf, env newOSConfigEnv) (ret string, err error) {
|
||||
dbg("resolved", "not-in-use")
|
||||
return "direct", nil
|
||||
}
|
||||
if err := env.dbusPing("org.freedesktop.resolve1", "/org/freedesktop/resolve1"); err != nil {
|
||||
if !resolvedUp {
|
||||
dbg("resolved", "no")
|
||||
return "direct", nil
|
||||
}
|
||||
@@ -309,15 +321,15 @@ func resolvedIsActuallyResolver(bs []byte) error {
|
||||
}
|
||||
|
||||
func dbusPing(name, objectPath string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
// DBus probably not running.
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
obj := conn.Object(name, dbus.ObjectPath(objectPath))
|
||||
call := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0)
|
||||
return call.Err
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
env: env(
|
||||
resolvDotConf("# Managed by systemd-resolved", "nameserver 127.0.0.53"),
|
||||
resolvedRunning()),
|
||||
wantLog: "dns: [rc=resolved nm=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
@@ -88,7 +88,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
resolvDotConf("# Managed by systemd-resolved", "nameserver 127.0.0.53"),
|
||||
resolvedRunning(),
|
||||
nmRunning("1.2.3", false)),
|
||||
wantLog: "dns: [rc=resolved nm=yes nm-resolved=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=yes nm-resolved=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
@@ -97,7 +97,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
resolvDotConf("# Managed by systemd-resolved", "nameserver 127.0.0.53"),
|
||||
resolvedRunning(),
|
||||
nmRunning("1.26.2", true)),
|
||||
wantLog: "dns: [rc=resolved nm=yes nm-resolved=yes nm-safe=yes ret=network-manager]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=yes nm-resolved=yes nm-safe=yes ret=network-manager]",
|
||||
want: "network-manager",
|
||||
},
|
||||
{
|
||||
@@ -106,7 +106,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
resolvDotConf("# Managed by systemd-resolved", "nameserver 127.0.0.53"),
|
||||
resolvedRunning(),
|
||||
nmRunning("1.27.0", true)),
|
||||
wantLog: "dns: [rc=resolved nm=yes nm-resolved=yes nm-safe=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=yes nm-resolved=yes nm-safe=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
@@ -115,7 +115,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
resolvDotConf("# Managed by systemd-resolved", "nameserver 127.0.0.53"),
|
||||
resolvedRunning(),
|
||||
nmRunning("1.22.0", true)),
|
||||
wantLog: "dns: [rc=resolved nm=yes nm-resolved=yes nm-safe=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=yes nm-resolved=yes nm-safe=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
// Regression tests for extreme corner cases below.
|
||||
@@ -141,7 +141,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
"nameserver 127.0.0.53",
|
||||
"nameserver 127.0.0.53"),
|
||||
resolvedRunning()),
|
||||
wantLog: "dns: [rc=resolved nm=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
@@ -156,7 +156,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
"# run \"systemd-resolve --status\" to see details about the actual nameservers.",
|
||||
"nameserver 127.0.0.53"),
|
||||
resolvedRunning()),
|
||||
wantLog: "dns: [rc=resolved nm=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
@@ -183,7 +183,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
"options edns0 trust-ad"),
|
||||
resolvedRunning(),
|
||||
nmRunning("1.32.12", true)),
|
||||
wantLog: "dns: [rc=nm nm-resolved=yes nm-safe=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=nm nm-resolved=yes nm-safe=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
@@ -206,7 +206,7 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
"options edns0 trust-ad"),
|
||||
resolvedRunning(),
|
||||
nmRunning("1.26.3", true)),
|
||||
wantLog: "dns: [rc=nm nm-resolved=yes nm-safe=yes ret=network-manager]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=nm nm-resolved=yes nm-safe=yes ret=network-manager]",
|
||||
want: "network-manager",
|
||||
},
|
||||
{
|
||||
@@ -217,7 +217,27 @@ func TestLinuxDNSMode(t *testing.T) {
|
||||
"nameserver 127.0.0.53",
|
||||
"options edns0 trust-ad"),
|
||||
resolvedRunning()),
|
||||
wantLog: "dns: [rc=nm nm-resolved=yes nm=no ret=systemd-resolved]",
|
||||
wantLog: "dns: [resolved-ping=yes rc=nm nm-resolved=yes nm=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
// regression test for https://github.com/tailscale/tailscale/issues/3531
|
||||
name: "networkmanager_but_systemd-resolved_with_search_domain",
|
||||
env: env(resolvDotConf(
|
||||
"# Generated by NetworkManager",
|
||||
"search lan",
|
||||
"nameserver 127.0.0.53"),
|
||||
resolvedRunning()),
|
||||
wantLog: "dns: [resolved-ping=yes rc=nm nm-resolved=yes nm=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
{
|
||||
// Make sure that we ping systemd-resolved to let it start up and write its resolv.conf
|
||||
// before we read its file.
|
||||
env: env(resolvedStartOnPingAndThen(
|
||||
resolvDotConf("# Managed by systemd-resolved", "nameserver 127.0.0.53"),
|
||||
)),
|
||||
wantLog: "dns: [resolved-ping=yes rc=resolved nm=no ret=systemd-resolved]",
|
||||
want: "systemd-resolved",
|
||||
},
|
||||
}
|
||||
@@ -281,9 +301,14 @@ func (m memFS) WriteFile(name string, contents []byte, perm os.FileMode) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbusService struct {
|
||||
name, path string
|
||||
hook func() // if non-nil, run on ping
|
||||
}
|
||||
|
||||
type envBuilder struct {
|
||||
fs memFS
|
||||
dbus []struct{ name, path string }
|
||||
dbus []dbusService
|
||||
nmUsingResolved bool
|
||||
nmVersion string
|
||||
resolvconfStyle string
|
||||
@@ -312,6 +337,9 @@ func env(opts ...envOption) newOSConfigEnv {
|
||||
dbusPing: func(name, path string) error {
|
||||
for _, svc := range b.dbus {
|
||||
if svc.name == name && svc.path == path {
|
||||
if svc.hook != nil {
|
||||
svc.hook()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -337,9 +365,25 @@ func resolvDotConf(ss ...string) envOption {
|
||||
})
|
||||
}
|
||||
|
||||
// resolvedRunning returns an option that makes resolved reply to a dbusPing.
|
||||
func resolvedRunning() envOption {
|
||||
return resolvedStartOnPingAndThen( /* nothing */ )
|
||||
}
|
||||
|
||||
// resolvedStartOnPingAndThen returns an option that makes resolved be
|
||||
// active but not yet running. On a dbus ping, it then applies the
|
||||
// provided options.
|
||||
func resolvedStartOnPingAndThen(opts ...envOption) envOption {
|
||||
return envOpt(func(b *envBuilder) {
|
||||
b.dbus = append(b.dbus, struct{ name, path string }{"org.freedesktop.resolve1", "/org/freedesktop/resolve1"})
|
||||
b.dbus = append(b.dbus, dbusService{
|
||||
name: "org.freedesktop.resolve1",
|
||||
path: "/org/freedesktop/resolve1",
|
||||
hook: func() {
|
||||
for _, opt := range opts {
|
||||
opt.apply(b)
|
||||
}
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -347,7 +391,7 @@ func nmRunning(version string, usingResolved bool) envOption {
|
||||
return envOpt(func(b *envBuilder) {
|
||||
b.nmUsingResolved = usingResolved
|
||||
b.nmVersion = version
|
||||
b.dbus = append(b.dbus, struct{ name, path string }{"org.freedesktop.NetworkManager", "/org/freedesktop/NetworkManager/DnsManager"})
|
||||
b.dbus = append(b.dbus, dbusService{name: "org.freedesktop.NetworkManager", path: "/org/freedesktop/NetworkManager/DnsManager"})
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -74,38 +74,6 @@ type resolvedLinkDomain struct {
|
||||
RoutingOnly bool
|
||||
}
|
||||
|
||||
// isResolvedActive determines if resolved is currently managing system DNS settings.
|
||||
func isResolvedActive() bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), reconfigTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
// Probably no DBus on the system, or we're not allowed to use
|
||||
// it. Cannot control resolved.
|
||||
return false
|
||||
}
|
||||
|
||||
rd := conn.Object("org.freedesktop.resolve1", dbus.ObjectPath("/org/freedesktop/resolve1"))
|
||||
call := rd.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0)
|
||||
if call.Err != nil {
|
||||
// Can't talk to resolved.
|
||||
return false
|
||||
}
|
||||
|
||||
config, err := newDirectManager(logger.Discard).readResolvFile(resolvConf)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// The sole nameserver must be the systemd-resolved stub.
|
||||
if len(config.Nameservers) == 1 && config.Nameservers[0] == resolvedListenAddr {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// resolvedManager is an OSConfigurator which uses the systemd-resolved DBus API.
|
||||
type resolvedManager struct {
|
||||
logf logger.Logf
|
||||
|
||||
@@ -278,53 +278,160 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
|
||||
|
||||
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
|
||||
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
||||
return func(ctx context.Context, network, address string) (retConn net.Conn, ret error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// Bogus. But just let the real dialer return an error rather than
|
||||
// inventing a similar one.
|
||||
return fwd(ctx, network, address)
|
||||
}
|
||||
defer func() {
|
||||
// On any failure, assume our DNS is wrong and try our fallback, if any.
|
||||
if ret == nil || dnsCache.LookupIPFallback == nil {
|
||||
return
|
||||
}
|
||||
ips, err := dnsCache.LookupIPFallback(ctx, host)
|
||||
if err != nil {
|
||||
// Return with original error
|
||||
return
|
||||
}
|
||||
if c, err := raceDial(ctx, fwd, network, ips, port); err == nil {
|
||||
retConn = c
|
||||
ret = nil
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
ip, ip6, allIPs, err := dnsCache.LookupIP(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
|
||||
}
|
||||
i4s := v4addrs(allIPs)
|
||||
if len(i4s) < 2 {
|
||||
dst := net.JoinHostPort(ip.String(), port)
|
||||
if debug {
|
||||
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
|
||||
}
|
||||
c, err := fwd(ctx, network, dst)
|
||||
if err == nil || ctx.Err() != nil || ip6 == nil {
|
||||
return c, err
|
||||
}
|
||||
// Fall back to trying IPv6.
|
||||
dst = net.JoinHostPort(ip6.String(), port)
|
||||
return fwd(ctx, network, dst)
|
||||
}
|
||||
|
||||
// Multiple IPv4 candidates, and 0+ IPv6.
|
||||
ipsToTry := append(i4s, v6addrs(allIPs)...)
|
||||
return raceDial(ctx, fwd, network, ipsToTry, port)
|
||||
d := &dialer{
|
||||
fwd: fwd,
|
||||
dnsCache: dnsCache,
|
||||
pastConnect: map[netaddr.IP]time.Time{},
|
||||
}
|
||||
return d.DialContext
|
||||
}
|
||||
|
||||
// dialer is the config and accumulated state for a dial func returned by Dialer.
|
||||
type dialer struct {
|
||||
fwd DialContextFunc
|
||||
dnsCache *Resolver
|
||||
|
||||
mu sync.Mutex
|
||||
pastConnect map[netaddr.IP]time.Time
|
||||
}
|
||||
|
||||
func (d *dialer) DialContext(ctx context.Context, network, address string) (retConn net.Conn, ret error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// Bogus. But just let the real dialer return an error rather than
|
||||
// inventing a similar one.
|
||||
return d.fwd(ctx, network, address)
|
||||
}
|
||||
dc := &dialCall{
|
||||
d: d,
|
||||
network: network,
|
||||
address: address,
|
||||
host: host,
|
||||
port: port,
|
||||
}
|
||||
defer func() {
|
||||
// On failure, consider that our DNS might be wrong and ask the DNS fallback mechanism for
|
||||
// some other IPs to try.
|
||||
if ret == nil || d.dnsCache.LookupIPFallback == nil || dc.dnsWasTrustworthy() {
|
||||
return
|
||||
}
|
||||
ips, err := d.dnsCache.LookupIPFallback(ctx, host)
|
||||
if err != nil {
|
||||
// Return with original error
|
||||
return
|
||||
}
|
||||
if c, err := dc.raceDial(ctx, ips); err == nil {
|
||||
retConn = c
|
||||
ret = nil
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
ip, ip6, allIPs, err := d.dnsCache.LookupIP(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
|
||||
}
|
||||
i4s := v4addrs(allIPs)
|
||||
if len(i4s) < 2 {
|
||||
if debug {
|
||||
log.Printf("dnscache: dialing %s, %s for %s", network, ip, address)
|
||||
}
|
||||
ipNA, ok := netaddr.FromStdIP(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP %q", ip)
|
||||
}
|
||||
c, err := dc.dialOne(ctx, ipNA)
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return c, err
|
||||
}
|
||||
// Fall back to trying IPv6, if any.
|
||||
ip6NA, ok := netaddr.FromStdIP(ip6)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
return dc.dialOne(ctx, ip6NA)
|
||||
}
|
||||
|
||||
// Multiple IPv4 candidates, and 0+ IPv6.
|
||||
ipsToTry := append(i4s, v6addrs(allIPs)...)
|
||||
return dc.raceDial(ctx, ipsToTry)
|
||||
}
|
||||
|
||||
// dialCall is the state around a single call to dial.
|
||||
type dialCall struct {
|
||||
d *dialer
|
||||
network, address, host, port string
|
||||
|
||||
mu sync.Mutex // lock ordering: dialer.mu, then dialCall.mu
|
||||
fails map[netaddr.IP]error // set of IPs that failed to dial thus far
|
||||
}
|
||||
|
||||
// dnsWasTrustworthy reports whether we think the IP address(es) we
|
||||
// tried (and failed) to dial were probably the correct IPs. Currently
|
||||
// the heuristic is whether they ever worked previously.
|
||||
func (dc *dialCall) dnsWasTrustworthy() bool {
|
||||
dc.d.mu.Lock()
|
||||
defer dc.d.mu.Unlock()
|
||||
dc.mu.Lock()
|
||||
defer dc.mu.Unlock()
|
||||
|
||||
if len(dc.fails) == 0 {
|
||||
// No information.
|
||||
return false
|
||||
}
|
||||
|
||||
// If any of the IPs we failed to dial worked previously in
|
||||
// this dialer, assume the DNS is fine.
|
||||
for ip := range dc.fails {
|
||||
if _, ok := dc.d.pastConnect[ip]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dc *dialCall) dialOne(ctx context.Context, ip netaddr.IP) (net.Conn, error) {
|
||||
c, err := dc.d.fwd(ctx, dc.network, net.JoinHostPort(ip.String(), dc.port))
|
||||
dc.noteDialResult(ip, err)
|
||||
return c, err
|
||||
}
|
||||
|
||||
// noteDialResult records that a dial to ip either succeeded or
|
||||
// failed.
|
||||
func (dc *dialCall) noteDialResult(ip netaddr.IP, err error) {
|
||||
if err == nil {
|
||||
d := dc.d
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.pastConnect[ip] = time.Now()
|
||||
return
|
||||
}
|
||||
dc.mu.Lock()
|
||||
defer dc.mu.Unlock()
|
||||
if dc.fails == nil {
|
||||
dc.fails = map[netaddr.IP]error{}
|
||||
}
|
||||
dc.fails[ip] = err
|
||||
}
|
||||
|
||||
// uniqueIPs returns a possibly-mutated subslice of ips, filtering out
|
||||
// dups and ones that have already failed previously.
|
||||
func (dc *dialCall) uniqueIPs(ips []netaddr.IP) (ret []netaddr.IP) {
|
||||
dc.mu.Lock()
|
||||
defer dc.mu.Unlock()
|
||||
seen := map[netaddr.IP]bool{}
|
||||
ret = ips[:0]
|
||||
for _, ip := range ips {
|
||||
if seen[ip] {
|
||||
continue
|
||||
}
|
||||
seen[ip] = true
|
||||
if dc.fails[ip] != nil {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, ip)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// fallbackDelay is how long to wait between trying subsequent
|
||||
@@ -334,7 +441,7 @@ const fallbackDelay = 300 * time.Millisecond
|
||||
|
||||
// raceDial tries to dial port on each ip in ips, starting a new race
|
||||
// dial every fallbackDelay apart, returning whichever completes first.
|
||||
func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []netaddr.IP, port string) (net.Conn, error) {
|
||||
func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -345,6 +452,14 @@ func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []ne
|
||||
resc := make(chan res) // must be unbuffered
|
||||
failBoost := make(chan struct{}) // best effort send on dial failure
|
||||
|
||||
// Remove IPs that we tried & failed to dial previously
|
||||
// (such as when we're being called after a dnsfallback lookup and get
|
||||
// the same results)
|
||||
ips = dc.uniqueIPs(ips)
|
||||
if len(ips) == 0 {
|
||||
return nil, errors.New("no IPs")
|
||||
}
|
||||
|
||||
go func() {
|
||||
for i, ip := range ips {
|
||||
if i != 0 {
|
||||
@@ -359,7 +474,7 @@ func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []ne
|
||||
}
|
||||
}
|
||||
go func(ip netaddr.IP) {
|
||||
c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
c, err := dc.dialOne(ctx, ip)
|
||||
if err != nil {
|
||||
// Best effort wake-up a pending dial.
|
||||
// e.g. IPv4 dials failing quickly on an IPv6-only system.
|
||||
|
||||
@@ -6,10 +6,14 @@ package dnscache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
|
||||
@@ -31,3 +35,78 @@ func TestDialer(t *testing.T) {
|
||||
t.Logf("dialed in %v", time.Since(t0))
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func TestDialCall_DNSWasTrustworthy(t *testing.T) {
|
||||
type step struct {
|
||||
ip netaddr.IP // IP we pretended to dial
|
||||
err error // the dial error or nil for success
|
||||
}
|
||||
mustIP := netaddr.MustParseIP
|
||||
errFail := errors.New("some connect failure")
|
||||
tests := []struct {
|
||||
name string
|
||||
steps []step
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no-info",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "previous-dial",
|
||||
steps: []step{
|
||||
{mustIP("2003::1"), nil},
|
||||
{mustIP("2003::1"), errFail},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no-previous-dial",
|
||||
steps: []step{
|
||||
{mustIP("2003::1"), errFail},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &dialer{
|
||||
pastConnect: map[netaddr.IP]time.Time{},
|
||||
}
|
||||
dc := &dialCall{
|
||||
d: d,
|
||||
}
|
||||
for _, st := range tt.steps {
|
||||
dc.noteDialResult(st.ip, st.err)
|
||||
}
|
||||
got := dc.dnsWasTrustworthy()
|
||||
if got != tt.want {
|
||||
t.Errorf("got %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialCall_uniqueIPs(t *testing.T) {
|
||||
dc := &dialCall{}
|
||||
mustIP := netaddr.MustParseIP
|
||||
errFail := errors.New("some connect failure")
|
||||
dc.noteDialResult(mustIP("2003::1"), errFail)
|
||||
dc.noteDialResult(mustIP("2003::2"), errFail)
|
||||
got := dc.uniqueIPs([]netaddr.IP{
|
||||
mustIP("2003::1"),
|
||||
mustIP("2003::2"),
|
||||
mustIP("2003::2"),
|
||||
mustIP("2003::3"),
|
||||
mustIP("2003::3"),
|
||||
mustIP("2003::4"),
|
||||
mustIP("2003::4"),
|
||||
})
|
||||
want := []netaddr.IP{
|
||||
mustIP("2003::3"),
|
||||
mustIP("2003::4"),
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
50
net/netutil/netutil.go
Normal file
50
net/netutil/netutil.go
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package netutil contains misc shared networking code & types.
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// NewOneConnListener returns a net.Listener that returns c on its first
|
||||
// Accept and EOF thereafter. The Listener's Addr is a dummy address.
|
||||
func NewOneConnListener(c net.Conn) net.Listener {
|
||||
return NewOneConnListenerFrom(c, dummyListener{})
|
||||
}
|
||||
|
||||
// NewOneConnListenerFrom returns a net.Listener wrapping ln where
|
||||
// its Accept returns c on the first call and io.EOF thereafter.
|
||||
func NewOneConnListenerFrom(c net.Conn, ln net.Listener) net.Listener {
|
||||
return &oneConnListener{c, ln}
|
||||
}
|
||||
|
||||
type oneConnListener struct {
|
||||
conn net.Conn
|
||||
net.Listener
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Accept() (c net.Conn, err error) {
|
||||
c = l.conn
|
||||
if c == nil {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
l.conn = nil
|
||||
return
|
||||
}
|
||||
|
||||
type dummyListener struct{}
|
||||
|
||||
func (dummyListener) Close() error { return nil }
|
||||
func (dummyListener) Addr() net.Addr { return dummyAddr("unused-address") }
|
||||
func (dummyListener) Accept() (c net.Conn, err error) { return nil, io.EOF }
|
||||
|
||||
type dummyAddr string
|
||||
|
||||
func (a dummyAddr) Network() string { return string(a) }
|
||||
func (a dummyAddr) String() string { return string(a) }
|
||||
@@ -431,7 +431,6 @@ func (t *Wrapper) filterOut(p *packet.Parsed) filter.Response {
|
||||
return filter.DropSilently // don't pass on to OS; already handled
|
||||
}
|
||||
}
|
||||
// TODO(bradfitz): support pinging TailscaleServiceIPv6 too.
|
||||
|
||||
// Issue 1526 workaround: if we sent disco packets over
|
||||
// Tailscale from ourselves, then drop them, as that shouldn't
|
||||
|
||||
@@ -407,7 +407,8 @@ main() {
|
||||
fi
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
if ! type gpg >/dev/null; then
|
||||
apt-get install -y gnupg
|
||||
$SUDO apt-get update
|
||||
$SUDO apt-get install -y gnupg
|
||||
fi
|
||||
|
||||
set -x
|
||||
@@ -467,7 +468,7 @@ main() {
|
||||
;;
|
||||
xbps)
|
||||
set -x
|
||||
$SUDO xbps-install tailscale
|
||||
$SUDO xbps-install tailscale -y
|
||||
set +x
|
||||
;;
|
||||
emerge)
|
||||
|
||||
167
tsweb/debug.go
167
tsweb/debug.go
@@ -6,14 +6,20 @@ package tsweb
|
||||
|
||||
import (
|
||||
"expvar"
|
||||
"flag"
|
||||
"fmt"
|
||||
"html"
|
||||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/version"
|
||||
)
|
||||
@@ -33,6 +39,9 @@ type DebugHandler struct {
|
||||
kvs []func(io.Writer) // output one <li>...</li> each, see KV()
|
||||
urls []string // one <li>...</li> block with link each
|
||||
sections []func(io.Writer, *http.Request) // invoked in registration order prior to outputting </body>
|
||||
flagmu sync.Mutex // flagmu protects access to flagset and flagc
|
||||
flagset *flag.FlagSet // runtime-modifiable flags, may be nil
|
||||
flagc chan map[string]interface{} // DebugHandler sends new flag values on flagc when the flags have been modified
|
||||
}
|
||||
|
||||
// Debugger returns the DebugHandler registered on mux at /debug/,
|
||||
@@ -139,3 +148,161 @@ func gcHandler(w http.ResponseWriter, r *http.Request) {
|
||||
runtime.GC()
|
||||
w.Write([]byte("Done.\n"))
|
||||
}
|
||||
|
||||
// FlagSet returns a FlagSet that can be used to add runtime-modifiable flags to d.
|
||||
// Calling code should add flags to fs, but not retain the values directly.
|
||||
// Modifications to fs will be delivered via c.
|
||||
// Maps sent to c will be keyed on the flag name, and contain the new value.
|
||||
// Only modified values will be sent on c.
|
||||
//
|
||||
// Sample usage:
|
||||
// flagset, flagc := debug.FlagSet()
|
||||
// flagset.Int("max", 0, "maximum number of bars")
|
||||
// flagset.String("s", "qux", "default name for new foos")
|
||||
// go func() {
|
||||
// for change := range flagc {
|
||||
// // TODO: handle change, which will contain values for keys "max" and/or "s"
|
||||
// }
|
||||
// }()
|
||||
func (d *DebugHandler) FlagSet() (fs *flag.FlagSet, c chan map[string]interface{}) {
|
||||
d.flagmu.Lock()
|
||||
defer d.flagmu.Unlock()
|
||||
if d.flagset == nil {
|
||||
d.flagset = flag.NewFlagSet("debug", flag.ContinueOnError)
|
||||
d.flagc = make(chan map[string]interface{})
|
||||
d.Handle("flags", "Runtime flags", http.HandlerFunc(d.handleFlags))
|
||||
}
|
||||
return d.flagset, d.flagc
|
||||
}
|
||||
|
||||
type copiedFlag struct {
|
||||
Name string
|
||||
Value string
|
||||
Usage string
|
||||
}
|
||||
|
||||
func copyFlags(fs *flag.FlagSet) []copiedFlag {
|
||||
var all []copiedFlag
|
||||
fs.VisitAll(func(f *flag.Flag) {
|
||||
all = append(all, copiedFlag{Name: f.Name, Value: f.Value.String(), Usage: f.Usage})
|
||||
})
|
||||
return all
|
||||
}
|
||||
|
||||
func (d *DebugHandler) handleFlags(w http.ResponseWriter, r *http.Request) {
|
||||
d.flagmu.Lock()
|
||||
defer d.flagmu.Unlock()
|
||||
|
||||
var userError string
|
||||
var modified string
|
||||
if r.Method == http.MethodPost {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// Make a copy of existing values, in case we need to roll back.
|
||||
all := copyFlags(d.flagset)
|
||||
// Set inbound values.
|
||||
changed := make(map[string][2]string)
|
||||
rollback := false
|
||||
for k, v := range r.PostForm {
|
||||
if len(v) != 1 {
|
||||
userError = fmt.Sprintf("multiple values for name %q: %q", k, v)
|
||||
rollback = true
|
||||
break
|
||||
}
|
||||
f := d.flagset.Lookup(k)
|
||||
if f == nil {
|
||||
userError = fmt.Sprintf("unknown name %q", k)
|
||||
rollback = true
|
||||
break
|
||||
}
|
||||
prev := f.Value.String()
|
||||
new := strings.TrimSpace(v[0])
|
||||
if prev == new {
|
||||
continue
|
||||
}
|
||||
err := d.flagset.Set(k, new)
|
||||
if err != nil {
|
||||
userError = fmt.Sprintf("parsing value %q for name %q: %v", new, k, err)
|
||||
rollback = true
|
||||
break
|
||||
}
|
||||
changed[k] = [2]string{prev, new}
|
||||
}
|
||||
if rollback {
|
||||
for _, f := range all {
|
||||
d.flagset.Set(f.Name, f.Value)
|
||||
}
|
||||
} else {
|
||||
// Generate description of modifications.
|
||||
var names []string
|
||||
for k := range changed {
|
||||
names = append(names, k)
|
||||
}
|
||||
sort.Strings(names)
|
||||
buf := new(strings.Builder)
|
||||
for i, k := range names {
|
||||
if i != 0 {
|
||||
buf.WriteString("; ")
|
||||
}
|
||||
pn := changed[k]
|
||||
fmt.Fprintf(buf, "%q: %v → %v", k, pn[0], pn[1])
|
||||
}
|
||||
modified = buf.String()
|
||||
vals := make(map[string]interface{})
|
||||
for _, k := range names {
|
||||
vals[k] = d.flagset.Lookup(k).Value.(flag.Getter).Get()
|
||||
}
|
||||
|
||||
d.flagc <- vals
|
||||
// TODO: post modifications to Slack, along with attribution
|
||||
}
|
||||
}
|
||||
|
||||
dot := &struct {
|
||||
Error string
|
||||
Modified string
|
||||
Flags []copiedFlag
|
||||
}{
|
||||
Error: userError,
|
||||
Modified: modified,
|
||||
Flags: copyFlags(d.flagset),
|
||||
}
|
||||
err := flagsTemplate.Execute(w, dot)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
flagsTemplate = template.Must(template.New("flags").Parse(`
|
||||
<html>
|
||||
<body>
|
||||
|
||||
{{if .Error}}
|
||||
<h2>Error: <mark>{{.Error}}</mark></h2>
|
||||
{{end}}
|
||||
|
||||
{{if .Modified}}
|
||||
<h3>Modified: <mark>{{.Modified}}</mark></h3>
|
||||
{{end}}
|
||||
|
||||
<h3>Modifiable runtime flags</h3>
|
||||
|
||||
<p>Warning! Modifying these values will take effect immediately and impact the running service</p>
|
||||
|
||||
<form method="POST">
|
||||
<table>
|
||||
<tr> <th>Name</th> <th>Value</th> <th>Usage</th> </tr>
|
||||
{{range .Flags}}
|
||||
<tr> <td>{{.Name}}</td> <td><input type="text" value="{{.Value}}" name="{{.Name}}"/></td> <td>{{.Usage}}</td> </tr>
|
||||
{{end}}
|
||||
</table>
|
||||
<input type="submit"></input>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
`))
|
||||
)
|
||||
|
||||
@@ -100,6 +100,7 @@ func (fn JSONHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(encb)))
|
||||
w.WriteHeader(status)
|
||||
w.Write(encb)
|
||||
} else {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(b)))
|
||||
|
||||
@@ -182,6 +182,23 @@ func TestNewJSONHandler(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("gzipped_400", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`))
|
||||
r.Header.Set("Accept-Encoding", "gzip")
|
||||
value := []string{"foo", "foo", "foo"}
|
||||
JSONHandlerFunc(func(r *http.Request) (int, interface{}, error) {
|
||||
return 400, value, nil
|
||||
}).ServeHTTPReturn(w, r)
|
||||
res := w.Result()
|
||||
if ct := res.Header.Get("Content-Encoding"); ct != "gzip" {
|
||||
t.Fatalf("encoding = %q; want gzip", ct)
|
||||
}
|
||||
if res.StatusCode != 400 {
|
||||
t.Errorf("Status = %v; want 400", res.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("400 post data error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`))
|
||||
|
||||
@@ -9,7 +9,31 @@ package winutil
|
||||
// are stored. This constant is a non-empty string only when GOOS=windows.
|
||||
const RegBase = regBase
|
||||
|
||||
// GetRegString looks up a registry path in our local machine path, or returns
|
||||
// GetPolicyString looks up a registry value in the local machine's path for
|
||||
// system policies, or returns the given default if it can't.
|
||||
// Use this function to read values that may be set by sysadmins via the MSI
|
||||
// installer or via GPO. For registry settings that you do *not* want to be
|
||||
// visible to sysadmin tools, use GetRegString instead.
|
||||
//
|
||||
// This function will only work on GOOS=windows. Trying to run it on any other
|
||||
// OS will always return the default value.
|
||||
func GetPolicyString(name, defval string) string {
|
||||
return getPolicyString(name, defval)
|
||||
}
|
||||
|
||||
// GetPolicyInteger looks up a registry value in the local machine's path for
|
||||
// system policies, or returns the given default if it can't.
|
||||
// Use this function to read values that may be set by sysadmins via the MSI
|
||||
// installer or via GPO. For registry settings that you do *not* want to be
|
||||
// visible to sysadmin tools, use GetRegInteger instead.
|
||||
//
|
||||
// This function will only work on GOOS=windows. Trying to run it on any other
|
||||
// OS will always return the default value.
|
||||
func GetPolicyInteger(name string, defval uint64) uint64 {
|
||||
return getPolicyInteger(name, defval)
|
||||
}
|
||||
|
||||
// GetRegString looks up a registry path in the local machine path, or returns
|
||||
// the given default if it can't.
|
||||
//
|
||||
// This function will only work on GOOS=windows. Trying to run it on any other
|
||||
@@ -18,7 +42,7 @@ func GetRegString(name, defval string) string {
|
||||
return getRegString(name, defval)
|
||||
}
|
||||
|
||||
// GetRegInteger looks up a registry path in our local machine path, or returns
|
||||
// GetRegInteger looks up a registry path in the local machine path, or returns
|
||||
// the given default if it can't.
|
||||
//
|
||||
// This function will only work on GOOS=windows. Trying to run it on any other
|
||||
|
||||
@@ -9,6 +9,10 @@ package winutil
|
||||
|
||||
const regBase = ``
|
||||
|
||||
func getPolicyString(name, defval string) string { return defval }
|
||||
|
||||
func getPolicyInteger(name string, defval uint64) uint64 { return defval }
|
||||
|
||||
func getRegString(name, defval string) string { return defval }
|
||||
|
||||
func getRegInteger(name string, defval uint64) uint64 { return defval }
|
||||
|
||||
@@ -5,33 +5,83 @@
|
||||
package winutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
const regBase = `SOFTWARE\Tailscale IPN`
|
||||
const (
|
||||
regBase = `SOFTWARE\Tailscale IPN`
|
||||
regPolicyBase = `SOFTWARE\Policies\Tailscale`
|
||||
)
|
||||
|
||||
// ErrNoShell is returned when the shell process is not found.
|
||||
var ErrNoShell = errors.New("no Shell process is present")
|
||||
|
||||
// GetDesktopPID searches the PID of the process that's running the
|
||||
// currently active desktop and whether it was found.
|
||||
// currently active desktop. Returns ErrNoShell if the shell is not present.
|
||||
// Usually the PID will be for explorer.exe.
|
||||
func GetDesktopPID() (pid uint32, ok bool) {
|
||||
func GetDesktopPID() (uint32, error) {
|
||||
hwnd := windows.GetShellWindow()
|
||||
if hwnd == 0 {
|
||||
return 0, false
|
||||
return 0, ErrNoShell
|
||||
}
|
||||
|
||||
var pid uint32
|
||||
windows.GetWindowThreadProcessId(hwnd, &pid)
|
||||
return pid, pid != 0
|
||||
if pid == 0 {
|
||||
return 0, fmt.Errorf("invalid PID for HWND %v", hwnd)
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
func getPolicyString(name, defval string) string {
|
||||
s, err := getRegStringInternal(regPolicyBase, name)
|
||||
if err != nil {
|
||||
// Fall back to the legacy path
|
||||
return getRegString(name, defval)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func getPolicyInteger(name string, defval uint64) uint64 {
|
||||
i, err := getRegIntegerInternal(regPolicyBase, name)
|
||||
if err != nil {
|
||||
// Fall back to the legacy path
|
||||
return getRegInteger(name, defval)
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getRegString(name, defval string) string {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, RegBase, registry.READ)
|
||||
s, err := getRegStringInternal(regBase, name)
|
||||
if err != nil {
|
||||
log.Printf("registry.OpenKey(%v): %v", RegBase, err)
|
||||
return defval
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func getRegInteger(name string, defval uint64) uint64 {
|
||||
i, err := getRegIntegerInternal(regBase, name)
|
||||
if err != nil {
|
||||
return defval
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func getRegStringInternal(subKey, name string) (string, error) {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
|
||||
if err != nil {
|
||||
log.Printf("registry.OpenKey(%v): %v", subKey, err)
|
||||
return "", err
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
val, _, err := key.GetStringValue(name)
|
||||
@@ -39,16 +89,16 @@ func getRegString(name, defval string) string {
|
||||
if err != registry.ErrNotExist {
|
||||
log.Printf("registry.GetStringValue(%v): %v", name, err)
|
||||
}
|
||||
return defval
|
||||
return "", err
|
||||
}
|
||||
return val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func getRegInteger(name string, defval uint64) uint64 {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, RegBase, registry.READ)
|
||||
func getRegIntegerInternal(subKey, name string) (uint64, error) {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
|
||||
if err != nil {
|
||||
log.Printf("registry.OpenKey(%v): %v", RegBase, err)
|
||||
return defval
|
||||
log.Printf("registry.OpenKey(%v): %v", subKey, err)
|
||||
return 0, err
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
@@ -57,9 +107,9 @@ func getRegInteger(name string, defval uint64) uint64 {
|
||||
if err != registry.ErrNotExist {
|
||||
log.Printf("registry.GetIntegerValue(%v): %v", name, err)
|
||||
}
|
||||
return defval
|
||||
return 0, err
|
||||
}
|
||||
return val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -93,3 +143,114 @@ func isSIDValidPrincipal(uid string) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// EnableCurrentThreadPrivilege enables the named privilege
|
||||
// in the current thread access token.
|
||||
func EnableCurrentThreadPrivilege(name string) error {
|
||||
var t windows.Token
|
||||
err := windows.OpenThreadToken(windows.CurrentThread(),
|
||||
windows.TOKEN_QUERY|windows.TOKEN_ADJUST_PRIVILEGES, false, &t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer t.Close()
|
||||
|
||||
var tp windows.Tokenprivileges
|
||||
|
||||
privStr, err := syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = windows.LookupPrivilegeValue(nil, privStr, &tp.Privileges[0].Luid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tp.PrivilegeCount = 1
|
||||
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
|
||||
return windows.AdjustTokenPrivileges(t, false, &tp, 0, nil, nil)
|
||||
}
|
||||
|
||||
// StartProcessAsChild starts exePath process as a child of parentPID.
|
||||
// StartProcessAsChild copies parentPID's environment variables into
|
||||
// the new process, along with any optional environment variables in extraEnv.
|
||||
func StartProcessAsChild(parentPID uint32, exePath string, extraEnv []string) error {
|
||||
// The rest of this function requires SeDebugPrivilege to be held.
|
||||
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
err := windows.ImpersonateSelf(windows.SecurityImpersonation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer windows.RevertToSelf()
|
||||
|
||||
// According to https://docs.microsoft.com/en-us/windows/win32/procthread/process-security-and-access-rights
|
||||
//
|
||||
// ... To open a handle to another process and obtain full access rights,
|
||||
// you must enable the SeDebugPrivilege privilege. ...
|
||||
//
|
||||
// But we only need PROCESS_CREATE_PROCESS. So perhaps SeDebugPrivilege is too much.
|
||||
//
|
||||
// https://devblogs.microsoft.com/oldnewthing/20080314-00/?p=23113
|
||||
//
|
||||
// TODO: try look for something less than SeDebugPrivilege
|
||||
|
||||
err = EnableCurrentThreadPrivilege("SeDebugPrivilege")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ph, err := windows.OpenProcess(
|
||||
windows.PROCESS_CREATE_PROCESS|windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_DUP_HANDLE,
|
||||
false, parentPID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer windows.CloseHandle(ph)
|
||||
|
||||
var pt windows.Token
|
||||
err = windows.OpenProcessToken(ph, windows.TOKEN_QUERY, &pt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer pt.Close()
|
||||
|
||||
env, err := pt.Environ(false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
}
|
||||
env = append(env, extraEnv...)
|
||||
|
||||
sys := &syscall.SysProcAttr{ParentProcess: syscall.Handle(ph)}
|
||||
|
||||
cmd := exec.Command(exePath)
|
||||
cmd.Env = env
|
||||
cmd.SysProcAttr = sys
|
||||
|
||||
return cmd.Start()
|
||||
}
|
||||
|
||||
// StartProcessAsCurrentGUIUser is like StartProcessAsChild, but if finds
|
||||
// current logged in user desktop process (normally explorer.exe),
|
||||
// and passes found PID to StartProcessAsChild.
|
||||
func StartProcessAsCurrentGUIUser(exePath string, extraEnv []string) error {
|
||||
// as described in https://devblogs.microsoft.com/oldnewthing/20190425-00/?p=102443
|
||||
desktop, err := GetDesktopPID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find desktop: %v", err)
|
||||
}
|
||||
err = StartProcessAsChild(desktop, exePath, extraEnv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start executable: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateAppMutex creates a named Windows mutex, returning nil if the mutex
|
||||
// is created successfully or an error if the mutex already exists or could not
|
||||
// be created for some other reason.
|
||||
func CreateAppMutex(name string) (windows.Handle, error) {
|
||||
return windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(name))
|
||||
}
|
||||
|
||||
@@ -2111,12 +2111,12 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netaddr.IPPort, de *endpoint) {
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.lastEndpointsTime.After(time.Now().Add(-endpointsFreshEnoughDuration)) {
|
||||
c.logf("magicsock: want call-me-maybe but endpoints stale; restunning")
|
||||
c.logf("[v1] magicsock: want call-me-maybe but endpoints stale; restunning")
|
||||
if c.onEndpointRefreshed == nil {
|
||||
c.onEndpointRefreshed = map[*endpoint]func(){}
|
||||
}
|
||||
c.onEndpointRefreshed[de] = func() {
|
||||
c.logf("magicsock: STUN done; sending call-me-maybe to %v %v", de.discoShort, de.publicKey.ShortString())
|
||||
c.logf("[v1] magicsock: STUN done; sending call-me-maybe to %v %v", de.discoShort, de.publicKey.ShortString())
|
||||
c.enqueueCallMeMaybe(derpAddr, de)
|
||||
}
|
||||
// TODO(bradfitz): make a new 'reSTUNQuickly' method
|
||||
|
||||
@@ -101,9 +101,21 @@ func (c *nlConn) Receive() (message, error) {
|
||||
dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength)
|
||||
gw := netaddrIP(rmsg.Attributes.Gateway)
|
||||
|
||||
if msg.Header.Type == unix.RTM_NEWROUTE &&
|
||||
(rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) &&
|
||||
(dst.IP().IsMulticast() || dst.IP().IsLinkLocalUnicast()) {
|
||||
// Normal Linux route changes on new interface coming up; don't log or react.
|
||||
return ignoreMessage{}, nil
|
||||
}
|
||||
|
||||
if rmsg.Table == tsTable && dst.IsSingleIP() {
|
||||
// Don't log. Spammy and normal to see a bunch of these on start-up,
|
||||
// which we make ourselves.
|
||||
} else if tsaddr.IsTailscaleIP(dst.IP()) {
|
||||
// Verbose only.
|
||||
c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
|
||||
condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
|
||||
rmsg.Attributes.OutIface, rmsg.Attributes.Table)
|
||||
} else {
|
||||
c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr,
|
||||
condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw),
|
||||
|
||||
@@ -132,7 +132,7 @@ feist
|
||||
spitz
|
||||
squirrel
|
||||
gerbil
|
||||
hampster
|
||||
hamster
|
||||
panda
|
||||
gibbon
|
||||
flyingfox
|
||||
@@ -210,3 +210,5 @@ tyrannosaurus
|
||||
velociraptor
|
||||
siren
|
||||
mudpuppy
|
||||
ferret
|
||||
roborovski
|
||||
|
||||
Reference in New Issue
Block a user