Compare commits
8 Commits
v1.48.2
...
bradfitz/g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b256c319c0 | ||
|
|
57da1f1501 | ||
|
|
37c0b9be63 | ||
|
|
18280ebf7d | ||
|
|
623d72c83b | ||
|
|
f101a75dce | ||
|
|
f75a36f9bc | ||
|
|
cf31b58ed1 |
@@ -1 +1 @@
|
||||
1.48.2
|
||||
1.49.0
|
||||
|
||||
@@ -10,6 +10,7 @@ type DNSConfig struct {
|
||||
Domains []string `json:"domains"`
|
||||
Nameservers []string `json:"nameservers"`
|
||||
Proxied bool `json:"proxied"`
|
||||
DNSFilterURL string `json:"DNSFilterURL"`
|
||||
}
|
||||
|
||||
type DNSResolver struct {
|
||||
|
||||
@@ -7,12 +7,19 @@ export default function App() {
|
||||
|
||||
return (
|
||||
<div className="py-14">
|
||||
<main className="container max-w-lg mx-auto mb-8 py-6 px-8 bg-white rounded-md shadow-2xl">
|
||||
<Header data={data} />
|
||||
<IP data={data} />
|
||||
<State data={data} />
|
||||
</main>
|
||||
<Footer data={data} />
|
||||
{!data ? (
|
||||
// TODO(sonia): add a loading view
|
||||
<div className="text-center">Loading...</div>
|
||||
) : (
|
||||
<>
|
||||
<main className="container max-w-lg mx-auto mb-8 py-6 px-8 bg-white rounded-md shadow-2xl">
|
||||
<Header data={data} />
|
||||
<IP data={data} />
|
||||
<State data={data} />
|
||||
</main>
|
||||
<Footer data={data} />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
export type UserProfile = {
|
||||
LoginName: string
|
||||
DisplayName: string
|
||||
ProfilePicURL: string
|
||||
}
|
||||
import { useEffect, useState } from "react"
|
||||
|
||||
export type NodeData = {
|
||||
Profile: UserProfile
|
||||
@@ -20,29 +16,22 @@ export type NodeData = {
|
||||
IPNVersion: string
|
||||
}
|
||||
|
||||
// testData is static set of nodedata used during development.
|
||||
// This can be removed once we have a real node data API.
|
||||
const testData: NodeData = {
|
||||
Profile: {
|
||||
LoginName: "amelie",
|
||||
DisplayName: "Amelie Pangolin",
|
||||
ProfilePicURL: "https://login.tailscale.com/logo192.png",
|
||||
},
|
||||
Status: "Running",
|
||||
DeviceName: "amelies-laptop",
|
||||
IP: "100.1.2.3",
|
||||
AdvertiseExitNode: false,
|
||||
AdvertiseRoutes: "",
|
||||
LicensesURL: "https://tailscale.com/licenses/tailscale",
|
||||
TUNMode: false,
|
||||
IsSynology: true,
|
||||
DSMVersion: 7,
|
||||
IsUnraid: false,
|
||||
UnraidToken: "",
|
||||
IPNVersion: "0.1.0",
|
||||
export type UserProfile = {
|
||||
LoginName: string
|
||||
DisplayName: string
|
||||
ProfilePicURL: string
|
||||
}
|
||||
|
||||
// useNodeData returns basic data about the current node.
|
||||
export default function useNodeData() {
|
||||
return testData
|
||||
const [data, setData] = useState<NodeData>()
|
||||
|
||||
useEffect(() => {
|
||||
fetch("/api/data")
|
||||
.then((response) => response.json())
|
||||
.then((json) => setData(json))
|
||||
.catch((error) => console.error(error))
|
||||
}, [])
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/groupmember"
|
||||
"tailscale.com/util/httpm"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
|
||||
@@ -78,30 +79,6 @@ func init() {
|
||||
template.Must(tmpl.New("web.css").Parse(webCSS))
|
||||
}
|
||||
|
||||
type tmplData struct {
|
||||
Profile tailcfg.UserProfile
|
||||
SynologyUser string
|
||||
Status string
|
||||
DeviceName string
|
||||
IP string
|
||||
AdvertiseExitNode bool
|
||||
AdvertiseRoutes string
|
||||
LicensesURL string
|
||||
TUNMode bool
|
||||
IsSynology bool
|
||||
DSMVersion int // 6 or 7, if IsSynology=true
|
||||
IsUnraid bool
|
||||
UnraidToken string
|
||||
IPNVersion string
|
||||
}
|
||||
|
||||
type postedData struct {
|
||||
AdvertiseRoutes string
|
||||
AdvertiseExitNode bool
|
||||
Reauthenticate bool
|
||||
ForceLogout bool
|
||||
}
|
||||
|
||||
// authorize returns the name of the user accessing the web UI after verifying
|
||||
// whether the user has access to the web UI. The function will write the
|
||||
// error to the provided http.ResponseWriter.
|
||||
@@ -294,12 +271,26 @@ req.send(null);
|
||||
// ServeHTTP processes all requests for the Tailscale web client.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if s.devMode {
|
||||
if r.URL.Path == "/api/data" {
|
||||
user, err := authorize(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case httpm.GET:
|
||||
s.serveGetNodeDataJSON(w, r, user)
|
||||
case httpm.POST:
|
||||
s.servePostNodeUpdate(w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
return
|
||||
}
|
||||
// When in dev mode, proxy to the Vite dev server.
|
||||
s.devProxy.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
if authRedirect(w, r) {
|
||||
return
|
||||
}
|
||||
@@ -309,80 +300,49 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/redirect" || r.URL.Path == "/redirect/" {
|
||||
switch {
|
||||
case r.URL.Path == "/redirect" || r.URL.Path == "/redirect/":
|
||||
io.WriteString(w, authenticationRedirectHTML)
|
||||
return
|
||||
case r.Method == "POST":
|
||||
s.servePostNodeUpdate(w, r)
|
||||
return
|
||||
default:
|
||||
s.serveGetNodeData(w, r, user)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type nodeData struct {
|
||||
Profile tailcfg.UserProfile
|
||||
SynologyUser string
|
||||
Status string
|
||||
DeviceName string
|
||||
IP string
|
||||
AdvertiseExitNode bool
|
||||
AdvertiseRoutes string
|
||||
LicensesURL string
|
||||
TUNMode bool
|
||||
IsSynology bool
|
||||
DSMVersion int // 6 or 7, if IsSynology=true
|
||||
IsUnraid bool
|
||||
UnraidToken string
|
||||
IPNVersion string
|
||||
}
|
||||
|
||||
func (s *Server) getNodeData(ctx context.Context, user string) (*nodeData, error) {
|
||||
st, err := s.lc.Status(ctx)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
prefs, err := s.lc.GetPrefs(ctx)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if r.Method == "POST" {
|
||||
defer r.Body.Close()
|
||||
var postData postedData
|
||||
type mi map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&postData); err != nil {
|
||||
w.WriteHeader(400)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := netutil.CalcAdvertiseRoutes(postData.AdvertiseRoutes, postData.AdvertiseExitNode)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
mp := &ipn.MaskedPrefs{
|
||||
AdvertiseRoutesSet: true,
|
||||
WantRunningSet: true,
|
||||
}
|
||||
mp.Prefs.WantRunning = true
|
||||
mp.Prefs.AdvertiseRoutes = routes
|
||||
log.Printf("Doing edit: %v", mp.Pretty())
|
||||
|
||||
if _, err := s.lc.EditPrefs(ctx, mp); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var reauth, logout bool
|
||||
if postData.Reauthenticate {
|
||||
reauth = true
|
||||
}
|
||||
if postData.ForceLogout {
|
||||
logout = true
|
||||
}
|
||||
log.Printf("tailscaleUp(reauth=%v, logout=%v) ...", reauth, logout)
|
||||
url, err := s.tailscaleUp(r.Context(), st, postData)
|
||||
log.Printf("tailscaleUp = (URL %v, %v)", url != "", err)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if url != "" {
|
||||
json.NewEncoder(w).Encode(mi{"url": url})
|
||||
} else {
|
||||
io.WriteString(w, "{}")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
profile := st.User[st.Self.UserID]
|
||||
deviceName := strings.Split(st.Self.DNSName, ".")[0]
|
||||
versionShort := strings.Split(st.Version, "-")[0]
|
||||
data := tmplData{
|
||||
data := &nodeData{
|
||||
SynologyUser: user,
|
||||
Profile: profile,
|
||||
Status: st.BackendState,
|
||||
@@ -410,16 +370,106 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if len(st.TailscaleIPs) != 0 {
|
||||
data.IP = st.TailscaleIPs[0].String()
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request, user string) {
|
||||
data, err := s.getNodeData(r.Context(), user)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
if err := tmpl.Execute(buf, data); err != nil {
|
||||
if err := tmpl.Execute(buf, *data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) tailscaleUp(ctx context.Context, st *ipnstate.Status, postData postedData) (authURL string, retErr error) {
|
||||
func (s *Server) serveGetNodeDataJSON(w http.ResponseWriter, r *http.Request, user string) {
|
||||
data, err := s.getNodeData(r.Context(), user)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(*data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return
|
||||
}
|
||||
|
||||
type nodeUpdate struct {
|
||||
AdvertiseRoutes string
|
||||
AdvertiseExitNode bool
|
||||
Reauthenticate bool
|
||||
ForceLogout bool
|
||||
}
|
||||
|
||||
func (s *Server) servePostNodeUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
defer r.Body.Close()
|
||||
|
||||
st, err := s.lc.Status(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var postData nodeUpdate
|
||||
type mi map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&postData); err != nil {
|
||||
w.WriteHeader(400)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := netutil.CalcAdvertiseRoutes(postData.AdvertiseRoutes, postData.AdvertiseExitNode)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
mp := &ipn.MaskedPrefs{
|
||||
AdvertiseRoutesSet: true,
|
||||
WantRunningSet: true,
|
||||
}
|
||||
mp.Prefs.WantRunning = true
|
||||
mp.Prefs.AdvertiseRoutes = routes
|
||||
log.Printf("Doing edit: %v", mp.Pretty())
|
||||
|
||||
if _, err := s.lc.EditPrefs(r.Context(), mp); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var reauth, logout bool
|
||||
if postData.Reauthenticate {
|
||||
reauth = true
|
||||
}
|
||||
if postData.ForceLogout {
|
||||
logout = true
|
||||
}
|
||||
log.Printf("tailscaleUp(reauth=%v, logout=%v) ...", reauth, logout)
|
||||
url, err := s.tailscaleUp(r.Context(), st, postData)
|
||||
log.Printf("tailscaleUp = (URL %v, %v)", url != "", err)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(mi{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if url != "" {
|
||||
json.NewEncoder(w).Encode(mi{"url": url})
|
||||
} else {
|
||||
io.WriteString(w, "{}")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) tailscaleUp(ctx context.Context, st *ipnstate.Status, postData nodeUpdate) (authURL string, retErr error) {
|
||||
if postData.ForceLogout {
|
||||
if err := s.lc.Logout(ctx); err != nil {
|
||||
return "", fmt.Errorf("Logout error: %w", err)
|
||||
|
||||
@@ -28,7 +28,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/net/tshttpproxy"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/winutil"
|
||||
@@ -185,8 +187,6 @@ func (up *updater) confirm(ver string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
const synoinfoConfPath = "/etc/synoinfo.conf"
|
||||
|
||||
func (up *updater) updateSynology() error {
|
||||
if up.Version != "" {
|
||||
return errors.New("installing a specific version on Synology is not supported")
|
||||
@@ -194,7 +194,7 @@ func (up *updater) updateSynology() error {
|
||||
|
||||
// Get the latest version and list of SPKs from pkgs.tailscale.com.
|
||||
osName := fmt.Sprintf("dsm%d", distro.DSMVersion())
|
||||
arch, err := synoArch(runtime.GOARCH, synoinfoConfPath)
|
||||
arch, err := synoArch(hostinfo.New())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -245,62 +245,51 @@ func (up *updater) updateSynology() error {
|
||||
|
||||
// synoArch returns the Synology CPU architecture matching one of the SPK
|
||||
// architectures served from pkgs.tailscale.com.
|
||||
func synoArch(goArch, synoinfoPath string) (string, error) {
|
||||
func synoArch(hinfo *tailcfg.Hostinfo) (string, error) {
|
||||
// Most Synology boxes just use a different arch name from GOARCH.
|
||||
arch := map[string]string{
|
||||
"amd64": "x86_64",
|
||||
"386": "i686",
|
||||
"arm64": "armv8",
|
||||
}[goArch]
|
||||
|
||||
}[hinfo.GoArch]
|
||||
// Here's the fun part, some older ARM boxes require you to use SPKs
|
||||
// specifically for their CPU.
|
||||
//
|
||||
// See https://github.com/SynoCommunity/spksrc/wiki/Synology-and-SynoCommunity-Package-Architectures
|
||||
// for a complete list. Here, we override GOARCH for those older boxes that
|
||||
// support at least DSM6.
|
||||
//
|
||||
// This is an artisanal hand-crafted list based on the wiki page. Some
|
||||
// values may be wrong, since we don't have all those devices to actually
|
||||
// test with.
|
||||
switch hinfo.DeviceModel {
|
||||
case "DS213air", "DS213", "DS413j",
|
||||
"DS112", "DS112+", "DS212", "DS212+", "RS212", "RS812", "DS212j", "DS112j",
|
||||
"DS111", "DS211", "DS211+", "DS411slim", "DS411", "RS411", "DS211j", "DS411j":
|
||||
arch = "88f6281"
|
||||
case "NVR1218", "NVR216", "VS960HD", "VS360HD":
|
||||
arch = "hi3535"
|
||||
case "DS1517", "DS1817", "DS416", "DS2015xs", "DS715", "DS1515", "DS215+":
|
||||
arch = "alpine"
|
||||
case "DS216se", "DS115j", "DS114", "DS214se", "DS414slim", "RS214", "DS14", "EDS14", "DS213j":
|
||||
arch = "armada370"
|
||||
case "DS115", "DS215j":
|
||||
arch = "armada375"
|
||||
case "DS419slim", "DS218j", "RS217", "DS116", "DS216j", "DS216", "DS416slim", "RS816", "DS416j":
|
||||
arch = "armada38x"
|
||||
case "RS815", "DS214", "DS214+", "DS414", "RS814":
|
||||
arch = "armadaxp"
|
||||
case "DS414j":
|
||||
arch = "comcerto2k"
|
||||
case "DS216play":
|
||||
arch = "monaco"
|
||||
}
|
||||
if arch == "" {
|
||||
// Here's the fun part, some older ARM boxes require you to use SPKs
|
||||
// specifically for their CPU. See
|
||||
// https://github.com/SynoCommunity/spksrc/wiki/Synology-and-SynoCommunity-Package-Architectures
|
||||
// for a complete list.
|
||||
//
|
||||
// Some CPUs will map to neither this list nor the goArch map above, and we
|
||||
// don't have SPKs for them.
|
||||
cpu, err := parseSynoinfo(synoinfoPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get CPU architecture: %w", err)
|
||||
}
|
||||
switch cpu {
|
||||
case "88f6281", "88f6282", "hi3535", "alpine", "armada370",
|
||||
"armada375", "armada38x", "armadaxp", "comcerto2k", "monaco":
|
||||
arch = cpu
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported Synology CPU architecture %q (Go arch %q), please report a bug at https://github.com/tailscale/tailscale/issues/new/choose", cpu, goArch)
|
||||
}
|
||||
return "", fmt.Errorf("cannot determine CPU architecture for Synology model %q (Go arch %q), please report a bug at https://github.com/tailscale/tailscale/issues/new/choose", hinfo.DeviceModel, hinfo.GoArch)
|
||||
}
|
||||
return arch, nil
|
||||
}
|
||||
|
||||
func parseSynoinfo(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Look for a line like:
|
||||
// unique="synology_88f6282_413j"
|
||||
// Extract the CPU in the middle (88f6282 in the above example).
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
l := s.Text()
|
||||
if !strings.HasPrefix(l, "unique=") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(l, "_", 3)
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf(`malformed %q: found %q, expected format like 'unique="synology_$cpu_$model'`, path, l)
|
||||
}
|
||||
return parts[1], nil
|
||||
}
|
||||
return "", fmt.Errorf(`missing "unique=" field in %q`, path)
|
||||
}
|
||||
|
||||
func (up *updater) updateDebLike() error {
|
||||
ver, err := requestedTailscaleVersion(up.Version, up.track)
|
||||
if err != nil {
|
||||
@@ -599,12 +588,7 @@ func parseAlpinePackageVersion(out []byte) (string, error) {
|
||||
}
|
||||
|
||||
func (up *updater) updateMacSys() error {
|
||||
// use sparkle? do we have permissions from this context? does sudo help?
|
||||
// We can at least fail with a command they can run to update from the shell.
|
||||
// Like "tailscale update --macsys | sudo sh" or something.
|
||||
//
|
||||
// TODO(bradfitz,mihai): implement. But for now:
|
||||
return errors.ErrUnsupported
|
||||
return errors.New("NOTREACHED: On MacSys builds, `tailscale update` is handled in Swift to launch the GUI updater")
|
||||
}
|
||||
|
||||
func (up *updater) updateMacAppStore() error {
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestUpdateDebianAptSourcesListBytes(t *testing.T) {
|
||||
@@ -444,151 +446,29 @@ tailscale installed size:
|
||||
|
||||
func TestSynoArch(t *testing.T) {
|
||||
tests := []struct {
|
||||
goarch string
|
||||
synoinfoUnique string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{goarch: "amd64", synoinfoUnique: "synology_x86_224", want: "x86_64"},
|
||||
{goarch: "arm64", synoinfoUnique: "synology_armv8_124", want: "armv8"},
|
||||
{goarch: "386", synoinfoUnique: "synology_i686_415play", want: "i686"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_88f6281_213air", want: "88f6281"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_88f6282_413j", want: "88f6282"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_hi3535_NVR1218", want: "hi3535"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_alpine_1517", want: "alpine"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_armada370_216se", want: "armada370"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_armada375_115", want: "armada375"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_armada38x_419slim", want: "armada38x"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_armadaxp_RS815", want: "armadaxp"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_comcerto2k_414j", want: "comcerto2k"},
|
||||
{goarch: "arm", synoinfoUnique: "synology_monaco_216play", want: "monaco"},
|
||||
{goarch: "ppc64", synoinfoUnique: "synology_qoriq_413", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%s-%s", tt.goarch, tt.synoinfoUnique), func(t *testing.T) {
|
||||
synoinfoConfPath := filepath.Join(t.TempDir(), "synoinfo.conf")
|
||||
if err := os.WriteFile(
|
||||
synoinfoConfPath,
|
||||
[]byte(fmt.Sprintf("unique=%q\n", tt.synoinfoUnique)),
|
||||
0600,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, err := synoArch(tt.goarch, synoinfoConfPath)
|
||||
if err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Fatalf("got unexpected error %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
t.Fatalf("got %q, expected an error", got)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSynoinfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
content string
|
||||
goarch string
|
||||
model string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
desc: "double-quoted",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
unique="synology_88f6281_213air"
|
||||
`,
|
||||
want: "88f6281",
|
||||
},
|
||||
{
|
||||
desc: "single-quoted",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
unique='synology_88f6281_213air'
|
||||
`,
|
||||
want: "88f6281",
|
||||
},
|
||||
{
|
||||
desc: "unquoted",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
unique=synology_88f6281_213air
|
||||
`,
|
||||
want: "88f6281",
|
||||
},
|
||||
{
|
||||
desc: "missing unique",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "empty unique",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
unique=
|
||||
`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "empty unique double-quoted",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
unique=""
|
||||
`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "empty unique single-quoted",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
unique=''
|
||||
`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "malformed unique",
|
||||
content: `
|
||||
company_title="Synology"
|
||||
unique="synology_88f6281"
|
||||
`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "empty file",
|
||||
content: ``,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
desc: "empty lines and comments",
|
||||
content: `
|
||||
|
||||
# In a file named synoinfo? Shocking!
|
||||
company_title="Synology"
|
||||
|
||||
|
||||
# unique= is_a_field_that_follows
|
||||
unique="synology_88f6281_213air"
|
||||
|
||||
`,
|
||||
want: "88f6281",
|
||||
},
|
||||
{goarch: "amd64", model: "DS224+", want: "x86_64"},
|
||||
{goarch: "arm64", model: "DS124", want: "armv8"},
|
||||
{goarch: "386", model: "DS415play", want: "i686"},
|
||||
{goarch: "arm", model: "DS213air", want: "88f6281"},
|
||||
{goarch: "arm", model: "NVR1218", want: "hi3535"},
|
||||
{goarch: "arm", model: "DS1517", want: "alpine"},
|
||||
{goarch: "arm", model: "DS216se", want: "armada370"},
|
||||
{goarch: "arm", model: "DS115", want: "armada375"},
|
||||
{goarch: "arm", model: "DS419slim", want: "armada38x"},
|
||||
{goarch: "arm", model: "RS815", want: "armadaxp"},
|
||||
{goarch: "arm", model: "DS414j", want: "comcerto2k"},
|
||||
{goarch: "arm", model: "DS216play", want: "monaco"},
|
||||
{goarch: "riscv64", model: "DS999", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
synoinfoConfPath := filepath.Join(t.TempDir(), "synoinfo.conf")
|
||||
if err := os.WriteFile(synoinfoConfPath, []byte(tt.content), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, err := parseSynoinfo(synoinfoConfPath)
|
||||
t.Run(fmt.Sprintf("%s-%s", tt.goarch, tt.model), func(t *testing.T) {
|
||||
got, err := synoArch(&tailcfg.Hostinfo{GoArch: tt.goarch, DeviceModel: tt.model})
|
||||
if err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Fatalf("got unexpected error %v", err)
|
||||
|
||||
@@ -66,7 +66,7 @@ func runExitNodeList(ctx context.Context, args []string) error {
|
||||
var peers []*ipnstate.PeerStatus
|
||||
for _, ps := range st.Peer {
|
||||
if !ps.ExitNodeOption {
|
||||
// We only show location based exit nodes.
|
||||
// We only show exit nodes under the exit-node subcommand.
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
|
||||
tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/util/groupmember from tailscale.com/client/web
|
||||
tailscale.com/util/httpm from tailscale.com/client/tailscale
|
||||
tailscale.com/util/httpm from tailscale.com/client/tailscale+
|
||||
tailscale.com/util/lineread from tailscale.com/net/interfaces+
|
||||
L tailscale.com/util/linuxfw from tailscale.com/net/netns
|
||||
tailscale.com/util/mak from tailscale.com/net/netcheck+
|
||||
|
||||
@@ -1110,6 +1110,16 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
|
||||
}
|
||||
|
||||
nm := sess.netmapForResponse(&resp)
|
||||
|
||||
// Occasionally print the netmap header.
|
||||
// This is handy for debugging, and our logs processing
|
||||
// pipeline depends on it. (TODO: Remove this dependency.)
|
||||
// Code elsewhere prints netmap diffs every time they are received.
|
||||
now := c.clock.Now()
|
||||
if now.Sub(c.lastPrintMap) >= 5*time.Minute {
|
||||
c.lastPrintMap = now
|
||||
c.logf("[v1] new network map[%d]:\n%s", i, nm.VeryConcise())
|
||||
}
|
||||
if nm.SelfNode == nil {
|
||||
c.logf("MapResponse lacked node")
|
||||
return errors.New("MapResponse lacked node")
|
||||
@@ -1129,15 +1139,6 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
|
||||
nm.SelfNode.Capabilities = nil
|
||||
}
|
||||
|
||||
// Occasionally print the netmap header.
|
||||
// This is handy for debugging, and our logs processing
|
||||
// pipeline depends on it. (TODO: Remove this dependency.)
|
||||
// Code elsewhere prints netmap diffs every time they are received.
|
||||
now := c.clock.Now()
|
||||
if now.Sub(c.lastPrintMap) >= 5*time.Minute {
|
||||
c.lastPrintMap = now
|
||||
c.logf("[v1] new network map[%d]:\n%s", i, nm.VeryConcise())
|
||||
}
|
||||
newPersist := persist.AsStruct()
|
||||
newPersist.NodeID = nm.SelfNode.StableID
|
||||
newPersist.UserProfile = nm.UserProfiles[nm.User]
|
||||
|
||||
@@ -64,6 +64,7 @@ const (
|
||||
NotifyInitialState // if set, the first Notify message (sent immediately) will contain the current State + BrowseToURL
|
||||
NotifyInitialPrefs // if set, the first Notify message (sent immediately) will contain the current Prefs
|
||||
NotifyInitialNetMap // if set, the first Notify message (sent immediately) will contain the current NetMap
|
||||
NotifyGUINetMap // if set, only use the Notify.GUINetMap; Notify.Netmap will always be nil. Also impacts NotifyInitialNetMap.
|
||||
|
||||
NotifyNoPrivateKeys // if set, private keys that would normally be sent in updates are zeroed out
|
||||
)
|
||||
@@ -81,13 +82,14 @@ type Notify struct {
|
||||
// For State InUseOtherUser, ErrMessage is not critical and just contains the details.
|
||||
ErrMessage *string
|
||||
|
||||
LoginFinished *empty.Message // non-nil when/if the login process succeeded
|
||||
State *State // if non-nil, the new or current IPN state
|
||||
Prefs *PrefsView // if non-nil && Valid, the new or current preferences
|
||||
NetMap *netmap.NetworkMap // if non-nil, the new or current netmap
|
||||
Engine *EngineStatus // if non-nil, the new or current wireguard stats
|
||||
BrowseToURL *string // if non-nil, UI should open a browser right now
|
||||
BackendLogID *string // if non-nil, the public logtail ID used by backend
|
||||
LoginFinished *empty.Message // non-nil when/if the login process succeeded
|
||||
State *State // if non-nil, the new or current IPN state
|
||||
Prefs *PrefsView // if non-nil && Valid, the new or current preferences
|
||||
//NetMap *netmap.NetworkMap // if non-nil, the new or current netmap
|
||||
GUINetMap *netmap.GUINetworkMap // if non-nil, the new or current netmap
|
||||
Engine *EngineStatus // if non-nil, the new or current wireguard stats
|
||||
BrowseToURL *string // if non-nil, UI should open a browser right now
|
||||
BackendLogID *string // if non-nil, the public logtail ID used by backend
|
||||
|
||||
// FilesWaiting if non-nil means that files are buffered in
|
||||
// the Tailscale daemon and ready for local transfer to the
|
||||
@@ -133,9 +135,9 @@ func (n Notify) String() string {
|
||||
if n.Prefs != nil && n.Prefs.Valid() {
|
||||
fmt.Fprintf(&sb, "%v ", n.Prefs.Pretty())
|
||||
}
|
||||
if n.NetMap != nil {
|
||||
sb.WriteString("NetMap{...} ")
|
||||
}
|
||||
// if n.NetMap != nil {
|
||||
// sb.WriteString("NetMap{...} ")
|
||||
// }
|
||||
if n.Engine != nil {
|
||||
fmt.Fprintf(&sb, "wg=%v ", *n.Engine)
|
||||
}
|
||||
|
||||
@@ -29,8 +29,6 @@ type strideEntry[T any] struct {
|
||||
prefixIndex int
|
||||
// value is the value associated with the strideEntry, if any.
|
||||
value *T
|
||||
// child is the child strideTable associated with the strideEntry, if any.
|
||||
child *strideTable[T]
|
||||
}
|
||||
|
||||
// strideTable is a binary tree that implements an 8-bit routing table.
|
||||
@@ -50,12 +48,17 @@ type strideTable[T any] struct {
|
||||
// parent of the node at index i is located at index i>>1, and its children
|
||||
// at indices i<<1 and (i<<1)+1.
|
||||
//
|
||||
// A few consequences of this arrangement: host routes (/8) occupy the last
|
||||
// 256 entries in the table; the single default route /0 is at index 1, and
|
||||
// index 0 is unused (in the original paper, it's hijacked through sneaky C
|
||||
// memory trickery to store the refcount, but this is Go, where we don't
|
||||
// store random bits in pointers lest we confuse the GC)
|
||||
// A few consequences of this arrangement: host routes (/8) occupy
|
||||
// the last numChildren entries in the table; the single default
|
||||
// route /0 is at index 1, and index 0 is unused (in the original
|
||||
// paper, it's hijacked through sneaky C memory trickery to store
|
||||
// the refcount, but this is Go, where we don't store random bits
|
||||
// in pointers lest we confuse the GC)
|
||||
entries [lastHostIndex + 1]strideEntry[T]
|
||||
// children are the child tables of this table. Each child
|
||||
// represents the address space within one of this table's host
|
||||
// routes (/8).
|
||||
children [numChildren]*strideTable[T]
|
||||
// routeRefs is the number of route entries in this table.
|
||||
routeRefs uint16
|
||||
// childRefs is the number of child strideTables referenced by this table.
|
||||
@@ -67,63 +70,60 @@ const (
|
||||
firstHostIndex = 0b1_0000_0000
|
||||
// lastHostIndex is the array index of the last host route. This is hostIndex(0xFF/8).
|
||||
lastHostIndex = 0b1_1111_1111
|
||||
|
||||
// numChildren is the maximum number of child tables a strideTable can hold.
|
||||
numChildren = 256
|
||||
)
|
||||
|
||||
// getChild returns the child strideTable pointer for addr (if any), and an
|
||||
// internal array index that can be used with deleteChild.
|
||||
func (t *strideTable[T]) getChild(addr uint8) (child *strideTable[T], idx int) {
|
||||
idx = hostIndex(addr)
|
||||
return t.entries[idx].child, idx
|
||||
// getChild returns the child strideTable pointer for addr, or nil if none.
|
||||
func (t *strideTable[T]) getChild(addr uint8) *strideTable[T] {
|
||||
return t.children[addr]
|
||||
}
|
||||
|
||||
// deleteChild deletes the child strideTable at idx (if any). idx should be
|
||||
// obtained via a call to getChild.
|
||||
func (t *strideTable[T]) deleteChild(idx int) {
|
||||
t.entries[idx].child = nil
|
||||
t.childRefs--
|
||||
// deleteChild deletes the child strideTable at addr. It is valid to
|
||||
// delete a non-existent child.
|
||||
func (t *strideTable[T]) deleteChild(addr uint8) {
|
||||
if t.children[addr] != nil {
|
||||
t.childRefs--
|
||||
}
|
||||
t.children[addr] = nil
|
||||
}
|
||||
|
||||
// setChild replaces the child strideTable for addr (if any) with child.
|
||||
// setChild sets the child strideTable for addr to child.
|
||||
func (t *strideTable[T]) setChild(addr uint8, child *strideTable[T]) {
|
||||
t.setChildByIndex(hostIndex(addr), child)
|
||||
}
|
||||
|
||||
// setChildByIndex replaces the child strideTable at idx (if any) with
|
||||
// child. idx should be obtained via a call to getChild.
|
||||
func (t *strideTable[T]) setChildByIndex(idx int, child *strideTable[T]) {
|
||||
if t.entries[idx].child == nil {
|
||||
if t.children[addr] == nil {
|
||||
t.childRefs++
|
||||
}
|
||||
t.entries[idx].child = child
|
||||
t.children[addr] = child
|
||||
}
|
||||
|
||||
// getOrCreateChild returns the child strideTable for addr, creating it if
|
||||
// necessary.
|
||||
func (t *strideTable[T]) getOrCreateChild(addr uint8) (child *strideTable[T], created bool) {
|
||||
idx := hostIndex(addr)
|
||||
if t.entries[idx].child == nil {
|
||||
t.entries[idx].child = &strideTable[T]{
|
||||
ret := t.children[addr]
|
||||
if ret == nil {
|
||||
ret = &strideTable[T]{
|
||||
prefix: childPrefixOf(t.prefix, addr),
|
||||
}
|
||||
t.children[addr] = ret
|
||||
t.childRefs++
|
||||
return t.entries[idx].child, true
|
||||
return ret, true
|
||||
}
|
||||
return t.entries[idx].child, false
|
||||
return ret, false
|
||||
}
|
||||
|
||||
// getValAndChild returns both the prefix and child strideTable for
|
||||
// addr. Both returned values can be nil if no entry of that type
|
||||
// exists for addr.
|
||||
func (t *strideTable[T]) getValAndChild(addr uint8) (*T, *strideTable[T]) {
|
||||
idx := hostIndex(addr)
|
||||
return t.entries[idx].value, t.entries[idx].child
|
||||
return t.entries[hostIndex(addr)].value, t.children[addr]
|
||||
}
|
||||
|
||||
// findFirstChild returns the first child strideTable in t, or nil if
|
||||
// t has no children.
|
||||
func (t *strideTable[T]) findFirstChild() *strideTable[T] {
|
||||
for i := firstHostIndex; i <= lastHostIndex; i++ {
|
||||
if child := t.entries[i].child; child != nil {
|
||||
for _, child := range t.children {
|
||||
if child != nil {
|
||||
return child
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,7 +364,7 @@ func (t *Table[T]) Delete(pfx netip.Prefix) {
|
||||
// write to strideTables[N] and strideIndexes[N-1].
|
||||
strideIdx := 0
|
||||
strideTables := [16]*strideTable[T]{st}
|
||||
strideIndexes := [15]int{}
|
||||
strideIndexes := [15]uint8{}
|
||||
|
||||
// Similar to Insert, navigate down the tree of strideTables,
|
||||
// looking for the one that houses this prefix. This part is
|
||||
@@ -384,7 +384,7 @@ func (t *Table[T]) Delete(pfx netip.Prefix) {
|
||||
if debugDelete {
|
||||
fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix)
|
||||
}
|
||||
child, idx := st.getChild(bs[byteIdx])
|
||||
child := st.getChild(bs[byteIdx])
|
||||
if child == nil {
|
||||
// Prefix can't exist in the table, because one of the
|
||||
// necessary strideTables doesn't exist.
|
||||
@@ -393,7 +393,7 @@ func (t *Table[T]) Delete(pfx netip.Prefix) {
|
||||
}
|
||||
return
|
||||
}
|
||||
strideIndexes[strideIdx] = idx
|
||||
strideIndexes[strideIdx] = bs[byteIdx]
|
||||
strideTables[strideIdx+1] = child
|
||||
strideIdx++
|
||||
|
||||
@@ -475,7 +475,7 @@ func (t *Table[T]) Delete(pfx netip.Prefix) {
|
||||
if debugDelete {
|
||||
fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix)
|
||||
}
|
||||
strideTables[strideIdx-1].setChildByIndex(strideIndexes[strideIdx-1], child)
|
||||
strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child)
|
||||
return
|
||||
default:
|
||||
// This table has two or more children, so it's acting as a "fork in
|
||||
@@ -505,12 +505,12 @@ func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) {
|
||||
fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs)
|
||||
indent += 4
|
||||
st.treeDebugStringRec(w, 1, indent)
|
||||
for i := firstHostIndex; i <= lastHostIndex; i++ {
|
||||
if child := st.entries[i].child; child != nil {
|
||||
addr, len := inversePrefixIndex(i)
|
||||
fmt.Fprintf(w, "%s%d/%d (%02x/%d): ", strings.Repeat(" ", indent), addr, len, addr, len)
|
||||
strideSummary(w, child, indent)
|
||||
for addr, child := range st.children {
|
||||
if child == nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr)
|
||||
strideSummary(w, child, indent)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -607,7 +607,7 @@ func TestInsertCompare(t *testing.T) {
|
||||
seenVals4[fastVal] = true
|
||||
}
|
||||
if slowVal != fastVal {
|
||||
t.Errorf("get(%q) = %p, want %p", a, fastVal, slowVal)
|
||||
t.Fatalf("get(%q) = %p, want %p", a, fastVal, slowVal)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1092,11 +1092,12 @@ func (t *Table[T]) numStridesRec(seen map[*strideTable[T]]bool, st *strideTable[
|
||||
if st.childRefs == 0 {
|
||||
return ret
|
||||
}
|
||||
for i := firstHostIndex; i <= lastHostIndex; i++ {
|
||||
if c := st.entries[i].child; c != nil && !seen[c] {
|
||||
seen[c] = true
|
||||
ret += t.numStridesRec(seen, c)
|
||||
for _, c := range st.children {
|
||||
if c == nil || seen[c] {
|
||||
continue
|
||||
}
|
||||
seen[c] = true
|
||||
ret += t.numStridesRec(seen, c)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -106,13 +106,11 @@ type upnpClient interface {
|
||||
// It is not used for anything other than labelling.
|
||||
const tsPortMappingDesc = "tailscale-portmap"
|
||||
|
||||
// addAnyPortMapping abstracts over different UPnP client connections, calling
|
||||
// the available AddAnyPortMapping call if available for WAN IP connection v2,
|
||||
// otherwise picking either the previous port (if one is present) or a random
|
||||
// port and trying to obtain a mapping using AddPortMapping.
|
||||
//
|
||||
// It returns the new external port (which may not be identical to the external
|
||||
// port specified), or an error.
|
||||
// addAnyPortMapping abstracts over different UPnP client connections, calling the available
|
||||
// AddAnyPortMapping call if available for WAN IP connection v2, otherwise defaulting to the old
|
||||
// behavior of calling AddPortMapping with port = 0 to specify a wildcard port.
|
||||
// It returns the new external port (which may not be identical to the external port specified),
|
||||
// or an error.
|
||||
//
|
||||
// TODO(bradfitz): also returned the actual lease duration obtained. and check it regularly.
|
||||
func addAnyPortMapping(
|
||||
@@ -123,31 +121,6 @@ func addAnyPortMapping(
|
||||
internalClient string,
|
||||
leaseDuration time.Duration,
|
||||
) (newPort uint16, err error) {
|
||||
// Some devices don't let clients add a port mapping for privileged
|
||||
// ports (ports below 1024). Additionally, per section 2.3.18 of the
|
||||
// UPnP spec, regarding the ExternalPort field:
|
||||
//
|
||||
// If this value is specified as a wildcard (i.e. 0), connection
|
||||
// request on all external ports (that are not otherwise mapped)
|
||||
// will be forwarded to InternalClient. In the wildcard case, the
|
||||
// value(s) of InternalPort on InternalClient are ignored by the IGD
|
||||
// for those connections that are forwarded to InternalClient.
|
||||
// Obviously only one such entry can exist in the NAT at any time
|
||||
// and conflicts are handled with a “first write wins” behavior.
|
||||
//
|
||||
// We obviously do not want to open all ports on the user's device to
|
||||
// the internet, so we want to do this prior to calling either
|
||||
// AddAnyPortMapping or AddPortMapping.
|
||||
//
|
||||
// Pick an external port that's greater than 1024 by getting a random
|
||||
// number in [0, 65535 - 1024] and then adding 1024 to it, shifting the
|
||||
// range to [1024, 65535].
|
||||
if externalPort < 1024 {
|
||||
externalPort = uint16(rand.Intn(65535-1024) + 1024)
|
||||
}
|
||||
|
||||
// First off, try using AddAnyPortMapping; if there's a conflict, the
|
||||
// router will pick another port and return it.
|
||||
if upnp, ok := upnp.(*internetgateway2.WANIPConnection2); ok {
|
||||
return upnp.AddAnyPortMapping(
|
||||
ctx,
|
||||
@@ -162,8 +135,15 @@ func addAnyPortMapping(
|
||||
)
|
||||
}
|
||||
|
||||
// Fall back to using AddPortMapping, which requests a mapping to/from
|
||||
// a specific external port.
|
||||
// Some devices don't let clients add a port mapping for privileged
|
||||
// ports (ports below 1024).
|
||||
//
|
||||
// Pick an external port that's greater than 1024 by getting a random
|
||||
// number in [0, 65535 - 1024] and then adding 1024 to it, shifting the
|
||||
// range to [1024, 65535].
|
||||
if externalPort < 1024 {
|
||||
externalPort = uint16(rand.Intn(65535-1024) + 1024)
|
||||
}
|
||||
err = upnp.AddPortMapping(
|
||||
ctx,
|
||||
"",
|
||||
|
||||
@@ -734,12 +734,9 @@ type NetInfo struct {
|
||||
// the control plane.
|
||||
DERPLatency map[string]float64 `json:",omitempty"`
|
||||
|
||||
// FirewallMode encodes both which firewall mode was selected and why.
|
||||
// It is Linux-specific (at least as of 2023-08-19) and is meant to help
|
||||
// debug iptables-vs-nftables issues. The string is of the form
|
||||
// "{nft,ift}-REASON", like "nft-forced" or "ipt-default". Empty means
|
||||
// either not Linux or a configuration in which the host firewall rules
|
||||
// are not managed by tailscaled.
|
||||
// FirewallMode is the current firewall utility in use by router (iptables, nftables).
|
||||
// FirewallMode ipt means iptables, nft means nftables. When it's empty user is not using
|
||||
// our netfilter runners to manage firewall rules.
|
||||
FirewallMode string `json:",omitempty"`
|
||||
|
||||
// Update BasicallyEqual when adding fields.
|
||||
@@ -1406,6 +1403,8 @@ type DNSConfig struct {
|
||||
//
|
||||
// Matches are case insensitive.
|
||||
ExitNodeFilteredSet []string `json:",omitempty"`
|
||||
// DNSFilterURL contains a user inputed URL that should have a list of domains to be blocked
|
||||
DNSFilterURL string `json:",omitempty"`
|
||||
}
|
||||
|
||||
// DNSRecord is an extra DNS record to add to MagicDNS.
|
||||
|
||||
@@ -261,6 +261,7 @@ var _DNSConfigCloneNeedsRegeneration = DNSConfig(struct {
|
||||
CertDomains []string
|
||||
ExtraRecords []DNSRecord
|
||||
ExitNodeFilteredSet []string
|
||||
DNSFilterURL string
|
||||
}{})
|
||||
|
||||
// Clone makes a deep copy of RegisterResponse.
|
||||
|
||||
@@ -557,6 +557,7 @@ func (v DNSConfigView) ExtraRecords() views.Slice[DNSRecord] { return views.Slic
|
||||
func (v DNSConfigView) ExitNodeFilteredSet() views.Slice[string] {
|
||||
return views.SliceOf(v.ж.ExitNodeFilteredSet)
|
||||
}
|
||||
func (v DNSConfigView) DNSFilterURL() string { return v.ж.DNSFilterURL }
|
||||
|
||||
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
|
||||
var _DNSConfigViewNeedsRegeneration = DNSConfig(struct {
|
||||
@@ -569,6 +570,7 @@ var _DNSConfigViewNeedsRegeneration = DNSConfig(struct {
|
||||
CertDomains []string
|
||||
ExtraRecords []DNSRecord
|
||||
ExitNodeFilteredSet []string
|
||||
DNSFilterURL string
|
||||
}{})
|
||||
|
||||
// View returns a readonly view of RegisterResponse.
|
||||
|
||||
@@ -45,6 +45,12 @@ type AccessLogRecord struct {
|
||||
Bytes int `json:"bytes,omitempty"`
|
||||
// Error encountered during request processing.
|
||||
Err string `json:"err,omitempty"`
|
||||
// RequestID is a unique ID for this request. When a request fails due to an
|
||||
// error, the ID is generated and displayed to the client immediately after
|
||||
// the error text, as well as logged here. This makes it easier to correlate
|
||||
// support requests with server logs. If a RequestID generator is not
|
||||
// configured, RequestID will be empty.
|
||||
RequestID RequestID `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// String returns m as a JSON string.
|
||||
|
||||
@@ -169,7 +169,8 @@ type ReturnHandler interface {
|
||||
type HandlerOptions struct {
|
||||
QuietLoggingIfSuccessful bool // if set, do not log successfully handled HTTP requests (200 and 304 status codes)
|
||||
Logf logger.Logf
|
||||
Now func() time.Time // if nil, defaults to time.Now
|
||||
Now func() time.Time // if nil, defaults to time.Now
|
||||
GenerateRequestID func(*http.Request) RequestID // if nil, no request IDs are generated
|
||||
|
||||
// If non-nil, StatusCodeCounters maintains counters
|
||||
// of status codes for handled responses.
|
||||
@@ -266,6 +267,11 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
msg.Code = 499 // nginx convention: Client Closed Request
|
||||
msg.Err = context.Canceled.Error()
|
||||
case hErrOK:
|
||||
if hErr.RequestID == "" && h.opts.GenerateRequestID != nil {
|
||||
hErr.RequestID = h.opts.GenerateRequestID(r)
|
||||
}
|
||||
msg.RequestID = hErr.RequestID
|
||||
|
||||
// Handler asked us to send an error. Do so, if we haven't
|
||||
// already sent a response.
|
||||
msg.Err = hErr.Msg
|
||||
@@ -296,14 +302,24 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
lw.WriteHeader(msg.Code)
|
||||
fmt.Fprintln(lw, hErr.Msg)
|
||||
if hErr.RequestID != "" {
|
||||
fmt.Fprintln(lw, hErr.RequestID)
|
||||
}
|
||||
}
|
||||
case err != nil:
|
||||
const internalServerError = "internal server error"
|
||||
|
||||
errorMessage := internalServerError
|
||||
if h.opts.GenerateRequestID != nil {
|
||||
msg.RequestID = h.opts.GenerateRequestID(r)
|
||||
errorMessage = errorMessage + "\n" + string(msg.RequestID)
|
||||
}
|
||||
// Handler returned a generic error. Serve an internal server
|
||||
// error, if necessary.
|
||||
msg.Err = err.Error()
|
||||
if lw.code == 0 {
|
||||
msg.Code = http.StatusInternalServerError
|
||||
http.Error(lw, "internal server error", msg.Code)
|
||||
http.Error(lw, errorMessage, msg.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,18 +414,44 @@ func (l loggingResponseWriter) Flush() {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// RequestID is an opaque identifier for a HTTP request, used to correlate
|
||||
// user-visible errors with backend server logs. If present in a HTTPError, the
|
||||
// RequestID will be printed alongside the message text and logged in the
|
||||
// AccessLogRecord. If an HTTPError has no RequestID (or a non-HTTPError error
|
||||
// is returned), but the StdHandler has a RequestID generator function, then a
|
||||
// RequestID will be generated before responding to the client and logging the
|
||||
// error.
|
||||
//
|
||||
// In the event that there is no ErrorHandlerFunc and a non-HTTPError is
|
||||
// returned to a StdHandler, the response body will be formatted like
|
||||
// "internal server error\n{RequestID}\n".
|
||||
//
|
||||
// There is no particular format required for a RequestID, but ideally it should
|
||||
// be obvious to an end-user that it is something to record for support
|
||||
// purposes. One possible example for a RequestID format is:
|
||||
// REQ-{server identifier}-{timestamp}-{random hex string}.
|
||||
type RequestID string
|
||||
|
||||
// HTTPError is an error with embedded HTTP response information.
|
||||
//
|
||||
// It is the error type to be (optionally) used by Handler.ServeHTTPReturn.
|
||||
type HTTPError struct {
|
||||
Code int // HTTP response code to send to client; 0 means 500
|
||||
Msg string // Response body to send to client
|
||||
Err error // Detailed error to log on the server
|
||||
Header http.Header // Optional set of HTTP headers to set in the response
|
||||
Code int // HTTP response code to send to client; 0 means 500
|
||||
Msg string // Response body to send to client
|
||||
Err error // Detailed error to log on the server
|
||||
RequestID RequestID // Optional identifier to connect client-visible errors with server logs
|
||||
Header http.Header // Optional set of HTTP headers to set in the response
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e HTTPError) Error() string { return fmt.Sprintf("httperror{%d, %q, %v}", e.Code, e.Msg, e.Err) }
|
||||
func (e HTTPError) Error() string {
|
||||
if e.RequestID != "" {
|
||||
return fmt.Sprintf("httperror{%d, %q, %v, RequestID=%q}", e.Code, e.Msg, e.Err, e.RequestID)
|
||||
} else {
|
||||
// Backwards compatibility
|
||||
return fmt.Sprintf("httperror{%d, %q, %v}", e.Code, e.Msg, e.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e HTTPError) Unwrap() error { return e.Err }
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ func (f handlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) err
|
||||
}
|
||||
|
||||
func TestStdHandler(t *testing.T) {
|
||||
const exampleRequestID = "example-request-id"
|
||||
var (
|
||||
handlerCode = func(code int) ReturnHandler {
|
||||
return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
@@ -66,16 +67,20 @@ func TestStdHandler(t *testing.T) {
|
||||
bgCtx = context.Background()
|
||||
// canceledCtx, cancel = context.WithCancel(bgCtx)
|
||||
startTime = time.Unix(1687870000, 1234)
|
||||
|
||||
setExampleRequestID = func(_ *http.Request) RequestID { return exampleRequestID }
|
||||
)
|
||||
// cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rh ReturnHandler
|
||||
r *http.Request
|
||||
errHandler ErrorHandlerFunc
|
||||
wantCode int
|
||||
wantLog AccessLogRecord
|
||||
name string
|
||||
rh ReturnHandler
|
||||
r *http.Request
|
||||
errHandler ErrorHandlerFunc
|
||||
generateRequestID func(*http.Request) RequestID
|
||||
wantCode int
|
||||
wantLog AccessLogRecord
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "handler returns 200",
|
||||
@@ -94,6 +99,24 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 200 with request ID",
|
||||
rh: handlerCode(200),
|
||||
r: req(bgCtx, "http://example.com/"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 200,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
TLS: false,
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
Code: 200,
|
||||
RequestURI: "/",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404",
|
||||
rh: handlerCode(404),
|
||||
@@ -110,6 +133,23 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404 with request ID",
|
||||
rh: handlerCode(404),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 404,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Code: 404,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404 via HTTPError",
|
||||
rh: handlerErr(0, Error(404, "not found", testErr)),
|
||||
@@ -125,6 +165,27 @@ func TestStdHandler(t *testing.T) {
|
||||
Err: "not found: " + testErr.Error(),
|
||||
Code: 404,
|
||||
},
|
||||
wantBody: "not found\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404 via HTTPError with request ID",
|
||||
rh: handlerErr(0, Error(404, "not found", testErr)),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 404,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Err: "not found: " + testErr.Error(),
|
||||
Code: 404,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "not found\n" + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
@@ -142,6 +203,27 @@ func TestStdHandler(t *testing.T) {
|
||||
Err: "not found",
|
||||
Code: 404,
|
||||
},
|
||||
wantBody: "not found\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns 404 with request ID and nil child error",
|
||||
rh: handlerErr(0, Error(404, "not found", nil)),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 404,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Err: "not found",
|
||||
Code: 404,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "not found\n" + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
@@ -159,6 +241,27 @@ func TestStdHandler(t *testing.T) {
|
||||
Err: "visible error",
|
||||
Code: 500,
|
||||
},
|
||||
wantBody: "visible error\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns user-visible error with request ID",
|
||||
rh: handlerErr(0, vizerror.New("visible error")),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 500,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Err: "visible error",
|
||||
Code: 500,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "visible error\n" + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
@@ -176,6 +279,27 @@ func TestStdHandler(t *testing.T) {
|
||||
Err: "visible error",
|
||||
Code: 500,
|
||||
},
|
||||
wantBody: "visible error\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns user-visible error wrapped by private error with request ID",
|
||||
rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 500,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Err: "visible error",
|
||||
Code: 500,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "visible error\n" + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
@@ -193,6 +317,27 @@ func TestStdHandler(t *testing.T) {
|
||||
Err: testErr.Error(),
|
||||
Code: 500,
|
||||
},
|
||||
wantBody: "internal server error\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns generic error with request ID",
|
||||
rh: handlerErr(0, testErr),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 500,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Err: testErr.Error(),
|
||||
Code: 500,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "internal server error\n" + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
@@ -212,6 +357,25 @@ func TestStdHandler(t *testing.T) {
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns error after writing response with request ID",
|
||||
rh: handlerErr(200, testErr),
|
||||
r: req(bgCtx, "http://example.com/foo"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 200,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Err: testErr.Error(),
|
||||
Code: 200,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns HTTPError after writing response",
|
||||
rh: handlerErr(200, Error(404, "not found", testErr)),
|
||||
@@ -267,6 +431,7 @@ func TestStdHandler(t *testing.T) {
|
||||
Code: 101,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "error handler gets run",
|
||||
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
|
||||
@@ -286,6 +451,62 @@ func TestStdHandler(t *testing.T) {
|
||||
Err: "not found",
|
||||
RequestURI: "/",
|
||||
},
|
||||
wantBody: "not found\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "error handler gets run with request ID",
|
||||
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
|
||||
r: req(bgCtx, "http://example.com/"),
|
||||
generateRequestID: setExampleRequestID,
|
||||
wantCode: 200,
|
||||
errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
|
||||
http.Error(w, fmt.Sprintf("%s with request ID %s", e.Msg, e.RequestID), 200)
|
||||
},
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
TLS: false,
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
Code: 404,
|
||||
Err: "not found",
|
||||
RequestURI: "/",
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "not found with request ID " + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "request ID can use information from request",
|
||||
rh: handlerErr(0, Error(400, "bad request", nil)),
|
||||
r: func() *http.Request {
|
||||
r := req(bgCtx, "http://example.com/")
|
||||
r.AddCookie(&http.Cookie{Name: "want_request_id", Value: "asdf1234"})
|
||||
return r
|
||||
}(),
|
||||
generateRequestID: func(r *http.Request) RequestID {
|
||||
c, _ := r.Cookie("want_request_id")
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return RequestID(c.Value)
|
||||
},
|
||||
wantCode: 400,
|
||||
wantLog: AccessLogRecord{
|
||||
When: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
TLS: false,
|
||||
Host: "example.com",
|
||||
RequestURI: "/",
|
||||
Method: "GET",
|
||||
Code: 400,
|
||||
Err: "bad request",
|
||||
RequestID: "asdf1234",
|
||||
},
|
||||
wantBody: "bad request\nasdf1234\n",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -305,7 +526,7 @@ func TestStdHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
rec := noopHijacker{httptest.NewRecorder(), false}
|
||||
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
|
||||
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, GenerateRequestID: test.generateRequestID, OnError: test.errHandler})
|
||||
h.ServeHTTP(&rec, test.r)
|
||||
res := rec.Result()
|
||||
if res.StatusCode != test.wantCode {
|
||||
@@ -324,6 +545,9 @@ func TestStdHandler(t *testing.T) {
|
||||
if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" {
|
||||
t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(rec.Body.String(), test.wantBody); diff != "" {
|
||||
t.Errorf("handler wrote incorrect body (-got+want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,34 @@ import (
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
type NetworkMapView struct {
|
||||
nm *NetworkMap
|
||||
}
|
||||
|
||||
type GUIPeerNode struct {
|
||||
ID tailcfg.StableNodeID
|
||||
BaseOrFQDNName string // no "." substring if in your tailnet, else FQDN
|
||||
Owner tailcfg.UserID // user or fake userid for tagged nodes
|
||||
IPv4 netip.Addr // may be be zero value (empty string in JSON)
|
||||
IPv6 netip.Addr // may be be zero value (empty string in JSON)
|
||||
MachineStatus tailcfg.MachineStatus
|
||||
Hostinfo GUIHostInfo
|
||||
|
||||
IsExitNode bool
|
||||
}
|
||||
|
||||
type GUIHostInfo struct {
|
||||
ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user
|
||||
}
|
||||
|
||||
type GUINetworkMap struct {
|
||||
SelfNode *tailcfg.Node // TODO: GUISelfNode ?
|
||||
Peers []*GUIPeerNode
|
||||
UserProfiles map[tailcfg.UserID]tailcfg.UserProfile
|
||||
|
||||
TKAEnabled bool
|
||||
}
|
||||
|
||||
// NetworkMap is the current state of the world.
|
||||
//
|
||||
// The fields should all be considered read-only. They might
|
||||
@@ -70,8 +98,6 @@ type NetworkMap struct {
|
||||
// hash of the latest update message to tick through TKA).
|
||||
TKAHead tka.AUMHash
|
||||
|
||||
// ACLs
|
||||
|
||||
User tailcfg.UserID
|
||||
|
||||
// Domain is the current Tailnet name.
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
@@ -27,16 +26,12 @@ const (
|
||||
chainNamePostrouting = "ts-postrouting"
|
||||
)
|
||||
|
||||
// chainTypeRegular is an nftables chain that does not apply to a hook.
|
||||
const chainTypeRegular = ""
|
||||
|
||||
type chainInfo struct {
|
||||
table *nftables.Table
|
||||
name string
|
||||
chainType nftables.ChainType
|
||||
chainHook *nftables.ChainHook
|
||||
chainPriority *nftables.ChainPriority
|
||||
chainPolicy *nftables.ChainPolicy
|
||||
}
|
||||
|
||||
type nftable struct {
|
||||
@@ -45,21 +40,6 @@ type nftable struct {
|
||||
Nat *nftables.Table
|
||||
}
|
||||
|
||||
// nftablesRunner implements a netfilterRunner using the netlink based nftables
|
||||
// library. As nftables allows for arbitrary tables and chains, there is a need
|
||||
// to follow conventions in order to integrate well with a surrounding
|
||||
// ecosystem. The rules installed by nftablesRunner have the following
|
||||
// properties:
|
||||
// - Install rules that intend to take precedence over rules installed by
|
||||
// other software. Tailscale provides packet filtering for tailnet traffic
|
||||
// inside the daemon based on the tailnet ACL rules.
|
||||
// - As nftables "accept" is not final, rules from high priority tables (low
|
||||
// numbers) will fall through to lower priority tables (high numbers). In
|
||||
// order to effectively be 'final', we install "jump" rules into conventional
|
||||
// tables and chains that will reach an accept verdict inside those tables.
|
||||
// - The table and chain conventions followed here are those used by
|
||||
// `iptables-nft` and `ufw`, so that those tools co-exist and do not
|
||||
// negatively affect Tailscale function.
|
||||
type nftablesRunner struct {
|
||||
conn *nftables.Conn
|
||||
nft4 *nftable
|
||||
@@ -136,11 +116,6 @@ func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Ch
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// isTSChain retruns true if the chain name starts with ts
|
||||
func isTSChain(name string) bool {
|
||||
return strings.HasPrefix(name, "ts-")
|
||||
}
|
||||
|
||||
// createChainIfNotExist creates a chain with the given name in the given table
|
||||
// if it does not exist.
|
||||
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
||||
@@ -148,11 +123,8 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
|
||||
return fmt.Errorf("get chain: %w", err)
|
||||
} else if err == nil {
|
||||
// The chain already exists. If it is a TS chain, check the
|
||||
// type/hook/priority, but for "conventional chains" assume they're what
|
||||
// we expect (in case iptables-nft/ufw make minor behavior changes in
|
||||
// the future).
|
||||
if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) {
|
||||
// Chain already exists
|
||||
if chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority {
|
||||
return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
|
||||
}
|
||||
return nil
|
||||
@@ -164,7 +136,6 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
||||
Type: cinfo.chainType,
|
||||
Hooknum: cinfo.chainHook,
|
||||
Priority: cinfo.chainPriority,
|
||||
Policy: cinfo.chainPolicy,
|
||||
})
|
||||
|
||||
if err := c.Flush(); err != nil {
|
||||
@@ -257,10 +228,6 @@ ruleLoop:
|
||||
}
|
||||
|
||||
for i, e := range r.Exprs {
|
||||
// Skip counter expressions, as they will not match.
|
||||
if _, ok := e.(*expr.Counter); ok {
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(e, rule.Exprs[i]) {
|
||||
continue ruleLoop
|
||||
}
|
||||
@@ -421,49 +388,27 @@ func (n *nftablesRunner) getNATTables() []*nftable {
|
||||
// AddChains creates custom Tailscale chains in netfilter via nftables
|
||||
// if the ts-chain doesn't already exist.
|
||||
func (n *nftablesRunner) AddChains() error {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
for _, table := range n.getTables() {
|
||||
// Create the filter table if it doesn't exist, this table name is the same
|
||||
// as the name used by iptables-nft and ufw. We install rules into the
|
||||
// same conventional table so that `accept` verdicts from our jump
|
||||
// chains are conclusive.
|
||||
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
||||
filter, err := createTableIfNotExist(n.conn, table.Proto, "ts-filter")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create table: %w", err)
|
||||
}
|
||||
table.Filter = filter
|
||||
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityRef(-1)}); err != nil {
|
||||
return fmt.Errorf("create forward chain: %w", err)
|
||||
}
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
||||
return fmt.Errorf("create input chain: %w", err)
|
||||
}
|
||||
// Adding the tailscale chains that contain our rules.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
return fmt.Errorf("create forward chain: %w", err)
|
||||
}
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityRef(-1)}); err != nil {
|
||||
return fmt.Errorf("create input chain: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
// Create the nat table if it doesn't exist, this table name is the same
|
||||
// as the name used by iptables-nft and ufw. We install rules into the
|
||||
// same conventional table so that `accept` verdicts from our jump
|
||||
// chains are conclusive.
|
||||
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
||||
nat, err := createTableIfNotExist(n.conn, table.Proto, "ts-nat")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create table: %w", err)
|
||||
}
|
||||
table.Nat = nat
|
||||
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
||||
return fmt.Errorf("create postrouting chain: %w", err)
|
||||
}
|
||||
// Adding the tailscale chain that contains our rules.
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
||||
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATDest}); err != nil {
|
||||
return fmt.Errorf("create postrouting chain: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -500,16 +445,19 @@ func (n *nftablesRunner) DelChains() error {
|
||||
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
|
||||
return fmt.Errorf("delete chain: %w", err)
|
||||
}
|
||||
n.conn.DelTable(table.Filter)
|
||||
}
|
||||
|
||||
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
|
||||
return fmt.Errorf("delete chain: %w", err)
|
||||
}
|
||||
n.conn.DelTable(n.nft4.Nat)
|
||||
|
||||
if n.v6NATAvailable {
|
||||
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
|
||||
return fmt.Errorf("delete chain: %w", err)
|
||||
}
|
||||
n.conn.DelTable(n.nft6.Nat)
|
||||
}
|
||||
|
||||
if err := n.conn.Flush(); err != nil {
|
||||
@@ -519,128 +467,15 @@ func (n *nftablesRunner) DelChains() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createHookRule creates a rule to jump from a hooked chain to a regular chain.
|
||||
func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
|
||||
exprs := []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: toChainName,
|
||||
},
|
||||
}
|
||||
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
Chain: fromChain,
|
||||
Exprs: exprs,
|
||||
}
|
||||
|
||||
return rule
|
||||
}
|
||||
|
||||
// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
|
||||
func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
||||
rule := createHookRule(table, fromChain, toChainName)
|
||||
_ = conn.InsertRule(rule)
|
||||
|
||||
if err := conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush add rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
|
||||
// in tables and jump from those chains to tailscale chains.
|
||||
// AddHooks is defined to satisfy the interface. NfTables does not require
|
||||
// AddHooks, since we don't have any default tables or chains in nftables.
|
||||
func (n *nftablesRunner) AddHooks() error {
|
||||
conn := n.conn
|
||||
|
||||
for _, table := range n.getTables() {
|
||||
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Addhook: %w", err)
|
||||
}
|
||||
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get FORWARD chain: %w", err)
|
||||
}
|
||||
err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Addhook: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Addhook: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// delHookRule deletes a rule that jumps from a hooked chain to a regular chain.
|
||||
func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
||||
rule := createHookRule(table, fromChain, toChainName)
|
||||
existingRule, err := findRule(conn, rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to find hook rule: %w", err)
|
||||
}
|
||||
|
||||
if existingRule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = conn.DelRule(existingRule)
|
||||
|
||||
if err := conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush del hook rule: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
|
||||
// DelHooks is defined to satisfy the interface. NfTables does not require
|
||||
// DelHooks, since we don't have any default tables or chains in nftables.
|
||||
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
||||
conn := n.conn
|
||||
|
||||
for _, table := range n.getTables() {
|
||||
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delhook: %w", err)
|
||||
}
|
||||
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get FORWARD chain: %w", err)
|
||||
}
|
||||
err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delhook: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, table := range n.getNATTables() {
|
||||
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get INPUT chain: %w", err)
|
||||
}
|
||||
err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delhook: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1118,62 +953,25 @@ func (n *nftablesRunner) DelSNATRule() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupChain removes a jump rule from hookChainName to tsChainName, and then
|
||||
// the entire chain tsChainName. Errors are logged, but attempts to remove both
|
||||
// the jump rule and chain continue even if one errors.
|
||||
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
|
||||
// remove the jump first, before removing the jump destination.
|
||||
defaultChain, err := getChainFromTable(conn, table, hookChainName)
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
||||
logf("cleanup: did not find default chain: %s", err)
|
||||
}
|
||||
if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
||||
// delete hook in convention chain
|
||||
_ = delHookRule(conn, table, defaultChain, tsChainName)
|
||||
}
|
||||
|
||||
tsChain, err := getChainFromTable(conn, table, tsChainName)
|
||||
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
|
||||
logf("cleanup: did not find ts-chain: %s", err)
|
||||
}
|
||||
|
||||
if tsChain != nil {
|
||||
// flush and delete ts-chain
|
||||
conn.FlushChain(tsChain)
|
||||
conn.DelChain(tsChain)
|
||||
err = conn.Flush()
|
||||
logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// NfTablesCleanUp removes all Tailscale added nftables rules.
|
||||
// Any errors that occur are logged to the provided logf.
|
||||
func NfTablesCleanUp(logf logger.Logf) {
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
logf("cleanup: nftables connection: %s", err)
|
||||
logf("ERROR: nftables connection: %w", err)
|
||||
}
|
||||
|
||||
tables, err := conn.ListTables() // both v4 and v6
|
||||
if err != nil {
|
||||
logf("cleanup: list tables: %s", err)
|
||||
logf("ERROR: list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
// These table names were used briefly in 1.48.0.
|
||||
if table.Name == "ts-filter" || table.Name == "ts-nat" {
|
||||
conn.DelTable(table)
|
||||
if err := conn.Flush(); err != nil {
|
||||
logf("cleanup: flush delete table %s: %s", table.Name, err)
|
||||
logf("ERROR: flush table %s: %w", table.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
if table.Name == "filter" {
|
||||
cleanupChain(logf, conn, table, "INPUT", chainNameInput)
|
||||
cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
|
||||
}
|
||||
if table.Name == "nat" {
|
||||
cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,48 +101,6 @@ func newTestConn(t *testing.T, want [][]byte) *nftables.Conn {
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestInsertHookRule(t *testing.T) {
|
||||
proto := nftables.TableFamilyIPv4
|
||||
want := [][]byte{
|
||||
// batch begin
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
// nft add table ip ts-filter-test
|
||||
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"),
|
||||
// nft add chain ip ts-filter-test ts-input-test { type filter hook input priority 0 \; }
|
||||
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"),
|
||||
// nft add chain ip ts-filter-test ts-jumpto
|
||||
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x0e\x00\x03\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"),
|
||||
// nft add rule ip ts-filter-test ts-input-test counter jump ts-jumptp
|
||||
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x70\x00\x04\x80\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x40\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x02\x80\x1c\x00\x02\x80\x08\x00\x01\x00\xff\xff\xff\xfd\x0e\x00\x02\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"),
|
||||
// batch end
|
||||
[]byte("\x00\x00\x00\x0a"),
|
||||
}
|
||||
testConn := newTestConn(t, want)
|
||||
table := testConn.AddTable(&nftables.Table{
|
||||
Family: proto,
|
||||
Name: "ts-filter-test",
|
||||
})
|
||||
|
||||
fromchain := testConn.AddChain(&nftables.Chain{
|
||||
Name: "ts-input-test",
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
})
|
||||
|
||||
tochain := testConn.AddChain(&nftables.Chain{
|
||||
Name: "ts-jumpto",
|
||||
Table: table,
|
||||
})
|
||||
|
||||
err := addHookRule(testConn, table, fromchain, tochain.Name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInsertLoopbackRule(t *testing.T) {
|
||||
proto := nftables.TableFamilyIPv4
|
||||
want := [][]byte{
|
||||
@@ -503,8 +461,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) {
|
||||
t.Fatalf("list chains failed: %v", err)
|
||||
}
|
||||
|
||||
if len(chainsV4) != 6 {
|
||||
t.Fatalf("len(chainsV4) = %d, want 6", len(chainsV4))
|
||||
if len(chainsV4) != 3 {
|
||||
t.Fatalf("len(chainsV4) = %d, want 3", len(chainsV4))
|
||||
}
|
||||
|
||||
chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6)
|
||||
@@ -512,8 +470,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) {
|
||||
t.Fatalf("list chains failed: %v", err)
|
||||
}
|
||||
|
||||
if len(chainsV6) != 6 {
|
||||
t.Fatalf("len(chainsV6) = %d, want 6", len(chainsV6))
|
||||
if len(chainsV6) != 3 {
|
||||
t.Fatalf("len(chainsV6) = %d, want 3", len(chainsV6))
|
||||
}
|
||||
|
||||
runner.DelChains()
|
||||
@@ -830,87 +788,3 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
|
||||
t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNFTAddAndDelHookRule(t *testing.T) {
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
|
||||
return
|
||||
}
|
||||
|
||||
conn := newSysConn(t)
|
||||
runner := newFakeNftablesRunner(t, conn)
|
||||
runner.AddChains()
|
||||
defer runner.DelChains()
|
||||
runner.AddHooks()
|
||||
|
||||
forwardChain, err := getChainFromTable(conn, runner.nft4.Filter, "FORWARD")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get forwardChain: %v", err)
|
||||
}
|
||||
|
||||
forwardChainRules, err := conn.GetRules(forwardChain.Table, forwardChain)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get rules: %v", err)
|
||||
}
|
||||
|
||||
if len(forwardChainRules) != 1 {
|
||||
t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules))
|
||||
}
|
||||
|
||||
inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get inputChain: %v", err)
|
||||
}
|
||||
|
||||
inputChainRules, err := conn.GetRules(inputChain.Table, inputChain)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get rules: %v", err)
|
||||
}
|
||||
|
||||
if len(inputChainRules) != 1 {
|
||||
t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules))
|
||||
}
|
||||
|
||||
postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get postroutingChain: %v", err)
|
||||
}
|
||||
|
||||
postroutingChainRules, err := conn.GetRules(postroutingChain.Table, postroutingChain)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get rules: %v", err)
|
||||
}
|
||||
|
||||
if len(postroutingChainRules) != 1 {
|
||||
t.Fatalf("expected 1 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
|
||||
}
|
||||
|
||||
runner.DelHooks(t.Logf)
|
||||
|
||||
forwardChainRules, err = conn.GetRules(forwardChain.Table, forwardChain)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get rules: %v", err)
|
||||
}
|
||||
|
||||
if len(forwardChainRules) != 0 {
|
||||
t.Fatalf("expected 0 rule in FORWARD chain, got %v", len(forwardChainRules))
|
||||
}
|
||||
|
||||
inputChainRules, err = conn.GetRules(inputChain.Table, inputChain)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get rules: %v", err)
|
||||
}
|
||||
|
||||
if len(inputChainRules) != 0 {
|
||||
t.Fatalf("expected 0 rule in INPUT chain, got %v", len(inputChainRules))
|
||||
}
|
||||
|
||||
postroutingChainRules, err = conn.GetRules(postroutingChain.Table, postroutingChain)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get rules: %v", err)
|
||||
}
|
||||
|
||||
if len(postroutingChainRules) != 0 {
|
||||
t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -665,7 +665,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
ep.noteRecvActivity(ipp)
|
||||
ep.noteRecvActivity()
|
||||
if stats := c.stats.Load(); stats != nil {
|
||||
stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n)
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ type endpoint struct {
|
||||
|
||||
heartBeatTimer *time.Timer // nil when idle
|
||||
lastSend mono.Time // last time there was outgoing packets sent to this peer (from wireguard-go)
|
||||
lastFullPing mono.Time // last time we pinged all disco or wireguard only endpoints
|
||||
lastFullPing mono.Time // last time we pinged all disco endpoints
|
||||
derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients)
|
||||
|
||||
bestAddr addrLatency // best non-DERP path; zero if none
|
||||
@@ -132,14 +132,6 @@ type endpointState struct {
|
||||
index int16 // index in nodecfg.Node.Endpoints; meaningless if lastGotPing non-zero
|
||||
}
|
||||
|
||||
// clear removes all derived / probed state from an endpointState.
|
||||
func (s *endpointState) clear() {
|
||||
*s = endpointState{
|
||||
index: s.index,
|
||||
lastGotPing: s.lastGotPing,
|
||||
}
|
||||
}
|
||||
|
||||
// pongHistoryCount is how many pongReply values we keep per endpointState
|
||||
const pongHistoryCount = 64
|
||||
|
||||
@@ -228,26 +220,14 @@ func (de *endpoint) initFakeUDPAddr() {
|
||||
|
||||
// noteRecvActivity records receive activity on de, and invokes
|
||||
// Conn.noteRecvActivity no more than once every 10s.
|
||||
func (de *endpoint) noteRecvActivity(ipp netip.AddrPort) {
|
||||
now := mono.Now()
|
||||
|
||||
// TODO(raggi): this probably applies relatively equally well to disco
|
||||
// managed endpoints, but that would be a less conservative change.
|
||||
if de.isWireguardOnly {
|
||||
de.mu.Lock()
|
||||
de.bestAddr.AddrPort = ipp
|
||||
de.bestAddrAt = now
|
||||
de.trustBestAddrUntil = now.Add(5 * time.Second)
|
||||
de.mu.Unlock()
|
||||
func (de *endpoint) noteRecvActivity() {
|
||||
if de.c.noteRecvActivity == nil {
|
||||
return
|
||||
}
|
||||
|
||||
now := mono.Now()
|
||||
elapsed := now.Sub(de.lastRecv.LoadAtomic())
|
||||
if elapsed > 10*time.Second {
|
||||
de.lastRecv.StoreAtomic(now)
|
||||
|
||||
if de.c.noteRecvActivity == nil {
|
||||
return
|
||||
}
|
||||
de.c.noteRecvActivity(de.publicKey)
|
||||
}
|
||||
}
|
||||
@@ -309,23 +289,11 @@ func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.Ad
|
||||
//
|
||||
// de.mu must be held.
|
||||
func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) {
|
||||
if len(de.endpointState) == 0 {
|
||||
de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint")
|
||||
return udpAddr, false
|
||||
}
|
||||
|
||||
// lowestLatency is a high duration initially, so we
|
||||
// can be sure we're going to have a duration lower than this
|
||||
// for the first latency retrieved.
|
||||
lowestLatency := time.Hour
|
||||
var oldestPing mono.Time
|
||||
for ipp, state := range de.endpointState {
|
||||
if oldestPing.IsZero() {
|
||||
oldestPing = state.lastPing
|
||||
} else if state.lastPing.Before(oldestPing) {
|
||||
oldestPing = state.lastPing
|
||||
}
|
||||
|
||||
if latency, ok := state.latencyLocked(); ok {
|
||||
if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() {
|
||||
// If we have the same latency,IPv6 is prioritized.
|
||||
@@ -336,25 +304,35 @@ func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.Add
|
||||
}
|
||||
}
|
||||
}
|
||||
needPing := len(de.endpointState) > 1 && now.Sub(oldestPing) > wireguardPingInterval
|
||||
|
||||
if !udpAddr.IsValid() {
|
||||
candidates := maps.Keys(de.endpointState)
|
||||
|
||||
// Randomly select an address to use until we retrieve latency information
|
||||
// and give it a short trustBestAddrUntil time so we avoid flapping between
|
||||
// addresses while waiting on latency information to be populated.
|
||||
udpAddr = candidates[rand.Intn(len(candidates))]
|
||||
if udpAddr.IsValid() {
|
||||
// Set trustBestAddrUntil to an hour, so we will
|
||||
// continue to use this address for a long period of time.
|
||||
de.bestAddr.AddrPort = udpAddr
|
||||
de.trustBestAddrUntil = now.Add(1 * time.Hour)
|
||||
return udpAddr, false
|
||||
}
|
||||
|
||||
candidates := maps.Keys(de.endpointState)
|
||||
if len(candidates) == 0 {
|
||||
de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint")
|
||||
return udpAddr, false
|
||||
}
|
||||
|
||||
// Randomly select an address to use until we retrieve latency information
|
||||
// and give it a short trustBestAddrUntil time so we avoid flapping between
|
||||
// addresses while waiting on latency information to be populated.
|
||||
udpAddr = candidates[rand.Intn(len(candidates))]
|
||||
de.bestAddr.AddrPort = udpAddr
|
||||
// Only extend trustBestAddrUntil by one second to avoid packet
|
||||
// reordering and/or CPU usage from random selection during the first
|
||||
// second. We should receive a response due to a WireGuard handshake in
|
||||
// less than one second in good cases, in which case this will be then
|
||||
// extended to 15 seconds.
|
||||
de.trustBestAddrUntil = now.Add(time.Second)
|
||||
return udpAddr, needPing
|
||||
if len(candidates) == 1 {
|
||||
// if we only have one address that we can send data too,
|
||||
// we should trust it for a longer period of time.
|
||||
de.trustBestAddrUntil = now.Add(1 * time.Hour)
|
||||
} else {
|
||||
de.trustBestAddrUntil = now.Add(15 * time.Second)
|
||||
}
|
||||
|
||||
return udpAddr, len(candidates) > 1
|
||||
}
|
||||
|
||||
// heartbeat is called every heartbeatInterval to keep the best UDP path alive,
|
||||
@@ -489,14 +467,6 @@ func (de *endpoint) send(buffs [][]byte) error {
|
||||
var err error
|
||||
if udpAddr.IsValid() {
|
||||
_, err = de.c.sendUDPBatch(udpAddr, buffs)
|
||||
|
||||
// If the error is known to indicate that the endpoint is no longer
|
||||
// usable, clear the endpoint statistics so that the next send will
|
||||
// re-evaluate the best endpoint.
|
||||
if err != nil && isBadEndpointErr(err) {
|
||||
de.noteBadEndpoint(udpAddr)
|
||||
}
|
||||
|
||||
// TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends.
|
||||
if stats := de.c.stats.Load(); err == nil && stats != nil {
|
||||
var txBytes int
|
||||
@@ -888,30 +858,6 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T
|
||||
return false
|
||||
}
|
||||
|
||||
// clearBestAddrLocked clears the bestAddr and related fields such that future
|
||||
// packets will re-evaluate the best address to send to next.
|
||||
//
|
||||
// de.mu must be held.
|
||||
func (de *endpoint) clearBestAddrLocked() {
|
||||
de.bestAddr = addrLatency{}
|
||||
de.bestAddrAt = 0
|
||||
de.trustBestAddrUntil = 0
|
||||
}
|
||||
|
||||
// noteBadEndpoint marks ipp as a bad endpoint that would need to be
|
||||
// re-evaluated before future use, this should be called for example if a send
|
||||
// to ipp fails due to a host unreachable error or similar.
|
||||
func (de *endpoint) noteBadEndpoint(ipp netip.AddrPort) {
|
||||
de.mu.Lock()
|
||||
defer de.mu.Unlock()
|
||||
|
||||
de.clearBestAddrLocked()
|
||||
|
||||
if st, ok := de.endpointState[ipp]; ok {
|
||||
st.clear()
|
||||
}
|
||||
}
|
||||
|
||||
// noteConnectivityChange is called when connectivity changes enough
|
||||
// that we should question our earlier assumptions about which paths
|
||||
// work.
|
||||
@@ -919,11 +865,7 @@ func (de *endpoint) noteConnectivityChange() {
|
||||
de.mu.Lock()
|
||||
defer de.mu.Unlock()
|
||||
|
||||
de.clearBestAddrLocked()
|
||||
|
||||
for k := range de.endpointState {
|
||||
de.endpointState[k].clear()
|
||||
}
|
||||
de.trustBestAddrUntil = 0
|
||||
}
|
||||
|
||||
// handlePongConnLocked handles a Pong message (a reply to an earlier ping).
|
||||
@@ -1200,7 +1142,9 @@ func (de *endpoint) stopAndReset() {
|
||||
func (de *endpoint) resetLocked() {
|
||||
de.lastSend = 0
|
||||
de.lastFullPing = 0
|
||||
de.clearBestAddrLocked()
|
||||
de.bestAddr = addrLatency{}
|
||||
de.bestAddrAt = 0
|
||||
de.trustBestAddrUntil = 0
|
||||
for _, es := range de.endpointState {
|
||||
es.lastPing = 0
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !js && !wasm
|
||||
// +build !js,!wasm
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to
|
||||
// errors.Is while avoiding an allocation per call.
|
||||
var errHOSTUNREACH error = syscall.EHOSTUNREACH
|
||||
|
||||
// isBadEndpointErr checks if err is one which is known to report that an
|
||||
// endpoint can no longer be sent to. It is not exhaustive, and for unknown
|
||||
// errors always reports false.
|
||||
func isBadEndpointErr(err error) bool {
|
||||
return errors.Is(err, errHOSTUNREACH)
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build js || wasm
|
||||
// +build js wasm
|
||||
|
||||
package magicsock
|
||||
|
||||
// isBadEndpointErr checks if err is one which is known to report that an
|
||||
// endpoint can no longer be sent to. It is not exhaustive, but covers known
|
||||
// cases.
|
||||
func isBadEndpointErr(err error) bool {
|
||||
return false
|
||||
}
|
||||
@@ -1188,7 +1188,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache)
|
||||
cache.gen = de.numStopAndReset()
|
||||
ep = de
|
||||
}
|
||||
ep.noteRecvActivity(ipp)
|
||||
ep.noteRecvActivity()
|
||||
if stats := c.stats.Load(); stats != nil {
|
||||
stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b))
|
||||
}
|
||||
@@ -2607,11 +2607,6 @@ var (
|
||||
// resetting the counter, as the first pings likely didn't through
|
||||
// the firewall)
|
||||
discoPingInterval = 5 * time.Second
|
||||
|
||||
// wireguardPingInterval is the minimum time between pings to an endpoint.
|
||||
// Pings are only sent if we have not observed bidirectional traffic with an
|
||||
// endpoint in at least this duration.
|
||||
wireguardPingInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// indexSentinelDeleted is the temporary value that endpointState.index takes while
|
||||
|
||||
@@ -1210,11 +1210,11 @@ func Test32bitAlignment(t *testing.T) {
|
||||
t.Fatalf("endpoint.lastRecv is not 8-byte aligned")
|
||||
}
|
||||
|
||||
de.noteRecvActivity(netip.AddrPort{}) // verify this doesn't panic on 32-bit
|
||||
de.noteRecvActivity() // verify this doesn't panic on 32-bit
|
||||
if called != 1 {
|
||||
t.Fatal("expected call to noteRecvActivity")
|
||||
}
|
||||
de.noteRecvActivity(netip.AddrPort{})
|
||||
de.noteRecvActivity()
|
||||
if called != 1 {
|
||||
t.Error("expected no second call to noteRecvActivity")
|
||||
}
|
||||
@@ -2668,7 +2668,6 @@ func newPingResponder(t *testing.T) *pingResponder {
|
||||
|
||||
func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
|
||||
testTime := mono.Now()
|
||||
secondPingTime := testTime.Add(10 * time.Second)
|
||||
|
||||
type endpointDetails struct {
|
||||
addrPort netip.AddrPort
|
||||
@@ -2676,79 +2675,16 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
wgTests := []struct {
|
||||
name string
|
||||
sendInitialPing bool
|
||||
validAddr bool
|
||||
sendFollowUpPing bool
|
||||
pingTime mono.Time
|
||||
ep []endpointDetails
|
||||
want netip.AddrPort
|
||||
name string
|
||||
noV4 bool
|
||||
noV6 bool
|
||||
sendWGPing bool
|
||||
ep []endpointDetails
|
||||
want netip.AddrPort
|
||||
}{
|
||||
{
|
||||
name: "no endpoints",
|
||||
sendInitialPing: false,
|
||||
validAddr: false,
|
||||
sendFollowUpPing: false,
|
||||
pingTime: testTime,
|
||||
ep: []endpointDetails{},
|
||||
want: netip.AddrPort{},
|
||||
},
|
||||
{
|
||||
name: "singular endpoint does not request ping",
|
||||
sendInitialPing: false,
|
||||
validAddr: true,
|
||||
sendFollowUpPing: false,
|
||||
pingTime: testTime,
|
||||
ep: []endpointDetails{
|
||||
{
|
||||
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
latency: 100 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
want: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
},
|
||||
{
|
||||
name: "ping sent within wireguardPingInterval should not request ping",
|
||||
sendInitialPing: true,
|
||||
validAddr: true,
|
||||
sendFollowUpPing: false,
|
||||
pingTime: testTime.Add(7 * time.Second),
|
||||
ep: []endpointDetails{
|
||||
{
|
||||
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
latency: 100 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
|
||||
latency: 2000 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
want: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
},
|
||||
{
|
||||
name: "ping sent outside of wireguardPingInterval should request ping",
|
||||
sendInitialPing: true,
|
||||
validAddr: true,
|
||||
sendFollowUpPing: true,
|
||||
pingTime: testTime.Add(3 * time.Second),
|
||||
ep: []endpointDetails{
|
||||
{
|
||||
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
latency: 100 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
|
||||
latency: 150 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
want: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
},
|
||||
{
|
||||
name: "choose lowest latency for useable IPv4 and IPv6",
|
||||
sendInitialPing: true,
|
||||
validAddr: true,
|
||||
sendFollowUpPing: false,
|
||||
pingTime: secondPingTime,
|
||||
name: "choose lowest latency for useable IPv4 and IPv6",
|
||||
sendWGPing: true,
|
||||
ep: []endpointDetails{
|
||||
{
|
||||
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
@@ -2762,11 +2698,8 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
|
||||
want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
|
||||
},
|
||||
{
|
||||
name: "choose IPv6 address when latency is the same for v4 and v6",
|
||||
sendInitialPing: true,
|
||||
validAddr: true,
|
||||
sendFollowUpPing: false,
|
||||
pingTime: secondPingTime,
|
||||
name: "choose IPv6 address when latency is the same for v4 and v6",
|
||||
sendWGPing: true,
|
||||
ep: []endpointDetails{
|
||||
{
|
||||
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
|
||||
@@ -2782,57 +2715,52 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range wgTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
endpoint := &endpoint{
|
||||
isWireguardOnly: true,
|
||||
endpointState: map[netip.AddrPort]*endpointState{},
|
||||
c: &Conn{
|
||||
logf: t.Logf,
|
||||
noV4: atomic.Bool{},
|
||||
noV6: atomic.Bool{},
|
||||
},
|
||||
endpoint := &endpoint{
|
||||
isWireguardOnly: true,
|
||||
endpointState: map[netip.AddrPort]*endpointState{},
|
||||
c: &Conn{
|
||||
noV4: atomic.Bool{},
|
||||
noV6: atomic.Bool{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, epd := range test.ep {
|
||||
endpoint.endpointState[epd.addrPort] = &endpointState{}
|
||||
}
|
||||
|
||||
udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime)
|
||||
if !udpAddr.IsValid() {
|
||||
t.Error("udpAddr returned is not valid")
|
||||
}
|
||||
if shouldPing != test.sendWGPing {
|
||||
t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendWGPing)
|
||||
}
|
||||
|
||||
for _, epd := range test.ep {
|
||||
state, ok := endpoint.endpointState[epd.addrPort]
|
||||
if !ok {
|
||||
t.Errorf("addr does not exist in endpoint state map")
|
||||
}
|
||||
|
||||
for _, epd := range test.ep {
|
||||
endpoint.endpointState[epd.addrPort] = &endpointState{}
|
||||
}
|
||||
udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime)
|
||||
if udpAddr.IsValid() != test.validAddr {
|
||||
t.Errorf("udpAddr validity is incorrect; got %v, want %v", udpAddr.IsValid(), test.validAddr)
|
||||
}
|
||||
if shouldPing != test.sendInitialPing {
|
||||
t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendInitialPing)
|
||||
latency, ok := state.latencyLocked()
|
||||
if ok {
|
||||
t.Errorf("latency was set for %v: %v", epd.addrPort, latency)
|
||||
}
|
||||
state.recentPongs = append(state.recentPongs, pongReply{
|
||||
latency: epd.latency,
|
||||
})
|
||||
state.recentPong = 0
|
||||
}
|
||||
|
||||
// Update the endpointState to simulate a ping having been
|
||||
// sent and a pong received.
|
||||
for _, epd := range test.ep {
|
||||
state, ok := endpoint.endpointState[epd.addrPort]
|
||||
if !ok {
|
||||
t.Errorf("addr does not exist in endpoint state map")
|
||||
}
|
||||
state.lastPing = test.pingTime
|
||||
|
||||
latency, ok := state.latencyLocked()
|
||||
if ok {
|
||||
t.Errorf("latency was set for %v: %v", epd.addrPort, latency)
|
||||
}
|
||||
state.recentPongs = append(state.recentPongs, pongReply{
|
||||
latency: epd.latency,
|
||||
})
|
||||
state.recentPong = 0
|
||||
}
|
||||
|
||||
udpAddr, _, shouldPing = endpoint.addrForSendLocked(secondPingTime)
|
||||
if udpAddr != test.want {
|
||||
t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want)
|
||||
}
|
||||
if shouldPing != test.sendFollowUpPing {
|
||||
t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendFollowUpPing)
|
||||
}
|
||||
if endpoint.bestAddr.AddrPort != test.want {
|
||||
t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want)
|
||||
}
|
||||
})
|
||||
udpAddr, _, shouldPing = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute))
|
||||
if udpAddr != test.want {
|
||||
t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want)
|
||||
}
|
||||
if shouldPing {
|
||||
t.Error("addrForSendLocked should not indicate ping is required")
|
||||
}
|
||||
if endpoint.bestAddr.AddrPort != test.want {
|
||||
t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,32 +85,41 @@ func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMod
|
||||
iptAva, nftAva := true, true
|
||||
iptRuleCount, err := det.iptDetect()
|
||||
if err != nil {
|
||||
logf("detect iptables rule: %v", err)
|
||||
logf("router: detect iptables rule: %v", err)
|
||||
iptAva = false
|
||||
}
|
||||
nftRuleCount, err := det.nftDetect()
|
||||
if err != nil {
|
||||
logf("detect nftables rule: %v", err)
|
||||
logf("router: detect nftables rule: %v", err)
|
||||
nftAva = false
|
||||
}
|
||||
logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount)
|
||||
logf("router: nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount)
|
||||
switch {
|
||||
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables":
|
||||
// TODO(KevinLiang10): Updates to a flag
|
||||
logf("router: envknob TS_DEBUG_FIREWALL_MODE=nftables set")
|
||||
hostinfo.SetFirewallMode("nft-forced")
|
||||
return linuxfw.FirewallModeNfTables
|
||||
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables":
|
||||
logf("router: envknob TS_DEBUG_FIREWALL_MODE=iptables set")
|
||||
hostinfo.SetFirewallMode("ipt-forced")
|
||||
return linuxfw.FirewallModeIPTables
|
||||
case nftRuleCount > 0 && iptRuleCount == 0:
|
||||
logf("nftables is currently in use")
|
||||
logf("router: nftables is currently in use")
|
||||
hostinfo.SetFirewallMode("nft-inuse")
|
||||
return linuxfw.FirewallModeNfTables
|
||||
case iptRuleCount > 0 && nftRuleCount == 0:
|
||||
logf("iptables is currently in use")
|
||||
logf("router: iptables is currently in use")
|
||||
hostinfo.SetFirewallMode("ipt-inuse")
|
||||
return linuxfw.FirewallModeIPTables
|
||||
case nftAva:
|
||||
// if both iptables and nftables are available but
|
||||
// neither/both are currently used, use nftables.
|
||||
logf("nftables is available")
|
||||
logf("router: nftables is available")
|
||||
hostinfo.SetFirewallMode("nft")
|
||||
return linuxfw.FirewallModeNfTables
|
||||
case iptAva:
|
||||
logf("iptables is available")
|
||||
logf("router: iptables is available")
|
||||
hostinfo.SetFirewallMode("ipt")
|
||||
return linuxfw.FirewallModeIPTables
|
||||
default:
|
||||
@@ -127,44 +136,18 @@ func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMod
|
||||
// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
|
||||
func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
|
||||
tableDetector := &linuxFWDetector{}
|
||||
var mode linuxfw.FirewallMode
|
||||
|
||||
// We now use iptables as default and have "auto" and "nftables" as
|
||||
// options for people to test further.
|
||||
switch {
|
||||
case distro.Get() == distro.Gokrazy:
|
||||
// Reduce startup logging on gokrazy. There's no way to do iptables on
|
||||
// gokrazy anyway.
|
||||
logf("GoKrazy should use nftables.")
|
||||
hostinfo.SetFirewallMode("nft-gokrazy")
|
||||
mode = linuxfw.FirewallModeNfTables
|
||||
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables":
|
||||
logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set")
|
||||
hostinfo.SetFirewallMode("nft-forced")
|
||||
mode = linuxfw.FirewallModeNfTables
|
||||
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto":
|
||||
mode = chooseFireWallMode(logf, tableDetector)
|
||||
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables":
|
||||
logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set")
|
||||
hostinfo.SetFirewallMode("ipt-forced")
|
||||
mode = linuxfw.FirewallModeIPTables
|
||||
default:
|
||||
logf("default choosing iptables")
|
||||
hostinfo.SetFirewallMode("ipt-default")
|
||||
mode = linuxfw.FirewallModeIPTables
|
||||
}
|
||||
|
||||
mode := chooseFireWallMode(logf, tableDetector)
|
||||
var nfr netfilterRunner
|
||||
var err error
|
||||
switch mode {
|
||||
case linuxfw.FirewallModeIPTables:
|
||||
logf("using iptables")
|
||||
logf("router: using iptables")
|
||||
nfr, err = linuxfw.NewIPTablesRunner(logf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case linuxfw.FirewallModeNfTables:
|
||||
logf("using nftables")
|
||||
logf("router: using nftables")
|
||||
nfr, err = linuxfw.NewNfTablesRunner(logf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user