Compare commits
1 Commits
awly/cli-j
...
nickkhyl/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cab0e1a6f7 |
@@ -10,7 +10,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw
|
||||
W 💣 github.com/dblohm7/wingoes from tailscale.com/util/winutil
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+
|
||||
@@ -146,9 +146,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
tailscale.com/util/cloudenv from tailscale.com/hostinfo+
|
||||
W tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy
|
||||
tailscale.com/util/ctxkey from tailscale.com/tsweb+
|
||||
💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting
|
||||
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
|
||||
tailscale.com/util/dnsname from tailscale.com/hostinfo+
|
||||
tailscale.com/util/fastuuid from tailscale.com/tsweb
|
||||
💣 tailscale.com/util/hashx from tailscale.com/util/deephash
|
||||
tailscale.com/util/httpm from tailscale.com/client/tailscale
|
||||
tailscale.com/util/lineread from tailscale.com/hostinfo+
|
||||
L tailscale.com/util/linuxfw from tailscale.com/net/netns
|
||||
@@ -159,8 +161,17 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
tailscale.com/util/singleflight from tailscale.com/net/dnscache
|
||||
tailscale.com/util/slicesx from tailscale.com/cmd/derper+
|
||||
tailscale.com/util/syspolicy from tailscale.com/ipn
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/testenv from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/version from tailscale.com/derp+
|
||||
tailscale.com/version/distro from tailscale.com/envknob+
|
||||
@@ -180,6 +191,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box
|
||||
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
|
||||
W golang.org/x/exp/constraints from tailscale.com/util/winutil
|
||||
golang.org/x/exp/maps from tailscale.com/util/syspolicy/internal/metrics+
|
||||
L golang.org/x/net/bpf from github.com/mdlayher/netlink+
|
||||
golang.org/x/net/dns/dnsmessage from net+
|
||||
golang.org/x/net/http/httpguts from net/http
|
||||
@@ -240,7 +252,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
encoding/pem from crypto/tls+
|
||||
errors from bufio+
|
||||
expvar from github.com/prometheus/client_golang/prometheus+
|
||||
flag from tailscale.com/cmd/derper
|
||||
flag from tailscale.com/cmd/derper+
|
||||
fmt from compress/flate+
|
||||
go/token from google.golang.org/protobuf/internal/strs
|
||||
hash from crypto+
|
||||
@@ -273,7 +285,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
os from crypto/rand+
|
||||
os/exec from github.com/coreos/go-iptables/iptables+
|
||||
os/signal from tailscale.com/cmd/derper
|
||||
W os/user from tailscale.com/util/winutil
|
||||
W os/user from tailscale.com/util/winutil+
|
||||
path from github.com/prometheus/client_golang/prometheus/internal+
|
||||
path/filepath from crypto/x509+
|
||||
reflect from crypto/x509+
|
||||
|
||||
@@ -96,7 +96,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
💣 github.com/fsnotify/fsnotify from sigs.k8s.io/controller-runtime/pkg/certwatcher
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/gaissmai/bart from tailscale.com/net/ipset+
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+
|
||||
@@ -803,6 +803,13 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
tailscale.com/util/singleflight from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/slicesx from tailscale.com/appc+
|
||||
tailscale.com/util/syspolicy from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/util/systemd from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/testenv from tailscale.com/control/controlclient+
|
||||
@@ -811,7 +818,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
|
||||
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+
|
||||
|
||||
@@ -9,7 +9,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/pe+
|
||||
W 💣 github.com/dblohm7/wingoes/pe from tailscale.com/util/winutil/authenticode
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+
|
||||
@@ -152,9 +152,11 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/util/cloudenv from tailscale.com/net/dnscache+
|
||||
tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+
|
||||
tailscale.com/util/ctxkey from tailscale.com/types/logger
|
||||
💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting
|
||||
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
|
||||
tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/util/groupmember from tailscale.com/client/web
|
||||
💣 tailscale.com/util/hashx from tailscale.com/util/deephash
|
||||
tailscale.com/util/httpm from tailscale.com/client/tailscale+
|
||||
tailscale.com/util/lineread from tailscale.com/hostinfo+
|
||||
L tailscale.com/util/linuxfw from tailscale.com/net/netns
|
||||
@@ -167,11 +169,19 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/util/singleflight from tailscale.com/net/dnscache+
|
||||
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
|
||||
tailscale.com/util/syspolicy from tailscale.com/ipn
|
||||
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/version from tailscale.com/client/web+
|
||||
tailscale.com/version/distro from tailscale.com/client/web+
|
||||
@@ -191,7 +201,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12
|
||||
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
|
||||
W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+
|
||||
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli
|
||||
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli+
|
||||
golang.org/x/net/bpf from github.com/mdlayher/netlink+
|
||||
golang.org/x/net/dns/dnsmessage from net+
|
||||
golang.org/x/net/http/httpguts from net/http+
|
||||
|
||||
@@ -90,7 +90,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
💣 github.com/djherbis/times from tailscale.com/drive/driveimpl
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/gaissmai/bart from tailscale.com/net/tstun+
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+
|
||||
@@ -395,6 +395,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/util/singleflight from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
|
||||
tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/util/systemd from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+
|
||||
@@ -403,7 +410,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
|
||||
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+
|
||||
|
||||
@@ -52,6 +52,8 @@ import (
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/wgcfg"
|
||||
@@ -2546,6 +2548,14 @@ func TestPreferencePolicyInfo(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
definitions := make([]*setting.Definition, 0, len(preferencePolicies)+1)
|
||||
definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(syspolicy.ControlURL)))
|
||||
for _, pp := range preferencePolicies {
|
||||
definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(pp.key)))
|
||||
}
|
||||
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
for _, pp := range preferencePolicies {
|
||||
t.Run(string(pp.key), func(t *testing.T) {
|
||||
var h syspolicy.Handler
|
||||
@@ -2572,7 +2582,7 @@ func TestPreferencePolicyInfo(t *testing.T) {
|
||||
msh.stringPolicies[pp.key] = &tt.policyValue
|
||||
h = msh
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, h)
|
||||
rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, syspolicy.WrapHandler(h))
|
||||
|
||||
prefs := defaultPrefs.AsStruct()
|
||||
pp.set(prefs, tt.initialValue)
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
|
||||
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
|
||||
// otherwise the actual error is returned and the next read for that key will retry using the handler.
|
||||
type CachingHandler struct {
|
||||
mu sync.Mutex
|
||||
strings map[string]string
|
||||
uint64s map[string]uint64
|
||||
bools map[string]bool
|
||||
strArrs map[string][]string
|
||||
notFound map[string]bool
|
||||
handler Handler
|
||||
}
|
||||
|
||||
// NewCachingHandler creates a CachingHandler given a handler.
|
||||
func NewCachingHandler(handler Handler) *CachingHandler {
|
||||
return &CachingHandler{
|
||||
handler: handler,
|
||||
strings: make(map[string]string),
|
||||
uint64s: make(map[string]uint64),
|
||||
bools: make(map[string]bool),
|
||||
strArrs: make(map[string][]string),
|
||||
notFound: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString reads the policy settings value string given the key.
|
||||
// ReadString first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadString(key string) (string, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.strings[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return "", ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadString(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return "", err
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ch.strings[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 reads the policy settings uint64 value given the key.
|
||||
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.uint64s[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return 0, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadUInt64(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return 0, err
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ch.uint64s[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.bools[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return false, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadBoolean(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return false, err
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
ch.bools[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.strArrs[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadStringArray(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.strArrs[key] = val
|
||||
return val, nil
|
||||
}
|
||||
@@ -1,262 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandlerReadString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue string
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue string
|
||||
wantErr error
|
||||
strings map[string]string
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
strings: map[string]string{"test": "foo"},
|
||||
wantValue: "foo",
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: "foo",
|
||||
wantValue: "foo",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.strings != nil {
|
||||
cache.strings = tt.strings
|
||||
}
|
||||
got, err := cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerReadUint64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue uint64
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue uint64
|
||||
wantErr error
|
||||
uint64s map[string]uint64
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
uint64s: map[string]uint64{"test": 1},
|
||||
wantValue: 1,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: 1,
|
||||
wantValue: 1,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
u64: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.uint64s != nil {
|
||||
cache.uint64s = tt.uint64s
|
||||
}
|
||||
got, err := cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestHandlerReadBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue bool
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue bool
|
||||
wantErr error
|
||||
bools map[string]bool
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
bools: map[string]bool{"test": true},
|
||||
wantValue: true,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: true,
|
||||
wantValue: true,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
b: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.bools != nil {
|
||||
cache.bools = tt.bools
|
||||
}
|
||||
got, err := cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,16 +4,15 @@
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
handlerUsed atomic.Bool
|
||||
handler Handler = defaultHandler{}
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
// Handler reads system policies from OS-specific storage.
|
||||
//
|
||||
// Deprecated: implementing a [Store] should be preferred.
|
||||
type Handler interface {
|
||||
// ReadString reads the policy setting's string value for the given key.
|
||||
// It should return ErrNoSuchKey if the key does not have a value set.
|
||||
@@ -29,55 +28,81 @@ type Handler interface {
|
||||
ReadStringArray(key string) ([]string, error)
|
||||
}
|
||||
|
||||
// ErrNoSuchKey is returned by a Handler when the specified key does not have a
|
||||
// value set.
|
||||
var ErrNoSuchKey = errors.New("no such key")
|
||||
|
||||
// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple.
|
||||
type defaultHandler struct{}
|
||||
|
||||
func (defaultHandler) ReadString(_ string) (string, error) {
|
||||
return "", ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadUInt64(_ string) (uint64, error) {
|
||||
return 0, ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadBoolean(_ string) (bool, error) {
|
||||
return false, ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadStringArray(_ string) ([]string, error) {
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// markHandlerInUse is called before handler methods are called.
|
||||
func markHandlerInUse() {
|
||||
handlerUsed.Store(true)
|
||||
}
|
||||
|
||||
// RegisterHandler initializes the policy handler and ensures registration will happen once.
|
||||
// RegisterHandler wraps and registers the specified handler as the device's
|
||||
// policy [Store] for the program's lifetime.
|
||||
//
|
||||
// Deprecated: using [RegisterStore] should be preferred.
|
||||
func RegisterHandler(h Handler) {
|
||||
// Technically this assignment is not concurrency safe, but in the
|
||||
// event that there was any risk of a data race, we will panic due to
|
||||
// the CompareAndSwap failing.
|
||||
handler = h
|
||||
if !handlerUsed.CompareAndSwap(false, true) {
|
||||
panic("handler was already used before registration")
|
||||
}
|
||||
rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h))
|
||||
}
|
||||
|
||||
// TB is a subset of testing.TB that we use to set up test helpers.
|
||||
// It's defined here to avoid pulling in the testing package.
|
||||
type TB interface {
|
||||
Helper()
|
||||
Cleanup(func())
|
||||
type TB = internal.TB
|
||||
|
||||
// SetHandlerForTest wraps and sets the specified handler as the device's policy
|
||||
// [Store] for the duration of tb.
|
||||
//
|
||||
// Deprecated: using [resultant.RegisterStoreForTest] should be preferred.
|
||||
func SetHandlerForTest(tb TB, h Handler) {
|
||||
if err := setWellKnownSettingsForTest(tb); err != nil {
|
||||
tb.Fatalf("setWellKnownSettingsForTest failed: %v", err)
|
||||
}
|
||||
rsop.RegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.CurrentScope(), WrapHandler(h))
|
||||
}
|
||||
|
||||
func SetHandlerForTest(tb TB, h Handler) {
|
||||
tb.Helper()
|
||||
oldHandler := handler
|
||||
handler = h
|
||||
tb.Cleanup(func() { handler = oldHandler })
|
||||
var _ source.Store = (*handlerStore)(nil)
|
||||
|
||||
// handlerStore is a [source.Store] that calls the underlying [Handler].
|
||||
// TODO(nickkhyl): remove it when the corp and android repos are updated.
|
||||
type handlerStore struct {
|
||||
h Handler
|
||||
}
|
||||
|
||||
// WrapHandler returns a [source.Store] that wraps the specified [Handler].
|
||||
func WrapHandler(h Handler) source.Store {
|
||||
return handlerStore{h}
|
||||
}
|
||||
|
||||
func (s handlerStore) Lock() error {
|
||||
if lockable, ok := s.h.(source.Lockable); ok {
|
||||
return lockable.Lock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s handlerStore) Unlock() {
|
||||
if lockable, ok := s.h.(source.Lockable); ok {
|
||||
lockable.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
||||
if lockable, ok := s.h.(source.Changeable); ok {
|
||||
return lockable.RegisterChangeCallback(callback)
|
||||
}
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadString(key setting.Key) (string, error) {
|
||||
return s.h.ReadString(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
return s.h.ReadUInt64(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
return s.h.ReadBoolean(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
return s.h.ReadStringArray(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) Done() <-chan struct{} {
|
||||
if expirable, ok := s.h.(source.Expirable); ok {
|
||||
return expirable.Done()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultHandlerReadValues(t *testing.T) {
|
||||
var h defaultHandler
|
||||
|
||||
got, err := h.ReadString(string(AdminConsoleVisibility))
|
||||
if got != "" || err != ErrNoSuchKey {
|
||||
t.Fatalf("got %v err %v", got, err)
|
||||
}
|
||||
result, err := h.ReadUInt64(string(LogSCMInteractions))
|
||||
if result != 0 || err != ErrNoSuchKey {
|
||||
t.Fatalf("got %v err %v", result, err)
|
||||
}
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/winutil"
|
||||
)
|
||||
|
||||
var (
|
||||
windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors")
|
||||
windowsAny = clientmetric.NewGauge("windows_syspolicy_any")
|
||||
)
|
||||
|
||||
type windowsHandler struct{}
|
||||
|
||||
func init() {
|
||||
RegisterHandler(NewCachingHandler(windowsHandler{}))
|
||||
|
||||
keyList := []struct {
|
||||
isSet func(Key) bool
|
||||
keys []Key
|
||||
}{
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadString(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: stringKeys,
|
||||
},
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadBoolean(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: boolKeys,
|
||||
},
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadUInt64(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: uint64Keys,
|
||||
},
|
||||
}
|
||||
|
||||
var anySet bool
|
||||
for _, l := range keyList {
|
||||
for _, k := range l.keys {
|
||||
if !l.isSet(k) {
|
||||
continue
|
||||
}
|
||||
clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1)
|
||||
anySet = true
|
||||
}
|
||||
}
|
||||
if anySet {
|
||||
windowsAny.Set(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadString(key string) (string, error) {
|
||||
s, err := winutil.GetPolicyString(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
|
||||
return s, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadUInt64(key string) (uint64, error) {
|
||||
value, err := winutil.GetPolicyInteger(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadBoolean(key string) (bool, error) {
|
||||
value, err := winutil.GetPolicyInteger(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value != 0, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadStringArray(key string) ([]string, error) {
|
||||
value, err := winutil.GetPolicyStringArray(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
63
util/syspolicy/internal/internal.go
Normal file
63
util/syspolicy/internal/internal.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package internal contains miscellaneous functions and types
|
||||
// that are internal to the syspolicy packages.
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
// OSForTesting is the operating system override used for testing.
|
||||
// It follows the same naming convention as [version.OS].
|
||||
var OSForTesting lazy.SyncValue[string]
|
||||
|
||||
// OS is like [version.OS], but supports a test hook.
|
||||
func OS() string {
|
||||
return OSForTesting.Get(version.OS)
|
||||
}
|
||||
|
||||
// TB is a subset of testing.TB that we use to set up test helpers.
|
||||
// It's defined here to avoid pulling in the testing package.
|
||||
type TB interface {
|
||||
Helper()
|
||||
Cleanup(func())
|
||||
Logf(format string, args ...any)
|
||||
Error(args ...any)
|
||||
Errorf(format string, args ...any)
|
||||
Fatal(args ...any)
|
||||
Fatalf(format string, args ...any)
|
||||
}
|
||||
|
||||
// EqualJSONForTest compares the JSON in j1 and j2 for semantic equality.
|
||||
// It returns "", "", true if j1 and j2 are equal. Otherwise, it returns
|
||||
// indented versions of j1 and j2 and false.
|
||||
func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) {
|
||||
tb.Helper()
|
||||
j1 = j1.Clone()
|
||||
j2 = j2.Clone()
|
||||
// Canonicalize JSON values for comparison.
|
||||
if err := j1.Canonicalize(); err != nil {
|
||||
tb.Error(err)
|
||||
}
|
||||
if err := j2.Canonicalize(); err != nil {
|
||||
tb.Error(err)
|
||||
}
|
||||
// Check and return true if the two values are structurally equal.
|
||||
if bytes.Equal(j1, j2) {
|
||||
return "", "", true
|
||||
}
|
||||
// Otherwise, format the values for display and return false.
|
||||
if err := j1.Indent("", "\t"); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := j2.Indent("", "\t"); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return j1.String(), j2.String(), false
|
||||
}
|
||||
84
util/syspolicy/internal/lazyinit/lazyinit.go
Normal file
84
util/syspolicy/internal/lazyinit/lazyinit.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// The lazyinit package facilitates deferred package initialization.
|
||||
package lazyinit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var packageInit deferredOnce
|
||||
|
||||
// Defer defers the specified action until [Do] is called.
|
||||
// It returns a boolean indicating whether [Do] has already been called.
|
||||
func Defer(action func() error) bool {
|
||||
return packageInit.Defer(action)
|
||||
}
|
||||
|
||||
// DeferWithCleanup is like [Defer], but the action function returns a cleanup
|
||||
// function to be called in case of an error.
|
||||
func DeferWithCleanup(action func() (cleanup func(), err error)) bool {
|
||||
return packageInit.DeferWithCleanup(action)
|
||||
}
|
||||
|
||||
// Do runs all deferred functions and returns an error if any of them fail.
|
||||
func Do() error {
|
||||
return packageInit.Do()
|
||||
}
|
||||
|
||||
type deferredOnce struct {
|
||||
done atomic.Uint32
|
||||
err error
|
||||
m sync.Mutex
|
||||
funcs []func() (cleanup func(), err error)
|
||||
}
|
||||
|
||||
func (o *deferredOnce) Defer(action func() error) bool {
|
||||
return o.DeferWithCleanup(func() (cleanup func(), err error) {
|
||||
return nil, action()
|
||||
})
|
||||
}
|
||||
|
||||
func (o *deferredOnce) DeferWithCleanup(action func() (cleanup func(), err error)) bool {
|
||||
o.m.Lock()
|
||||
defer o.m.Unlock()
|
||||
if o.done.Load() != 0 {
|
||||
return false
|
||||
}
|
||||
o.funcs = append(o.funcs, action)
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *deferredOnce) Do() error {
|
||||
if o.done.Load() == 0 {
|
||||
o.doSlow()
|
||||
}
|
||||
return o.err
|
||||
}
|
||||
|
||||
func (o *deferredOnce) doSlow() (err error) {
|
||||
o.m.Lock()
|
||||
defer o.m.Unlock()
|
||||
if o.done.Load() == 0 {
|
||||
defer func() {
|
||||
o.done.Store(1)
|
||||
o.err = err
|
||||
}()
|
||||
for _, f := range o.funcs {
|
||||
cleanup, err := f()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cleanup != nil {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cleanup()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
return o.err
|
||||
}
|
||||
46
util/syspolicy/internal/loggerx/logger.go
Normal file
46
util/syspolicy/internal/loggerx/logger.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package loggerx provides logging functions to the rest of the syspolicy packages.
|
||||
package loggerx
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
const (
|
||||
errorPrefix = "syspolicy: "
|
||||
verbosePrefix = "syspolicy: [v2] "
|
||||
)
|
||||
|
||||
var (
|
||||
lazyErrorf lazy.SyncValue[logger.Logf]
|
||||
lazyVerbosef lazy.SyncValue[logger.Logf]
|
||||
)
|
||||
|
||||
// Errorf formats and writes an error message to the log.
|
||||
func Errorf(format string, args ...any) {
|
||||
errorf := lazyErrorf.Get(func() logger.Logf {
|
||||
return logger.WithPrefix(log.Printf, errorPrefix)
|
||||
})
|
||||
errorf(format, args...)
|
||||
}
|
||||
|
||||
// Verbosef formats and writes an optional, verbose message to the log.
|
||||
func Verbosef(format string, args ...any) {
|
||||
verbosef := lazyVerbosef.Get(func() logger.Logf {
|
||||
return logger.WithPrefix(log.Printf, verbosePrefix)
|
||||
})
|
||||
verbosef(format, args...)
|
||||
}
|
||||
|
||||
// SetForTest sets the specified errorf and verbosef functions for the duration
|
||||
// of tb and its subtests.
|
||||
func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) {
|
||||
lazyErrorf.SetForTest(tb, errorf, nil)
|
||||
lazyVerbosef.SetForTest(tb, verbosef, nil)
|
||||
}
|
||||
315
util/syspolicy/internal/metrics/metrics.go
Normal file
315
util/syspolicy/internal/metrics/metrics.go
Normal file
@@ -0,0 +1,315 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package metrics provides logging and reporting for policy settings and scopes.
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
var lazyReportMetrics lazy.SyncValue[bool] // used as a test hook
|
||||
|
||||
// ShouldReport reports whether metrics should be reported on the current environment.
|
||||
func ShouldReport() bool {
|
||||
return lazyReportMetrics.Get(func() bool {
|
||||
// macOS, iOS and tvOS create their own metrics,
|
||||
// and we don't have syspolicy on any other platforms.
|
||||
return setting.PlatformList{"android", "windows"}.HasCurrent()
|
||||
})
|
||||
}
|
||||
|
||||
// Reset metrics for the specified policy origin.
|
||||
func Reset(origin *setting.Origin) {
|
||||
scopeMetrics(origin).Reset()
|
||||
}
|
||||
|
||||
// ReportConfigured updates metrics and logs that the specified setting is
|
||||
// configured with the given value in the origin.
|
||||
func ReportConfigured(origin *setting.Origin, setting *setting.Definition, value any) {
|
||||
settingMetricsFor(setting).ReportValue(origin, value)
|
||||
}
|
||||
|
||||
// ReportError updates metrics and logs that the specified setting has an error
|
||||
// in the origin.
|
||||
func ReportError(origin *setting.Origin, setting *setting.Definition, err error) {
|
||||
settingMetricsFor(setting).ReportError(origin, err)
|
||||
}
|
||||
|
||||
// ReportNotConfigured updates metrics and logs that the specified setting is
|
||||
// not configured in the origin.
|
||||
func ReportNotConfigured(origin *setting.Origin, setting *setting.Definition) {
|
||||
settingMetricsFor(setting).Reset(origin)
|
||||
}
|
||||
|
||||
// metric is an interface implemented by [clientmetric.Metric] and [funcMetric].
|
||||
type metric interface {
|
||||
Add(v int64)
|
||||
Set(v int64)
|
||||
}
|
||||
|
||||
// policyScopeMetrics are metrics that apply to an entire policy scope rather
|
||||
// than a specific policy setting.
|
||||
type policyScopeMetrics struct {
|
||||
hasAny metric
|
||||
numErrored metric
|
||||
}
|
||||
|
||||
func newScopeMetrics(scope setting.Scope) *policyScopeMetrics {
|
||||
prefix := metricScopeName(scope)
|
||||
if prefix != "" {
|
||||
prefix += "_"
|
||||
}
|
||||
// {os}_syspolicy_{scope_unless_device}_any
|
||||
// Example: windows_syspolicy_any or windows_syspolicy_user_any.
|
||||
hasAny := newMetric(prefix+"any", clientmetric.TypeGauge)
|
||||
// {os}_syspolicy_{scope_unless_device}_errors
|
||||
// Example: windows_syspolicy_errors or windows_syspolicy_user_errors.
|
||||
//
|
||||
// TODO(nickkhyl): maybe make the `{os}_syspolicy_errors` metric a gauge rather than a counter?
|
||||
// It was a counter prior to https://github.com/tailscale/tailscale/issues/12687, so I kept it as such.
|
||||
// But I think a gauge makes more sense: syspolicy errors indicate a mismatch between the expected
|
||||
// policy value type or format and the actual value read from the underlying store (like the Windows Registry).
|
||||
// We'll encounter the same error every time we re-read the policy setting from the backing store
|
||||
// until the policy value is corrected by the user, or until we fix the bug in the code or ADMX.
|
||||
// There's probably no reason to count and accumulate them over time.
|
||||
numErrored := newMetric(prefix+"errors", clientmetric.TypeCounter)
|
||||
return &policyScopeMetrics{hasAny, numErrored}
|
||||
}
|
||||
|
||||
// ReportHasSettings is called when there's any configured policy setting in the scope.
|
||||
func (m *policyScopeMetrics) ReportHasSettings() {
|
||||
if m != nil {
|
||||
m.hasAny.Set(1)
|
||||
}
|
||||
}
|
||||
|
||||
// ReportError is called when there's any errored policy setting in the scope.
|
||||
func (m *policyScopeMetrics) ReportError() {
|
||||
if m != nil {
|
||||
m.numErrored.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset is called to reset the policy scope metrics, such as when the policy scope
|
||||
// is about to be reloaded.
|
||||
func (m *policyScopeMetrics) Reset() {
|
||||
if m != nil {
|
||||
m.hasAny.Set(0)
|
||||
// numErrored is a counter and cannot be (re-)set.
|
||||
}
|
||||
}
|
||||
|
||||
// settingMetrics are metrics for a single policy setting in one or more scopes.
|
||||
type settingMetrics struct {
|
||||
definition *setting.Definition
|
||||
isSet []metric // by scope
|
||||
hasErrors []metric // by scope
|
||||
}
|
||||
|
||||
// ReportValue is called when the policy setting is found to be configured in the specified source.
|
||||
func (m *settingMetrics) ReportValue(origin *setting.Origin, v any) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
|
||||
m.isSet[scope].Set(1)
|
||||
m.hasErrors[scope].Set(0)
|
||||
}
|
||||
scopeMetrics(origin).ReportHasSettings()
|
||||
loggerx.Verbosef("%v(%q) = %v\n", origin, m.definition.Key(), v)
|
||||
}
|
||||
|
||||
// ReportError is called when there's an error with the policy setting in the specified source.
|
||||
func (m *settingMetrics) ReportError(origin *setting.Origin, err error) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); int(scope) < len(m.hasErrors) {
|
||||
m.isSet[scope].Set(0)
|
||||
m.hasErrors[scope].Set(1)
|
||||
}
|
||||
scopeMetrics(origin).ReportError()
|
||||
loggerx.Errorf("%v(%q): %v\n", origin, m.definition.Key(), err)
|
||||
}
|
||||
|
||||
// Reset is called to reset the policy setting's metrics, such as when
|
||||
// the policy setting does not exist or the source containing the policy
|
||||
// is about to be reloaded.
|
||||
func (m *settingMetrics) Reset(origin *setting.Origin) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
|
||||
m.isSet[scope].Set(0)
|
||||
m.hasErrors[scope].Set(0)
|
||||
}
|
||||
}
|
||||
|
||||
// metricFn is a function that adds or sets a metric value.
|
||||
type metricFn = func(name string, typ clientmetric.Type, v int64)
|
||||
|
||||
// funcMetric implements [metric] by calling the specified add and set functions.
|
||||
// Used for testing, and with nil functions on platforms that do not support
|
||||
// syspolicy, and on platforms that report policy metrics from the GUI.
|
||||
type funcMetric struct {
|
||||
name string
|
||||
typ clientmetric.Type
|
||||
add, set metricFn
|
||||
}
|
||||
|
||||
func (m funcMetric) Add(v int64) {
|
||||
if m.add != nil {
|
||||
m.add(m.name, m.typ, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (m funcMetric) Set(v int64) {
|
||||
if m.set != nil {
|
||||
m.set(m.name, m.typ, v)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
lazyDeviceMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
lazyProfileMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
lazyUserMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
)
|
||||
|
||||
func scopeMetrics(origin *setting.Origin) *policyScopeMetrics {
|
||||
switch origin.Scope().Kind() {
|
||||
case setting.DeviceSetting:
|
||||
return lazyDeviceMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.DeviceSetting)
|
||||
})
|
||||
case setting.ProfileSetting:
|
||||
return lazyProfileMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.ProfileSetting)
|
||||
})
|
||||
case setting.UserSetting:
|
||||
return lazyUserMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.UserSetting)
|
||||
})
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
settingMetricsMu sync.RWMutex
|
||||
settingMetricsMap map[setting.Key]*settingMetrics
|
||||
)
|
||||
|
||||
func settingMetricsFor(setting *setting.Definition) *settingMetrics {
|
||||
settingMetricsMu.RLock()
|
||||
if metrics, ok := settingMetricsMap[setting.Key()]; ok {
|
||||
settingMetricsMu.RUnlock()
|
||||
return metrics
|
||||
}
|
||||
settingMetricsMu.RUnlock()
|
||||
return settingMetricsForSlow(setting)
|
||||
}
|
||||
|
||||
func settingMetricsForSlow(d *setting.Definition) *settingMetrics {
|
||||
settingMetricsMu.Lock()
|
||||
defer settingMetricsMu.Unlock()
|
||||
if metrics, ok := settingMetricsMap[d.Key()]; ok {
|
||||
return metrics
|
||||
}
|
||||
|
||||
isSet := make([]metric, d.Scope()+1)
|
||||
hasErrors := make([]metric, d.Scope()+1)
|
||||
for i := range isSet {
|
||||
scope := setting.Scope(i)
|
||||
// {os}_syspolicy_{key}_{scope_unless_device}
|
||||
// Example: windows_syspolicy_AdminConsole or windows_syspolicy_AdminConsole_user.
|
||||
isSet[i] = newSettingMetric(d.Key(), scope, "", clientmetric.TypeGauge)
|
||||
// {os}_syspolicy_{key}_{scope_unless_device}_error
|
||||
// Example: windows_syspolicy_AdminConsole_error or windows_syspolicy_TestSetting01_user_error.
|
||||
hasErrors[i] = newSettingMetric(d.Key(), scope, "error", clientmetric.TypeGauge)
|
||||
}
|
||||
metrics := &settingMetrics{d, isSet, hasErrors}
|
||||
mak.Set(&settingMetricsMap, d.Key(), metrics)
|
||||
return metrics
|
||||
}
|
||||
|
||||
// hooks for testing
|
||||
var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn]
|
||||
|
||||
// SetHooksForTest sets the specified addMetric and setMetric functions
|
||||
// as the metric functions for the duration of tb and all its subtests.
|
||||
func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) {
|
||||
oldAddMetric := addMetricTestHook.Swap(addMetric)
|
||||
oldSetMetric := setMetricTestHook.Swap(setMetric)
|
||||
tb.Cleanup(func() {
|
||||
addMetricTestHook.Store(oldAddMetric)
|
||||
setMetricTestHook.Store(oldSetMetric)
|
||||
})
|
||||
|
||||
settingMetricsMu.Lock()
|
||||
oldSettingMetricsMap := xmaps.Clone(settingMetricsMap)
|
||||
clear(settingMetricsMap)
|
||||
settingMetricsMu.Unlock()
|
||||
tb.Cleanup(func() {
|
||||
settingMetricsMu.Lock()
|
||||
settingMetricsMap = oldSettingMetricsMap
|
||||
settingMetricsMu.Unlock()
|
||||
})
|
||||
|
||||
// (re-)set the scope metrics to use the test hooks for the duration of tb.
|
||||
lazyDeviceMetrics.SetForTest(tb, newScopeMetrics(setting.DeviceSetting), nil)
|
||||
lazyProfileMetrics.SetForTest(tb, newScopeMetrics(setting.ProfileSetting), nil)
|
||||
lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil)
|
||||
}
|
||||
|
||||
func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric {
|
||||
name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_")
|
||||
if tag := metricScopeName(scope); tag != "" {
|
||||
name += "_" + tag
|
||||
}
|
||||
if suffix != "" {
|
||||
name += "_" + suffix
|
||||
}
|
||||
return newMetric(name, typ)
|
||||
}
|
||||
|
||||
func newMetric(name string, typ clientmetric.Type) metric {
|
||||
name = internal.OS() + "_syspolicy_" + name
|
||||
switch {
|
||||
case !ShouldReport():
|
||||
return &funcMetric{name: name, typ: typ}
|
||||
case testenv.InTest():
|
||||
return &funcMetric{name, typ, addMetricTestHook.Load(), setMetricTestHook.Load()}
|
||||
case typ == clientmetric.TypeCounter:
|
||||
return clientmetric.NewCounter(name)
|
||||
case typ == clientmetric.TypeGauge:
|
||||
return clientmetric.NewGauge(name)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func metricScopeName(scope setting.Scope) string {
|
||||
switch scope {
|
||||
case setting.DeviceSetting:
|
||||
return ""
|
||||
case setting.ProfileSetting:
|
||||
return "profile"
|
||||
case setting.UserSetting:
|
||||
return "user"
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
423
util/syspolicy/internal/metrics/metrics_test.go
Normal file
423
util/syspolicy/internal/metrics/metrics_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestSettingMetricNames(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key setting.Key
|
||||
scope setting.Scope
|
||||
suffix string
|
||||
typ clientmetric.Type
|
||||
osOverride string
|
||||
wantMetricName string
|
||||
}{
|
||||
{
|
||||
name: "windows-device-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole",
|
||||
},
|
||||
{
|
||||
name: "windows-user-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.UserSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_user",
|
||||
},
|
||||
{
|
||||
name: "windows-profile-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.ProfileSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_profile",
|
||||
},
|
||||
{
|
||||
name: "windows-profile-err",
|
||||
key: "AdminConsole",
|
||||
scope: setting.ProfileSetting,
|
||||
suffix: "error",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_profile_error",
|
||||
},
|
||||
{
|
||||
name: "android-device-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "android",
|
||||
wantMetricName: "android_syspolicy_AdminConsole",
|
||||
},
|
||||
{
|
||||
name: "key-path",
|
||||
key: "category/subcategory/setting",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "fakeos",
|
||||
wantMetricName: "fakeos_syspolicy_category_subcategory_setting",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
metric, ok := newSettingMetric(tt.key, tt.scope, tt.suffix, tt.typ).(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("metric is not a funcMetric")
|
||||
}
|
||||
if metric.name != tt.wantMetricName {
|
||||
t.Errorf("got %q, want %q", metric.name, tt.wantMetricName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopeMetrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope setting.Scope
|
||||
osOverride string
|
||||
wantHasAnyName string
|
||||
wantNumErroredName string
|
||||
wantHasAnyType clientmetric.Type
|
||||
wantNumErroredType clientmetric.Type
|
||||
}{
|
||||
{
|
||||
name: "windows-device",
|
||||
scope: setting.DeviceSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "windows-profile",
|
||||
scope: setting.ProfileSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_profile_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_profile_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "windows-user",
|
||||
scope: setting.UserSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_user_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_user_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "android-device",
|
||||
scope: setting.DeviceSetting,
|
||||
osOverride: "android",
|
||||
wantHasAnyName: "android_syspolicy_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "android_syspolicy_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
metrics := newScopeMetrics(tt.scope)
|
||||
hasAny, ok := metrics.hasAny.(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("hasAny is not a funcMetric")
|
||||
}
|
||||
numErrored, ok := metrics.numErrored.(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("numErrored is not a funcMetric")
|
||||
}
|
||||
if hasAny.name != tt.wantHasAnyName {
|
||||
t.Errorf("hasAny.Name: got %q, want %q", hasAny.name, tt.wantHasAnyName)
|
||||
}
|
||||
if hasAny.typ != tt.wantHasAnyType {
|
||||
t.Errorf("hasAny.Type: got %q, want %q", hasAny.typ, tt.wantHasAnyType)
|
||||
}
|
||||
if numErrored.name != tt.wantNumErroredName {
|
||||
t.Errorf("numErrored.Name: got %q, want %q", numErrored.name, tt.wantNumErroredName)
|
||||
}
|
||||
if numErrored.typ != tt.wantNumErroredType {
|
||||
t.Errorf("hasAny.Type: got %q, want %q", numErrored.typ, tt.wantNumErroredType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testSettingDetails struct {
|
||||
definition *setting.Definition
|
||||
origin *setting.Origin
|
||||
value any
|
||||
err error
|
||||
}
|
||||
|
||||
func TestReportMetrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
osOverride string
|
||||
useMetrics bool
|
||||
settings []testSettingDetails
|
||||
wantMetrics []TestState
|
||||
wantResetMetrics []TestState
|
||||
}{
|
||||
{
|
||||
name: "none",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{},
|
||||
wantMetrics: []TestState{},
|
||||
},
|
||||
{
|
||||
name: "single-value",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-error",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "value-and-error",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two-values",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 17,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two-errors",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 2},
|
||||
{"windows_syspolicy_TestSetting01_error", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 2},
|
||||
{"windows_syspolicy_TestSetting01_error", 0},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-scope",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.ProfileSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.ProfileSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.CurrentProfileScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting03", setting.UserSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.CurrentUserScope),
|
||||
value: 17,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_profile_errors", 1},
|
||||
{"windows_syspolicy_user_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02_profile_error", 1},
|
||||
{"windows_syspolicy_TestSetting03_user", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_profile_errors", 1},
|
||||
{"windows_syspolicy_user_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02_profile_error", 0},
|
||||
{"windows_syspolicy_TestSetting03_user", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "report-metrics-on-android",
|
||||
osOverride: "android",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"android_syspolicy_any", 1},
|
||||
{"android_syspolicy_TestSetting01", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"android_syspolicy_any", 0},
|
||||
{"android_syspolicy_TestSetting01", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "do-not-report-metrics-on-macos",
|
||||
osOverride: "macos",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{}, // none reported
|
||||
},
|
||||
{
|
||||
name: "do-not-report-metrics-on-ios",
|
||||
osOverride: "ios",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{}, // none reported
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset the lazy value so it'll be re-evaluated with the osOverride.
|
||||
lazyReportMetrics = lazy.SyncValue[bool]{}
|
||||
t.Cleanup(func() {
|
||||
// Also reset it during the cleanup.
|
||||
lazyReportMetrics = lazy.SyncValue[bool]{}
|
||||
})
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
|
||||
h := NewTestHandler(t)
|
||||
SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
for _, s := range tt.settings {
|
||||
if s.err != nil {
|
||||
ReportError(s.origin, s.definition, s.err)
|
||||
} else {
|
||||
ReportConfigured(s.origin, s.definition, s.value)
|
||||
}
|
||||
}
|
||||
h.MustEqual(tt.wantMetrics...)
|
||||
|
||||
for _, s := range tt.settings {
|
||||
Reset(s.origin)
|
||||
ReportNotConfigured(s.origin, s.definition)
|
||||
}
|
||||
h.MustEqual(tt.wantResetMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
88
util/syspolicy/internal/metrics/test_handler.go
Normal file
88
util/syspolicy/internal/metrics/test_handler.go
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
// TestState represents a metric name and its expected value.
|
||||
type TestState struct {
|
||||
Name string // `$os` in the name will be replaced by the actual operating system name.`
|
||||
Value int64
|
||||
}
|
||||
|
||||
// TestHandler facilitates testing of the code that uses metrics.
|
||||
type TestHandler struct {
|
||||
t internal.TB
|
||||
|
||||
m map[string]int64
|
||||
}
|
||||
|
||||
// NewTestHandler returns a new TestHandler.
|
||||
func NewTestHandler(t internal.TB) *TestHandler {
|
||||
return &TestHandler{t, make(map[string]int64)}
|
||||
}
|
||||
|
||||
// AddMetric increments the metric with the specified name and type by delta d.
|
||||
func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) {
|
||||
h.t.Helper()
|
||||
if typ == clientmetric.TypeCounter && d < 0 {
|
||||
h.t.Fatalf("an attempt was made to decrement a counter metric %q", name)
|
||||
}
|
||||
if v, ok := h.m[name]; ok || d != 0 {
|
||||
h.m[name] = v + d
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetric sets the metric with the specified name and type to the value v.
|
||||
func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) {
|
||||
h.t.Helper()
|
||||
if typ == clientmetric.TypeCounter {
|
||||
h.t.Fatalf("an attempt was made to set a counter metric %q", name)
|
||||
}
|
||||
if _, ok := h.m[name]; ok || v != 0 {
|
||||
h.m[name] = v
|
||||
}
|
||||
}
|
||||
|
||||
// MustEqual fails the test if the actual metric state differs from the specified state.
|
||||
func (h *TestHandler) MustEqual(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
h.MustContain(metrics...)
|
||||
h.mustNoExtra(metrics...)
|
||||
}
|
||||
|
||||
// MustContain fails the test if the specified metrics are not set or have
|
||||
// different values than specified. It permits other metrics to be set in
|
||||
// addition to the ones being tested.
|
||||
func (h *TestHandler) MustContain(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
for _, m := range metrics {
|
||||
name := strings.ReplaceAll(m.Name, "$os", internal.OS())
|
||||
v, ok := h.m[name]
|
||||
if !ok {
|
||||
h.t.Errorf("%q: got (none), want %v", name, m.Value)
|
||||
} else if v != m.Value {
|
||||
h.t.Fatalf("%q: got %v, want %v", name, v, m.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TestHandler) mustNoExtra(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
s := make(set.Set[string])
|
||||
for i := range metrics {
|
||||
s.Add(strings.ReplaceAll(metrics[i].Name, "$os", internal.OS()))
|
||||
}
|
||||
for n, v := range h.m {
|
||||
if !s.Contains(n) {
|
||||
h.t.Errorf("%q: got %v, want (none)", n, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,21 @@
|
||||
|
||||
package syspolicy
|
||||
|
||||
type Key string
|
||||
import (
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
type Key = setting.Key
|
||||
|
||||
// The const block below lists known policy keys.
|
||||
// When adding a key to this list, remember to add a corresponding
|
||||
// [setting.Definition] to [implicitDefinitions] below.
|
||||
// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder.
|
||||
// Preferably, use a strongly typed policy hierarchy, such as [Policy],
|
||||
// instead of adding each key to the list below.
|
||||
|
||||
const (
|
||||
// Keys with a string value
|
||||
@@ -96,3 +110,83 @@ const (
|
||||
// AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes.
|
||||
AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes"
|
||||
)
|
||||
|
||||
// implicitDefinitions is a list of [setting.Definition] that will be registered
|
||||
// automatically by [settingDefinitions] as soon as the package needs to ready a policy.
|
||||
var implicitDefinitions = []*setting.Definition{
|
||||
// Device policy settings
|
||||
setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue),
|
||||
|
||||
// User policy settings
|
||||
setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
}
|
||||
|
||||
func init() {
|
||||
lazyinit.Defer(func() error {
|
||||
// Avoid implicit [SettingDefinition] registration during tests.
|
||||
// Each test should control which policy settings to register.
|
||||
// Use [setting.SetDefinitionsForTest] to specify necessary definitions,
|
||||
// or [setWellKnownSettingsForTest] to set implicit definitions for the test duration.
|
||||
if testenv.InTest() {
|
||||
return nil
|
||||
}
|
||||
for _, d := range implicitDefinitions {
|
||||
setting.RegisterDefinition(d)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap]
|
||||
|
||||
// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key,
|
||||
// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist
|
||||
// among implicit policy definitions.
|
||||
func WellKnownSettingDefinition(k Key) (*setting.Definition, error) {
|
||||
m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) {
|
||||
return setting.DefinitionMapOf(implicitDefinitions)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d, ok := m[k]; ok {
|
||||
return d, nil
|
||||
}
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// setWellKnownSettingsForTest registers all implicit setting definitions
|
||||
// for the duration of the test.
|
||||
func setWellKnownSettingsForTest(tb lazy.TB) error {
|
||||
return setting.SetDefinitionsForTest(tb, implicitDefinitions...)
|
||||
}
|
||||
|
||||
95
util/syspolicy/policy_keys_test.go
Normal file
95
util/syspolicy/policy_keys_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestKnownKeysRegistered(t *testing.T) {
|
||||
keyConsts, err := listStringConsts[Key]("policy_keys.go")
|
||||
if err != nil {
|
||||
t.Fatalf("listStringConsts failed: %v", err)
|
||||
}
|
||||
|
||||
m, err := setting.DefinitionMapOf(implicitDefinitions)
|
||||
if err != nil {
|
||||
t.Fatalf("definitionMapOf failed: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range keyConsts {
|
||||
t.Run(string(key), func(t *testing.T) {
|
||||
d := m[key]
|
||||
if d == nil {
|
||||
t.Fatalf("%q was not registered", key)
|
||||
}
|
||||
if d.Key() != key {
|
||||
t.Fatalf("d.Key got: %s, want %s", d.Key(), key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotAWellKnownSetting(t *testing.T) {
|
||||
d, err := WellKnownSettingDefinition("TestSettingDoesNotExist")
|
||||
if d != nil || err == nil {
|
||||
t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey)
|
||||
}
|
||||
}
|
||||
|
||||
func listStringConsts[T ~string](filename string) (map[string]T, error) {
|
||||
fset := token.NewFileSet()
|
||||
src, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := parser.ParseFile(fset, filename, src, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
consts := make(map[string]T)
|
||||
typeName := reflect.TypeFor[T]().Name()
|
||||
for _, d := range f.Decls {
|
||||
g, ok := d.(*ast.GenDecl)
|
||||
if !ok || g.Tok != token.CONST {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, s := range g.Specs {
|
||||
vs, ok := s.(*ast.ValueSpec)
|
||||
if !ok || len(vs.Names) != len(vs.Values) {
|
||||
continue
|
||||
}
|
||||
if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, n := range vs.Names {
|
||||
lit, ok := vs.Values[i].(*ast.BasicLit)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i]))
|
||||
}
|
||||
val, err := strconv.Unquote(lit.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value)
|
||||
}
|
||||
consts[n.Name] = T(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return consts, nil
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
var stringKeys = []Key{
|
||||
ControlURL,
|
||||
LogTarget,
|
||||
Tailnet,
|
||||
ExitNodeID,
|
||||
ExitNodeIP,
|
||||
EnableIncomingConnections,
|
||||
EnableServerMode,
|
||||
ExitNodeAllowLANAccess,
|
||||
EnableTailscaleDNS,
|
||||
EnableTailscaleSubnets,
|
||||
AdminConsoleVisibility,
|
||||
NetworkDevicesVisibility,
|
||||
TestMenuVisibility,
|
||||
UpdateMenuVisibility,
|
||||
RunExitNodeVisibility,
|
||||
PreferencesMenuVisibility,
|
||||
ExitNodeMenuVisibility,
|
||||
AutoUpdateVisibility,
|
||||
ResetToDefaultsVisibility,
|
||||
KeyExpirationNoticeTime,
|
||||
PostureChecking,
|
||||
ManagedByOrganizationName,
|
||||
ManagedByCaption,
|
||||
ManagedByURL,
|
||||
}
|
||||
|
||||
var boolKeys = []Key{
|
||||
LogSCMInteractions,
|
||||
FlushDNSOnSessionUnlock,
|
||||
}
|
||||
|
||||
var uint64Keys = []Key{}
|
||||
109
util/syspolicy/rsop/change_callbacks.go
Normal file
109
util/syspolicy/rsop/change_callbacks.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// Change represents a change from the Old to the New value of type T.
|
||||
type Change[T any] struct {
|
||||
New, Old T
|
||||
}
|
||||
|
||||
// PolicyChangeCallback is a function called whenever a policy changes.
|
||||
type PolicyChangeCallback func(*PolicyChange)
|
||||
|
||||
// PolicyChange describes a policy change.
|
||||
type PolicyChange struct {
|
||||
snapshots Change[*setting.Snapshot]
|
||||
}
|
||||
|
||||
// New returns the [setting.Snapshot] after the change.
|
||||
func (c PolicyChange) New() *setting.Snapshot {
|
||||
return c.snapshots.New
|
||||
}
|
||||
|
||||
// Old returns the [setting.Snapshot] before the change.
|
||||
func (c PolicyChange) Old() *setting.Snapshot {
|
||||
return c.snapshots.Old
|
||||
}
|
||||
|
||||
// HasChanged reports whether a policy setting with the specified [setting.Key], has changed.
|
||||
func (c PolicyChange) HasChanged(key setting.Key) bool {
|
||||
new, newErr := c.snapshots.New.GetErr(key)
|
||||
old, oldErr := c.snapshots.Old.GetErr(key)
|
||||
if newErr != nil && oldErr != nil {
|
||||
return false
|
||||
}
|
||||
if newErr != nil || oldErr != nil {
|
||||
return true
|
||||
}
|
||||
switch newVal := new.(type) {
|
||||
case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration:
|
||||
return newVal != old
|
||||
case []string:
|
||||
if oldVal, ok := old.([]string); ok {
|
||||
return slices.Equal(newVal, oldVal)
|
||||
}
|
||||
return false
|
||||
default:
|
||||
loggerx.Errorf("%q has an unsupported value type: %T", newVal)
|
||||
return reflect.DeepEqual(new, old)
|
||||
}
|
||||
}
|
||||
|
||||
// policyChangeCallbacks are the callbacks to invoke when the resultant policy changes.
|
||||
// It is safe for concurrent use.
|
||||
type policyChangeCallbacks struct {
|
||||
mu sync.RWMutex
|
||||
cbs set.HandleSet[PolicyChangeCallback]
|
||||
}
|
||||
|
||||
// Register adds the specified callback to be invoked whenever the policy changes.
|
||||
func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) {
|
||||
c.mu.Lock()
|
||||
handle := c.cbs.Add(callback)
|
||||
c.mu.Unlock()
|
||||
return func() {
|
||||
c.mu.Lock()
|
||||
delete(c.cbs, handle)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke calls the registered callback functions with the specified policy change info.
|
||||
func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) {
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
wg.Add(len(c.cbs))
|
||||
change := &PolicyChange{snapshots: snapshots}
|
||||
for _, cb := range c.cbs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cb(change)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Close awaits the completion of active callbacks and prevents any further invocations.
|
||||
func (c *policyChangeCallbacks) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cbs != nil {
|
||||
clear(c.cbs)
|
||||
c.cbs = nil
|
||||
}
|
||||
}
|
||||
698
util/syspolicy/rsop/resultant_policy.go
Normal file
698
util/syspolicy/rsop/resultant_policy.go
Normal file
@@ -0,0 +1,698 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package rsop facilitates [source.Store] registration via [RegisterStore]
|
||||
// and provides access to the resultant policy merged from all registered sources
|
||||
// via [PolicyFor].
|
||||
package rsop
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/slicesx"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
var errResultantPolicyClosed = errors.New("resultant policy closed")
|
||||
|
||||
// The minimum and maximum wait times after detecting a policy change
|
||||
// before reloading the policy.
|
||||
// Policy changes occurring within [policyReloadMinDelay] of each other
|
||||
// will be batched together, resulting in a single policy reload
|
||||
// no later than [policyReloadMaxDelay] after the first detected change.
|
||||
// In other words, the resultant policy will be reloaded no more often than once
|
||||
// every 5 seconds, but at most 15 seconds after an underlying [source.Store]
|
||||
// has issued a policy change callback.
|
||||
// See [Policy.watchReload].
|
||||
const (
|
||||
defaultPolicyReloadMinDelay = 5 * time.Second
|
||||
defaultPolicyReloadMaxDelay = 15 * time.Second
|
||||
)
|
||||
|
||||
// policyReloadMinDelay and policyReloadMaxDelay are test hooks.
|
||||
// Their values default to [defaultPolicyReloadMinDelay] and [defaultPolicyReloadMaxDelay].
|
||||
var (
|
||||
policyReloadMinDelay, policyReloadMaxDelay lazy.SyncValue[time.Duration]
|
||||
)
|
||||
|
||||
// Policy provides access to the current resultant [setting.Snapshot] for a given
|
||||
// scope and allows to reload it from the underlying [source.Store]s. It also allows to
|
||||
// subscribe and receive a callback whenever the resultant [setting.Snapshot] is
|
||||
// changed. It is safe for concurrent use.
|
||||
type Policy struct {
|
||||
scope setting.PolicyScope
|
||||
|
||||
reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required
|
||||
changeSourceCh chan sourceChangeRequest // written to to add a new or remove an existing source
|
||||
closeCh chan struct{} // closed to signal that the Policy is being closed
|
||||
doneCh chan struct{} // closed by closeInternal when watchReload exits
|
||||
|
||||
// resultant is the most recent version of the [setting.Snapshot] containing policy settings
|
||||
// merged from all applicable sources.
|
||||
resultant atomic.Pointer[setting.Snapshot]
|
||||
|
||||
changeCallbacks policyChangeCallbacks
|
||||
|
||||
mu sync.RWMutex
|
||||
sources source.ReadableSources
|
||||
closing bool // Close was called (even if we're still closing)
|
||||
}
|
||||
|
||||
// newPolicy returns a new [Policy] for the specified [setting.PolicyScope]
|
||||
// that tracks changes and merges policy settings read from the specified sources.
|
||||
func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (p *Policy, err error) {
|
||||
readableSources := source.ReadableSources(make([]source.ReadableSource, len(sources)))
|
||||
for i, s := range sources {
|
||||
reader, err := s.Reader()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get a store reader: %v", err)
|
||||
}
|
||||
session, err := reader.OpenSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open a reading session: %v", err)
|
||||
}
|
||||
|
||||
readableSource := source.ReadableSource{
|
||||
Source: s,
|
||||
ReadingSession: session,
|
||||
}
|
||||
readableSources[i] = readableSource
|
||||
defer func() {
|
||||
if err != nil {
|
||||
readableSource.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Sort policy sources by their precedence from lower to higher.
|
||||
// For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}.
|
||||
readableSources.StableSort()
|
||||
|
||||
p = &Policy{
|
||||
scope: scope,
|
||||
sources: readableSources,
|
||||
reloadCh: make(chan reloadRequest, 1),
|
||||
changeSourceCh: make(chan sourceChangeRequest),
|
||||
closeCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
if err := p.start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// IsValid reports whether p is in a valid state and has not been closed.
|
||||
func (p *Policy) IsValid() bool {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Scope returns the [setting.PolicyScope] that this resultant policy applies to.
|
||||
func (p *Policy) Scope() setting.PolicyScope {
|
||||
return p.scope
|
||||
}
|
||||
|
||||
// Get returns the most recent resultant [setting.Snapshot].
|
||||
func (p *Policy) Get() *setting.Snapshot {
|
||||
return p.resultant.Load()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback adds a function to be called whenever the resultant
|
||||
// policy changes. The returned function can be used to unregister the callback.
|
||||
func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) {
|
||||
return p.changeCallbacks.Register(callback)
|
||||
}
|
||||
|
||||
// Reload synchronously re-reads policy settings from the underlying policy
|
||||
// [source.Store], constructing a new merged [setting.Snapshot] even if the policy remains
|
||||
// unchanged. In most scenarios, there's no need to re-read the policy manually.
|
||||
// Instead, it is recommended to register a policy change callback, or to use
|
||||
// the most recent [setting.Snapshot] returned by the [Policy.Get] method.
|
||||
func (p *Policy) Reload() (*setting.Snapshot, error) {
|
||||
return p.reload(true)
|
||||
}
|
||||
|
||||
// reload is like Reload, but allows to specify whether to re-read policy settings
|
||||
// from unchanged policy sources.
|
||||
func (p *Policy) reload(force bool) (*setting.Snapshot, error) {
|
||||
respCh := make(chan reloadResponse, 1)
|
||||
select {
|
||||
case p.reloadCh <- reloadRequest{force: force, respCh: respCh}:
|
||||
// continue
|
||||
case <-p.closeCh:
|
||||
return nil, errResultantPolicyClosed
|
||||
}
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
return resp.policy, resp.err
|
||||
case <-p.closeCh:
|
||||
return nil, errResultantPolicyClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the [Policy] is closed.
|
||||
func (p *Policy) Done() <-chan struct{} {
|
||||
return p.doneCh
|
||||
}
|
||||
|
||||
func (p *Policy) start() error {
|
||||
if _, err := p.reloadNow(false); err != nil {
|
||||
return err
|
||||
}
|
||||
go p.watchPolicyChanges()
|
||||
go p.watchReload()
|
||||
return nil
|
||||
}
|
||||
|
||||
// readAndMerge reads and merges policy settings from the underlying sources,
|
||||
// returning a [setting.Snapshot] with the merged result.
|
||||
// If the force parameter is true, it re-reads policy settings from each store
|
||||
// even if no policy change was observed, and returns an error if the read
|
||||
// operation fails.
|
||||
func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
// Start with an empty policy in the target scope.
|
||||
resultant := setting.NewSnapshot(nil, setting.SummaryWith(p.scope))
|
||||
// Then merge policy settings from all sources.
|
||||
// Policy sources with the highest precedence (e.g., the device policy) are merged last,
|
||||
// overriding any conflicting policy settings with lower precedence.
|
||||
for _, s := range p.sources {
|
||||
var policy *setting.Snapshot
|
||||
if force {
|
||||
var err error
|
||||
if policy, err = s.ReadSettings(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
policy = s.GetSettings()
|
||||
}
|
||||
resultant = setting.MergeSnapshots(resultant, policy)
|
||||
}
|
||||
return resultant, nil
|
||||
}
|
||||
|
||||
// reloadAsync requests an asynchronous background policy reload.
|
||||
// The policy will be reloaded no later than in [policyReloadMaxDelay].
|
||||
func (p *Policy) reloadAsync() {
|
||||
select {
|
||||
case p.reloadCh <- reloadRequest{}:
|
||||
// Sent.
|
||||
default:
|
||||
// A reload request is already en route.
|
||||
}
|
||||
}
|
||||
|
||||
// reloadNow loads and merges policies from all sources, updating the resultant policy.
|
||||
// If the force parameter is true, it forcibly reloads policies
|
||||
// from the underlying policy store, even if no policy changes were detected.
|
||||
//
|
||||
// Except for the initial policy reload during the [Policy] creation,
|
||||
// this method should only be called from the [Policy.watchReload] goroutine.
|
||||
func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) {
|
||||
new, err := p.readAndMerge(force)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
old := p.resultant.Swap(new)
|
||||
// A nil old value indicates the initial policy load rather than a policy change.
|
||||
// Additionally, we should not invoke the policy change callbacks unless the
|
||||
// policy items have actually changed.
|
||||
if old != nil && !old.EqualItems(new) {
|
||||
snapshots := Change[*setting.Snapshot]{New: new, Old: old}
|
||||
p.changeCallbacks.Invoke(snapshots)
|
||||
}
|
||||
return new, nil
|
||||
}
|
||||
|
||||
// AddSource adds the specified source to the list of sources used by p,
|
||||
// and triggers a synchronous policy refresh. It returns an error
|
||||
// if the source is not a valid source for this resultant policy,
|
||||
// or if the resultant policy is being closed,
|
||||
// or if policy refresh fails with an error.
|
||||
func (p *Policy) AddSource(source *source.Source) error {
|
||||
return p.changeSource(source, nil)
|
||||
}
|
||||
|
||||
// RemoveSource removes the specified source from the list of sources used by p,
|
||||
// and triggers a synchronous policy refresh. It returns an error if the
|
||||
// resultant policy is being closed, or if policy refresh fails with an error.
|
||||
func (p *Policy) RemoveSource(source *source.Source) error {
|
||||
return p.changeSource(nil, source)
|
||||
}
|
||||
|
||||
// ReplaceSource replaces the old source with the new source atomically,
|
||||
// and triggers a synchronous policy refresh. It returns an error
|
||||
// if the source is not a valid source for this resultant policy,
|
||||
// or if the resultant policy is being closed,
|
||||
// or if policy refresh fails with an error.
|
||||
func (p *Policy) ReplaceSource(old, new *source.Source) error {
|
||||
return p.changeSource(new, old)
|
||||
}
|
||||
|
||||
func (p *Policy) changeSource(toAdd, toRemove *source.Source) error {
|
||||
if toAdd == toRemove {
|
||||
return nil
|
||||
}
|
||||
if toAdd != nil && !p.scope.IsWithinOf(toAdd.Scope()) {
|
||||
return errors.New("scope mismatch")
|
||||
}
|
||||
respCh := make(chan error, 1)
|
||||
req := sourceChangeRequest{toAdd, toRemove, respCh}
|
||||
select {
|
||||
case p.changeSourceCh <- req:
|
||||
return <-respCh
|
||||
case <-p.closeCh:
|
||||
return errResultantPolicyClosed
|
||||
}
|
||||
}
|
||||
|
||||
// watchPolicyChanges awaits a policy change notification from any of the sources
|
||||
// and calls reloadAsync whenever a notification is received.
|
||||
func (p *Policy) watchPolicyChanges() {
|
||||
const (
|
||||
closeIdx = iota
|
||||
changeSourceIdx
|
||||
policyChangedOffset
|
||||
)
|
||||
|
||||
// The cases are Close, ChangeSource, PolicyChanged[0],...,PolicyChanged[N-1].
|
||||
p.mu.RLock()
|
||||
cases := make([]reflect.SelectCase, len(p.sources)+policyChangedOffset)
|
||||
// Add the PolicyChanged[N] cases.
|
||||
for i, source := range p.sources {
|
||||
cases[i+policyChangedOffset] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(source.PolicyChanged())}
|
||||
}
|
||||
// Add the Close case.
|
||||
cases[closeIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.closeCh)}
|
||||
// Add the ChangeSource case.
|
||||
cases[changeSourceIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.changeSourceCh)}
|
||||
p.mu.RUnlock()
|
||||
|
||||
for {
|
||||
switch chosen, recv, ok := reflect.Select(cases); chosen {
|
||||
case closeIdx: // Close
|
||||
// Exit the watch as the closeCh was closed, indicating that
|
||||
// the [Policy] is being closed.
|
||||
return
|
||||
case changeSourceIdx: // ChangeSource
|
||||
// We've received a source change request from one of the AddSource,
|
||||
// RemoveSource, or ReplaceSource methods, meaning that we need to:
|
||||
// - Open a reader session if a new source is being added;
|
||||
// - Update the p.sources slice;
|
||||
// - Update the cases slice;
|
||||
// - Trigger a synchronous policy reload;
|
||||
// - Report an error, if any, back to the caller.
|
||||
req := recv.Interface().(sourceChangeRequest)
|
||||
needClose, err := func() (close bool, err error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if req.toAdd != nil {
|
||||
if !p.sources.Contains(req.toAdd) {
|
||||
reader, err := req.toAdd.Reader()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get a store reader: %v", err)
|
||||
}
|
||||
session, err := reader.OpenSession()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to open a reading session: %v", err)
|
||||
}
|
||||
|
||||
addAt := p.sources.InsertionIndexOf(req.toAdd)
|
||||
toAdd := source.ReadableSource{
|
||||
Source: req.toAdd,
|
||||
ReadingSession: session,
|
||||
}
|
||||
p.sources = slices.Insert(p.sources, addAt, toAdd)
|
||||
newCase := reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(toAdd.PolicyChanged())}
|
||||
caseIndex := addAt + policyChangedOffset
|
||||
cases = slices.Insert(cases, caseIndex, newCase)
|
||||
}
|
||||
}
|
||||
if req.toDelete != nil {
|
||||
if deleteAt := p.sources.IndexOf(req.toDelete); deleteAt != -1 {
|
||||
p.sources.DeleteAt(deleteAt)
|
||||
caseIndex := deleteAt + policyChangedOffset
|
||||
cases = slices.Delete(cases, caseIndex, caseIndex+1)
|
||||
}
|
||||
}
|
||||
return len(p.sources) == 0, nil
|
||||
}()
|
||||
if err == nil {
|
||||
if needClose {
|
||||
// Close the resultant policy if the last policy source was deleted.
|
||||
p.Close()
|
||||
} else {
|
||||
// Otherwise, reload the policy synchronously.
|
||||
_, err = p.reload(false)
|
||||
}
|
||||
}
|
||||
req.respCh <- err
|
||||
default: // PolicyChanged[N]
|
||||
if !ok {
|
||||
// One of the PolicyChanged channels was closed, indicating that
|
||||
// the corresponding [source.Source] is no longer valid.
|
||||
// We can no longer keep this [Policy] up to date
|
||||
// and should close it.
|
||||
p.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// One of the PolicyChanged channels was signaled.
|
||||
// We should request an asynchronous policy reload.
|
||||
p.reloadAsync()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// watchReload processes incoming synchronous and asynchronous policy reload requests.
|
||||
// Synchronous requests (with a non-nil respCh) are served immediately.
|
||||
// Asynchronous requests are debounced and throttled: they are executed at least
|
||||
// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay]
|
||||
// after the first request in a batch.
|
||||
func (p *Policy) watchReload() {
|
||||
force := false // whether a forced refresh was requested
|
||||
var delayCh, timeoutCh <-chan time.Time
|
||||
reload := func(respCh chan<- reloadResponse) {
|
||||
delayCh, timeoutCh = nil, nil
|
||||
policy, err := p.reloadNow(force)
|
||||
if err != nil {
|
||||
loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err)
|
||||
}
|
||||
if respCh != nil {
|
||||
respCh <- reloadResponse{policy: policy, err: err}
|
||||
}
|
||||
force = false
|
||||
}
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case req := <-p.reloadCh:
|
||||
if req.force {
|
||||
force = true
|
||||
}
|
||||
if req.respCh != nil {
|
||||
reload(req.respCh)
|
||||
continue
|
||||
}
|
||||
if delayCh == nil {
|
||||
timeoutCh = time.After(policyReloadMaxDelay.Get(func() time.Duration { return defaultPolicyReloadMaxDelay }))
|
||||
}
|
||||
delayCh = time.After(policyReloadMinDelay.Get(func() time.Duration { return defaultPolicyReloadMinDelay }))
|
||||
case <-delayCh:
|
||||
reload(nil)
|
||||
case <-timeoutCh:
|
||||
reload(nil)
|
||||
case <-p.closeCh:
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
p.closeInternal()
|
||||
}
|
||||
|
||||
func (p *Policy) closeInternal() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.sources.Close()
|
||||
p.changeCallbacks.Close()
|
||||
close(p.doneCh)
|
||||
}
|
||||
|
||||
// Close initiates the closing of the resultant policy.
|
||||
// The actual closing is performed by closeInternal when watchReload exits,
|
||||
// and the Done() channel is closed when closeInternal finishes.
|
||||
func (p *Policy) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closing {
|
||||
return
|
||||
}
|
||||
p.closing = true
|
||||
close(p.closeCh)
|
||||
}
|
||||
|
||||
// sourceChangeRequest is a request to add and/or remove source from a [Policy].
|
||||
type sourceChangeRequest struct {
|
||||
toAdd, toDelete *source.Source
|
||||
respCh chan<- error
|
||||
}
|
||||
|
||||
// reloadRequest describes a policy reload request.
|
||||
type reloadRequest struct {
|
||||
// force triggers an immediate synchronous policy reload,
|
||||
// reloading the policy regardless of whether a policy change was detected.
|
||||
force bool
|
||||
// respCh is an optional channel. If non-nil, it makes the reload request
|
||||
// synchronous and receives the result.
|
||||
respCh chan<- reloadResponse
|
||||
}
|
||||
|
||||
type reloadResponse struct {
|
||||
policy *setting.Snapshot
|
||||
err error
|
||||
}
|
||||
|
||||
var (
|
||||
policyMu sync.RWMutex
|
||||
policySources []*source.Source
|
||||
resultantPolicies []*Policy
|
||||
|
||||
resultantPolicyLRU [setting.MaxSettingScope + 1]syncs.AtomicValue[*Policy] // by [Scope.Kind]
|
||||
)
|
||||
|
||||
// registerSource registers the specified [source.Source] to be used by the package.
|
||||
// It updates existing [Policy]s returned by [PolicyFor] to use this source if
|
||||
// they are within the source's [setting.PolicyScope].
|
||||
func registerSource(source *source.Source) error {
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
if slices.Contains(policySources, source) {
|
||||
return nil
|
||||
}
|
||||
policySources = append(policySources, source)
|
||||
return forEachResultantPolicyLocked(func(policy *Policy) error {
|
||||
if !policy.Scope().IsWithinOf(source.Scope()) {
|
||||
return nil
|
||||
}
|
||||
return policy.AddSource(source)
|
||||
})
|
||||
}
|
||||
|
||||
// replaceSource is like [unregisterSource](old) followed by [registerSource](new),
|
||||
// but is atomic from the perspective of each [Policy].
|
||||
func replaceSource(old, new *source.Source) error {
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
oldIndex := slices.Index(policySources, old)
|
||||
if oldIndex == -1 {
|
||||
return fmt.Errorf("the source is not registered: %v", old)
|
||||
}
|
||||
policySources[oldIndex] = new
|
||||
return forEachResultantPolicyLocked(func(policy *Policy) error {
|
||||
if policy.Scope().IsWithinOf(old.Scope()) || policy.Scope().IsWithinOf(new.Scope()) {
|
||||
return nil
|
||||
}
|
||||
return policy.ReplaceSource(old, new)
|
||||
})
|
||||
}
|
||||
|
||||
// unregisterSource unregisters the specified [source.Source],
|
||||
// so that it won't be used by any new or existing [Policy].
|
||||
func unregisterSource(source *source.Source) error {
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
index := slices.Index(policySources, source)
|
||||
if index == -1 {
|
||||
return nil
|
||||
}
|
||||
policySources = slices.Delete(policySources, index, index+1)
|
||||
return forEachResultantPolicyLocked(func(policy *Policy) error {
|
||||
if !policy.Scope().IsWithinOf(source.Scope()) {
|
||||
return nil
|
||||
}
|
||||
return policy.RemoveSource(source)
|
||||
})
|
||||
}
|
||||
|
||||
// forEachResultantPolicyLocked calls fn for every [Policy] in [resultantPolicies].
|
||||
// It accumulates the returned errors, except for [errResultantPolicyClosed],
|
||||
// and returns an error that wraps all errors returned by fn.
|
||||
// The [policyMu] mutex must be held while this function is executed.
|
||||
func forEachResultantPolicyLocked(fn func(p *Policy) error) error {
|
||||
var errs []error
|
||||
for _, policy := range resultantPolicies {
|
||||
err := fn(policy)
|
||||
if err != nil && !errors.Is(err, errResultantPolicyClosed) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// PolicyFor returns the [Policy] for the specified scope,
|
||||
// creating one from the registered [source.Store]s if it does not exist.
|
||||
func PolicyFor(scope setting.PolicyScope) (*Policy, error) {
|
||||
if err := lazyinit.Do(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if policy := resultantPolicyLRU[scope.Kind()].Load(); policy != nil && policy.Scope() == scope && policy.IsValid() {
|
||||
return policy, nil
|
||||
}
|
||||
return policyForSlow(scope)
|
||||
}
|
||||
|
||||
func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) {
|
||||
defer func() {
|
||||
if policy != nil {
|
||||
resultantPolicyLRU[scope.Kind()].Store(policy)
|
||||
}
|
||||
}()
|
||||
|
||||
policyMu.RLock()
|
||||
if policy, ok := findPolicyByScopeLocked(scope); ok {
|
||||
policyMu.RUnlock()
|
||||
return policy, nil
|
||||
}
|
||||
policyMu.RUnlock()
|
||||
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
if policy, ok := findPolicyByScopeLocked(scope); ok {
|
||||
return policy, nil
|
||||
}
|
||||
sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool {
|
||||
return scope.IsWithinOf(source.Scope())
|
||||
})
|
||||
policy, err = newPolicy(scope, sources...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resultantPolicies = append(resultantPolicies, policy)
|
||||
go func() {
|
||||
<-policy.Done()
|
||||
deletePolicy(policy)
|
||||
}()
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// findPolicyByScopeLocked returns a policy with the specified scope and true if
|
||||
// one exists, otherwise it returns nil, false.
|
||||
// [policyMu] must be held.
|
||||
func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) {
|
||||
for _, policy := range resultantPolicies {
|
||||
if policy.Scope() == target && policy.IsValid() {
|
||||
return policy, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// deletePolicy deletes the specified resultant policy from the [resultantPolicies] list.
|
||||
func deletePolicy(policy *Policy) {
|
||||
policyMu.Lock()
|
||||
if i := slices.Index(resultantPolicies, policy); i != -1 {
|
||||
resultantPolicies = slices.Delete(resultantPolicies, i, i+1)
|
||||
}
|
||||
resultantPolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil)
|
||||
policyMu.Unlock()
|
||||
}
|
||||
|
||||
// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore]
|
||||
// or [StoreRegistration.Unregister] is called more than once.
|
||||
var ErrAlreadyConsumed = errors.New("the store registration is no longer valid")
|
||||
|
||||
// StoreRegistration is a [source.Store] registered for use in the specified scope.
|
||||
// It can be used to unregister the store, or replace it with another one.
|
||||
type StoreRegistration struct {
|
||||
source *source.Source
|
||||
consumed atomic.Uint32
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
|
||||
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
|
||||
return newStoreRegistration(name, scope, store)
|
||||
}
|
||||
|
||||
// RegisterStoreForTest is like [RegisterStore], but unregisters the store when
|
||||
// tb and all its subtests complete.
|
||||
func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
|
||||
reg, err := RegisterStore(name, scope, store)
|
||||
if err == nil {
|
||||
tb.Cleanup(func() {
|
||||
if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) {
|
||||
tb.Fatalf("Unregister failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
return reg, err // may be nil or non-nil
|
||||
}
|
||||
|
||||
func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
|
||||
source := source.NewSource(name, scope, store)
|
||||
if err := registerSource(source); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StoreRegistration{source: source}, nil
|
||||
}
|
||||
|
||||
// ReplaceStore replaces the registered store with the new one,
|
||||
// returning a new [StoreRegistration] or an error.
|
||||
func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) {
|
||||
var res *StoreRegistration
|
||||
err := r.consume(func() error {
|
||||
newSource := source.NewSource(r.source.Name(), r.source.Scope(), new)
|
||||
if err := replaceSource(r.source, newSource); err != nil {
|
||||
return err
|
||||
}
|
||||
res = &StoreRegistration{source: newSource}
|
||||
return nil
|
||||
})
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Unregister reverts the registration.
|
||||
func (r *StoreRegistration) Unregister() error {
|
||||
return r.consume(func() error { return unregisterSource(r.source) })
|
||||
}
|
||||
|
||||
// consume invokes fn, consuming r if no error is returned.
|
||||
// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call.
|
||||
func (r *StoreRegistration) consume(fn func() error) (err error) {
|
||||
if r.consumed.Load() != 0 {
|
||||
return ErrAlreadyConsumed
|
||||
}
|
||||
return r.consumeSlow(fn)
|
||||
}
|
||||
|
||||
func (r *StoreRegistration) consumeSlow(fn func() error) (err error) {
|
||||
r.m.Lock()
|
||||
defer r.m.Unlock()
|
||||
if r.consumed.Load() != 0 {
|
||||
return ErrAlreadyConsumed
|
||||
}
|
||||
if err = fn(); err == nil {
|
||||
r.consumed.Store(1)
|
||||
}
|
||||
return err // may be nil or non-nil
|
||||
}
|
||||
368
util/syspolicy/rsop/resultant_policy_test.go
Normal file
368
util/syspolicy/rsop/resultant_policy_test.go
Normal file
@@ -0,0 +1,368 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
func TestRegisterSourceAndGetResultantPolicy(t *testing.T) {
|
||||
type sourceConfig struct {
|
||||
name string
|
||||
scope setting.PolicyScope
|
||||
settingKey setting.Key
|
||||
settingValue string
|
||||
wantEffective bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
scope setting.PolicyScope
|
||||
initialSources []sourceConfig
|
||||
additionalSources []sourceConfig
|
||||
wantSnapshot *setting.Snapshot
|
||||
}{
|
||||
{
|
||||
name: "DevicePolicy/NoSources",
|
||||
scope: setting.DeviceScope,
|
||||
wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "UserScope/NoSources",
|
||||
scope: setting.CurrentUserScope,
|
||||
wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/OneInitialSource",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/OneAdditionalSource",
|
||||
scope: setting.DeviceScope,
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/ManyInitialSources/NoConflicts",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceB",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "TestValueB",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceC",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyC",
|
||||
settingValue: "TestValueC",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
|
||||
"TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
|
||||
}, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/ManyInitialSources/Conflicts",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceB",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "TestValueB",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceC",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueC",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
|
||||
}, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/MixedSources/Conflicts",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceB",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "TestValueB",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceC",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueC",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceD",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueD",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceE",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyC",
|
||||
settingValue: "TestValueE",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceF",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueF",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
|
||||
"TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)),
|
||||
}, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "UserScope/Init-DeviceSource",
|
||||
scope: setting.CurrentUserScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
}, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "UserScope/Init-DeviceSource/Add-UserSource",
|
||||
scope: setting.CurrentUserScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceUser",
|
||||
scope: setting.CurrentUserScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "UserValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)),
|
||||
}, setting.CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource",
|
||||
scope: setting.CurrentUserScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceProfile",
|
||||
scope: setting.CurrentProfileScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "ProfileValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceUser",
|
||||
scope: setting.CurrentUserScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "UserValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)),
|
||||
}, setting.CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/User-Source-does-not-apply",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceUser",
|
||||
scope: setting.CurrentUserScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "UserValue",
|
||||
wantEffective: false, // Registering a user source should have no impact on the device policy.
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Register all settings that we use in this test.
|
||||
var definitions []*setting.Definition
|
||||
for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) {
|
||||
definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue))
|
||||
}
|
||||
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
|
||||
// Add the initial policy sources.
|
||||
var wantSources []*source.Source
|
||||
for _, s := range tt.initialSources {
|
||||
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
|
||||
source := source.NewSource(s.name, s.scope, store)
|
||||
if err := registerSource(source); err != nil {
|
||||
t.Fatalf("failed to register policy source: %v", source)
|
||||
}
|
||||
if s.wantEffective {
|
||||
wantSources = append(wantSources, source)
|
||||
}
|
||||
t.Cleanup(func() { unregisterSource(source) })
|
||||
}
|
||||
|
||||
// Retrieve the resultant policy.
|
||||
policy, err := resultantPolicyForTest(t, tt.scope)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get resultant policy for %v", tt.scope)
|
||||
}
|
||||
|
||||
// Add additional setting sources one by one, and check the policy settings at each step.
|
||||
for _, s := range tt.additionalSources {
|
||||
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
|
||||
source := source.NewSource(s.name, s.scope, store)
|
||||
if err := registerSource(source); err != nil {
|
||||
t.Fatalf("failed to register additional policy source: %v", source)
|
||||
}
|
||||
if s.wantEffective {
|
||||
wantSources = append(wantSources, source)
|
||||
}
|
||||
t.Cleanup(func() { unregisterSource(source) })
|
||||
}
|
||||
|
||||
sort.SliceStable(wantSources, func(i, j int) bool {
|
||||
return wantSources[i].Compare(wantSources[j]) < 0
|
||||
})
|
||||
gotSources := make([]*source.Source, len(policy.sources))
|
||||
for i, s := range policy.sources {
|
||||
gotSources[i] = s.Source
|
||||
}
|
||||
if !slices.Equal(gotSources, wantSources) {
|
||||
t.Errorf("Sources: got %v; want %v", gotSources, wantSources)
|
||||
}
|
||||
|
||||
// Verify the final resultant settings snapshots.
|
||||
if got := policy.Get(); !got.Equal(tt.wantSnapshot) {
|
||||
t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// resultantPolicyForTest is like [resultantPolicyFor], but it deletes the policy
|
||||
// when tb and all its subtests complete.
|
||||
func resultantPolicyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) {
|
||||
policy, err := PolicyFor(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
policy.Close()
|
||||
<-policy.Done()
|
||||
deletePolicy(policy)
|
||||
})
|
||||
return policy, nil
|
||||
}
|
||||
60
util/syspolicy/setting/errors.go
Normal file
60
util/syspolicy/setting/errors.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNotConfigured is returned when the requested policy setting is not configured.
|
||||
ErrNotConfigured = errors.New("not configured")
|
||||
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
|
||||
// of the setting value and the expected type.
|
||||
ErrTypeMismatch = errors.New("type mismatch")
|
||||
// ErrNoSuchKey is returned by [DefinitionOf] when no policy setting
|
||||
// has been registered with the specified key.
|
||||
//
|
||||
// Until 2024-08-02, this error was also returned by a [Handler] when the specified
|
||||
// key did not have a value set. While the package maintains compatibility with this
|
||||
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
|
||||
// [source.Store] implementations.
|
||||
ErrNoSuchKey = errors.New("no such key")
|
||||
)
|
||||
|
||||
// Error is an error when reading or parsing a policy setting.
|
||||
type Error struct {
|
||||
text string
|
||||
}
|
||||
|
||||
// NewError returns a [Error] with the specified error message.
|
||||
func NewError(text string) *Error {
|
||||
return &Error{text}
|
||||
}
|
||||
|
||||
// WrapError returns an [Error] with the text of the specified error,
|
||||
// or nil if err is nil, [ErrNotConfigured], or [ErrNoSuchKey].
|
||||
func WrapError(err error) *Error {
|
||||
if err == nil || errors.Is(err, ErrNotConfigured) || errors.Is(err, ErrNoSuchKey) {
|
||||
return nil
|
||||
}
|
||||
if err, ok := err.(*Error); ok {
|
||||
return err
|
||||
}
|
||||
return &Error{err.Error()}
|
||||
}
|
||||
|
||||
// Error implements error.
|
||||
func (e Error) Error() string {
|
||||
return e.text
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (e Error) MarshalText() (text []byte, err error) {
|
||||
return []byte(e.Error()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (e *Error) UnmarshalText(text []byte) error {
|
||||
e.text = string(text)
|
||||
return nil
|
||||
}
|
||||
13
util/syspolicy/setting/key.go
Normal file
13
util/syspolicy/setting/key.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
// Key is a string that uniquely identifies a policy and must remain unchanged
|
||||
// once established and documented for a given policy setting. It may contain
|
||||
// alphanumeric characters and zero or more [KeyPathSeparator]s to group
|
||||
// individual policy settings into categories.
|
||||
type Key string
|
||||
|
||||
// KeyPathSeparator allows logical grouping of policy settings into categories.
|
||||
const KeyPathSeparator = "/"
|
||||
71
util/syspolicy/setting/origin.go
Normal file
71
util/syspolicy/setting/origin.go
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
)
|
||||
|
||||
// Origin describes where a policy or a policy setting is configured.
|
||||
type Origin struct {
|
||||
data settingOrigin
|
||||
}
|
||||
|
||||
// settingOrigin is the marshallable data of a [Origin].
|
||||
type settingOrigin struct {
|
||||
Name string `json:",omitzero"`
|
||||
Scope PolicyScope
|
||||
}
|
||||
|
||||
// NewOrigin returns a new [Origin] with the specified scope.
|
||||
func NewOrigin(scope PolicyScope) *Origin {
|
||||
return NewNamedOrigin("", scope)
|
||||
}
|
||||
|
||||
// NewNamedOrigin returns a new [Origin] with the specified scope and name.
|
||||
func NewNamedOrigin(name string, scope PolicyScope) *Origin {
|
||||
return &Origin{settingOrigin{name, scope}}
|
||||
}
|
||||
|
||||
// Scope reports the policy [PolicyScope] where the setting is configured.
|
||||
func (s Origin) Scope() PolicyScope {
|
||||
return s.data.Scope
|
||||
}
|
||||
|
||||
// Name returns the name of the policy source where the setting is configured,
|
||||
// or "" if not available.
|
||||
func (s Origin) Name() string {
|
||||
return s.data.Name
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s Origin) String() string {
|
||||
if s.Name() != "" {
|
||||
return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
|
||||
}
|
||||
return s.Scope().String()
|
||||
}
|
||||
|
||||
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
|
||||
func (s Origin) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
|
||||
return jsonv2.MarshalEncode(out, &s.data, opts)
|
||||
}
|
||||
|
||||
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
|
||||
func (s *Origin) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
|
||||
return jsonv2.UnmarshalDecode(in, &s.data, opts)
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler].
|
||||
func (s Origin) MarshalJSON() ([]byte, error) {
|
||||
return jsonv2.Marshal(s) // uses MarshalJSONV2
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler].
|
||||
func (s *Origin) UnmarshalJSON(b []byte) error {
|
||||
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
|
||||
}
|
||||
195
util/syspolicy/setting/policy_scope.go
Normal file
195
util/syspolicy/setting/policy_scope.go
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
)
|
||||
|
||||
var (
|
||||
lazyCurrentScope lazy.SyncValue[PolicyScope]
|
||||
|
||||
// DeviceScope indicates a scope containing device-global policies.
|
||||
DeviceScope = PolicyScope{kind: DeviceSetting}
|
||||
// CurrentProfileScope indicates a scope containing policies that apply to the
|
||||
// currently active Tailscale profile.
|
||||
CurrentProfileScope = PolicyScope{kind: ProfileSetting}
|
||||
// CurrentUserScope indicates a scope containing policies that apply to the
|
||||
// current user, for whatever that means on the current platform and
|
||||
// in the current application context.
|
||||
CurrentUserScope = PolicyScope{kind: UserSetting}
|
||||
)
|
||||
|
||||
// PolicyScope is a management scope.
|
||||
type PolicyScope struct {
|
||||
kind Scope
|
||||
userID string
|
||||
profileID string
|
||||
}
|
||||
|
||||
// CurrentScope returns the default [PolicyScope] that the package will use to return
|
||||
// the policy settings for unless a different scope is explicitly requested.
|
||||
// This defaults to [DeviceScope], unless the process runs as a user (rather than LocalSystem)
|
||||
// on Windows, in which case it returns the [CurrentUserScope].
|
||||
func CurrentScope() PolicyScope {
|
||||
// Allow deferred package init functions to override the default scope.
|
||||
lazyinit.Do()
|
||||
return lazyCurrentScope.Get(func() PolicyScope { return DeviceScope })
|
||||
}
|
||||
|
||||
// SetCurrentScope attempts to set the specified scope as the current scope,
|
||||
// and reports whether it succeeds.
|
||||
// It can be called only once and must be during lazy package initialization.
|
||||
func SetCurrentScope(scope PolicyScope) bool {
|
||||
return lazyCurrentScope.Set(scope)
|
||||
}
|
||||
|
||||
// UserScopeOf returns a policy [PolicyScope] of the specified user.
|
||||
func UserScopeOf(uid string) PolicyScope {
|
||||
return PolicyScope{kind: UserSetting, userID: uid}
|
||||
}
|
||||
|
||||
// Kind reports the base [Scope] of s.
|
||||
func (s PolicyScope) Kind() Scope {
|
||||
return s.kind
|
||||
}
|
||||
|
||||
// IsApplicableSetting reports whether the specified setting applies to
|
||||
// and can be retrieved for this scope. Policy settings are applicable
|
||||
// to their own scopes as well as more specific scopes. For example,
|
||||
// device settings are applicable to device, profile and user scopes,
|
||||
// but user settings are only applicable to user scopes.
|
||||
// For instance, a menu visibility setting is inherently a user setting
|
||||
// and only makes sense in the context of a specific user.
|
||||
func (s PolicyScope) IsApplicableSetting(setting *Definition) bool {
|
||||
return setting != nil && setting.Scope() <= s.Kind()
|
||||
}
|
||||
|
||||
// IsConfigurableSetting reports whether the specified setting can be configured
|
||||
// by a policy at this scope. Policy settings are configurable at their own scopes
|
||||
// as well as broader scopes. For example, [UserSetting]s are configurable in
|
||||
// user, profile, and device scopes, but [DeviceSetting]s are only configurable
|
||||
// in the [DeviceScope]. For instance, the InstallUpdates policy setting
|
||||
// can only be configured in the device scope, as it controls whether updates
|
||||
// will be installed automatically on the device, rather than for specific users.
|
||||
func (s PolicyScope) IsConfigurableSetting(setting *Definition) bool {
|
||||
return setting != nil && setting.Scope() >= s.Kind()
|
||||
}
|
||||
|
||||
// IsWithinOf reports whether policy settings that apply to s2 also apply to s.
|
||||
// For example, policy settings that apply to the [DeviceScope] also apply to
|
||||
// the [CurrentUserScope].
|
||||
func (s PolicyScope) IsWithinOf(s2 PolicyScope) bool {
|
||||
if s2.Kind() > s.Kind() {
|
||||
return false
|
||||
}
|
||||
switch s2.Kind() {
|
||||
case DeviceSetting:
|
||||
return true
|
||||
case ProfileSetting:
|
||||
return s.profileID == s2.profileID
|
||||
case UserSetting:
|
||||
return s.userID == s2.userID
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// IsStrictlyWithinOf is like [IsWithinOf], except it returns false
|
||||
// when s and s2 is the same scope.
|
||||
func (s PolicyScope) IsStrictlyWithinOf(s2 PolicyScope) bool {
|
||||
return s != s2 && s.IsWithinOf(s2)
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s PolicyScope) String() string {
|
||||
if s.profileID == "" && s.userID == "" {
|
||||
return s.kind.String()
|
||||
}
|
||||
return s.stringSlow()
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (s PolicyScope) MarshalText() ([]byte, error) {
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextUnmarshaler].
|
||||
func (s *PolicyScope) UnmarshalText(b []byte) error {
|
||||
*s = PolicyScope{}
|
||||
parts := strings.SplitN(string(b), "/", 2)
|
||||
if len(parts) == 0 {
|
||||
return fmt.Errorf("%s is not a valid scope", b)
|
||||
}
|
||||
for i, part := range parts {
|
||||
kind, id, err := parseScopeAndID(part)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if i > 0 && kind <= s.kind {
|
||||
return fmt.Errorf("invalid scope hierarchy: %s", b)
|
||||
}
|
||||
s.kind = kind
|
||||
switch kind {
|
||||
case DeviceSetting:
|
||||
if id != "" {
|
||||
return fmt.Errorf("the device scope must not have an ID: %s", b)
|
||||
}
|
||||
case ProfileSetting:
|
||||
s.profileID = id
|
||||
case UserSetting:
|
||||
s.userID = id
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s PolicyScope) stringSlow() string {
|
||||
var sb strings.Builder
|
||||
writeScopeWithID := func(s Scope, id string) {
|
||||
sb.WriteString(s.String())
|
||||
if id != "" {
|
||||
sb.WriteRune('(')
|
||||
sb.WriteString(id)
|
||||
sb.WriteRune(')')
|
||||
}
|
||||
}
|
||||
if s.kind == ProfileSetting || s.profileID != "" {
|
||||
writeScopeWithID(ProfileSetting, s.profileID)
|
||||
if s.kind != ProfileSetting {
|
||||
sb.WriteRune('/')
|
||||
}
|
||||
}
|
||||
if s.kind == UserSetting {
|
||||
writeScopeWithID(UserSetting, s.userID)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func parseScopeAndID(s string) (scope Scope, id string, err error) {
|
||||
name, params, ok := extractScopeAndParams(s)
|
||||
if !ok {
|
||||
return 0, "", fmt.Errorf("%q is not a valid scope string", s)
|
||||
}
|
||||
if err := scope.UnmarshalText([]byte(name)); err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
return scope, params, nil
|
||||
}
|
||||
|
||||
func extractScopeAndParams(s string) (name, params string, ok bool) {
|
||||
paramsStart := strings.Index(s, "(")
|
||||
if paramsStart == -1 {
|
||||
return s, "", true
|
||||
}
|
||||
paramsEnd := strings.LastIndex(s, ")")
|
||||
if paramsEnd < paramsStart {
|
||||
return "", "", false
|
||||
}
|
||||
return s[0:paramsStart], s[paramsStart+1 : paramsEnd], true
|
||||
}
|
||||
550
util/syspolicy/setting/policy_scope_test.go
Normal file
550
util/syspolicy/setting/policy_scope_test.go
Normal file
@@ -0,0 +1,550 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
)
|
||||
|
||||
func TestPolicyScopeIsApplicableSetting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope PolicyScope
|
||||
setting *Definition
|
||||
wantApplicable bool
|
||||
}{
|
||||
{
|
||||
name: "DeviceScope/DeviceSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/ProfileSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantApplicable: false,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/UserSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantApplicable: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/DeviceSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/ProfileSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/UserSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantApplicable: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/DeviceSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/ProfileSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/UserSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotApplicable := tt.scope.IsApplicableSetting(tt.setting)
|
||||
if gotApplicable != tt.wantApplicable {
|
||||
t.Fatalf("got %v, want %v", gotApplicable, tt.wantApplicable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeIsConfigurableSetting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope PolicyScope
|
||||
setting *Definition
|
||||
wantConfigurable bool
|
||||
}{
|
||||
{
|
||||
name: "DeviceScope/DeviceSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/ProfileSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/UserSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/DeviceSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantConfigurable: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/ProfileSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/UserSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/DeviceSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantConfigurable: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/ProfileSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantConfigurable: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/UserSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotConfigurable := tt.scope.IsConfigurableSetting(tt.setting)
|
||||
if gotConfigurable != tt.wantConfigurable {
|
||||
t.Fatalf("got %v, want %v", gotConfigurable, tt.wantConfigurable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeIsWithinOf(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopeA PolicyScope
|
||||
scopeB PolicyScope
|
||||
wantBWithinOfA bool
|
||||
wantBStrictlyWithinOfA bool
|
||||
}{
|
||||
{
|
||||
name: "DeviceScope/DeviceScope",
|
||||
scopeA: DeviceScope,
|
||||
scopeB: DeviceScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/CurrentProfileScope",
|
||||
scopeA: DeviceScope,
|
||||
scopeB: CurrentProfileScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/UserScope",
|
||||
scopeA: DeviceScope,
|
||||
scopeB: CurrentUserScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/DeviceScope",
|
||||
scopeA: CurrentProfileScope,
|
||||
scopeB: DeviceScope,
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/ProfileScope",
|
||||
scopeA: CurrentProfileScope,
|
||||
scopeB: CurrentProfileScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/UserScope",
|
||||
scopeA: CurrentProfileScope,
|
||||
scopeB: CurrentUserScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/DeviceScope",
|
||||
scopeA: CurrentUserScope,
|
||||
scopeB: DeviceScope,
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/ProfileScope",
|
||||
scopeA: CurrentUserScope,
|
||||
scopeB: CurrentProfileScope,
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/UserScope",
|
||||
scopeA: CurrentUserScope,
|
||||
scopeB: CurrentUserScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(1234)",
|
||||
scopeA: UserScopeOf("1234"),
|
||||
scopeB: UserScopeOf("1234"),
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(5678)",
|
||||
scopeA: UserScopeOf("1234"),
|
||||
scopeB: UserScopeOf("5678"),
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope(A)/UserScope(A/1234)",
|
||||
scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope(A)/UserScope(B/1234)",
|
||||
scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "B"},
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(A/1234)",
|
||||
scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(A/5678)",
|
||||
scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "5678", profileID: "A"},
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotWithinOf := tt.scopeB.IsWithinOf(tt.scopeA)
|
||||
if gotWithinOf != tt.wantBWithinOfA {
|
||||
t.Fatalf("WithinOf: got %v, want %v", gotWithinOf, tt.wantBWithinOfA)
|
||||
}
|
||||
|
||||
gotStrictlyWithinOf := tt.scopeB.IsStrictlyWithinOf(tt.scopeA)
|
||||
if gotStrictlyWithinOf != tt.wantBStrictlyWithinOfA {
|
||||
t.Fatalf("StrictlyWithinOf: got %v, want %v", gotStrictlyWithinOf, tt.wantBStrictlyWithinOfA)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeMarshalUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in any
|
||||
wantJSON string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "null-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{},
|
||||
wantJSON: `{"Scope":"Device"}`,
|
||||
},
|
||||
{
|
||||
name: "null-scope-omit-zero",
|
||||
in: &struct {
|
||||
Scope PolicyScope `json:",omitzero"`
|
||||
}{},
|
||||
wantJSON: `{}`,
|
||||
},
|
||||
{
|
||||
name: "device-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{DeviceScope},
|
||||
wantJSON: `{"Scope":"Device"}`,
|
||||
},
|
||||
{
|
||||
name: "current-profile-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{CurrentProfileScope},
|
||||
wantJSON: `{"Scope":"Profile"}`,
|
||||
},
|
||||
{
|
||||
name: "current-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{CurrentUserScope},
|
||||
wantJSON: `{"Scope":"User"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{UserScopeOf("_")},
|
||||
wantJSON: `{"Scope":"User(_)"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{UserScopeOf("S-1-5-21-3698941153-1525015703-2649197413-1001")},
|
||||
wantJSON: `{"Scope":"User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-profile-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{PolicyScope{kind: ProfileSetting, profileID: "1234"}},
|
||||
wantJSON: `{"Scope":"Profile(1234)"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-profile-and-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{PolicyScope{
|
||||
kind: UserSetting,
|
||||
profileID: "1234",
|
||||
userID: "S-1-5-21-3698941153-1525015703-2649197413-1001",
|
||||
}},
|
||||
wantJSON: `{"Scope":"Profile(1234)/User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotJSON, err := jsonv2.Marshal(tt.in)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(gotJSON) != tt.wantJSON {
|
||||
t.Fatalf("Marshal got %s, want %s", gotJSON, tt.wantJSON)
|
||||
}
|
||||
wantBack := tt.in
|
||||
gotBack := reflect.New(reflect.TypeOf(tt.in).Elem()).Interface()
|
||||
err = jsonv2.Unmarshal(gotJSON, gotBack)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotBack, wantBack) {
|
||||
t.Fatalf("Unmarshal got %+v, want %+v", gotBack, wantBack)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeUnmarshalSpecial(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
json string
|
||||
want any
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
json: "{}",
|
||||
want: &struct {
|
||||
Scope PolicyScope
|
||||
}{},
|
||||
},
|
||||
{
|
||||
name: "too-many-scopes",
|
||||
json: `{"Scope":"Device/Profile/User"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "user/profile", // incorrect order
|
||||
json: `{"Scope":"User/Profile"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "profile-user-no-params",
|
||||
json: `{"Scope":"Profile/User"}`,
|
||||
want: &struct {
|
||||
Scope PolicyScope
|
||||
}{CurrentUserScope},
|
||||
},
|
||||
{
|
||||
name: "unknown-scope",
|
||||
json: `{"Scope":"Unknown"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "unknown-scope/unknown-scope",
|
||||
json: `{"Scope":"Unknown/Unknown"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "device-scope/unknown-scope",
|
||||
json: `{"Scope":"Device/Unknown"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "unknown-scope/device-scope",
|
||||
json: `{"Scope":"Unknown/Device"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "slash",
|
||||
json: `{"Scope":"/"}`,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &struct {
|
||||
Scope PolicyScope
|
||||
}{}
|
||||
err := jsonv2.Unmarshal([]byte(tt.json), got)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("Marshal error: got %v, want %v", err, tt.wantError)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Fatalf("Unmarshal got %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestExtractScopeAndParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
scope string
|
||||
params string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
s: "",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "scope-only",
|
||||
s: "device",
|
||||
scope: "device",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "scope-with-params",
|
||||
s: "user(1234)",
|
||||
scope: "user",
|
||||
params: "1234",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "params-empty-scope",
|
||||
s: "(1234)",
|
||||
scope: "",
|
||||
params: "1234",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "params-with-brackets",
|
||||
s: "test()())))())",
|
||||
scope: "test",
|
||||
params: ")())))()",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "no-closing-bracket",
|
||||
s: "user(1234",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "open-before-close",
|
||||
s: ")user(1234",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "brackets-only",
|
||||
s: ")(",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "closing-bracket",
|
||||
s: ")",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "opening-bracket",
|
||||
s: ")",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
scope, params, ok := extractScopeAndParams(tt.s)
|
||||
if ok != tt.wantOk {
|
||||
t.Logf("OK: got %v; want %v", ok, tt.wantOk)
|
||||
}
|
||||
if scope != tt.scope {
|
||||
t.Logf("Scope: got %q; want %q", scope, tt.scope)
|
||||
}
|
||||
if params != tt.params {
|
||||
t.Logf("Params: got %v; want %v", params, tt.params)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
47
util/syspolicy/setting/raw_item.go
Normal file
47
util/syspolicy/setting/raw_item.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
// RawItem contains a raw policy setting as read from a policy store, or an
|
||||
// error if the requested setting could not be read from the store. As a special
|
||||
// case, it may also hold a value of the [Visibility], [PreferenceOption],
|
||||
// or [time.Duration] types. While the policy store interface does not support
|
||||
// these types natively, and the values of these types have to be unmarshalled
|
||||
// or converted from strings, these setting types predate the typed policy
|
||||
// hierarchies, and must be supported at this layer.
|
||||
type RawItem struct {
|
||||
value any
|
||||
err *Error
|
||||
origin *Origin // or nil
|
||||
}
|
||||
|
||||
// RawItemOf returns [RawItem] with the specified value.
|
||||
func RawItemOf(value any) RawItem {
|
||||
return RawItemWith(value, nil, nil)
|
||||
}
|
||||
|
||||
// RawItemWith returns an [RawItem] with the specified value, error and origin.
|
||||
func RawItemWith(value any, err *Error, origin *Origin) RawItem {
|
||||
return RawItem{value, err, origin}
|
||||
}
|
||||
|
||||
// Value returns the value of an untyped policy setting,
|
||||
// or nil if the policy setting is not configured.
|
||||
func (i RawItem) Value() any {
|
||||
return i.value
|
||||
}
|
||||
|
||||
// Error returns the error that occurred when reading the policy setting,
|
||||
// or nil if no error occurred.
|
||||
func (i RawItem) Error() error {
|
||||
if i.err != nil {
|
||||
return i.err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Origin returns an optional [Origin] indicating the policy settings is configured.
|
||||
func (i RawItem) Origin() *Origin {
|
||||
return i.origin
|
||||
}
|
||||
352
util/syspolicy/setting/setting.go
Normal file
352
util/syspolicy/setting/setting.go
Normal file
@@ -0,0 +1,352 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package setting contain types for policy settings.
|
||||
package setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
)
|
||||
|
||||
// Scope indicates the broadest scope at which a policy setting may apply,
|
||||
// and the narrowest scope at which it may be configured.
|
||||
type Scope int8
|
||||
|
||||
const (
|
||||
// DeviceSetting indicates a policy setting that applies to a device, regardless of
|
||||
// which OS user or Tailscale profile is currently active, if any.
|
||||
// It can only be configured at a [DeviceScope].
|
||||
DeviceSetting Scope = iota
|
||||
// ProfileSetting indicates a policy setting that applies to a Tailscale profile.
|
||||
// It can only be configured for a specific profile or at a [DeviceScope],
|
||||
// in which case it applies to all profiles on the device.
|
||||
ProfileSetting
|
||||
// UserSetting indicates a policy setting that applies to users.
|
||||
// It can be configured for a user, profile, or the entire device.
|
||||
UserSetting
|
||||
|
||||
// MaxSettingScope is the maximum possible [Scope] value.
|
||||
MaxSettingScope = UserSetting
|
||||
)
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s Scope) String() string {
|
||||
switch s {
|
||||
case DeviceSetting:
|
||||
return "Device"
|
||||
case ProfileSetting:
|
||||
return "Profile"
|
||||
case UserSetting:
|
||||
return "User"
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (s Scope) MarshalText() (text []byte, err error) {
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (s *Scope) UnmarshalText(text []byte) error {
|
||||
switch strings.ToLower(string(text)) {
|
||||
case "device":
|
||||
*s = DeviceSetting
|
||||
case "profile":
|
||||
*s = ProfileSetting
|
||||
case "user":
|
||||
*s = UserSetting
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid scope", string(text))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type is a policy setting value type.
|
||||
// Except for [InvalidValue], which represents an invalid policy setting type,
|
||||
// and [PreferenceOptionValue], [VisibilityValue], and [DurationValue],
|
||||
// which have special handling due to their legacy status in the package,
|
||||
// SettingTypes represent the raw value types readable from policy stores.
|
||||
type Type int
|
||||
|
||||
const (
|
||||
// InvalidValue indicates an invalid policy setting value type.
|
||||
InvalidValue Type = iota
|
||||
// BooleanValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a bool.
|
||||
BooleanValue
|
||||
// IntegerValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a uint64.
|
||||
IntegerValue
|
||||
// StringValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a string.
|
||||
StringValue
|
||||
// StringListValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a []string.
|
||||
StringListValue
|
||||
// PreferenceOptionValue indicates a three-state policy setting whose
|
||||
// underlying type in the [source.Store] is a string, but the actual value
|
||||
// is a [PreferenceOption].
|
||||
PreferenceOptionValue
|
||||
// VisibilityValue indicates a two-state boolean-like policy setting whose
|
||||
// underlying type in the [source.Store] is a string, but the actual value
|
||||
// is a [Visibility].
|
||||
VisibilityValue
|
||||
// DurationValue indicates an interval/period/duration policy setting whose
|
||||
// underlying type in the [source.Store] is a string, but the actual value
|
||||
// is a [time.Duration].
|
||||
DurationValue
|
||||
)
|
||||
|
||||
// String returns a string representation of t.
|
||||
func (t Type) String() string {
|
||||
switch t {
|
||||
case InvalidValue:
|
||||
return "Invalid"
|
||||
case BooleanValue:
|
||||
return "Boolean"
|
||||
case IntegerValue:
|
||||
return "Integer"
|
||||
case StringValue:
|
||||
return "String"
|
||||
case StringListValue:
|
||||
return "StringList"
|
||||
case PreferenceOptionValue:
|
||||
return "PreferenceOption"
|
||||
case VisibilityValue:
|
||||
return "Visibility"
|
||||
case DurationValue:
|
||||
return "Duration"
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// ValueType is a constraint that allows Go types corresponding to [Type].
|
||||
type ValueType interface {
|
||||
bool | uint64 | string | []string | Visibility | PreferenceOption | time.Duration
|
||||
}
|
||||
|
||||
// Definition defines policy key, scope and value type.
|
||||
type Definition struct {
|
||||
key Key
|
||||
scope Scope
|
||||
typ Type
|
||||
platforms PlatformList
|
||||
}
|
||||
|
||||
// NewDefinition returns a new [Definition] with the specified
|
||||
// key, scope, type and supported platforms (see [PlatformList]).
|
||||
func NewDefinition(k Key, s Scope, t Type, platforms ...string) *Definition {
|
||||
return &Definition{key: k, scope: s, typ: t, platforms: platforms}
|
||||
}
|
||||
|
||||
// Key returns a policy setting's identifier.
|
||||
func (d *Definition) Key() Key {
|
||||
if d == nil {
|
||||
return ""
|
||||
}
|
||||
return d.key
|
||||
}
|
||||
|
||||
// Scope reports the broadest [Scope] the policy setting may apply to.
|
||||
func (d *Definition) Scope() Scope {
|
||||
if d == nil {
|
||||
return 0
|
||||
}
|
||||
return d.scope
|
||||
}
|
||||
|
||||
// Type reports the underlying value type of the policy setting.
|
||||
func (d *Definition) Type() Type {
|
||||
if d == nil {
|
||||
return InvalidValue
|
||||
}
|
||||
return d.typ
|
||||
}
|
||||
|
||||
// IsSupported reports whether the policy setting is supported on the current OS.
|
||||
func (d *Definition) IsSupported() bool {
|
||||
if d == nil {
|
||||
return false
|
||||
}
|
||||
return d.platforms.HasCurrent()
|
||||
}
|
||||
|
||||
// SupportedPlatforms reports platforms on which the policy setting is supported.
|
||||
// An empty [PlatformList] indicates that s is available on all platforms.
|
||||
func (d *Definition) SupportedPlatforms() PlatformList {
|
||||
if d == nil {
|
||||
return nil
|
||||
}
|
||||
return d.platforms
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (d *Definition) String() string {
|
||||
if d == nil {
|
||||
return "(nil)"
|
||||
}
|
||||
return fmt.Sprintf("%v(%q, %v)", d.scope, d.key, d.typ)
|
||||
}
|
||||
|
||||
// Equal reports whether d and d2 have the same key, type and scope.
|
||||
// It does not check whether both s and s2 are supported on the same platforms.
|
||||
func (d *Definition) Equal(d2 *Definition) bool {
|
||||
if d == d2 {
|
||||
return true
|
||||
}
|
||||
if d == nil || d2 == nil {
|
||||
return false
|
||||
}
|
||||
return d.key == d2.key && d.typ == d2.typ && d.scope == d2.scope
|
||||
}
|
||||
|
||||
// DefinitionMap is a map of setting [Definition] by [Key].
|
||||
type DefinitionMap map[Key]*Definition
|
||||
|
||||
var (
|
||||
definitions lazy.SyncValue[DefinitionMap]
|
||||
|
||||
definitionsMu sync.Mutex
|
||||
definitionsList []*Definition
|
||||
definitionsUsed bool
|
||||
)
|
||||
|
||||
// Register registers a policy setting with the specified key, scope, and value type.
|
||||
// All policy settings must be registered before any of them can be used.
|
||||
// Register panics if called after invoking any syspolicy functions that use the
|
||||
// registered policy definitions, such as functions that read the policy.
|
||||
func Register(k Key, s Scope, t Type, platforms ...string) {
|
||||
RegisterDefinition(NewDefinition(k, s, t, platforms...))
|
||||
}
|
||||
|
||||
// RegisterDefinition is like [Register], but accepts a [Definition].
|
||||
func RegisterDefinition(d *Definition) {
|
||||
definitionsMu.Lock()
|
||||
defer definitionsMu.Unlock()
|
||||
registerLocked(d)
|
||||
}
|
||||
|
||||
func registerLocked(d *Definition) {
|
||||
if definitionsUsed {
|
||||
panic("policy definitions are already in use")
|
||||
}
|
||||
definitionsList = append(definitionsList, d)
|
||||
}
|
||||
|
||||
func settingDefinitions() (DefinitionMap, error) {
|
||||
return definitions.GetErr(func() (DefinitionMap, error) {
|
||||
lazyinit.Do()
|
||||
definitionsMu.Lock()
|
||||
defer definitionsMu.Unlock()
|
||||
definitionsUsed = true
|
||||
return DefinitionMapOf(definitionsList)
|
||||
})
|
||||
}
|
||||
|
||||
// DefinitionMapOf returns a [DefinitionMap] with the specified settings,
|
||||
// or an error if any settings have the same key but different type or scope.
|
||||
func DefinitionMapOf(settings []*Definition) (DefinitionMap, error) {
|
||||
m := make(DefinitionMap, len(settings))
|
||||
for _, s := range settings {
|
||||
if existing, exists := m[s.key]; exists {
|
||||
if existing.Equal(s) {
|
||||
// Ignore duplicate setting definitions if they match. It is acceptable
|
||||
// if the same policy setting was registered more than once
|
||||
// (e.g. by the syspolicy package itself and by iOS/Android code).
|
||||
existing.platforms.mergeFrom(s.platforms)
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("duplicate policy definition: %q", s.key)
|
||||
}
|
||||
m[s.key] = s
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// SetDefinitionsForTest allows to register the specified setting definitions
|
||||
// for the test duration. It is not concurrency-safe, but unlike [Register],
|
||||
// it does not panic and can be called anytime.
|
||||
// It returns an error if ds contains two different settings with the same [Key].
|
||||
func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error {
|
||||
m, err := DefinitionMapOf(ds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
definitions.SetForTest(tb, m, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefinitionOf returns a setting definition by key,
|
||||
// or [ErrNoSuchKey] if the specified key does not exist,
|
||||
// or an error if there are conflicting policy definitions.
|
||||
func DefinitionOf(k Key) (*Definition, error) {
|
||||
ds, err := settingDefinitions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d, ok := ds[k]; ok {
|
||||
return d, nil
|
||||
}
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// Definitions returns all registered setting definitions,
|
||||
// or an error if different policies were registered under the same name.
|
||||
func Definitions() ([]*Definition, error) {
|
||||
ds, err := settingDefinitions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := make([]*Definition, 0, len(ds))
|
||||
for _, d := range ds {
|
||||
res = append(res, d)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// PlatformList is a list of OSes.
|
||||
// An empty list indicates that all possible platforms are supported.
|
||||
type PlatformList []string
|
||||
|
||||
// Has reports whether the list contains the target platform.
|
||||
func (l PlatformList) Has(target string) bool {
|
||||
if len(l) == 0 {
|
||||
return true
|
||||
}
|
||||
return slices.ContainsFunc(l, func(os string) bool {
|
||||
return strings.EqualFold(os, target)
|
||||
})
|
||||
}
|
||||
|
||||
// HasCurrent is like Has, but for the current platform.
|
||||
func (l PlatformList) HasCurrent() bool {
|
||||
return l.Has(internal.OS())
|
||||
}
|
||||
|
||||
// mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions,
|
||||
// if either l or l2 is empty, the merged result in l will also be empty.
|
||||
func (l *PlatformList) mergeFrom(l2 PlatformList) {
|
||||
switch {
|
||||
case len(*l) == 0:
|
||||
// No-op. An empty list indicates no platform restrictions.
|
||||
case len(l2) == 0:
|
||||
// Merging with an empty list results in an empty list.
|
||||
*l = l2
|
||||
default:
|
||||
// Append, sort and dedup.
|
||||
*l = append(*l, l2...)
|
||||
slices.Sort(*l)
|
||||
*l = slices.Compact(*l)
|
||||
}
|
||||
}
|
||||
344
util/syspolicy/setting/setting_test.go
Normal file
344
util/syspolicy/setting/setting_test.go
Normal file
@@ -0,0 +1,344 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
func TestSettingDefinition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setting *Definition
|
||||
osOverride string
|
||||
wantKey Key
|
||||
wantScope Scope
|
||||
wantType Type
|
||||
wantIsSupported bool
|
||||
wantSupportedPlatforms PlatformList
|
||||
wantString string
|
||||
}{
|
||||
{
|
||||
name: "Nil",
|
||||
setting: nil,
|
||||
wantKey: "",
|
||||
wantScope: 0,
|
||||
wantType: InvalidValue,
|
||||
wantIsSupported: false,
|
||||
wantString: "(nil)",
|
||||
},
|
||||
{
|
||||
name: "Device/Invalid",
|
||||
setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, InvalidValue),
|
||||
wantKey: "TestDevicePolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: InvalidValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("TestDevicePolicySetting", Invalid)`,
|
||||
},
|
||||
{
|
||||
name: "Device/Integer",
|
||||
setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
|
||||
wantKey: "TestDevicePolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: IntegerValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("TestDevicePolicySetting", Integer)`,
|
||||
},
|
||||
{
|
||||
name: "Profile/String",
|
||||
setting: NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
|
||||
wantKey: "TestProfilePolicySetting",
|
||||
wantScope: ProfileSetting,
|
||||
wantType: StringValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Profile("TestProfilePolicySetting", String)`,
|
||||
},
|
||||
{
|
||||
name: "Device/StringList",
|
||||
setting: NewDefinition("AllowedSuggestedExitNodes", DeviceSetting, StringListValue),
|
||||
wantKey: "AllowedSuggestedExitNodes",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: StringListValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("AllowedSuggestedExitNodes", StringList)`,
|
||||
},
|
||||
{
|
||||
name: "Device/PreferenceOption",
|
||||
setting: NewDefinition("AdvertiseExitNode", DeviceSetting, PreferenceOptionValue),
|
||||
wantKey: "AdvertiseExitNode",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: PreferenceOptionValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("AdvertiseExitNode", PreferenceOption)`,
|
||||
},
|
||||
{
|
||||
name: "User/Boolean",
|
||||
setting: NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
|
||||
wantKey: "TestUserPolicySetting",
|
||||
wantScope: UserSetting,
|
||||
wantType: BooleanValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `User("TestUserPolicySetting", Boolean)`,
|
||||
},
|
||||
{
|
||||
name: "User/Visibility",
|
||||
setting: NewDefinition("AdminConsole", UserSetting, VisibilityValue),
|
||||
wantKey: "AdminConsole",
|
||||
wantScope: UserSetting,
|
||||
wantType: VisibilityValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `User("AdminConsole", Visibility)`,
|
||||
},
|
||||
{
|
||||
name: "User/Duration",
|
||||
setting: NewDefinition("KeyExpirationNotice", UserSetting, DurationValue),
|
||||
wantKey: "KeyExpirationNotice",
|
||||
wantScope: UserSetting,
|
||||
wantType: DurationValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `User("KeyExpirationNotice", Duration)`,
|
||||
},
|
||||
{
|
||||
name: "SupportedSetting",
|
||||
setting: NewDefinition("DesktopPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
|
||||
osOverride: "windows",
|
||||
wantKey: "DesktopPolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: StringValue,
|
||||
wantIsSupported: true,
|
||||
wantSupportedPlatforms: PlatformList{"macos", "windows"},
|
||||
wantString: `Device("DesktopPolicySetting", String)`,
|
||||
},
|
||||
{
|
||||
name: "UnsupportedSetting",
|
||||
setting: NewDefinition("AndroidPolicySetting", DeviceSetting, StringValue, "android"),
|
||||
osOverride: "macos",
|
||||
wantKey: "AndroidPolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: StringValue,
|
||||
wantIsSupported: false,
|
||||
wantSupportedPlatforms: PlatformList{"android"},
|
||||
wantString: `Device("AndroidPolicySetting", String)`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.osOverride != "" {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
}
|
||||
if !tt.setting.Equal(tt.setting) {
|
||||
t.Errorf("the setting should be equal to itself")
|
||||
}
|
||||
if tt.setting != nil && !tt.setting.Equal(ptr.To(*tt.setting)) {
|
||||
t.Errorf("the setting should be equal to its shallow copy")
|
||||
}
|
||||
if gotKey := tt.setting.Key(); gotKey != tt.wantKey {
|
||||
t.Errorf("Key: got %q, want %q", gotKey, tt.wantKey)
|
||||
}
|
||||
if gotScope := tt.setting.Scope(); gotScope != tt.wantScope {
|
||||
t.Errorf("Scope: got %v, want %v", gotScope, tt.wantScope)
|
||||
}
|
||||
if gotType := tt.setting.Type(); gotType != tt.wantType {
|
||||
t.Errorf("Type: got %v, want %v", gotType, tt.wantType)
|
||||
}
|
||||
if gotIsSupported := tt.setting.IsSupported(); gotIsSupported != tt.wantIsSupported {
|
||||
t.Errorf("IsSupported: got %v, want %v", gotIsSupported, tt.wantIsSupported)
|
||||
}
|
||||
if gotSupportedPlatforms := tt.setting.SupportedPlatforms(); !slices.Equal(gotSupportedPlatforms, tt.wantSupportedPlatforms) {
|
||||
t.Errorf("SupportedPlatforms: got %v, want %v", gotSupportedPlatforms, tt.wantSupportedPlatforms)
|
||||
}
|
||||
if gotString := tt.setting.String(); gotString != tt.wantString {
|
||||
t.Errorf("String: got %v, want %v", gotString, tt.wantString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterSettingDefinition(t *testing.T) {
|
||||
const testPolicySettingKey Key = "TestPolicySetting"
|
||||
tests := []struct {
|
||||
name string
|
||||
key Key
|
||||
wantEq *Definition
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "GetRegistered",
|
||||
key: "TestPolicySetting",
|
||||
wantEq: NewDefinition(testPolicySettingKey, DeviceSetting, StringValue),
|
||||
},
|
||||
{
|
||||
name: "GetNonRegistered",
|
||||
key: "OtherPolicySetting",
|
||||
wantEq: nil,
|
||||
wantErr: ErrNoSuchKey,
|
||||
},
|
||||
}
|
||||
|
||||
resetSettingDefinitions(t)
|
||||
Register(testPolicySettingKey, DeviceSetting, StringValue)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, gotErr := DefinitionOf(tt.key)
|
||||
if gotErr != tt.wantErr {
|
||||
t.Errorf("gotErr %v, wantErr %v", gotErr, tt.wantErr)
|
||||
}
|
||||
if !got.Equal(tt.wantEq) {
|
||||
t.Errorf("got %v, want %v", got, tt.wantEq)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterAfterUsePanics(t *testing.T) {
|
||||
resetSettingDefinitions(t)
|
||||
|
||||
Register("TestPolicySetting", DeviceSetting, StringValue)
|
||||
DefinitionOf("TestPolicySetting")
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if gotPanic, wantPanic := recover(), "policy definitions are already in use"; gotPanic != wantPanic {
|
||||
t.Errorf("gotPanic: %q, wantPanic: %q", gotPanic, wantPanic)
|
||||
}
|
||||
}()
|
||||
|
||||
Register("TestPolicySetting", DeviceSetting, StringValue)
|
||||
}()
|
||||
}
|
||||
|
||||
func TestRegisterDuplicateSettings(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
settings []*Definition
|
||||
wantEq *Definition
|
||||
wantErrStr string
|
||||
}{
|
||||
{
|
||||
name: "NoConflict/Exact",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
},
|
||||
{
|
||||
name: "NoConflict/MergeOS-First",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
},
|
||||
{
|
||||
name: "NoConflict/MergeOS-Second",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
},
|
||||
{
|
||||
name: "NoConflict/MergeOS-Both",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos"),
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "windows"),
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
|
||||
},
|
||||
{
|
||||
name: "Conflict/Scope",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
NewDefinition("TestPolicySetting", UserSetting, StringValue),
|
||||
},
|
||||
wantEq: nil,
|
||||
wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
|
||||
},
|
||||
{
|
||||
name: "Conflict/Type",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", UserSetting, StringValue),
|
||||
NewDefinition("TestPolicySetting", UserSetting, IntegerValue),
|
||||
},
|
||||
wantEq: nil,
|
||||
wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resetSettingDefinitions(t)
|
||||
for _, s := range tt.settings {
|
||||
Register(s.Key(), s.Scope(), s.Type(), s.SupportedPlatforms()...)
|
||||
}
|
||||
got, err := DefinitionOf("TestPolicySetting")
|
||||
var gotErrStr string
|
||||
if err != nil {
|
||||
gotErrStr = err.Error()
|
||||
}
|
||||
if gotErrStr != tt.wantErrStr {
|
||||
t.Fatalf("ErrStr: got %q, want %q", gotErrStr, tt.wantErrStr)
|
||||
}
|
||||
if !got.Equal(tt.wantEq) {
|
||||
t.Errorf("Definition got %v, want %v", got, tt.wantEq)
|
||||
}
|
||||
if !slices.Equal(got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms()) {
|
||||
t.Errorf("SupportedPlatforms got %v, want %v", got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSettingDefinitions(t *testing.T) {
|
||||
definitions := []*Definition{
|
||||
NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
|
||||
NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
|
||||
NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
|
||||
NewDefinition("TestStringListPolicySetting", DeviceSetting, StringListValue),
|
||||
}
|
||||
if err := SetDefinitionsForTest(t, definitions...); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
|
||||
cmp := func(l, r *Definition) int {
|
||||
return strings.Compare(string(l.Key()), string(r.Key()))
|
||||
}
|
||||
want := append([]*Definition{}, definitions...)
|
||||
slices.SortFunc(want, cmp)
|
||||
|
||||
got, err := Definitions()
|
||||
if err != nil {
|
||||
t.Fatalf("Definitions failed: %v", err)
|
||||
}
|
||||
slices.SortFunc(got, cmp)
|
||||
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func resetSettingDefinitions(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
definitionsMu.Lock()
|
||||
definitionsList = nil
|
||||
definitions = lazy.SyncValue[DefinitionMap]{}
|
||||
definitionsUsed = false
|
||||
definitionsMu.Unlock()
|
||||
})
|
||||
|
||||
definitionsMu.Lock()
|
||||
definitionsList = nil
|
||||
definitions = lazy.SyncValue[DefinitionMap]{}
|
||||
definitionsUsed = false
|
||||
definitionsMu.Unlock()
|
||||
}
|
||||
153
util/syspolicy/setting/snapshot.go
Normal file
153
util/syspolicy/setting/snapshot.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
// Snapshot is an immutable collection of [RawItem]s, representing
|
||||
// a set of policy settings applied at a specific moment in time.
|
||||
// A nil pointer to [Snapshot] is valid.
|
||||
type Snapshot struct {
|
||||
m map[Key]RawItem
|
||||
sig deephash.Sum // of m
|
||||
summary Summary
|
||||
}
|
||||
|
||||
// NewSnapshot returns a new [Snapshot] with the specified items and options.
|
||||
func NewSnapshot(items map[Key]RawItem, opts ...SummaryOption) *Snapshot {
|
||||
return &Snapshot{m: items, sig: deephash.Hash(&items), summary: SummaryWith(opts...)}
|
||||
}
|
||||
|
||||
type keyItemPair struct {
|
||||
Key Key
|
||||
Item RawItem
|
||||
}
|
||||
|
||||
// All returns an iterator over [[Key], [RawItem]] key-value pairs in b. The
|
||||
// iteration order is not specified and is not guaranteed to be the same from
|
||||
// one call to the next.
|
||||
func (s *Snapshot) All() []keyItemPair {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
// TODO(nickkhyl): return iter.Seq2[[Key], [RawItem]] in Go 1.23,
|
||||
// and remove [keyItemPair].
|
||||
items := make([]keyItemPair, 0, len(s.m))
|
||||
for k, i := range s.m {
|
||||
items = append(items, keyItemPair{k, i})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// Get returns the value of the policy setting with the specified key
|
||||
// or nil if it does not exist or could not be read.
|
||||
func (s *Snapshot) Get(k Key) any {
|
||||
v, _ := s.GetErr(k)
|
||||
return v
|
||||
}
|
||||
|
||||
// GetErr returns the value of the policy setting with the specified key,
|
||||
// [ErrNotConfigured] if it does not exist, or an error returned by
|
||||
// the policy Store if the policy setting could not be read.
|
||||
func (s *Snapshot) GetErr(k Key) (any, error) {
|
||||
if s != nil {
|
||||
if s, ok := s.m[k]; ok {
|
||||
return s.Value(), s.Error()
|
||||
}
|
||||
}
|
||||
return nil, ErrNotConfigured
|
||||
}
|
||||
|
||||
// GetSetting returns the untyped policy setting with the specified key and true
|
||||
// if a policy setting with such key has been configured;
|
||||
// otherwise, it returns zero, false.
|
||||
func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) {
|
||||
setting, ok = s.m[k]
|
||||
return setting, ok
|
||||
}
|
||||
|
||||
// Equal reports whether s and s2 are equal.
|
||||
func (s *Snapshot) Equal(s2 *Snapshot) bool {
|
||||
if !s.EqualItems(s2) {
|
||||
return false
|
||||
}
|
||||
return s.Summary() == s2.Summary()
|
||||
}
|
||||
|
||||
// EqualItems reports whether items in s and s2 are equal.
|
||||
func (s *Snapshot) EqualItems(s2 *Snapshot) bool {
|
||||
if s == s2 {
|
||||
return true
|
||||
}
|
||||
if s.Len() != s2.Len() {
|
||||
return false
|
||||
}
|
||||
if s.Len() == 0 {
|
||||
return true
|
||||
}
|
||||
return s.sig == s2.sig
|
||||
}
|
||||
|
||||
// Keys return an iterator over keys in s. The iteration order is not specified
|
||||
// and is not guaranteed to be the same from one call to the next.
|
||||
func (s *Snapshot) Keys() []Key {
|
||||
if s.m == nil {
|
||||
return nil
|
||||
}
|
||||
// TODO(nickkhyl): return iter.Seq[Key] in Go 1.23.
|
||||
return xmaps.Keys(s.m)
|
||||
}
|
||||
|
||||
// Len reports the number of [RawItem]s in s.
|
||||
func (s *Snapshot) Len() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return len(s.m)
|
||||
}
|
||||
|
||||
// Summary returns information about s as a whole rather than about specific [RawItem]s in it.
|
||||
func (s *Snapshot) Summary() Summary {
|
||||
if s == nil {
|
||||
return Summary{}
|
||||
}
|
||||
return s.summary
|
||||
}
|
||||
|
||||
// MergeSnapshots returns a [Snapshot] that contains all [RawItem]s
|
||||
// from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope].
|
||||
// If there's a conflict between policy settings in the two snapshots,
|
||||
// the policy settings from the snapshot with the broader scope take precedence.
|
||||
// In other words, policy settings configured for the [DeviceScope] win
|
||||
// over policy settings configured for a user scope.
|
||||
func MergeSnapshots(snapshot1, snapshot2 *Snapshot) *Snapshot {
|
||||
scope1, ok1 := snapshot1.Summary().Scope().GetOk()
|
||||
scope2, ok2 := snapshot2.Summary().Scope().GetOk()
|
||||
if ok1 && ok2 && scope2.IsStrictlyWithinOf(scope1) {
|
||||
// Swap snapshots if snapshot1 has higher precedence than snapshot2.
|
||||
snapshot1, snapshot2 = snapshot2, snapshot1
|
||||
}
|
||||
if snapshot2.Len() == 0 {
|
||||
return snapshot1
|
||||
}
|
||||
summaryOpts := make([]SummaryOption, 0, 2)
|
||||
if scope, ok := snapshot1.Summary().Scope().GetOk(); ok {
|
||||
// Use the scope from snapshot1, if present, which is the more specific snapshot.
|
||||
summaryOpts = append(summaryOpts, scope)
|
||||
}
|
||||
if snapshot1.Len() == 0 {
|
||||
if origin, ok := snapshot2.Summary().Origin().GetOk(); ok {
|
||||
// Use the origin from snapshot2 if snapshot1 is empty.
|
||||
summaryOpts = append(summaryOpts, origin)
|
||||
}
|
||||
return &Snapshot{snapshot2.m, snapshot2.sig, SummaryWith(summaryOpts...)}
|
||||
}
|
||||
m := make(map[Key]RawItem, snapshot1.Len()+snapshot2.Len())
|
||||
xmaps.Copy(m, snapshot1.m)
|
||||
xmaps.Copy(m, snapshot2.m) // snapshot2 has higher precedence
|
||||
return &Snapshot{m, deephash.Hash(&m), SummaryWith(summaryOpts...)}
|
||||
}
|
||||
372
util/syspolicy/setting/snapshot_test.go
Normal file
372
util/syspolicy/setting/snapshot_test.go
Normal file
@@ -0,0 +1,372 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMergeSnapshots(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s1, s2 *Snapshot
|
||||
want *Snapshot
|
||||
}{
|
||||
{
|
||||
name: "both-nil",
|
||||
s1: nil,
|
||||
s2: nil,
|
||||
want: NewSnapshot(map[Key]RawItem{}),
|
||||
},
|
||||
{
|
||||
name: "both-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}),
|
||||
s2: NewSnapshot(map[Key]RawItem{}),
|
||||
want: NewSnapshot(map[Key]RawItem{}),
|
||||
},
|
||||
{
|
||||
name: "first-nil",
|
||||
s1: nil,
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "first-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "second-nil",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
s2: nil,
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "second-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "no-conflicts",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "with-conflicts",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "with-scope-first-wins",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, DeviceScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-second-wins",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-both-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-first-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true}}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-second-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MergeSnapshots(tt.s1, tt.s2)
|
||||
if !got.Equal(tt.want) {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
b1, b2 *Snapshot
|
||||
wantEqual bool
|
||||
wantEqualItems bool
|
||||
}{
|
||||
{
|
||||
name: "nil-nil",
|
||||
b1: nil,
|
||||
b2: nil,
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "nil-empty",
|
||||
b1: nil,
|
||||
b2: NewSnapshot(map[Key]RawItem{}),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "empty-nil",
|
||||
b1: NewSnapshot(map[Key]RawItem{}),
|
||||
b2: nil,
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "empty-empty",
|
||||
b1: NewSnapshot(map[Key]RawItem{}),
|
||||
b2: NewSnapshot(map[Key]RawItem{}),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "first-nil",
|
||||
b1: nil,
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "first-empty",
|
||||
b1: NewSnapshot(map[Key]RawItem{}),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "second-nil",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
b2: nil,
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "second-empty",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
b2: NewSnapshot(map[Key]RawItem{}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "same-items-same-order-no-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "same-items-same-order-same-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "same-items-different-order-same-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting3": {value: false},
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
}, DeviceScope),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "same-items-same-order-different-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, CurrentUserScope),
|
||||
wantEqual: false,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "different-items-same-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}, DeviceScope),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotEqual := tt.b1.Equal(tt.b2); gotEqual != tt.wantEqual {
|
||||
t.Errorf("WantEqual: got %v, want %v", gotEqual, tt.wantEqual)
|
||||
}
|
||||
if gotEqualItems := tt.b1.EqualItems(tt.b2); gotEqualItems != tt.wantEqualItems {
|
||||
t.Errorf("WantEqualItems: got %v, want %v", gotEqualItems, tt.wantEqualItems)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
84
util/syspolicy/setting/summary.go
Normal file
84
util/syspolicy/setting/summary.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
"tailscale.com/types/opt"
|
||||
)
|
||||
|
||||
// Summary is an immutable [PolicyScope] and [Origin].
|
||||
type Summary struct {
|
||||
data summary
|
||||
}
|
||||
|
||||
type summary struct {
|
||||
Scope opt.Value[PolicyScope] `json:",omitzero"`
|
||||
Origin opt.Value[Origin] `json:",omitzero"`
|
||||
}
|
||||
|
||||
// SummaryWith returns a [Summary] with the specified options.
|
||||
func SummaryWith(opts ...SummaryOption) Summary {
|
||||
var summary Summary
|
||||
for _, o := range opts {
|
||||
o.applySummaryOption(&summary)
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
// Scope reports the [PolicyScope] in s.
|
||||
func (s Summary) Scope() opt.Value[PolicyScope] {
|
||||
return s.data.Scope
|
||||
}
|
||||
|
||||
// Origin reports the [Origin] in s.
|
||||
func (s Summary) Origin() opt.Value[Origin] {
|
||||
return s.data.Origin
|
||||
}
|
||||
|
||||
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
|
||||
func (s Summary) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
|
||||
return jsonv2.MarshalEncode(out, &s.data, opts)
|
||||
}
|
||||
|
||||
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
|
||||
func (s *Summary) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
|
||||
return jsonv2.UnmarshalDecode(in, &s.data, opts)
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler].
|
||||
func (s Summary) MarshalJSON() ([]byte, error) {
|
||||
return jsonv2.Marshal(s) // uses MarshalJSONV2
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler].
|
||||
func (s *Summary) UnmarshalJSON(b []byte) error {
|
||||
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
|
||||
}
|
||||
|
||||
// SummaryOption is an option that configures [Summary]
|
||||
// The following are allowed options:
|
||||
//
|
||||
// - [Summary]
|
||||
// - [PolicyScope]
|
||||
// - [Origin]
|
||||
type SummaryOption interface {
|
||||
applySummaryOption(summary *Summary)
|
||||
}
|
||||
|
||||
func (s PolicyScope) applySummaryOption(summary *Summary) {
|
||||
summary.data.Scope.Set(s)
|
||||
}
|
||||
|
||||
func (o Origin) applySummaryOption(summary *Summary) {
|
||||
summary.data.Origin.Set(o)
|
||||
if !summary.data.Scope.IsSet() {
|
||||
summary.data.Scope.Set(o.Scope())
|
||||
}
|
||||
}
|
||||
|
||||
func (s Summary) applySummaryOption(summary *Summary) {
|
||||
*summary = s
|
||||
}
|
||||
132
util/syspolicy/setting/types.go
Normal file
132
util/syspolicy/setting/types.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
)
|
||||
|
||||
// PreferenceOption is a policy that governs whether a boolean variable
|
||||
// is forcibly assigned an administrator-defined value, or allowed to receive
|
||||
// a user-defined value.
|
||||
type PreferenceOption int
|
||||
|
||||
const (
|
||||
ShowChoiceByPolicy PreferenceOption = iota
|
||||
NeverByPolicy
|
||||
AlwaysByPolicy
|
||||
)
|
||||
|
||||
// Show returns if the UI option that controls the choice administered by this
|
||||
// policy should be shown. Currently this is true if and only if the policy is
|
||||
// [ShowChoiceByPolicy].
|
||||
func (p PreferenceOption) Show() bool {
|
||||
return p == ShowChoiceByPolicy
|
||||
}
|
||||
|
||||
// ShouldEnable checks if the choice administered by this policy should be
|
||||
// enabled. If the administrator has chosen a setting, the administrator's
|
||||
// setting is returned, otherwise userChoice is returned.
|
||||
func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
|
||||
switch p {
|
||||
case NeverByPolicy:
|
||||
return false
|
||||
case AlwaysByPolicy:
|
||||
return true
|
||||
default:
|
||||
return userChoice
|
||||
}
|
||||
}
|
||||
|
||||
// IsAlways reports whether the preference should always be enabled.
|
||||
func (p PreferenceOption) IsAlways() bool {
|
||||
return p == AlwaysByPolicy
|
||||
}
|
||||
|
||||
// IsNever reports whether the preference should always be disabled.
|
||||
func (p PreferenceOption) IsNever() bool {
|
||||
return p == NeverByPolicy
|
||||
}
|
||||
|
||||
// WillOverride checks if the choice administered by the policy is different
|
||||
// from the user's choice.
|
||||
func (p PreferenceOption) WillOverride(userChoice bool) bool {
|
||||
return p.ShouldEnable(userChoice) != userChoice
|
||||
}
|
||||
|
||||
// String returns a string representation of p.
|
||||
func (p PreferenceOption) String() string {
|
||||
switch p {
|
||||
case AlwaysByPolicy:
|
||||
return "always"
|
||||
case NeverByPolicy:
|
||||
return "never"
|
||||
default:
|
||||
return "user-decides"
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (p *PreferenceOption) MarshalText() (text []byte, err error) {
|
||||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (p *PreferenceOption) UnmarshalText(text []byte) error {
|
||||
switch string(text) {
|
||||
case "always":
|
||||
*p = AlwaysByPolicy
|
||||
case "never":
|
||||
*p = NeverByPolicy
|
||||
default:
|
||||
*p = ShowChoiceByPolicy
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Visibility is a policy that controls whether or not a particular
|
||||
// component of a user interface is to be shown.
|
||||
type Visibility byte
|
||||
|
||||
var (
|
||||
_ encoding.TextMarshaler = (*Visibility)(nil)
|
||||
_ encoding.TextUnmarshaler = (*Visibility)(nil)
|
||||
)
|
||||
|
||||
const (
|
||||
VisibleByPolicy Visibility = 'v'
|
||||
HiddenByPolicy Visibility = 'h'
|
||||
)
|
||||
|
||||
// Show reports whether the UI option administered by this policy should be shown.
|
||||
// Currently this is true if the policy is not [hiddenByPolicy].
|
||||
func (p Visibility) Show() bool {
|
||||
return p != HiddenByPolicy
|
||||
}
|
||||
|
||||
// String returns a string representation of p.
|
||||
func (p Visibility) String() string {
|
||||
switch p {
|
||||
case 'h':
|
||||
return "hide"
|
||||
default:
|
||||
return "show"
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (p Visibility) MarshalText() (text []byte, err error) {
|
||||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (p *Visibility) UnmarshalText(text []byte) error {
|
||||
switch string(text) {
|
||||
case "hide":
|
||||
*p = HiddenByPolicy
|
||||
default:
|
||||
*p = VisibleByPolicy
|
||||
}
|
||||
return nil
|
||||
}
|
||||
393
util/syspolicy/source/policy_reader.go
Normal file
393
util/syspolicy/source/policy_reader.go
Normal file
@@ -0,0 +1,393 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/internal/metrics"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// Reader reads all configured policy settings from a given [Store].
|
||||
// It registers a change callback with the [Store] and maintains the current version
|
||||
// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store]
|
||||
// whenever a new snapshot is requested
|
||||
// It is safe for concurrent use.
|
||||
type Reader struct {
|
||||
store Store
|
||||
origin *setting.Origin
|
||||
settings []*setting.Definition
|
||||
unregisterChangeNotifier func()
|
||||
doneCh chan struct{} // closed when policyCache is closed.
|
||||
|
||||
mu sync.RWMutex
|
||||
closing bool
|
||||
upToDate bool
|
||||
lastPolicy *setting.Snapshot
|
||||
sessions set.HandleSet[*ReadingSession]
|
||||
}
|
||||
|
||||
// newReader returns a new [Reader] that reads policy settings from a given [Store].
|
||||
// The returned reader takes ownership of the store. If the store implements [io.Closer],
|
||||
// the returned reader will close the store when it is closed.
|
||||
func newReader(store Store, origin *setting.Origin) (*Reader, error) {
|
||||
settings, err := setting.Definitions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if expirable, ok := store.(Expirable); ok {
|
||||
select {
|
||||
case <-expirable.Done():
|
||||
return nil, ErrStoreClosed
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})}
|
||||
if changeable, ok := store.(Changeable); ok {
|
||||
// We should subscribe to policy change notifications first before reading
|
||||
// the policy settings from the store. This way we won't miss any notifications.
|
||||
if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil {
|
||||
// Errors registering policy change callbacks are non-fatal.
|
||||
// TODO(nickkhyl): implement a background policy refresh every X minutes?
|
||||
loggerx.Errorf("failed to register %v policy change callback: %v\n", origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := reader.reload(true); err != nil {
|
||||
if reader.unregisterChangeNotifier != nil {
|
||||
reader.unregisterChangeNotifier()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if expirable, ok := store.(Expirable); ok {
|
||||
if waitCh := expirable.Done(); waitCh != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-waitCh:
|
||||
reader.Close()
|
||||
case <-reader.doneCh:
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// GetSettings returns the current [*setting.Snapshot],
|
||||
// re-reading it from from the underlying [Store] only if the policy
|
||||
// has changed since it was read last. It never fails and returns
|
||||
// the previous version of the policy settings if a read attempt fails.
|
||||
func (r *Reader) GetSettings() *setting.Snapshot {
|
||||
r.mu.RLock()
|
||||
if r.upToDate {
|
||||
r.mu.RUnlock()
|
||||
return r.lastPolicy
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
policy, err := r.reload(false)
|
||||
if err != nil {
|
||||
// If the policy could not be reloaded at all, we'll return the last cached version of it.
|
||||
// On the contrary, errors specific to individual policy items are always propagated to the callers.
|
||||
loggerx.Errorf("failed to reload %v policy: %v\n", r.origin, err)
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
// ReadSettings reads policy settings from the underlying [Store] even if no
|
||||
// changes were detected. It returns the new [*setting.Snapshot], nil on
|
||||
// success, or nil, error in case of failure.
|
||||
func (r *Reader) ReadSettings() (*setting.Snapshot, error) {
|
||||
b, err := r.reload(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// reload is like [Reader.ReadSettings], but allows specifying whether to re-read
|
||||
// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails.
|
||||
func (r *Reader) reload(force bool) (*setting.Snapshot, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.upToDate && !force {
|
||||
return r.lastPolicy, nil
|
||||
}
|
||||
|
||||
if lockable, ok := r.store.(Lockable); ok {
|
||||
if err := lockable.Lock(); err != nil {
|
||||
return r.lastPolicy, err
|
||||
}
|
||||
defer lockable.Unlock()
|
||||
}
|
||||
|
||||
r.upToDate = true
|
||||
|
||||
metrics.Reset(r.origin)
|
||||
|
||||
var m map[setting.Key]setting.RawItem
|
||||
if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 {
|
||||
m = make(map[setting.Key]setting.RawItem, lastPolicyCount)
|
||||
}
|
||||
for _, s := range r.settings {
|
||||
if !r.origin.Scope().IsConfigurableSetting(s) {
|
||||
// Skip settings that cannot be configured in the current scope.
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := readPolicySettingValue(r.store, s)
|
||||
if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) {
|
||||
metrics.ReportNotConfigured(r.origin, s)
|
||||
continue
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
metrics.ReportConfigured(r.origin, s, val)
|
||||
} else {
|
||||
metrics.ReportError(r.origin, s, err)
|
||||
}
|
||||
|
||||
// If there's an error reading a single policy, such as a value type mismatch,
|
||||
// we'll wrap the error to preserve its text and return it
|
||||
// whenever someone attempts to fetch the value.
|
||||
mak.Set(&m, s.Key(), setting.RawItemWith(val, setting.WrapError(err), r.origin))
|
||||
}
|
||||
|
||||
newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin))
|
||||
if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) {
|
||||
r.lastPolicy = newPolicy
|
||||
}
|
||||
return r.lastPolicy, nil
|
||||
}
|
||||
|
||||
// ReadingSession is like [Reader], but with a channel that's written
|
||||
// to when there's a policy change, and closed when the session is terminated.
|
||||
type ReadingSession struct {
|
||||
reader *Reader
|
||||
policyChangedCh chan struct{} // 1-buffered channel
|
||||
handle set.Handle // in the reader.sessions
|
||||
closeInternal func()
|
||||
}
|
||||
|
||||
// OpenSession opens and returns a new session to r, allowing the caller
|
||||
// to get notified whenever a policy change is reported by the [source.Store],
|
||||
// or an [ErrStoreClosed] if the reader has already been closed.
|
||||
func (r *Reader) OpenSession() (*ReadingSession, error) {
|
||||
session := &ReadingSession{
|
||||
reader: r,
|
||||
policyChangedCh: make(chan struct{}, 1),
|
||||
}
|
||||
session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) })
|
||||
r.mu.Lock()
|
||||
if !r.closing {
|
||||
session.handle = r.sessions.Add(session)
|
||||
r.mu.Unlock()
|
||||
return session, nil
|
||||
}
|
||||
r.mu.Unlock()
|
||||
return nil, ErrStoreClosed
|
||||
}
|
||||
|
||||
// GetSettings is like [Reader.GetSettings].
|
||||
func (s *ReadingSession) GetSettings() *setting.Snapshot {
|
||||
return s.reader.GetSettings()
|
||||
}
|
||||
|
||||
// ReadSettings is like [Reader.ReadSettings].
|
||||
func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) {
|
||||
return s.reader.ReadSettings()
|
||||
}
|
||||
|
||||
// PolicyChanged returns a channel that's written to when
|
||||
// there's a policy change, closed when the session is terminated.
|
||||
func (s *ReadingSession) PolicyChanged() <-chan struct{} {
|
||||
return s.policyChangedCh
|
||||
}
|
||||
|
||||
// Close unregisters this session with the [Reader].
|
||||
func (s *ReadingSession) Close() {
|
||||
s.reader.mu.Lock()
|
||||
delete(s.reader.sessions, s.handle)
|
||||
s.closeInternal()
|
||||
s.reader.mu.Unlock()
|
||||
}
|
||||
|
||||
// onPolicyChange handles a policy change notification from the [Store],
|
||||
// invalidating the current [setting.Snapshot] in r,
|
||||
// and notifying the active [ReadingSession]s.
|
||||
func (r *Reader) onPolicyChange() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.upToDate = false
|
||||
for _, s := range r.sessions {
|
||||
select {
|
||||
case s.policyChangedCh <- struct{}{}:
|
||||
// Notified.
|
||||
default:
|
||||
// 1-buffered channel is full, meaning that another policy change
|
||||
// notification is already en route.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the store reader and the underlying store.
|
||||
func (r *Reader) Close() error {
|
||||
r.mu.Lock()
|
||||
if r.closing {
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
r.closing = true
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.unregisterChangeNotifier != nil {
|
||||
r.unregisterChangeNotifier()
|
||||
r.unregisterChangeNotifier = nil
|
||||
}
|
||||
|
||||
if closer, ok := r.store.(io.Closer); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.store = nil
|
||||
|
||||
close(r.doneCh)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, c := range r.sessions {
|
||||
c.closeInternal()
|
||||
}
|
||||
r.sessions = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the reader is closed.
|
||||
func (r *Reader) Done() <-chan struct{} {
|
||||
return r.doneCh
|
||||
}
|
||||
|
||||
// ReadableSource is a [Source] open for reading.
|
||||
type ReadableSource struct {
|
||||
*Source
|
||||
*ReadingSession
|
||||
}
|
||||
|
||||
// Close closes the underlying [ReadingSession].
|
||||
func (s ReadableSource) Close() {
|
||||
s.ReadingSession.Close()
|
||||
}
|
||||
|
||||
// ReadableSources is a slice of [ReadableSource].
|
||||
type ReadableSources []ReadableSource
|
||||
|
||||
// Contains reports whether s contains the specified source.
|
||||
func (s ReadableSources) Contains(source *Source) bool {
|
||||
return s.IndexOf(source) != -1
|
||||
}
|
||||
|
||||
// IndexOf returns position of the specified source in s, or -1
|
||||
// if the source does not exist.
|
||||
func (s ReadableSources) IndexOf(source *Source) int {
|
||||
return slices.IndexFunc(s, func(rs ReadableSource) bool {
|
||||
return rs.Source == source
|
||||
})
|
||||
}
|
||||
|
||||
// InsertionIndexOf returns the position at which source can be inserted
|
||||
// to maintain the sorted order of the readableSources.
|
||||
// The return value is unspecified if s is not sorted on entry to InsertionIndexOf.
|
||||
func (s ReadableSources) InsertionIndexOf(source *Source) int {
|
||||
low, high := 0, len(s)
|
||||
for low < high {
|
||||
mid := (low + high) / 2
|
||||
if s[mid].Compare(source) <= 0 {
|
||||
low = mid + 1
|
||||
} else {
|
||||
high = mid
|
||||
}
|
||||
}
|
||||
return low
|
||||
}
|
||||
|
||||
// StableSort sorts the readableSources by the precedence, so that policy settings
|
||||
// from sources with higher precedence (e.g., [DeviceScope]) will be merged last,
|
||||
// overriding any policy settings with the same keys configured in sources with
|
||||
// lower precedence (e.g., [CurrentUserScope]).
|
||||
func (s *ReadableSources) StableSort() {
|
||||
sort.SliceStable(*s, func(i, j int) bool {
|
||||
return (*s)[i].Source.Compare((*s)[j].Source) < 0
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAt closes and deletes the i-th source from s.
|
||||
func (s *ReadableSources) DeleteAt(i int) {
|
||||
(*s)[i].Close()
|
||||
*s = slices.Delete(*s, i, i+1)
|
||||
}
|
||||
|
||||
// Close closes and deletes all sources in s.
|
||||
func (s *ReadableSources) Close() {
|
||||
for _, s := range *s {
|
||||
s.Close()
|
||||
}
|
||||
*s = nil
|
||||
}
|
||||
|
||||
func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) {
|
||||
switch key := s.Key(); s.Type() {
|
||||
case setting.BooleanValue:
|
||||
return store.ReadBoolean(key)
|
||||
case setting.IntegerValue:
|
||||
return store.ReadUInt64(key)
|
||||
case setting.StringValue:
|
||||
return store.ReadString(key)
|
||||
case setting.StringListValue:
|
||||
return store.ReadStringArray(key)
|
||||
case setting.PreferenceOptionValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value setting.PreferenceOption
|
||||
if err = value.UnmarshalText([]byte(s)); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return setting.ShowChoiceByPolicy, err
|
||||
case setting.VisibilityValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value setting.Visibility
|
||||
if err = value.UnmarshalText([]byte(s)); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return setting.VisibleByPolicy, err
|
||||
case setting.DurationValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value time.Duration
|
||||
if value, err = time.ParseDuration(s); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type())
|
||||
}
|
||||
}
|
||||
291
util/syspolicy/source/policy_reader_test.go
Normal file
291
util/syspolicy/source/policy_reader_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestReaderLifecycle(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origin *setting.Origin
|
||||
definitions []*setting.Definition
|
||||
wantReads []TestExpectedReads
|
||||
initStrings []TestSetting[string]
|
||||
initUInt64s []TestSetting[uint64]
|
||||
initWant *setting.Snapshot
|
||||
addStrings []TestSetting[string]
|
||||
addStringLists []TestSetting[[]string]
|
||||
newWant *setting.Snapshot
|
||||
}{
|
||||
{
|
||||
name: "read-all-settings-once",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
|
||||
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "re-read-all-settings-when-the-policy-changes",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
|
||||
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")},
|
||||
addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})},
|
||||
newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/device",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/profile",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
// Device settings cannot be configured at the profile scope and should not be read.
|
||||
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/user",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
// Device and profile settings cannot be configured at the profile scope and should not be read.
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-stringy-settings",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initStrings: []TestSetting[string]{
|
||||
TestSettingOf("DurationValue", "2h30m"),
|
||||
TestSettingOf("PreferenceOptionValue", "always"),
|
||||
TestSettingOf("VisibilityValue", "show"),
|
||||
},
|
||||
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "read-erroneous-stringy-settings",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initStrings: []TestSetting[string]{
|
||||
TestSettingOf("DurationValue1", "soon"),
|
||||
TestSettingWithError[string]("DurationValue2", setting.NewError("bang!")),
|
||||
TestSettingOf("PreferenceOptionValue", "sometimes"),
|
||||
},
|
||||
initUInt64s: []TestSetting[uint64]{
|
||||
TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch
|
||||
},
|
||||
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"DurationValue1": setting.RawItemWith(nil, setting.NewError("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"DurationValue2": setting.RawItemWith(nil, setting.NewError("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewError("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setting.SetDefinitionsForTest(t, tt.definitions...)
|
||||
store := NewTestStore(t)
|
||||
store.SetStrings(tt.initStrings...)
|
||||
store.SetUInt64s(tt.initUInt64s...)
|
||||
|
||||
reader, err := newReader(store, tt.origin)
|
||||
if err != nil {
|
||||
t.Fatalf("newReader failed: %v", err)
|
||||
}
|
||||
|
||||
if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
|
||||
// Should not result in new reads as there were no changes.
|
||||
N := 100
|
||||
for range N {
|
||||
reader.GetSettings()
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
store.ResetCounters()
|
||||
|
||||
got, err := reader.ReadSettings()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadSettings failed: %v", err)
|
||||
}
|
||||
|
||||
if tt.initWant != nil && !got.Equal(tt.initWant) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
|
||||
}
|
||||
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
store.ResetCounters()
|
||||
|
||||
if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 {
|
||||
store.SetStrings(tt.addStrings...)
|
||||
store.SetStringLists(tt.addStringLists...)
|
||||
|
||||
// As the settings have changed, GetSettings needs to re-read them.
|
||||
if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) {
|
||||
t.Errorf("New Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-reader.Done():
|
||||
t.Fatalf("the reader is closed")
|
||||
default:
|
||||
}
|
||||
|
||||
store.Close()
|
||||
|
||||
<-reader.Done()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadingSession(t *testing.T) {
|
||||
setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue))
|
||||
store := NewTestStore(t)
|
||||
|
||||
origin := setting.NewOrigin(setting.DeviceScope)
|
||||
reader, err := newReader(store, origin)
|
||||
if err != nil {
|
||||
t.Fatalf("newReader failed: %v", err)
|
||||
}
|
||||
session, err := reader.OpenSession()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open a reading session: %v", err)
|
||||
}
|
||||
t.Cleanup(session.Close)
|
||||
|
||||
if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
select {
|
||||
case _, ok := <-session.PolicyChanged():
|
||||
if ok {
|
||||
t.Fatalf("the policy changed notification was sent prematurely")
|
||||
} else {
|
||||
t.Fatalf("the session was closed prematurely")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
store.SetStrings(TestSettingOf("StringValue", "S1"))
|
||||
_, ok := <-session.PolicyChanged()
|
||||
if !ok {
|
||||
t.Fatalf("the session was closed prematurely")
|
||||
}
|
||||
|
||||
want := setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"StringValue": setting.RawItemWith("S1", nil, origin),
|
||||
}, origin)
|
||||
if got := session.GetSettings(); !got.Equal(want) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
store.Close()
|
||||
if _, ok = <-session.PolicyChanged(); ok {
|
||||
t.Fatalf("the session must be closed")
|
||||
}
|
||||
}
|
||||
146
util/syspolicy/source/policy_store.go
Normal file
146
util/syspolicy/source/policy_store.go
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// ErrStoreClosed is an error returned when attempting to use a [Store] after it
|
||||
// has been closed.
|
||||
var ErrStoreClosed = errors.New("the policy store has been closed")
|
||||
|
||||
// Store provides methods to read system policy settings from OS-specific storage.
|
||||
// Implementations must be concurrency-safe, and may also implement
|
||||
// [Lockable], [Changeable], [Expirable] and [io.Closer].
|
||||
//
|
||||
// If a [Store] implementation also implements [io.Closer],
|
||||
// it will be called by the package to release the resources
|
||||
// when the store is no longer needed.
|
||||
type Store interface {
|
||||
// ReadString returns the value of a [setting.StringValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
|
||||
ReadString(key setting.Key) (string, error)
|
||||
// ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
|
||||
ReadUInt64(key setting.Key) (uint64, error)
|
||||
// ReadBoolean returns the value of a [setting.BooleanValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
|
||||
ReadBoolean(key setting.Key) (bool, error)
|
||||
// ReadStringArray returns the value of a [setting.StringListValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string list type.
|
||||
ReadStringArray(key setting.Key) ([]string, error)
|
||||
}
|
||||
|
||||
// Lockable is an optional interface that [Store] implementations may support.
|
||||
// Locking a [Store] is not mandatory as [Store] must be concurrency-safe,
|
||||
// but is recommended to avoid issues where consecutive read calls for related
|
||||
// policies might return inconsistent results if a policy change occurs between
|
||||
// the calls.
|
||||
type Lockable interface {
|
||||
|
||||
// Lock acquires a read lock on the policy store,
|
||||
// ensuring the store's state remains unchanged while locked.
|
||||
// Multiple readers can hold the lock simultaneously.
|
||||
// It should return nil if the store does not support locking,
|
||||
// or an error if the store cannot be locked.
|
||||
Lock() error
|
||||
// Unlock unlocks the policy store.
|
||||
// It is a runtime error if the store is not locked on entry to Unlock.
|
||||
Unlock()
|
||||
}
|
||||
|
||||
// Changeable is an optional interface that [Store] implementations may support.
|
||||
type Changeable interface {
|
||||
// RegisterChangeCallback adds a function that will be called
|
||||
// whenever there's a policy change in the [Store].
|
||||
// The returned function can be used to unregister the callback.
|
||||
RegisterChangeCallback(callback func()) (unregister func(), err error)
|
||||
}
|
||||
|
||||
// Expirable is an optional interface that [Store] implementations may support.
|
||||
type Expirable interface {
|
||||
// Done returns a channel that is closed when the policy [Store] should no longer be used.
|
||||
// It should return nil if the store never expires.
|
||||
Done() <-chan struct{}
|
||||
}
|
||||
|
||||
// Source represents a named source of policy settings for a given scope.
|
||||
type Source struct {
|
||||
name string
|
||||
scope setting.PolicyScope
|
||||
store Store
|
||||
origin *setting.Origin
|
||||
|
||||
lazyReader lazy.SyncValue[*Reader]
|
||||
}
|
||||
|
||||
// NewSource returns a new [Source] with the specified name, scope, and store.
|
||||
func NewSource(name string, scope setting.PolicyScope, store Store) *Source {
|
||||
return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)}
|
||||
}
|
||||
|
||||
// Name reports the name of the policy source.
|
||||
func (s *Source) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Scope reports the management scope of the policy source.
|
||||
func (s *Source) Scope() setting.PolicyScope {
|
||||
return s.scope
|
||||
}
|
||||
|
||||
// Store returns the [Store] that can be used to read policy settings from this source.
|
||||
func (s *Source) Store() Store {
|
||||
return s.store
|
||||
}
|
||||
|
||||
// Reader returns a [Reader] that reads from this source's [Store].
|
||||
func (s *Source) Reader() (*Reader, error) {
|
||||
return s.lazyReader.GetErr(func() (*Reader, error) {
|
||||
return newReader(s.store, s.origin)
|
||||
})
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s *Source) String() string {
|
||||
if s.Name() != "" {
|
||||
return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
|
||||
}
|
||||
return s.Scope().String()
|
||||
}
|
||||
|
||||
// Compare returns an integer comparing [Source] s and s2
|
||||
// by their precedence, following the "last-wins" model.
|
||||
// The result will be:
|
||||
//
|
||||
// -1 if policy settings from s should be processed before policy settings from s2;
|
||||
// +1 if policy settings from s should be processed after policy settings from s2, overriding s2;
|
||||
// 0 if the relative processing order of policy settings in s and s2 is unspecified.
|
||||
func (s *Source) Compare(s2 *Source) int {
|
||||
return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind())
|
||||
}
|
||||
|
||||
// Close closes the [Source] and the underlying [Store].
|
||||
func (s *Source) Close() error {
|
||||
// The [Reader], if any, owns the [Store].
|
||||
if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil {
|
||||
return reader.Close()
|
||||
}
|
||||
// Otherwise, it is our responsibility to close it.
|
||||
if closer, ok := s.store.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
438
util/syspolicy/source/policy_store_windows.go
Normal file
438
util/syspolicy/source/policy_store_windows.go
Normal file
@@ -0,0 +1,438 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/winutil/gp"
|
||||
)
|
||||
|
||||
const (
|
||||
softwareKeyName = "Software"
|
||||
tsPoliciesSubkey = `Policies\Tailscale`
|
||||
tsIPNSubkey = "Tailscale IPN" // the legacy key we need to fallback to
|
||||
)
|
||||
|
||||
var (
|
||||
// [PlatformPolicyStore] implements [Store].
|
||||
_ Store = (*PlatformPolicyStore)(nil)
|
||||
)
|
||||
|
||||
// PlatformPolicyStore implements [Store] by providing read access to the Registry-based
|
||||
// Tailscale policies, such as those configured via Group Policy or MDM. It is
|
||||
// recommended to lock it when reading multiple policy values in a row. It also
|
||||
// allows subscribing to notifications when there's a policy change.
|
||||
type PlatformPolicyStore struct {
|
||||
scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy]
|
||||
|
||||
// The softwareKey can be HKLM\Software, HKCU\Software, or
|
||||
// HKU\{SID}\Software. Anything below the Software subkey, including
|
||||
// Software\Policies, may not yet exist or could be deleted throughout the
|
||||
// [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer
|
||||
// to always use a real registry key (rather than a predefined HKLM or HKCU)
|
||||
// to simplify bookkeeping (predefined keys should never be closed).
|
||||
// Finally, this will allow us to watch for any registry changes directly
|
||||
// should we need this in the future in addition to gp.ChangeWatcher.
|
||||
softwareKey registry.Key
|
||||
watcher *gp.ChangeWatcher
|
||||
|
||||
done chan struct{} // done is closed when Close call completes
|
||||
|
||||
// The policyLock can be locked by the caller when reading multiple policy settings
|
||||
// to prevent the Group Policy Client service from modifying policies while
|
||||
// they are being read.
|
||||
//
|
||||
// When both policyLock and mu need to be taken, mu must be taken before policyLock.
|
||||
policyLock *gp.PolicyLock
|
||||
|
||||
mu sync.RWMutex
|
||||
tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked.
|
||||
cbs set.HandleSet[func()] // policy change callbacks
|
||||
lockCnt int
|
||||
locked sync.WaitGroup
|
||||
closing bool
|
||||
readable bool
|
||||
}
|
||||
|
||||
type registryValueGetter[T any] func(key registry.Key, name setting.Key) (T, error)
|
||||
|
||||
// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine.
|
||||
func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) {
|
||||
softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
|
||||
}
|
||||
return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, 0)
|
||||
}
|
||||
|
||||
// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token.
|
||||
// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY]
|
||||
// access. The caller retains ownership of the token.
|
||||
func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) {
|
||||
var err error
|
||||
var softwareKey registry.Key
|
||||
if token != 0 {
|
||||
var user *windows.Tokenuser
|
||||
if user, err = token.GetTokenUser(); err != nil {
|
||||
return nil, fmt.Errorf("failed to get token user: %w", err)
|
||||
}
|
||||
userSid := user.User.Sid
|
||||
softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ)
|
||||
} else {
|
||||
softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
|
||||
}
|
||||
return newPlatformPolicyStore(gp.UserPolicy, softwareKey, token)
|
||||
}
|
||||
|
||||
func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, token windows.Token) (_ *PlatformPolicyStore, err error) {
|
||||
store := &PlatformPolicyStore{
|
||||
scope: scope,
|
||||
softwareKey: softwareKey,
|
||||
done: make(chan struct{}),
|
||||
readable: true,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
store.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
switch scope {
|
||||
case gp.MachinePolicy:
|
||||
store.policyLock = gp.NewMachinePolicyLock()
|
||||
case gp.UserPolicy:
|
||||
if store.policyLock, err = gp.NewUserPolicyLock(token); err != nil {
|
||||
return nil, fmt.Errorf("failed to create a user policy lock: %w", err)
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// Lock locks the policy store, preventing the system from modifying the policies
|
||||
// while they are being read. It is a read lock that may be acquired by multiple goroutines.
|
||||
// Each Lock call must be balanced by exactly one Unlock call.
|
||||
func (ps *PlatformPolicyStore) Lock() (err error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
if ps.closing {
|
||||
return ErrStoreClosed
|
||||
}
|
||||
|
||||
ps.lockCnt += 1
|
||||
if ps.lockCnt != 1 {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.lockCnt -= 1
|
||||
}
|
||||
}()
|
||||
|
||||
// Ensure ps remains open while the lock is held.
|
||||
ps.locked.Add(1)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.locked.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
// Acquire the GP lock to prevent the system from modifying policy settings
|
||||
// while they are being read.
|
||||
if err := ps.policyLock.Lock(); err != nil {
|
||||
if errors.Is(err, gp.ErrInvalidLockState) {
|
||||
return ErrStoreClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.policyLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Keep the Tailscale's registry keys open for the duration of the lock.
|
||||
keyNames := tailscaleKeyNamesFor(ps.scope)
|
||||
ps.tsKeys = make([]registry.Key, 0, len(keyNames))
|
||||
for _, keyName := range keyNames {
|
||||
var tsKey registry.Key
|
||||
tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
ps.tsKeys = append(ps.tsKeys, tsKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0.
|
||||
// It panics if ps is not locked on entry to Unlock.
|
||||
func (ps *PlatformPolicyStore) Unlock() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
ps.lockCnt -= 1
|
||||
if ps.lockCnt < 0 {
|
||||
panic("negative lockCnt")
|
||||
} else if ps.lockCnt != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range ps.tsKeys {
|
||||
key.Close()
|
||||
}
|
||||
ps.tsKeys = nil
|
||||
ps.policyLock.Unlock()
|
||||
ps.locked.Done()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback adds a function that will be called whenever there's a policy change.
|
||||
// It returns a function that needs to be called to unregister the specified callback or an error.
|
||||
// The error is [ErrStoreClosed] if ps has already been closed.
|
||||
func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
if ps.closing {
|
||||
return nil, ErrStoreClosed
|
||||
}
|
||||
|
||||
handle := ps.cbs.Add(cb)
|
||||
if len(ps.cbs) == 1 {
|
||||
if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return func() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
delete(ps.cbs, handle)
|
||||
if len(ps.cbs) == 0 {
|
||||
if ps.watcher != nil {
|
||||
ps.watcher.Close()
|
||||
ps.watcher = nil
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ps *PlatformPolicyStore) onChange() {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
if ps.closing {
|
||||
return
|
||||
}
|
||||
for _, callback := range ps.cbs {
|
||||
go callback()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString retrieves a string policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadString(name setting.Key) (val string, err error) {
|
||||
return getPolicyValue(ps, canonicalizeValueName(name),
|
||||
func(key registry.Key, name setting.Key) (string, error) {
|
||||
val, _, err := key.GetStringValue(string(name))
|
||||
return val, err
|
||||
})
|
||||
}
|
||||
|
||||
// ReadUInt64 retrieves an integer policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadUInt64(name setting.Key) (uint64, error) {
|
||||
return getPolicyValue(ps, canonicalizeValueName(name),
|
||||
func(key registry.Key, name setting.Key) (uint64, error) {
|
||||
val, _, err := key.GetIntegerValue(string(name))
|
||||
return val, err
|
||||
})
|
||||
}
|
||||
|
||||
// ReadBoolean retrieves a boolean policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadBoolean(name setting.Key) (bool, error) {
|
||||
return getPolicyValue(ps, canonicalizeValueName(name),
|
||||
func(key registry.Key, name setting.Key) (bool, error) {
|
||||
val, _, err := key.GetIntegerValue(string(name))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return val != 0, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ReadString retrieves a multi-string policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadStringArray(name setting.Key) ([]string, error) {
|
||||
return getPolicyValue(ps, name,
|
||||
func(key registry.Key, name setting.Key) ([]string, error) {
|
||||
val, _, err := key.GetStringsValue(string(canonicalizeValueName(name)))
|
||||
if err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
|
||||
// The idiomatic way to store multiple string values in Group Policy
|
||||
// and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ)
|
||||
// values under a subkey rather than in a single REG_MULTI_SZ value.
|
||||
//
|
||||
// See the Group Policy: Registry Extension Encoding specification,
|
||||
// and specifically the ListElement and ListBox types.
|
||||
// https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf
|
||||
valKey, err := registry.OpenKey(key, string(canonicalizeKeyName(name)), windows.KEY_READ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valNames, err := valKey.ReadValueNames(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val = make([]string, 0, len(valNames))
|
||||
for _, name := range valNames {
|
||||
switch item, _, err := valKey.GetStringValue(name); {
|
||||
case err == registry.ErrNotExist:
|
||||
continue
|
||||
case err != nil:
|
||||
return nil, err
|
||||
default:
|
||||
val = append(val, item)
|
||||
}
|
||||
}
|
||||
return val, nil
|
||||
})
|
||||
}
|
||||
|
||||
func canonicalizeKeyName(name setting.Key) setting.Key {
|
||||
return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `\`))
|
||||
}
|
||||
|
||||
func canonicalizeValueName(name setting.Key) setting.Key {
|
||||
return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `_`))
|
||||
}
|
||||
|
||||
func getPolicyValue[T any](ps *PlatformPolicyStore, name setting.Key, getter registryValueGetter[T]) (T, error) {
|
||||
var zero T
|
||||
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
if !ps.readable {
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
if ps.tsKeys != nil {
|
||||
// A non-nil tsKeys indicates that ps has been locked.
|
||||
// It may be empty if Tailscale policy keys do not exist.
|
||||
for _, tsKey := range ps.tsKeys {
|
||||
val, err := getter(tsKey, name)
|
||||
if err == nil || err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
}
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
// The ps has not been locked, so we don't have any pre-opened keys.
|
||||
for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) {
|
||||
var tsKey registry.Key
|
||||
tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
continue
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
defer tsKey.Close()
|
||||
|
||||
val, err := getter(tsKey, name)
|
||||
if err == nil || err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
}
|
||||
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
// Close closes the policy store and releases any associated resources.
|
||||
// It cancels pending locks and prevents any new lock attempts,
|
||||
// but waits for existing locks to be released.
|
||||
func (ps *PlatformPolicyStore) Close() error {
|
||||
// Request to close the Group Policy read lock.
|
||||
// Existing held locks will remain valid, but any new or pending locks
|
||||
// will fail. In certain scenarios, the corresponding write lock may be held
|
||||
// by the Group Policy service for extended periods (minutes rather than
|
||||
// seconds or milliseconds). In such cases, we prefer not to wait that long
|
||||
// if the ps is being closed anyway.
|
||||
if ps.policyLock != nil {
|
||||
ps.policyLock.Close()
|
||||
}
|
||||
|
||||
// Signal to the external code that ps should no longer be used.
|
||||
close(ps.done)
|
||||
|
||||
// Mark ps as closing to fast-fail any new lock attempts.
|
||||
// Callers that have already locked it can finish their reading.
|
||||
ps.mu.Lock()
|
||||
if ps.closing {
|
||||
ps.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
ps.closing = true
|
||||
if ps.watcher != nil {
|
||||
ps.watcher.Close()
|
||||
ps.watcher = nil
|
||||
}
|
||||
ps.mu.Unlock()
|
||||
|
||||
// Wait for any outstanding locks to be released.
|
||||
ps.locked.Wait()
|
||||
|
||||
// Deny any further read attempts and release remaining resources.
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
ps.cbs = nil
|
||||
ps.policyLock = nil
|
||||
ps.readable = false
|
||||
if ps.softwareKey != 0 {
|
||||
ps.softwareKey.Close()
|
||||
ps.softwareKey = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the Close method is called.
|
||||
func (ps *PlatformPolicyStore) Done() <-chan struct{} {
|
||||
return ps.done
|
||||
}
|
||||
|
||||
func tailscaleKeyNamesFor(scope gp.Scope) []string {
|
||||
switch scope {
|
||||
case gp.MachinePolicy:
|
||||
// If a computer-side policy value does not exist under Software\Policies\Tailscale,
|
||||
// we need to fallback and use the legacy Software\Tailscale IPN key.
|
||||
return []string{tsPoliciesSubkey, tsIPNSubkey}
|
||||
case gp.UserPolicy:
|
||||
// However, we've never used the legacy key with user-side policies,
|
||||
// and we should never do so. Unlike HKLM\Software\Tailscale IPN,
|
||||
// its HKCU counterpart is user-writable.
|
||||
return []string{tsPoliciesSubkey}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
298
util/syspolicy/source/policy_store_windows_test.go
Normal file
298
util/syspolicy/source/policy_store_windows_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"tailscale.com/util/cibuild"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/winutil"
|
||||
"tailscale.com/util/winutil/gp"
|
||||
)
|
||||
|
||||
type testPolicyValue struct {
|
||||
name setting.Key
|
||||
value any
|
||||
}
|
||||
|
||||
func TestLockUnlockPolicyStore(t *testing.T) {
|
||||
store, err := NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
||||
}
|
||||
|
||||
t.Run("One-Goroutine", func(t *testing.T) {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("store.Lock(): got %v; want nil", err)
|
||||
return
|
||||
}
|
||||
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
||||
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
||||
}
|
||||
store.Unlock()
|
||||
})
|
||||
|
||||
// Lock the store N times from different goroutines.
|
||||
const N = 100
|
||||
var unlocked atomic.Int32
|
||||
t.Run("N-Goroutines", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(N)
|
||||
for range N {
|
||||
go func() {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("store.Lock(): got %v; want nil", err)
|
||||
return
|
||||
}
|
||||
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
||||
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
||||
}
|
||||
wg.Done()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
unlocked.Add(1)
|
||||
store.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait until the store is locked N times.
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
// Close the store. The call should wait for all held locks to be released.
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
if locked := unlocked.Load(); locked != N {
|
||||
t.Errorf("locked.Load(): got %v; want %v", locked, N)
|
||||
}
|
||||
|
||||
// Any further attempts to lock it should fail.
|
||||
if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) {
|
||||
t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPolicyStore(t *testing.T) {
|
||||
if !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user")
|
||||
}
|
||||
tests := []struct {
|
||||
name setting.Key
|
||||
newValue any
|
||||
legacyValue any
|
||||
want any
|
||||
}{
|
||||
{name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"},
|
||||
{name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"},
|
||||
{name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""},
|
||||
{name: "BoolPolicy_True", newValue: true, want: true},
|
||||
{name: "BoolPolicy_False", newValue: false, want: false},
|
||||
{name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64
|
||||
{name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)},
|
||||
{name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
|
||||
{name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}},
|
||||
}
|
||||
|
||||
runTests := func(t *testing.T, userStore bool, token windows.Token) {
|
||||
var hive registry.Key
|
||||
if userStore {
|
||||
hive = registry.CURRENT_USER
|
||||
} else {
|
||||
hive = registry.LOCAL_MACHINE
|
||||
}
|
||||
|
||||
// Write policy values to the registry.
|
||||
newValues := make([]testPolicyValue, 0, len(tests))
|
||||
for _, tt := range tests {
|
||||
if tt.newValue != nil {
|
||||
newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue})
|
||||
}
|
||||
}
|
||||
policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey
|
||||
cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues)
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPolicyValues failed: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// Write legacy policy values to the registry.
|
||||
legacyValues := make([]testPolicyValue, 0, len(tests))
|
||||
for _, tt := range tests {
|
||||
if tt.legacyValue != nil {
|
||||
legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue})
|
||||
}
|
||||
}
|
||||
legacyKeyName := softwareKeyName + `\` + tsIPNSubkey
|
||||
cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues)
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPolicyValues failed: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
var store *PlatformPolicyStore
|
||||
if userStore {
|
||||
store, err = NewUserPlatformPolicyStore(token)
|
||||
} else {
|
||||
store, err = NewMachinePlatformPolicyStore()
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NewXPolicyStore failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := store.Close(); err != nil {
|
||||
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry.
|
||||
testReadValues := func(t *testing.T, withLocks bool) {
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.name), func(t *testing.T) {
|
||||
if userStore && tt.newValue == nil {
|
||||
t.Skip("there is no legacy policies for users")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
if withLocks {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("failed to acquire the lock: %v", err)
|
||||
}
|
||||
defer store.Unlock()
|
||||
}
|
||||
|
||||
var got any
|
||||
var err error
|
||||
switch tt.want.(type) {
|
||||
case string:
|
||||
got, err = store.ReadString(tt.name)
|
||||
case uint64:
|
||||
got, err = store.ReadUInt64(tt.name)
|
||||
case bool:
|
||||
got, err = store.ReadBoolean(tt.name)
|
||||
case []string:
|
||||
got, err = store.ReadStringArray(tt.name)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("got %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
t.Run("NoLock", func(t *testing.T) {
|
||||
testReadValues(t, false)
|
||||
})
|
||||
|
||||
t.Run("WithLock", func(t *testing.T) {
|
||||
testReadValues(t, true)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("MachineStore", func(t *testing.T) {
|
||||
runTests(t, false, 0)
|
||||
})
|
||||
|
||||
t.Run("CurrentUserStore", func(t *testing.T) {
|
||||
runTests(t, true, 0)
|
||||
})
|
||||
|
||||
t.Run("UserStoreWithToken", func(t *testing.T) {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
|
||||
t.Fatalf("OpenProcessToken: %v", err)
|
||||
}
|
||||
defer token.Close()
|
||||
runTests(t, true, token)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicyStoreChangeNotifications(t *testing.T) {
|
||||
if cibuild.On() {
|
||||
t.Skipf("test requires running on a real Windows environment")
|
||||
}
|
||||
store, err := NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := store.Close(); err != nil {
|
||||
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
unregister, err := store.RegisterChangeCallback(func() { close(done) })
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterChangeCallback failed: %v", err)
|
||||
}
|
||||
t.Cleanup(unregister)
|
||||
|
||||
// RefreshMachinePolicy is a non-blocking call.
|
||||
if err := gp.RefreshMachinePolicy(true); err != nil {
|
||||
t.Fatalf("RefreshMachinePolicy failed: %v", err)
|
||||
}
|
||||
|
||||
// We should receive a policy change notification when
|
||||
// the Group Policy service completes policy processing.
|
||||
// Otherwise, the test will eventually time out.
|
||||
<-done
|
||||
}
|
||||
|
||||
func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) {
|
||||
key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
doCleanup := func() {
|
||||
for _, v := range values {
|
||||
key.DeleteValue(string(v.name))
|
||||
}
|
||||
key.Close()
|
||||
if !existing {
|
||||
registry.DeleteKey(hive, keyName)
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
doCleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, v := range values {
|
||||
switch value := v.value.(type) {
|
||||
case string:
|
||||
err = key.SetStringValue(string(v.name), value)
|
||||
case uint32:
|
||||
err = key.SetDWordValue(string(v.name), value)
|
||||
case uint64:
|
||||
err = key.SetQWordValue(string(v.name), value)
|
||||
case bool:
|
||||
if value {
|
||||
err = key.SetDWordValue(string(v.name), 1)
|
||||
} else {
|
||||
err = key.SetDWordValue(string(v.name), 0)
|
||||
}
|
||||
case []string:
|
||||
err = key.SetStringsValue(string(v.name), value)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return doCleanup, nil
|
||||
}
|
||||
446
util/syspolicy/source/test_store.go
Normal file
446
util/syspolicy/source/test_store.go
Normal file
@@ -0,0 +1,446 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
var _ Store = (*TestStore)(nil)
|
||||
|
||||
// TestValueType is a constraint that allows types supported by [TestStore].
|
||||
type TestValueType interface {
|
||||
bool | uint64 | string | []string
|
||||
}
|
||||
|
||||
// TestSetting is a policy setting in a [TestStore].
|
||||
type TestSetting[T TestValueType] struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Error is the error to be returned by the [TestStore] when reading
|
||||
// a policy setting with the specified key.
|
||||
Error error
|
||||
// Value is the value to be returned by the [TestStore] when reading
|
||||
// a policy setting with the specified key.
|
||||
// It is only used if the Error is nil.
|
||||
Value T
|
||||
}
|
||||
|
||||
// TestSettingOf returns a [TestSetting] representing a policy setting
|
||||
// configured with the specified key and value.
|
||||
func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
|
||||
return TestSetting[T]{Key: key, Value: value}
|
||||
}
|
||||
|
||||
// TestSettingWithError returns a [TestSetting] representing a policy setting
|
||||
// with the specified key and error.
|
||||
func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
|
||||
return TestSetting[T]{Key: key, Error: err}
|
||||
}
|
||||
|
||||
// testReadOperation describes a single policy setting read operation.
|
||||
type testReadOperation struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Type is a value type of a read operation.
|
||||
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
|
||||
Type setting.Type
|
||||
}
|
||||
|
||||
// TestExpectedReads is the number of read operations with the specified details.
|
||||
type TestExpectedReads struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Type is a value type of a read operation.
|
||||
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
|
||||
Type setting.Type
|
||||
// NumTimes is how many times a setting with the specified key and type should have been read.
|
||||
NumTimes int
|
||||
}
|
||||
|
||||
func (r TestExpectedReads) operation() testReadOperation {
|
||||
return testReadOperation{r.Key, r.Type}
|
||||
}
|
||||
|
||||
// TestStore is a [Store] that can be used in tests.
|
||||
type TestStore struct {
|
||||
tb internal.TB
|
||||
|
||||
done chan struct{}
|
||||
|
||||
storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock].
|
||||
storeLockCount atomic.Int32
|
||||
|
||||
mu sync.RWMutex
|
||||
suspendCount int // change callback are suspended if > 0
|
||||
mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
|
||||
cbs set.HandleSet[func()]
|
||||
|
||||
readsMu sync.Mutex
|
||||
reads map[testReadOperation]int // how many times a policy setting was read
|
||||
}
|
||||
|
||||
// NewTestStore returns a new [TestStore].
|
||||
// The tb will be used to report coding errors detected by the [TestStore].
|
||||
func NewTestStore(tb internal.TB) *TestStore {
|
||||
m := make(map[setting.Key]any)
|
||||
return &TestStore{
|
||||
tb: tb,
|
||||
done: make(chan struct{}),
|
||||
mr: m,
|
||||
mw: m,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
|
||||
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
|
||||
func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
|
||||
m := make(map[setting.Key]any)
|
||||
store := &TestStore{
|
||||
tb: tb,
|
||||
done: make(chan struct{}),
|
||||
mr: m,
|
||||
mw: m,
|
||||
}
|
||||
switch settings := any(settings).(type) {
|
||||
case []TestSetting[bool]:
|
||||
store.SetBooleans(settings...)
|
||||
case []TestSetting[uint64]:
|
||||
store.SetUInt64s(settings...)
|
||||
case []TestSetting[string]:
|
||||
store.SetStrings(settings...)
|
||||
case []TestSetting[[]string]:
|
||||
store.SetStringLists(settings...)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// Lock implements [Store].
|
||||
func (s *TestStore) Lock() error {
|
||||
s.storeLock.RLock()
|
||||
s.storeLockCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock implements [Store].
|
||||
func (s *TestStore) Unlock() {
|
||||
if s.storeLockCount.Add(-1) < 0 {
|
||||
s.tb.Fatal("negative storeLockCount")
|
||||
}
|
||||
s.storeLock.RUnlock()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback implements [Store].
|
||||
func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
handle := s.cbs.Add(callback)
|
||||
return func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.cbs, handle)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadString implements [Store].
|
||||
func (s *TestStore) ReadString(key setting.Key) (string, error) {
|
||||
defer s.recordRead(key, setting.StringValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return "", setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return "", err
|
||||
}
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 implements [Store].
|
||||
func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
defer s.recordRead(key, setting.IntegerValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return 0, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return 0, err
|
||||
}
|
||||
u64, ok := v.(uint64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return u64, nil
|
||||
}
|
||||
|
||||
// ReadBoolean implements [Store].
|
||||
func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
defer s.recordRead(key, setting.BooleanValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return false, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return false, err
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// ReadStringArray implements [Store].
|
||||
func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
defer s.recordRead(key, setting.StringListValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return nil, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return nil, err
|
||||
}
|
||||
slice, ok := v.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
|
||||
s.readsMu.Lock()
|
||||
op := testReadOperation{key, typ}
|
||||
num := s.reads[op]
|
||||
num++
|
||||
mak.Set(&s.reads, op, num)
|
||||
s.readsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *TestStore) ResetCounters() {
|
||||
s.readsMu.Lock()
|
||||
clear(s.reads)
|
||||
s.readsMu.Unlock()
|
||||
}
|
||||
|
||||
// ReadsMustEqual fails the test if the actual reads differs from the specified reads.
|
||||
func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
s.readsMu.Lock()
|
||||
defer s.readsMu.Unlock()
|
||||
s.readsMustContainLocked(reads...)
|
||||
s.readMustNoExtraLocked(reads...)
|
||||
}
|
||||
|
||||
// ReadsMustContain fails the test if the specified reads have not been made,
|
||||
// or have been made a different number of times. It permits other values to be
|
||||
// read in addition to the ones being tested.
|
||||
func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
s.readsMu.Lock()
|
||||
defer s.readsMu.Unlock()
|
||||
s.readsMustContainLocked(reads...)
|
||||
}
|
||||
|
||||
func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
for _, r := range reads {
|
||||
if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
|
||||
s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
rs := make(set.Set[testReadOperation])
|
||||
for i := range reads {
|
||||
rs.Add(reads[i].operation())
|
||||
}
|
||||
for ro, num := range s.reads {
|
||||
if !rs.Contains(ro) {
|
||||
s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Suspend suspends the store, batching changes and notifications
|
||||
// until [TestStore.Resume] is called the same number of times as Suspend.
|
||||
func (s *TestStore) Suspend() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.suspendCount++; s.suspendCount == 1 {
|
||||
s.mw = xmaps.Clone(s.mr)
|
||||
}
|
||||
}
|
||||
|
||||
// Resume resumes the store, applying the changes and invoking
|
||||
// the change callbacks.
|
||||
func (s *TestStore) Resume() {
|
||||
s.storeLock.Lock()
|
||||
s.mu.Lock()
|
||||
switch s.suspendCount--; {
|
||||
case s.suspendCount == 0:
|
||||
s.mr = s.mw
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
case s.suspendCount < 0:
|
||||
s.tb.Fatal("negative suspendCount")
|
||||
default:
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// SetBooleans sets the specified boolean settings in s.
|
||||
func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetUInt64s sets the specified integer settings in s.
|
||||
func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetStrings sets the specified string settings in s.
|
||||
func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetStrings sets the specified string list settings in s.
|
||||
func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// Delete deletes the specified settings from s.
|
||||
func (s *TestStore) Delete(keys ...setting.Key) {
|
||||
s.storeLock.Lock()
|
||||
for _, key := range keys {
|
||||
s.mu.Lock()
|
||||
delete(s.mw, key)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// Clear deletes all settings from s.
|
||||
func (s *TestStore) Clear() {
|
||||
s.storeLock.Lock()
|
||||
s.mu.Lock()
|
||||
clear(s.mw)
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
func (s *TestStore) notifyPolicyChanged() {
|
||||
s.mu.RLock()
|
||||
if s.suspendCount != 0 {
|
||||
s.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
cbs := xmaps.Values(s.cbs)
|
||||
s.mu.RUnlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(cbs))
|
||||
for _, cb := range cbs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cb()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Close closes s, notifying its users that it has expired.
|
||||
func (s *TestStore) Close() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.done != nil {
|
||||
close(s.done)
|
||||
s.done = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Done implements [Store].
|
||||
func (s *TestStore) Done() <-chan struct{} {
|
||||
return s.done
|
||||
}
|
||||
@@ -1,122 +1,83 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package syspolicy provides functions to retrieve system settings of a device.
|
||||
// Package syspolicy facilitates retrieval of the current policy settings
|
||||
// applied to the device or user and receiving notifications when the policy
|
||||
// changes.
|
||||
//
|
||||
// It provides functions that return specific policy settings by their unique
|
||||
// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString],
|
||||
// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration].
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotConfigured is returned when the requested policy setting is not configured.
|
||||
ErrNotConfigured = setting.ErrNotConfigured
|
||||
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
|
||||
// of the setting value and the expected type.
|
||||
ErrTypeMismatch = setting.ErrTypeMismatch
|
||||
// ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting
|
||||
// has been registered with the specified key.
|
||||
//
|
||||
// Until 2024-08-02, this error was also returned by a [Handler] when the specified
|
||||
// key did not have a value set. While the package maintains compatibility with this
|
||||
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
|
||||
// [source.Store] implementations.
|
||||
ErrNoSuchKey = setting.ErrNoSuchKey
|
||||
)
|
||||
|
||||
// GetString returns a string policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetString(key Key, defaultValue string) (string, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadString(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// GetUint64 returns a numeric policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetUint64(key Key, defaultValue uint64) (uint64, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadUInt64(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// GetBoolean returns a boolean policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetBoolean(key Key, defaultValue bool) (bool, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadBoolean(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// GetStringArray returns a multi-string policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetStringArray(key Key, defaultValue []string) ([]string, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadStringArray(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// PreferenceOption is a policy that governs whether a boolean variable
|
||||
// is forcibly assigned an administrator-defined value, or allowed to receive
|
||||
// a user-defined value.
|
||||
type PreferenceOption int
|
||||
|
||||
const (
|
||||
showChoiceByPolicy PreferenceOption = iota
|
||||
neverByPolicy
|
||||
alwaysByPolicy
|
||||
type (
|
||||
// PreferenceOption is a policy that governs whether a boolean variable
|
||||
// is forcibly assigned an administrator-defined value, or allowed to receive
|
||||
// a user-defined value.
|
||||
PreferenceOption = setting.PreferenceOption
|
||||
// Visibility is a policy that controls whether or not a particular
|
||||
// component of a user interface is to be shown.
|
||||
Visibility = setting.Visibility
|
||||
)
|
||||
|
||||
// Show returns if the UI option that controls the choice administered by this
|
||||
// policy should be shown. Currently this is true if and only if the policy is
|
||||
// showChoiceByPolicy.
|
||||
func (p PreferenceOption) Show() bool {
|
||||
return p == showChoiceByPolicy
|
||||
}
|
||||
|
||||
// ShouldEnable checks if the choice administered by this policy should be
|
||||
// enabled. If the administrator has chosen a setting, the administrator's
|
||||
// setting is returned, otherwise userChoice is returned.
|
||||
func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
|
||||
switch p {
|
||||
case neverByPolicy:
|
||||
return false
|
||||
case alwaysByPolicy:
|
||||
return true
|
||||
default:
|
||||
return userChoice
|
||||
}
|
||||
}
|
||||
|
||||
// WillOverride checks if the choice administered by the policy is different
|
||||
// from the user's choice.
|
||||
func (p PreferenceOption) WillOverride(userChoice bool) bool {
|
||||
return p.ShouldEnable(userChoice) != userChoice
|
||||
}
|
||||
|
||||
// GetPreferenceOption loads a policy from the registry that can be
|
||||
// managed by an enterprise policy management system and allows administrative
|
||||
// overrides of users' choices in a way that we do not want tailcontrol to have
|
||||
// the authority to set. It describes user-decides/always/never options, where
|
||||
// "always" and "never" remove the user's ability to make a selection. If not
|
||||
// present or set to a different value, "user-decides" is the default.
|
||||
func GetPreferenceOption(name Key) (PreferenceOption, error) {
|
||||
opt, err := GetString(name, "user-decides")
|
||||
if err != nil {
|
||||
return showChoiceByPolicy, err
|
||||
}
|
||||
switch opt {
|
||||
case "always":
|
||||
return alwaysByPolicy, nil
|
||||
case "never":
|
||||
return neverByPolicy, nil
|
||||
default:
|
||||
return showChoiceByPolicy, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Visibility is a policy that controls whether or not a particular
|
||||
// component of a user interface is to be shown.
|
||||
type Visibility byte
|
||||
|
||||
const (
|
||||
visibleByPolicy Visibility = 'v'
|
||||
hiddenByPolicy Visibility = 'h'
|
||||
)
|
||||
|
||||
// Show reports whether the UI option administered by this policy should be shown.
|
||||
// Currently this is true if and only if the policy is visibleByPolicy.
|
||||
func (p Visibility) Show() bool {
|
||||
return p == visibleByPolicy
|
||||
func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
|
||||
return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy)
|
||||
}
|
||||
|
||||
// GetVisibility loads a policy from the registry that can be managed
|
||||
@@ -124,17 +85,8 @@ func (p Visibility) Show() bool {
|
||||
// for UI elements. The registry value should be a string set to "show" (return
|
||||
// true) or "hide" (return true). If not present or set to a different value,
|
||||
// "show" (return false) is the default.
|
||||
func GetVisibility(name Key) (Visibility, error) {
|
||||
opt, err := GetString(name, "show")
|
||||
if err != nil {
|
||||
return visibleByPolicy, err
|
||||
}
|
||||
switch opt {
|
||||
case "hide":
|
||||
return hiddenByPolicy, nil
|
||||
default:
|
||||
return visibleByPolicy, nil
|
||||
}
|
||||
func GetVisibility(name Key) (setting.Visibility, error) {
|
||||
return getCurrentPolicySettingValue(name, setting.VisibleByPolicy)
|
||||
}
|
||||
|
||||
// GetDuration loads a policy from the registry that can be managed
|
||||
@@ -143,15 +95,48 @@ func GetVisibility(name Key) (Visibility, error) {
|
||||
// understands. If the registry value is "" or can not be processed,
|
||||
// defaultValue is returned instead.
|
||||
func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) {
|
||||
opt, err := GetString(name, "")
|
||||
if opt == "" || err != nil {
|
||||
return defaultValue, err
|
||||
d, err := getCurrentPolicySettingValue(name, defaultValue)
|
||||
if err != nil {
|
||||
return d, err
|
||||
}
|
||||
v, err := time.ParseDuration(opt)
|
||||
if err != nil || v < 0 {
|
||||
if d < 0 {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, nil
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// getCurrentPolicySettingValue returns the value of the policy setting
|
||||
// specified by its key from the [rsop.Policy] of the [CurrentScope]. It
|
||||
// returns def if the policy setting is not configured, or an error if it has
|
||||
// an error or could not be converted to the specified type T.
|
||||
func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) {
|
||||
resultant, err := rsop.PolicyFor(setting.CurrentScope())
|
||||
if err != nil {
|
||||
return def, err
|
||||
}
|
||||
value, err := resultant.Get().GetErr(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) {
|
||||
return def, nil
|
||||
}
|
||||
return def, err
|
||||
}
|
||||
if res, ok := value.(T); ok {
|
||||
return res, nil
|
||||
}
|
||||
return convertPolicySettingValueTo(value, def)
|
||||
}
|
||||
|
||||
func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) {
|
||||
// Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string
|
||||
// if someone requests a string instead of the actual setting's value.
|
||||
// TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests.
|
||||
if reflect.TypeFor[T]().Kind() == reflect.String {
|
||||
if str, ok := value.(fmt.Stringer); ok {
|
||||
return any(str.String()).(T), nil
|
||||
}
|
||||
}
|
||||
return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def)
|
||||
}
|
||||
|
||||
// SelectControlURL returns the ControlURL to use based on a value in
|
||||
|
||||
@@ -5,16 +5,24 @@ package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/internal/metrics"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
// testHandler encompasses all data types returned when testing any of the syspolicy
|
||||
// methods that involve getting a policy value.
|
||||
// For keys and the corresponding values, check policy_keys.go.
|
||||
type testHandler struct {
|
||||
t *testing.T
|
||||
t testing.TB
|
||||
key Key
|
||||
s string
|
||||
u64 uint64
|
||||
@@ -28,7 +36,10 @@ var someOtherError = errors.New("error other than not found")
|
||||
|
||||
func (th *testHandler) ReadString(key string) (string, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadString(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return "", ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.s, th.err
|
||||
@@ -36,7 +47,10 @@ func (th *testHandler) ReadString(key string) (string, error) {
|
||||
|
||||
func (th *testHandler) ReadUInt64(key string) (uint64, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return 0, ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.u64, th.err
|
||||
@@ -44,7 +58,10 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) {
|
||||
|
||||
func (th *testHandler) ReadBoolean(key string) (bool, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadBool(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return false, ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.b, th.err
|
||||
@@ -52,7 +69,10 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) {
|
||||
|
||||
func (th *testHandler) ReadStringArray(key string) ([]string, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadStringArray(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return nil, ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.sArr, th.err
|
||||
@@ -67,23 +87,28 @@ func TestGetString(t *testing.T) {
|
||||
defaultValue string
|
||||
wantValue string
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "hide",
|
||||
wantValue: "hide",
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: EnableServerMode,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non-blank default",
|
||||
key: EnableServerMode,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
defaultValue: "test",
|
||||
wantValue: "test",
|
||||
wantError: nil,
|
||||
@@ -93,11 +118,17 @@ func TestGetString(t *testing.T) {
|
||||
key: NetworkDevicesVisibility,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_NetworkDevices_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@@ -105,12 +136,21 @@ func TestGetString(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetString(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,7 +167,7 @@ func TestGetUint64(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: 1,
|
||||
wantValue: 1,
|
||||
},
|
||||
@@ -135,14 +175,14 @@ func TestGetUint64(t *testing.T) {
|
||||
name: "read non-existing value",
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: 0,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 0,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non-zero default",
|
||||
key: LogSCMInteractions,
|
||||
defaultValue: 2,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 2,
|
||||
},
|
||||
{
|
||||
@@ -155,14 +195,21 @@ func TestGetUint64(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
// None of the policy settings tested here are integers.
|
||||
// In fact, we don't have any integer policies as of 2024-07-29.
|
||||
// However, we can register each of them as an integer policy setting
|
||||
// for the duration of the test, providing us with something to test against.
|
||||
if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, WrapHandler(&testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
u64: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
}))
|
||||
value, err := GetUint64(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
@@ -181,32 +228,43 @@ func TestGetBoolean(t *testing.T) {
|
||||
defaultValue bool
|
||||
wantValue bool
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: FlushDNSOnSessionUnlock,
|
||||
handlerValue: true,
|
||||
wantValue: true,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: false,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: false,
|
||||
},
|
||||
{
|
||||
name: "reading value returns other error",
|
||||
key: FlushDNSOnSessionUnlock,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantError: someOtherError, // expect error...
|
||||
defaultValue: true,
|
||||
wantValue: false,
|
||||
wantValue: true, // ...AND default value if the handler fails.
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@@ -214,12 +272,21 @@ func TestGetBoolean(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetBoolean(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -232,42 +299,61 @@ func TestGetPreferenceOption(t *testing.T) {
|
||||
handlerError error
|
||||
wantValue PreferenceOption
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "always by policy",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "always",
|
||||
wantValue: alwaysByPolicy,
|
||||
wantValue: setting.AlwaysByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "never by policy",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "never",
|
||||
wantValue: neverByPolicy,
|
||||
wantValue: setting.NeverByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "use default",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "",
|
||||
wantValue: showChoiceByPolicy,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: EnableIncomingConnections,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: showChoiceByPolicy,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
},
|
||||
{
|
||||
name: "other error is returned",
|
||||
key: EnableIncomingConnections,
|
||||
handlerError: someOtherError,
|
||||
wantValue: showChoiceByPolicy,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@@ -275,12 +361,21 @@ func TestGetPreferenceOption(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
option, err := GetPreferenceOption(tt.key)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if option != tt.wantValue {
|
||||
t.Errorf("option=%v, want %v", option, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -293,38 +388,53 @@ func TestGetVisibility(t *testing.T) {
|
||||
handlerError error
|
||||
wantValue Visibility
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "hidden by policy",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "hide",
|
||||
wantValue: hiddenByPolicy,
|
||||
wantValue: setting.HiddenByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "visibility default",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
wantValue: visibleByPolicy,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: visibleByPolicy,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
},
|
||||
{
|
||||
name: "other error is returned",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
handlerError: someOtherError,
|
||||
wantValue: visibleByPolicy,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@@ -332,12 +442,21 @@ func TestGetVisibility(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
visibility, err := GetVisibility(tt.key)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if visibility != tt.wantValue {
|
||||
t.Errorf("visibility=%v, want %v", visibility, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -351,6 +470,7 @@ func TestGetDuration(t *testing.T) {
|
||||
defaultValue time.Duration
|
||||
wantValue time.Duration
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
@@ -358,25 +478,34 @@ func TestGetDuration(t *testing.T) {
|
||||
handlerValue: "2h",
|
||||
wantValue: 2 * time.Hour,
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid duration value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerValue: "-20",
|
||||
wantValue: 24 * time.Hour,
|
||||
wantError: errors.New(`time: missing unit in duration "-20"`),
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 24 * time.Hour,
|
||||
defaultValue: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value different default",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 0 * time.Second,
|
||||
defaultValue: 0 * time.Second,
|
||||
},
|
||||
@@ -387,11 +516,17 @@ func TestGetDuration(t *testing.T) {
|
||||
wantValue: 24 * time.Hour,
|
||||
wantError: someOtherError,
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@@ -399,12 +534,21 @@ func TestGetDuration(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
duration, err := GetDuration(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if fmt.Sprint(err) != fmt.Sprint(tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if duration != tt.wantValue {
|
||||
t.Errorf("duration=%v, want %v", duration, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -418,23 +562,28 @@ func TestGetStringArray(t *testing.T) {
|
||||
defaultValue []string
|
||||
wantValue []string
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerValue: []string{"foo", "bar"},
|
||||
wantValue: []string{"foo", "bar"},
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non nil default",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
defaultValue: []string{"foo", "bar"},
|
||||
wantValue: []string{"foo", "bar"},
|
||||
wantError: nil,
|
||||
@@ -444,11 +593,17 @@ func TestGetStringArray(t *testing.T) {
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@@ -456,16 +611,47 @@ func TestGetStringArray(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetStringArray(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if !slices.Equal(tt.wantValue, value) {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetString(b *testing.B) {
|
||||
loggerx.SetForTest(b, logger.Discard, logger.Discard)
|
||||
setWellKnownSettingsForTest(b)
|
||||
|
||||
store := source.NewTestStore(b)
|
||||
wantControlURL := "https://login.tailscale.com"
|
||||
store.SetStrings(source.TestSetting[string]{Key: ControlURL, Value: wantControlURL})
|
||||
|
||||
_, err := rsop.RegisterStoreForTest(b, "Test Store", setting.DeviceScope, store)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com")
|
||||
if gotControlURL != wantControlURL {
|
||||
b.Fatalf("got %v; want %v", gotControlURL, wantControlURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectControlURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
reg, disk, want string
|
||||
@@ -497,3 +683,13 @@ func TestSelectControlURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func errorsMatchForTest(got, want error) bool {
|
||||
if got == nil && want == nil {
|
||||
return true
|
||||
}
|
||||
if got == nil || want == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(got, want) || got.Error() == want.Error()
|
||||
}
|
||||
|
||||
93
util/syspolicy/syspolicy_windows.go
Normal file
93
util/syspolicy/syspolicy_windows.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/user"
|
||||
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// On Windows, we should automatically register the Registry-based policy
|
||||
// store for the device. If we are running in a user's security context
|
||||
// (e.g., we're the GUI), we should also register the Registry policy store for
|
||||
// the user. In the future, we should register (and unregister) user policy
|
||||
// stores whenever a user connects to the local backend. This ensures the
|
||||
// backend is aware of the user's policy settings and can send them to the
|
||||
// GUI/CLI/Web clients on demand or whenever they change.
|
||||
//
|
||||
// Other platforms, such as macOS, iOS and Android, should register their
|
||||
// platform-specific policy stores via [RegisterStore] (or [RegisterHandler]
|
||||
// until they implement the [Store] interface).
|
||||
//
|
||||
// External code, such as the ipnlocal package, may choose to register
|
||||
// additional policy stores, such as config files and policies received from
|
||||
// the control plane.
|
||||
lazyinit.Defer(func() error {
|
||||
// Do not register or use default policy stores during tests.
|
||||
// Each test should set up its own necessary configurations.
|
||||
if testenv.InTest() {
|
||||
return nil
|
||||
}
|
||||
return configureSyspolicy(nil)
|
||||
})
|
||||
}
|
||||
|
||||
// configureSyspolicy configures syspolicy for use on Windows,
|
||||
// either in test or regular builds depending on whether tb has a non-nil value.
|
||||
func configureSyspolicy(tb internal.TB) error {
|
||||
const localSystemSID = "S-1-5-18"
|
||||
// Always create and register a machine policy store that reads
|
||||
// policy settings from the HKEY_LOCAL_MACHINE registry hive.
|
||||
machineStore, err := source.NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create the machine policy store: %v", err)
|
||||
}
|
||||
if tb == nil {
|
||||
_, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore)
|
||||
} else {
|
||||
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check whether the current process is running as Local System or not.
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Uid == localSystemSID {
|
||||
return nil
|
||||
}
|
||||
// If it's not a Local System's process (e.g., the GUI and not the tailscaled service),
|
||||
// we should create and use a policy store for the current user that reads
|
||||
// policy settings from that user's registry hive (HKEY_CURRENT_USER).
|
||||
userStore, err := source.NewUserPlatformPolicyStore(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create the current user's policy store: %v", err)
|
||||
}
|
||||
if tb == nil {
|
||||
_, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore)
|
||||
} else {
|
||||
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// And also set [CurrentUserScope] as the [CurrentScope], so [GetString],
|
||||
// [GetVisibility] and similar functions would be returning a merged result
|
||||
// of the machine's and user's policies.
|
||||
if !setting.SetCurrentScope(setting.CurrentUserScope) {
|
||||
return errors.New("current scope already set")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -189,6 +189,7 @@ func (l *PolicyLock) lockSlow() (err error) {
|
||||
select {
|
||||
case resultCh <- policyLockResult{handle, err}:
|
||||
// lockSlow has received the result.
|
||||
break send_result
|
||||
default:
|
||||
select {
|
||||
case <-closing:
|
||||
|
||||
Reference in New Issue
Block a user