Compare commits
1 Commits
lp
...
andrew/key
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc4048014e |
@@ -40,7 +40,6 @@ import (
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/tkatype"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// defaultLocalClient is the default LocalClient when using the legacy
|
||||
@@ -815,33 +814,6 @@ func (lc *LocalClient) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn
|
||||
return decodeJSON[*ipn.Prefs](body)
|
||||
}
|
||||
|
||||
// GetEffectivePolicy returns the effective policy for the specified scope.
|
||||
func (lc *LocalClient) GetEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) {
|
||||
scopeID, err := scope.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, err := lc.get200(ctx, "/localapi/v0/policy/"+string(scopeID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodeJSON[*setting.Snapshot](body)
|
||||
}
|
||||
|
||||
// ReloadEffectivePolicy reloads the effective policy for the specified scope
|
||||
// by reading and merging policy settings from all applicable policy sources.
|
||||
func (lc *LocalClient) ReloadEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) {
|
||||
scopeID, err := scope.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, err := lc.send(ctx, "POST", "/localapi/v0/policy/"+string(scopeID), 200, http.NoBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodeJSON[*setting.Snapshot](body)
|
||||
}
|
||||
|
||||
// GetDNSOSConfig returns the system DNS configuration for the current device.
|
||||
// That is, it returns the DNS configuration that the system would use if Tailscale weren't being used.
|
||||
func (lc *LocalClient) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig, error) {
|
||||
|
||||
@@ -164,16 +164,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
tailscale.com/util/slicesx from tailscale.com/cmd/derper+
|
||||
tailscale.com/util/syspolicy from tailscale.com/ipn
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/testenv from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/usermetric from tailscale.com/health
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/version from tailscale.com/derp+
|
||||
tailscale.com/version/distro from tailscale.com/envknob+
|
||||
@@ -194,7 +189,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
|
||||
golang.org/x/crypto/sha3 from crypto/internal/mlkem768+
|
||||
W golang.org/x/exp/constraints from tailscale.com/util/winutil
|
||||
golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting+
|
||||
golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting
|
||||
L golang.org/x/net/bpf from github.com/mdlayher/netlink+
|
||||
golang.org/x/net/dns/dnsmessage from net+
|
||||
golang.org/x/net/http/httpguts from net/http
|
||||
@@ -255,7 +250,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
encoding/pem from crypto/tls+
|
||||
errors from bufio+
|
||||
expvar from github.com/prometheus/client_golang/prometheus+
|
||||
flag from tailscale.com/cmd/derper+
|
||||
flag from tailscale.com/cmd/derper
|
||||
fmt from compress/flate+
|
||||
go/token from google.golang.org/protobuf/internal/strs
|
||||
hash from crypto+
|
||||
@@ -289,7 +284,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
os from crypto/rand+
|
||||
os/exec from github.com/coreos/go-iptables/iptables+
|
||||
os/signal from tailscale.com/cmd/derper
|
||||
W os/user from tailscale.com/util/winutil+
|
||||
W os/user from tailscale.com/util/winutil
|
||||
path from github.com/prometheus/client_golang/prometheus/internal+
|
||||
path/filepath from crypto/x509+
|
||||
reflect from crypto/x509+
|
||||
|
||||
@@ -659,6 +659,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
tailscale.com/control/controlclient from tailscale.com/ipn/ipnlocal+
|
||||
tailscale.com/control/controlhttp from tailscale.com/control/controlclient
|
||||
tailscale.com/control/controlknobs from tailscale.com/control/controlclient+
|
||||
tailscale.com/control/keyfallback from tailscale.com/control/controlclient
|
||||
tailscale.com/derp from tailscale.com/derp/derphttp+
|
||||
tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+
|
||||
tailscale.com/disco from tailscale.com/derp+
|
||||
@@ -812,11 +813,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
tailscale.com/util/slicesx from tailscale.com/appc+
|
||||
tailscale.com/util/syspolicy from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/util/systemd from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/testenv from tailscale.com/control/controlclient+
|
||||
@@ -826,7 +824,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
|
||||
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+
|
||||
|
||||
@@ -1896,182 +1896,6 @@ spec:
|
||||
Value is the taint value the toleration matches to.
|
||||
If the operator is Exists, the value should be empty, otherwise just a regular string.
|
||||
type: string
|
||||
topologySpreadConstraints:
|
||||
description: |-
|
||||
Proxy Pod's topology spread constraints.
|
||||
By default Tailscale Kubernetes operator does not apply any topology spread constraints.
|
||||
https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/
|
||||
type: array
|
||||
items:
|
||||
description: TopologySpreadConstraint specifies how to spread matching pods among the given topology.
|
||||
type: object
|
||||
required:
|
||||
- maxSkew
|
||||
- topologyKey
|
||||
- whenUnsatisfiable
|
||||
properties:
|
||||
labelSelector:
|
||||
description: |-
|
||||
LabelSelector is used to find matching pods.
|
||||
Pods that match this label selector are counted to determine the number of pods
|
||||
in their corresponding topology domain.
|
||||
type: object
|
||||
properties:
|
||||
matchExpressions:
|
||||
description: matchExpressions is a list of label selector requirements. The requirements are ANDed.
|
||||
type: array
|
||||
items:
|
||||
description: |-
|
||||
A label selector requirement is a selector that contains values, a key, and an operator that
|
||||
relates the key and values.
|
||||
type: object
|
||||
required:
|
||||
- key
|
||||
- operator
|
||||
properties:
|
||||
key:
|
||||
description: key is the label key that the selector applies to.
|
||||
type: string
|
||||
operator:
|
||||
description: |-
|
||||
operator represents a key's relationship to a set of values.
|
||||
Valid operators are In, NotIn, Exists and DoesNotExist.
|
||||
type: string
|
||||
values:
|
||||
description: |-
|
||||
values is an array of string values. If the operator is In or NotIn,
|
||||
the values array must be non-empty. If the operator is Exists or DoesNotExist,
|
||||
the values array must be empty. This array is replaced during a strategic
|
||||
merge patch.
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
x-kubernetes-list-type: atomic
|
||||
x-kubernetes-list-type: atomic
|
||||
matchLabels:
|
||||
description: |-
|
||||
matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels
|
||||
map is equivalent to an element of matchExpressions, whose key field is "key", the
|
||||
operator is "In", and the values array contains only "value". The requirements are ANDed.
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
x-kubernetes-map-type: atomic
|
||||
matchLabelKeys:
|
||||
description: |-
|
||||
MatchLabelKeys is a set of pod label keys to select the pods over which
|
||||
spreading will be calculated. The keys are used to lookup values from the
|
||||
incoming pod labels, those key-value labels are ANDed with labelSelector
|
||||
to select the group of existing pods over which spreading will be calculated
|
||||
for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector.
|
||||
MatchLabelKeys cannot be set when LabelSelector isn't set.
|
||||
Keys that don't exist in the incoming pod labels will
|
||||
be ignored. A null or empty list means only match against labelSelector.
|
||||
|
||||
This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default).
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
x-kubernetes-list-type: atomic
|
||||
maxSkew:
|
||||
description: |-
|
||||
MaxSkew describes the degree to which pods may be unevenly distributed.
|
||||
When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference
|
||||
between the number of matching pods in the target topology and the global minimum.
|
||||
The global minimum is the minimum number of matching pods in an eligible domain
|
||||
or zero if the number of eligible domains is less than MinDomains.
|
||||
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
|
||||
labelSelector spread as 2/2/1:
|
||||
In this case, the global minimum is 1.
|
||||
| zone1 | zone2 | zone3 |
|
||||
| P P | P P | P |
|
||||
- if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2;
|
||||
scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2)
|
||||
violate MaxSkew(1).
|
||||
- if MaxSkew is 2, incoming pod can be scheduled onto any zone.
|
||||
When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence
|
||||
to topologies that satisfy it.
|
||||
It's a required field. Default value is 1 and 0 is not allowed.
|
||||
type: integer
|
||||
format: int32
|
||||
minDomains:
|
||||
description: |-
|
||||
MinDomains indicates a minimum number of eligible domains.
|
||||
When the number of eligible domains with matching topology keys is less than minDomains,
|
||||
Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed.
|
||||
And when the number of eligible domains with matching topology keys equals or greater than minDomains,
|
||||
this value has no effect on scheduling.
|
||||
As a result, when the number of eligible domains is less than minDomains,
|
||||
scheduler won't schedule more than maxSkew Pods to those domains.
|
||||
If value is nil, the constraint behaves as if MinDomains is equal to 1.
|
||||
Valid values are integers greater than 0.
|
||||
When value is not nil, WhenUnsatisfiable must be DoNotSchedule.
|
||||
|
||||
For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same
|
||||
labelSelector spread as 2/2/2:
|
||||
| zone1 | zone2 | zone3 |
|
||||
| P P | P P | P P |
|
||||
The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0.
|
||||
In this situation, new pod with the same labelSelector cannot be scheduled,
|
||||
because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones,
|
||||
it will violate MaxSkew.
|
||||
type: integer
|
||||
format: int32
|
||||
nodeAffinityPolicy:
|
||||
description: |-
|
||||
NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector
|
||||
when calculating pod topology spread skew. Options are:
|
||||
- Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations.
|
||||
- Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations.
|
||||
|
||||
If this value is nil, the behavior is equivalent to the Honor policy.
|
||||
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
|
||||
type: string
|
||||
nodeTaintsPolicy:
|
||||
description: |-
|
||||
NodeTaintsPolicy indicates how we will treat node taints when calculating
|
||||
pod topology spread skew. Options are:
|
||||
- Honor: nodes without taints, along with tainted nodes for which the incoming pod
|
||||
has a toleration, are included.
|
||||
- Ignore: node taints are ignored. All nodes are included.
|
||||
|
||||
If this value is nil, the behavior is equivalent to the Ignore policy.
|
||||
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
|
||||
type: string
|
||||
topologyKey:
|
||||
description: |-
|
||||
TopologyKey is the key of node labels. Nodes that have a label with this key
|
||||
and identical values are considered to be in the same topology.
|
||||
We consider each <key, value> as a "bucket", and try to put balanced number
|
||||
of pods into each bucket.
|
||||
We define a domain as a particular instance of a topology.
|
||||
Also, we define an eligible domain as a domain whose nodes meet the requirements of
|
||||
nodeAffinityPolicy and nodeTaintsPolicy.
|
||||
e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology.
|
||||
And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology.
|
||||
It's a required field.
|
||||
type: string
|
||||
whenUnsatisfiable:
|
||||
description: |-
|
||||
WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy
|
||||
the spread constraint.
|
||||
- DoNotSchedule (default) tells the scheduler not to schedule it.
|
||||
- ScheduleAnyway tells the scheduler to schedule the pod in any location,
|
||||
but giving higher precedence to topologies that would help reduce the
|
||||
skew.
|
||||
A constraint is considered "Unsatisfiable" for an incoming pod
|
||||
if and only if every possible node assignment for that pod would violate
|
||||
"MaxSkew" on some topology.
|
||||
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
|
||||
labelSelector spread as 3/1/1:
|
||||
| zone1 | zone2 | zone3 |
|
||||
| P P P | P | P |
|
||||
If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled
|
||||
to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies
|
||||
MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler
|
||||
won't make it *more* imbalanced.
|
||||
It's a required field.
|
||||
type: string
|
||||
tailscale:
|
||||
description: |-
|
||||
TailscaleConfig contains options to configure the tailscale-specific
|
||||
|
||||
@@ -2323,182 +2323,6 @@ spec:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
topologySpreadConstraints:
|
||||
description: |-
|
||||
Proxy Pod's topology spread constraints.
|
||||
By default Tailscale Kubernetes operator does not apply any topology spread constraints.
|
||||
https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/
|
||||
items:
|
||||
description: TopologySpreadConstraint specifies how to spread matching pods among the given topology.
|
||||
properties:
|
||||
labelSelector:
|
||||
description: |-
|
||||
LabelSelector is used to find matching pods.
|
||||
Pods that match this label selector are counted to determine the number of pods
|
||||
in their corresponding topology domain.
|
||||
properties:
|
||||
matchExpressions:
|
||||
description: matchExpressions is a list of label selector requirements. The requirements are ANDed.
|
||||
items:
|
||||
description: |-
|
||||
A label selector requirement is a selector that contains values, a key, and an operator that
|
||||
relates the key and values.
|
||||
properties:
|
||||
key:
|
||||
description: key is the label key that the selector applies to.
|
||||
type: string
|
||||
operator:
|
||||
description: |-
|
||||
operator represents a key's relationship to a set of values.
|
||||
Valid operators are In, NotIn, Exists and DoesNotExist.
|
||||
type: string
|
||||
values:
|
||||
description: |-
|
||||
values is an array of string values. If the operator is In or NotIn,
|
||||
the values array must be non-empty. If the operator is Exists or DoesNotExist,
|
||||
the values array must be empty. This array is replaced during a strategic
|
||||
merge patch.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
x-kubernetes-list-type: atomic
|
||||
required:
|
||||
- key
|
||||
- operator
|
||||
type: object
|
||||
type: array
|
||||
x-kubernetes-list-type: atomic
|
||||
matchLabels:
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: |-
|
||||
matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels
|
||||
map is equivalent to an element of matchExpressions, whose key field is "key", the
|
||||
operator is "In", and the values array contains only "value". The requirements are ANDed.
|
||||
type: object
|
||||
type: object
|
||||
x-kubernetes-map-type: atomic
|
||||
matchLabelKeys:
|
||||
description: |-
|
||||
MatchLabelKeys is a set of pod label keys to select the pods over which
|
||||
spreading will be calculated. The keys are used to lookup values from the
|
||||
incoming pod labels, those key-value labels are ANDed with labelSelector
|
||||
to select the group of existing pods over which spreading will be calculated
|
||||
for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector.
|
||||
MatchLabelKeys cannot be set when LabelSelector isn't set.
|
||||
Keys that don't exist in the incoming pod labels will
|
||||
be ignored. A null or empty list means only match against labelSelector.
|
||||
|
||||
This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default).
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
x-kubernetes-list-type: atomic
|
||||
maxSkew:
|
||||
description: |-
|
||||
MaxSkew describes the degree to which pods may be unevenly distributed.
|
||||
When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference
|
||||
between the number of matching pods in the target topology and the global minimum.
|
||||
The global minimum is the minimum number of matching pods in an eligible domain
|
||||
or zero if the number of eligible domains is less than MinDomains.
|
||||
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
|
||||
labelSelector spread as 2/2/1:
|
||||
In this case, the global minimum is 1.
|
||||
| zone1 | zone2 | zone3 |
|
||||
| P P | P P | P |
|
||||
- if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2;
|
||||
scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2)
|
||||
violate MaxSkew(1).
|
||||
- if MaxSkew is 2, incoming pod can be scheduled onto any zone.
|
||||
When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence
|
||||
to topologies that satisfy it.
|
||||
It's a required field. Default value is 1 and 0 is not allowed.
|
||||
format: int32
|
||||
type: integer
|
||||
minDomains:
|
||||
description: |-
|
||||
MinDomains indicates a minimum number of eligible domains.
|
||||
When the number of eligible domains with matching topology keys is less than minDomains,
|
||||
Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed.
|
||||
And when the number of eligible domains with matching topology keys equals or greater than minDomains,
|
||||
this value has no effect on scheduling.
|
||||
As a result, when the number of eligible domains is less than minDomains,
|
||||
scheduler won't schedule more than maxSkew Pods to those domains.
|
||||
If value is nil, the constraint behaves as if MinDomains is equal to 1.
|
||||
Valid values are integers greater than 0.
|
||||
When value is not nil, WhenUnsatisfiable must be DoNotSchedule.
|
||||
|
||||
For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same
|
||||
labelSelector spread as 2/2/2:
|
||||
| zone1 | zone2 | zone3 |
|
||||
| P P | P P | P P |
|
||||
The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0.
|
||||
In this situation, new pod with the same labelSelector cannot be scheduled,
|
||||
because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones,
|
||||
it will violate MaxSkew.
|
||||
format: int32
|
||||
type: integer
|
||||
nodeAffinityPolicy:
|
||||
description: |-
|
||||
NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector
|
||||
when calculating pod topology spread skew. Options are:
|
||||
- Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations.
|
||||
- Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations.
|
||||
|
||||
If this value is nil, the behavior is equivalent to the Honor policy.
|
||||
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
|
||||
type: string
|
||||
nodeTaintsPolicy:
|
||||
description: |-
|
||||
NodeTaintsPolicy indicates how we will treat node taints when calculating
|
||||
pod topology spread skew. Options are:
|
||||
- Honor: nodes without taints, along with tainted nodes for which the incoming pod
|
||||
has a toleration, are included.
|
||||
- Ignore: node taints are ignored. All nodes are included.
|
||||
|
||||
If this value is nil, the behavior is equivalent to the Ignore policy.
|
||||
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
|
||||
type: string
|
||||
topologyKey:
|
||||
description: |-
|
||||
TopologyKey is the key of node labels. Nodes that have a label with this key
|
||||
and identical values are considered to be in the same topology.
|
||||
We consider each <key, value> as a "bucket", and try to put balanced number
|
||||
of pods into each bucket.
|
||||
We define a domain as a particular instance of a topology.
|
||||
Also, we define an eligible domain as a domain whose nodes meet the requirements of
|
||||
nodeAffinityPolicy and nodeTaintsPolicy.
|
||||
e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology.
|
||||
And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology.
|
||||
It's a required field.
|
||||
type: string
|
||||
whenUnsatisfiable:
|
||||
description: |-
|
||||
WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy
|
||||
the spread constraint.
|
||||
- DoNotSchedule (default) tells the scheduler not to schedule it.
|
||||
- ScheduleAnyway tells the scheduler to schedule the pod in any location,
|
||||
but giving higher precedence to topologies that would help reduce the
|
||||
skew.
|
||||
A constraint is considered "Unsatisfiable" for an incoming pod
|
||||
if and only if every possible node assignment for that pod would violate
|
||||
"MaxSkew" on some topology.
|
||||
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
|
||||
labelSelector spread as 3/1/1:
|
||||
| zone1 | zone2 | zone3 |
|
||||
| P P P | P | P |
|
||||
If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled
|
||||
to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies
|
||||
MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler
|
||||
won't make it *more* imbalanced.
|
||||
It's a required field.
|
||||
type: string
|
||||
required:
|
||||
- maxSkew
|
||||
- topologyKey
|
||||
- whenUnsatisfiable
|
||||
type: object
|
||||
type: array
|
||||
type: object
|
||||
type: object
|
||||
tailscale:
|
||||
|
||||
@@ -432,148 +432,6 @@ func TestTailnetTargetIPAnnotation(t *testing.T) {
|
||||
expectMissing[corev1.Secret](t, fc, "operator-ns", fullName)
|
||||
}
|
||||
|
||||
func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) {
|
||||
fc := fake.NewFakeClient()
|
||||
ft := &fakeTSClient{}
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
clock := tstest.NewClock(tstest.ClockOpts{})
|
||||
sr := &ServiceReconciler{
|
||||
Client: fc,
|
||||
ssr: &tailscaleSTSReconciler{
|
||||
Client: fc,
|
||||
tsClient: ft,
|
||||
defaultTags: []string{"tag:k8s"},
|
||||
operatorNamespace: "operator-ns",
|
||||
proxyImage: "tailscale/tailscale",
|
||||
},
|
||||
logger: zl.Sugar(),
|
||||
clock: clock,
|
||||
recorder: record.NewFakeRecorder(100),
|
||||
}
|
||||
tailnetTargetIP := "invalid-ip"
|
||||
mustCreate(t, fc, &corev1.Service{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "test",
|
||||
Namespace: "default",
|
||||
|
||||
UID: types.UID("1234-UID"),
|
||||
Annotations: map[string]string{
|
||||
AnnotationTailnetTargetIP: tailnetTargetIP,
|
||||
},
|
||||
},
|
||||
Spec: corev1.ServiceSpec{
|
||||
ClusterIP: "10.20.30.40",
|
||||
Type: corev1.ServiceTypeLoadBalancer,
|
||||
LoadBalancerClass: ptr.To("tailscale"),
|
||||
},
|
||||
})
|
||||
|
||||
expectReconciled(t, sr, "default", "test")
|
||||
|
||||
t0 := conditionTime(clock)
|
||||
|
||||
want := &corev1.Service{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "test",
|
||||
Namespace: "default",
|
||||
UID: types.UID("1234-UID"),
|
||||
Annotations: map[string]string{
|
||||
AnnotationTailnetTargetIP: tailnetTargetIP,
|
||||
},
|
||||
},
|
||||
Spec: corev1.ServiceSpec{
|
||||
ClusterIP: "10.20.30.40",
|
||||
Type: corev1.ServiceTypeLoadBalancer,
|
||||
LoadBalancerClass: ptr.To("tailscale"),
|
||||
},
|
||||
Status: corev1.ServiceStatus{
|
||||
Conditions: []metav1.Condition{{
|
||||
Type: string(tsapi.ProxyReady),
|
||||
Status: metav1.ConditionFalse,
|
||||
LastTransitionTime: t0,
|
||||
Reason: reasonProxyInvalid,
|
||||
Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "invalid-ip" could not be parsed as a valid IP Address, error: ParseAddr("invalid-ip"): unable to parse IP`,
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
expectEqual(t, fc, want, nil)
|
||||
}
|
||||
|
||||
func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) {
|
||||
fc := fake.NewFakeClient()
|
||||
ft := &fakeTSClient{}
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
clock := tstest.NewClock(tstest.ClockOpts{})
|
||||
sr := &ServiceReconciler{
|
||||
Client: fc,
|
||||
ssr: &tailscaleSTSReconciler{
|
||||
Client: fc,
|
||||
tsClient: ft,
|
||||
defaultTags: []string{"tag:k8s"},
|
||||
operatorNamespace: "operator-ns",
|
||||
proxyImage: "tailscale/tailscale",
|
||||
},
|
||||
logger: zl.Sugar(),
|
||||
clock: clock,
|
||||
recorder: record.NewFakeRecorder(100),
|
||||
}
|
||||
tailnetTargetIP := "999.999.999.999"
|
||||
mustCreate(t, fc, &corev1.Service{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "test",
|
||||
Namespace: "default",
|
||||
|
||||
UID: types.UID("1234-UID"),
|
||||
Annotations: map[string]string{
|
||||
AnnotationTailnetTargetIP: tailnetTargetIP,
|
||||
},
|
||||
},
|
||||
Spec: corev1.ServiceSpec{
|
||||
ClusterIP: "10.20.30.40",
|
||||
Type: corev1.ServiceTypeLoadBalancer,
|
||||
LoadBalancerClass: ptr.To("tailscale"),
|
||||
},
|
||||
})
|
||||
|
||||
expectReconciled(t, sr, "default", "test")
|
||||
|
||||
t0 := conditionTime(clock)
|
||||
|
||||
want := &corev1.Service{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "test",
|
||||
Namespace: "default",
|
||||
UID: types.UID("1234-UID"),
|
||||
Annotations: map[string]string{
|
||||
AnnotationTailnetTargetIP: tailnetTargetIP,
|
||||
},
|
||||
},
|
||||
Spec: corev1.ServiceSpec{
|
||||
ClusterIP: "10.20.30.40",
|
||||
Type: corev1.ServiceTypeLoadBalancer,
|
||||
LoadBalancerClass: ptr.To("tailscale"),
|
||||
},
|
||||
Status: corev1.ServiceStatus{
|
||||
Conditions: []metav1.Condition{{
|
||||
Type: string(tsapi.ProxyReady),
|
||||
Status: metav1.ConditionFalse,
|
||||
LastTransitionTime: t0,
|
||||
Reason: reasonProxyInvalid,
|
||||
Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "999.999.999.999" could not be parsed as a valid IP Address, error: ParseAddr("999.999.999.999"): IPv4 field has value >255`,
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
expectEqual(t, fc, want, nil)
|
||||
}
|
||||
|
||||
func TestAnnotations(t *testing.T) {
|
||||
fc := fake.NewFakeClient()
|
||||
ft := &fakeTSClient{}
|
||||
|
||||
@@ -718,7 +718,6 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet,
|
||||
ss.Spec.Template.Spec.NodeSelector = wantsPod.NodeSelector
|
||||
ss.Spec.Template.Spec.Affinity = wantsPod.Affinity
|
||||
ss.Spec.Template.Spec.Tolerations = wantsPod.Tolerations
|
||||
ss.Spec.Template.Spec.TopologySpreadConstraints = wantsPod.TopologySpreadConstraints
|
||||
|
||||
// Update containers.
|
||||
updateContainer := func(overlay *tsapi.Container, base corev1.Container) corev1.Container {
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
appsv1 "k8s.io/api/apps/v1"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
"k8s.io/apimachinery/pkg/api/resource"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"sigs.k8s.io/yaml"
|
||||
tsapi "tailscale.com/k8s-operator/apis/v1alpha1"
|
||||
"tailscale.com/types/ptr"
|
||||
@@ -74,16 +73,6 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) {
|
||||
NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"},
|
||||
Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}},
|
||||
Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}},
|
||||
TopologySpreadConstraints: []corev1.TopologySpreadConstraint{
|
||||
{
|
||||
WhenUnsatisfiable: "DoNotSchedule",
|
||||
TopologyKey: "kubernetes.io/hostname",
|
||||
MaxSkew: 3,
|
||||
LabelSelector: &metav1.LabelSelector{
|
||||
MatchLabels: map[string]string{"foo": "bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
TailscaleContainer: &tsapi.Container{
|
||||
SecurityContext: &corev1.SecurityContext{
|
||||
Privileged: ptr.To(true),
|
||||
@@ -170,7 +159,6 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) {
|
||||
wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector
|
||||
wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity
|
||||
wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations
|
||||
wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints
|
||||
wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext
|
||||
wantSS.Spec.Template.Spec.InitContainers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleInitContainer.SecurityContext
|
||||
wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources
|
||||
@@ -213,7 +201,6 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) {
|
||||
wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector
|
||||
wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity
|
||||
wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations
|
||||
wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints
|
||||
wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext
|
||||
wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources
|
||||
wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, []corev1.EnvVar{{Name: "foo", Value: "bar"}, {Name: "TS_USERSPACE", Value: "true"}, {Name: "bar"}}...)
|
||||
|
||||
@@ -358,14 +358,9 @@ func validateService(svc *corev1.Service) []string {
|
||||
violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q does not appear to be a valid MagicDNS name", AnnotationTailnetTargetFQDN, fqdn))
|
||||
}
|
||||
}
|
||||
if ipStr := svc.Annotations[AnnotationTailnetTargetIP]; ipStr != "" {
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q could not be parsed as a valid IP Address, error: %s", AnnotationTailnetTargetIP, ipStr, err))
|
||||
} else if !ip.IsValid() {
|
||||
violations = append(violations, fmt.Sprintf("parsed IP address in annotation %s: %q is not valid", AnnotationTailnetTargetIP, ipStr))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(irbekrm): validate that tailscale.com/tailnet-ip annotation is a
|
||||
// valid IP address (tailscale/tailscale#13671).
|
||||
|
||||
svcName := nameForService(svc)
|
||||
if err := dnsname.ValidLabel(svcName); err != nil {
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Tailscale LOPOWER
|
||||
|
||||
"Little Opinionated Proxy Over Wireguard-encrypted Routes"
|
||||
|
||||
**STATUS**: in-development alpha (as of 2024-11-03)
|
||||
|
||||
## Background
|
||||
|
||||
Some small devices such as ESP32 microcontrollers [support WireGuard](https://github.com/ciniml/WireGuard-ESP32-Arduino) but are too small to run Tailscale.
|
||||
|
||||
Tailscale LOPOWER is a proxy that you run nearby that bridges a low-power WireGuard-speaking device on one side to Tailscale on the other side. That way network traffic from the low-powered device never hits the network unencrypted but is still able to communicate to/from other Tailscale devices on your Tailnet.
|
||||
|
||||
## Diagram
|
||||
|
||||
<img src="./lopower.svg">
|
||||
|
||||
## Features
|
||||
|
||||
* Runs separate Wireguard server with separate keys (unknown to the Tailscale control plane) that proxy on to Tailscale
|
||||
* Outputs WireGuard-standard configuration to enrolls devices, including in QR code form.
|
||||
* embeds `tsnet`, with an identity on which the device(s) behind the proxy appear on your Tailnet
|
||||
* optional IPv4 support. IPv6 is always enabled, as it never conflicts with anything. But IPv4 (or CGNAT) might already be in use on your client's network.
|
||||
* includes a DNS server (at `fd7a:115c:a1e0:9909::1` by default and optionally also at `10.90.0.1`) to serve both MagicDNS names as well as forwarding non-Tailscale DNS names onwards
|
||||
* if IPv4 is disabled, MagicDNS `A` records are filtered out, and only `AAAA` records are served.
|
||||
|
||||
## Limitations
|
||||
|
||||
* this runs in userspace using gVisor's netstack. That means it's portable (and doesn't require kernel/system configuration), but that does mean it doesn't operate at a packet level but rather it stitches together two separate TCP (or UDP) flows and doesn't support IP protocols such as SCTP or other things that aren't TCP or UDP.
|
||||
* the standard WireGuard configuration doesn't support specifying DNS search domains, so resolving bare names like the `go` in `http://go/foo` won't work and you need to resolve names using the fully qualified `go.your-tailnet.ts.net` names.
|
||||
* since it's based on userspace tsnet mode, it doesn't pick up your system DNS configuration (yet?) and instead resolves non-tailnet DNS names using either your "Override DNS" tailnet settings for the global DNS resolver, or else defaults to `8.8.8.8` and `1.1.1.1` (using DoH) if that isn't set.
|
||||
|
||||
## TODO
|
||||
|
||||
* provisioning more than one low-powered device is possible, but requires manual config file edits. It should be possible to enroll multiple devices (including QR code support) easily.
|
||||
* incoming connections (from Tailscale to `lopower`) don't yet forward to the low-powered devices. When there's only one low-powered device, the mapping policy is obvious. When there are multiple, it's not as obvious. Maybe the answer is supporting [4via6 subnet routers](https://tailscale.com/kb/1201/4via6-subnets).
|
||||
|
||||
## Installing
|
||||
|
||||
* git clone this repo, switch to `lp` branch, `go install ./cmd/lopower` and see `lopower --help`.
|
||||
@@ -1,872 +0,0 @@
|
||||
// The lopower server is a "Little Opinionated Proxy Over
|
||||
// Wireguard-Encrypted Route". It bridges a static WireGuard
|
||||
// client into a Tailscale network.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
qrcode "github.com/skip2/go-qrcode"
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"golang.org/x/sys/unix"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tsnet"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/wgengine/wgcfg"
|
||||
)
|
||||
|
||||
var (
|
||||
wgListenPort = flag.Int("wg-port", 51820, "port number to listen on for WireGuard from the client")
|
||||
confDir = flag.String("dir", filepath.Join(os.Getenv("HOME"), ".config/lopower"), "directory to store configuration in")
|
||||
wgPubHost = flag.String("wg-host", "0.0.0.1", "IP address of lopower's WireGuard server that's accessible from the client")
|
||||
qrListenAddr = flag.String("qr-listen", "127.0.0.1:8014", "HTTP address to serve a QR code for client's WireGuard configuration, or empty for none")
|
||||
printConfig = flag.Bool("print-config", true, "print the client's WireGuard configuration to stdout on startup")
|
||||
includeV4 = flag.Bool("include-v4", true, "include IPv4 (CGNAT) in the WireGuard configuration; incompatible with some carriers. IPv6 is always included.")
|
||||
verbosePackets = flag.Bool("verbose-packets", false, "log packet contents")
|
||||
)
|
||||
|
||||
type config struct {
|
||||
PrivKey key.NodePrivate // the proxy server's key
|
||||
Peers []Peer
|
||||
|
||||
// V4 and V6 are the local IPs.
|
||||
V4 netip.Addr
|
||||
V6 netip.Addr
|
||||
|
||||
// CIDRs are used to allocate IPs to peers.
|
||||
V4CIDR netip.Prefix
|
||||
V6CIDR netip.Prefix
|
||||
}
|
||||
|
||||
// IsLocalIP reports whether ip is one of the local IPs.
|
||||
func (c *config) IsLocalIP(ip netip.Addr) bool {
|
||||
return ip.IsValid() && (ip == c.V4 || ip == c.V6)
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
PrivKey key.NodePrivate // e.g. proxy client's
|
||||
V4 netip.Addr
|
||||
V6 netip.Addr
|
||||
}
|
||||
|
||||
func (lp *lpServer) storeConfigLocked() {
|
||||
path := filepath.Join(lp.dir, "config.json")
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
log.Fatalf("os.MkdirAll(%q): %v", filepath.Dir(path), err)
|
||||
}
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
log.Fatalf("os.OpenFile(%q): %v", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
must.Do(json.NewEncoder(f).Encode(lp.c))
|
||||
if err := f.Close(); err != nil {
|
||||
log.Fatalf("f.Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (lp *lpServer) loadConfig() {
|
||||
path := filepath.Join(lp.dir, "config.json")
|
||||
f, err := os.Open(path)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
var cfg *config
|
||||
must.Do(json.NewDecoder(f).Decode(&cfg))
|
||||
if len(cfg.Peers) > 0 { // as early version didn't set this
|
||||
lp.mu.Lock()
|
||||
defer lp.mu.Unlock()
|
||||
lp.c = cfg
|
||||
}
|
||||
return
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
log.Fatalf("os.OpenFile(%q): %v", path, err)
|
||||
}
|
||||
const defaultV4CIDR = "10.90.0.0/24"
|
||||
const defaultV6CIDR = "fd7a:115c:a1e0:9909::/64" // 9909 = above QWERTY "LOPO"(wer)
|
||||
c := &config{
|
||||
PrivKey: key.NewNode(),
|
||||
V4CIDR: netip.MustParsePrefix(defaultV4CIDR),
|
||||
V6CIDR: netip.MustParsePrefix(defaultV6CIDR),
|
||||
}
|
||||
c.V4 = c.V4CIDR.Addr().Next()
|
||||
c.V6 = c.V6CIDR.Addr().Next()
|
||||
c.Peers = append(c.Peers, Peer{
|
||||
PrivKey: key.NewNode(),
|
||||
V4: c.V4.Next(),
|
||||
V6: c.V6.Next(),
|
||||
})
|
||||
|
||||
lp.mu.Lock()
|
||||
defer lp.mu.Unlock()
|
||||
lp.c = c
|
||||
lp.storeConfigLocked()
|
||||
return
|
||||
}
|
||||
|
||||
func (lp *lpServer) reconfig() {
|
||||
lp.mu.Lock()
|
||||
wc := &wgcfg.Config{
|
||||
Name: "lopower0",
|
||||
PrivateKey: lp.c.PrivKey,
|
||||
ListenPort: uint16(*wgListenPort),
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(lp.c.V4, 32),
|
||||
netip.PrefixFrom(lp.c.V6, 128),
|
||||
},
|
||||
}
|
||||
for _, p := range lp.c.Peers {
|
||||
wc.Peers = append(wc.Peers, wgcfg.Peer{
|
||||
PublicKey: p.PrivKey.Public(),
|
||||
AllowedIPs: []netip.Prefix{
|
||||
netip.PrefixFrom(p.V4, 32),
|
||||
netip.PrefixFrom(p.V6, 128),
|
||||
},
|
||||
})
|
||||
}
|
||||
lp.mu.Unlock()
|
||||
must.Do(wgcfg.ReconfigDevice(lp.d, wc, log.Printf))
|
||||
}
|
||||
|
||||
func newLP(ctx context.Context) *lpServer {
|
||||
logf := log.Printf
|
||||
deviceLogger := &device.Logger{
|
||||
Verbosef: logger.Discard,
|
||||
Errorf: logf,
|
||||
}
|
||||
lp := &lpServer{
|
||||
ctx: ctx,
|
||||
dir: *confDir,
|
||||
readCh: make(chan *stack.PacketBuffer, 16),
|
||||
}
|
||||
lp.loadConfig()
|
||||
lp.initNetstack(ctx)
|
||||
nst := &nsTUN{
|
||||
lp: lp,
|
||||
closeCh: make(chan struct{}),
|
||||
evChan: make(chan tun.Event),
|
||||
}
|
||||
|
||||
wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger)
|
||||
lp.d = wgdev
|
||||
must.Do(wgdev.Up())
|
||||
lp.reconfig()
|
||||
|
||||
if *printConfig {
|
||||
log.Printf("Device Wireguard config is:\n%s", lp.wgConfigForQR())
|
||||
}
|
||||
|
||||
lp.startTSNet(ctx)
|
||||
return lp
|
||||
}
|
||||
|
||||
type lpServer struct {
|
||||
dir string
|
||||
tsnet *tsnet.Server
|
||||
d *device.Device
|
||||
ns *stack.Stack
|
||||
ctx context.Context // canceled on shutdown
|
||||
linkEP *channel.Endpoint
|
||||
readCh chan *stack.PacketBuffer // from gvisor/dns server => out to network
|
||||
|
||||
// protocolConns tracks the number of active connections for each connection.
|
||||
// It is used to add and remove protocol addresses from netstack as needed.
|
||||
protocolConns syncs.Map[tcpip.ProtocolAddress, *atomic.Int32]
|
||||
|
||||
mu sync.Mutex // protects following
|
||||
c *config
|
||||
}
|
||||
|
||||
// MaxPacketSize is the maximum size (in bytes)
|
||||
// of a packet that can be injected into lpServer.
|
||||
const MaxPacketSize = device.MaxContentSize
|
||||
const nicID = 1
|
||||
|
||||
func (lp *lpServer) initNetstack(ctx context.Context) error {
|
||||
ns := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
udp.NewProtocol,
|
||||
},
|
||||
})
|
||||
lp.ns = ns
|
||||
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
|
||||
if tcpipErr := ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt); tcpipErr != nil {
|
||||
return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr)
|
||||
}
|
||||
lp.linkEP = channel.New(512, 1280, "")
|
||||
if tcpipProblem := ns.CreateNIC(nicID, lp.linkEP); tcpipProblem != nil {
|
||||
return fmt.Errorf("CreateNIC: %v", tcpipProblem)
|
||||
}
|
||||
ns.SetPromiscuousMode(nicID, true)
|
||||
|
||||
lp.mu.Lock()
|
||||
v4, v6 := lp.c.V4, lp.c.V6
|
||||
lp.mu.Unlock()
|
||||
prefix := tcpip.AddrFrom4Slice(v4.AsSlice()).WithPrefix()
|
||||
if *includeV4 {
|
||||
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: prefix,
|
||||
}, stack.AddressProperties{}); tcpProb != nil {
|
||||
return errors.New(tcpProb.String())
|
||||
}
|
||||
}
|
||||
prefix = tcpip.AddrFrom16Slice(v6.AsSlice()).WithPrefix()
|
||||
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: prefix,
|
||||
}, stack.AddressProperties{}); tcpProb != nil {
|
||||
return errors.New(tcpProb.String())
|
||||
}
|
||||
|
||||
ipv4Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create IPv4 subnet: %v", err)
|
||||
}
|
||||
ipv6Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 16)), tcpip.MaskFromBytes(make([]byte, 16)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create IPv6 subnet: %v", err)
|
||||
}
|
||||
|
||||
routes := []tcpip.Route{{
|
||||
Destination: ipv4Subnet,
|
||||
NIC: nicID,
|
||||
}, {
|
||||
Destination: ipv6Subnet,
|
||||
NIC: nicID,
|
||||
}}
|
||||
if !*includeV4 {
|
||||
routes = routes[1:]
|
||||
}
|
||||
|
||||
ns.SetRouteTable(routes)
|
||||
|
||||
const tcpReceiveBufferSize = 0 // default
|
||||
const maxInFlightConnectionAttempts = 8192
|
||||
tcpFwd := tcp.NewForwarder(ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, lp.acceptTCP)
|
||||
udpFwd := udp.NewForwarder(ns, lp.acceptUDP)
|
||||
ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) {
|
||||
return tcpFwd.HandlePacket(tei, pb)
|
||||
})
|
||||
ns.SetTransportProtocolHandler(udp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) {
|
||||
return udpFwd.HandlePacket(tei, pb)
|
||||
})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
pkt := lp.linkEP.ReadContext(ctx)
|
||||
if pkt == nil {
|
||||
if ctx.Err() != nil {
|
||||
// Return without logging.
|
||||
log.Printf("linkEP.ReadContext: %v", ctx.Err())
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
size := pkt.Size()
|
||||
if size > MaxPacketSize || size == 0 {
|
||||
pkt.DecRef()
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case lp.readCh <- pkt:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr {
|
||||
switch s.Len() {
|
||||
case 4:
|
||||
return netip.AddrFrom4(s.As4())
|
||||
case 16:
|
||||
return netip.AddrFrom16(s.As16()).Unmap()
|
||||
}
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (lp *lpServer) trackProtocolAddr(destIP netip.Addr) (untrack func()) {
|
||||
pa := tcpip.ProtocolAddress{
|
||||
AddressWithPrefix: tcpip.AddrFromSlice(destIP.AsSlice()).WithPrefix(),
|
||||
}
|
||||
if destIP.Is4() {
|
||||
pa.Protocol = ipv4.ProtocolNumber
|
||||
} else if destIP.Is6() {
|
||||
pa.Protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
|
||||
addrConns, _ := lp.protocolConns.LoadOrInit(pa, func() *atomic.Int32 { return new(atomic.Int32) })
|
||||
if addrConns.Add(1) == 1 {
|
||||
lp.ns.AddProtocolAddress(nicID, pa, stack.AddressProperties{
|
||||
PEB: stack.CanBePrimaryEndpoint, // zero value default
|
||||
ConfigType: stack.AddressConfigStatic, // zero value default
|
||||
})
|
||||
}
|
||||
return func() {
|
||||
if addrConns.Add(-1) == 0 {
|
||||
lp.ns.RemoveAddress(nicID, pa.AddressWithPrefix.Address)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (lp *lpServer) acceptUDP(r *udp.ForwarderRequest) {
|
||||
log.Printf("acceptUDP: %v", r.ID())
|
||||
destIP := netaddrIPFromNetstackIP(r.ID().LocalAddress)
|
||||
untrack := lp.trackProtocolAddr(destIP)
|
||||
var wq waiter.Queue
|
||||
ep, udpErr := r.CreateEndpoint(&wq)
|
||||
if udpErr != nil {
|
||||
log.Printf("CreateEndpoint: %v", udpErr)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer untrack()
|
||||
defer ep.Close()
|
||||
reqDetails := r.ID()
|
||||
|
||||
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
|
||||
destPort := reqDetails.LocalPort
|
||||
if !clientRemoteIP.IsValid() {
|
||||
log.Printf("acceptUDP: invalid remote IP %v", reqDetails.RemoteAddress)
|
||||
return
|
||||
}
|
||||
|
||||
randPort := rand.IntN(65536-1024) + 1024
|
||||
v4, v6 := lp.tsnet.TailscaleIPs()
|
||||
var listenAddr netip.Addr
|
||||
if destIP.Is4() {
|
||||
listenAddr = v4
|
||||
} else {
|
||||
listenAddr = v6
|
||||
}
|
||||
backendConn, err := lp.tsnet.ListenPacket("udp", fmt.Sprintf("%s:%d", listenAddr, randPort))
|
||||
if err != nil {
|
||||
log.Printf("ListenPacket: %v", err)
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
clientConn := gonet.NewUDPConn(&wq, ep)
|
||||
defer clientConn.Close()
|
||||
errCh := make(chan error, 2)
|
||||
go func() (err error) {
|
||||
defer func() { errCh <- err }()
|
||||
var buf [64]byte
|
||||
for {
|
||||
n, _, err := backendConn.ReadFrom(buf[:])
|
||||
if err != nil {
|
||||
log.Printf("UDP read: %v", err)
|
||||
return err
|
||||
}
|
||||
_, err = clientConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}()
|
||||
dstAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", destIP, destPort))
|
||||
if err != nil {
|
||||
log.Printf("ResolveUDPAddr: %v", err)
|
||||
return
|
||||
}
|
||||
go func() (err error) {
|
||||
defer func() { errCh <- err }()
|
||||
var buf [2048]byte
|
||||
for {
|
||||
n, err := clientConn.Read(buf[:])
|
||||
if err != nil {
|
||||
log.Printf("UDP read: %v", err)
|
||||
return err
|
||||
}
|
||||
_, err = backendConn.WriteTo(buf[:n], dstAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}()
|
||||
err = <-errCh
|
||||
if err != nil {
|
||||
log.Printf("io.Copy: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) {
|
||||
log.Printf("acceptTCP: %v", r.ID())
|
||||
reqDetails := r.ID()
|
||||
destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
|
||||
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
|
||||
destPort := reqDetails.LocalPort
|
||||
if !clientRemoteIP.IsValid() {
|
||||
log.Printf("acceptTCP: invalid remote IP %v", reqDetails.RemoteAddress)
|
||||
r.Complete(true) // sends a RST
|
||||
return
|
||||
}
|
||||
untrack := lp.trackProtocolAddr(destIP)
|
||||
defer untrack()
|
||||
|
||||
var wq waiter.Queue
|
||||
ep, tcpErr := r.CreateEndpoint(&wq)
|
||||
if tcpErr != nil {
|
||||
log.Printf("CreateEndpoint: %v", tcpErr)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
defer ep.Close()
|
||||
ep.SocketOptions().SetKeepAlive(true)
|
||||
|
||||
if destPort == 53 && lp.c.IsLocalIP(destIP) {
|
||||
tc := gonet.NewTCPConn(&wq, ep)
|
||||
defer tc.Close()
|
||||
r.Complete(false) // accept TCP connection
|
||||
lp.handleTCPDNSQuery(tc, netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort))
|
||||
return
|
||||
}
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort))
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("Dial(%s:%d): %v", destIP, destPort, err)
|
||||
r.Complete(true) // sends a RST
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
tc := gonet.NewTCPConn(&wq, ep)
|
||||
defer tc.Close()
|
||||
r.Complete(false) // accept TCP connection
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() { _, err := io.Copy(tc, c); errc <- err }()
|
||||
go func() { _, err := io.Copy(c, tc); errc <- err }()
|
||||
err = <-errc
|
||||
if err != nil {
|
||||
log.Printf("io.Copy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (lp *lpServer) wgConfigForQR() string {
|
||||
var b strings.Builder
|
||||
|
||||
p := lp.c.Peers[0]
|
||||
privHex, _ := p.PrivKey.MarshalText()
|
||||
privHex = bytes.TrimPrefix(privHex, []byte("privkey:"))
|
||||
priv := make([]byte, 32)
|
||||
got, err := hex.Decode(priv, privHex)
|
||||
if err != nil || got != 32 {
|
||||
log.Printf("marshal text was: %q", privHex)
|
||||
log.Fatalf("bad private key: %v, % bytes", err, got)
|
||||
}
|
||||
privb64 := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
fmt.Fprintf(&b, "[Interface]\nPrivateKey = %s\n", privb64)
|
||||
fmt.Fprintf(&b, "Address = %v,%v\n", p.V6, p.V4)
|
||||
|
||||
pubBin, _ := lp.c.PrivKey.Public().MarshalBinary()
|
||||
if len(pubBin) != 34 {
|
||||
log.Fatalf("bad pubkey length: %d", len(pubBin))
|
||||
}
|
||||
pubBin = pubBin[2:] // trim off "np"
|
||||
pubb64 := base64.StdEncoding.EncodeToString(pubBin)
|
||||
|
||||
fmt.Fprintf(&b, "\n[Peer]\nPublicKey = %v\n", pubb64)
|
||||
if *includeV4 {
|
||||
fmt.Fprintf(&b, "AllowedIPs = %v/32,%v/128,%v,%v\n", lp.c.V4, lp.c.V6, tsaddr.TailscaleULARange(), tsaddr.CGNATRange())
|
||||
} else {
|
||||
fmt.Fprintf(&b, "AllowedIPs = %v/128,%v\n", lp.c.V6, tsaddr.TailscaleULARange())
|
||||
}
|
||||
fmt.Fprintf(&b, "Endpoint = %v\n", net.JoinHostPort(*wgPubHost, fmt.Sprint(*wgListenPort)))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (lp *lpServer) serveQR() {
|
||||
ln, err := net.Listen("tcp", *qrListenAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("qr: %v", err)
|
||||
}
|
||||
log.Printf("# Serving QR code at http://%s/", ln.Addr())
|
||||
hs := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
conf := lp.wgConfigForQR()
|
||||
v, err := qrcode.Encode(conf, qrcode.Medium, 512)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(v)
|
||||
}),
|
||||
}
|
||||
if err := hs.Serve(ln); err != nil {
|
||||
log.Fatalf("qr: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type nsTUN struct {
|
||||
lp *lpServer
|
||||
closeCh chan struct{}
|
||||
evChan chan tun.Event
|
||||
}
|
||||
|
||||
func (t *nsTUN) File() *os.File {
|
||||
panic("nsTUN.File() called, which makes no sense")
|
||||
}
|
||||
|
||||
func (t *nsTUN) Close() error {
|
||||
close(t.closeCh)
|
||||
close(t.evChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read reads packets from gvisor (or the DNS server) to send out to the network.
|
||||
func (t *nsTUN) Read(out [][]byte, sizes []int, offset int) (int, error) {
|
||||
select {
|
||||
case <-t.closeCh:
|
||||
return 0, io.EOF
|
||||
case resPacket := <-t.lp.readCh:
|
||||
defer resPacket.DecRef()
|
||||
pkt := out[0][offset:]
|
||||
n := copy(pkt, resPacket.NetworkHeader().Slice())
|
||||
n += copy(pkt[n:], resPacket.TransportHeader().Slice())
|
||||
n += copy(pkt[n:], resPacket.Data().AsRange().ToSlice())
|
||||
if *verbosePackets {
|
||||
log.Printf("[v] nsTUN.Read (out): % 02x", pkt[:n])
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write accepts incoming packets. The packets begin at buffs[:][offset:],
|
||||
// like wireguard-go/tun.Device.Write. Write is called per-peer via
|
||||
// wireguard-go/device.Peer.RoutineSequentialReceiver, so it MUST be
|
||||
// thread-safe.
|
||||
func (t *nsTUN) Write(buffs [][]byte, offset int) (int, error) {
|
||||
var pkt packet.Parsed
|
||||
for _, buff := range buffs {
|
||||
raw := buff[offset:]
|
||||
pkt.Decode(raw)
|
||||
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(slices.Clone(raw)),
|
||||
})
|
||||
if *verbosePackets {
|
||||
log.Printf("[v] nsTUN.Write (in): % 02x", raw)
|
||||
}
|
||||
if pkt.IPProto == ipproto.UDP && pkt.Dst.Port() == 53 && t.lp.c.IsLocalIP(pkt.Dst.Addr()) {
|
||||
// Handle DNS queries before sending to gvisor.
|
||||
t.lp.handleDNSUDPQuery(raw)
|
||||
continue
|
||||
}
|
||||
if pkt.IPVersion == 4 {
|
||||
t.lp.linkEP.InjectInbound(ipv4.ProtocolNumber, packetBuf)
|
||||
} else if pkt.IPVersion == 6 {
|
||||
t.lp.linkEP.InjectInbound(ipv6.ProtocolNumber, packetBuf)
|
||||
}
|
||||
}
|
||||
return len(buffs), nil
|
||||
}
|
||||
|
||||
func (t *nsTUN) Flush() error { return nil }
|
||||
func (t *nsTUN) MTU() (int, error) { return 1500, nil }
|
||||
func (t *nsTUN) Name() (string, error) { return "nstun", nil }
|
||||
func (t *nsTUN) Events() <-chan tun.Event { return t.evChan }
|
||||
func (t *nsTUN) BatchSize() int { return 1 }
|
||||
|
||||
func (lp *lpServer) startTSNet(ctx context.Context) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ts := &tsnet.Server{
|
||||
Dir: filepath.Join(lp.dir, "tsnet"),
|
||||
Hostname: hostname,
|
||||
UserLogf: log.Printf,
|
||||
Ephemeral: false,
|
||||
}
|
||||
lp.tsnet = ts
|
||||
ts.PreStart = func() error {
|
||||
dnsMgr := ts.Sys().DNSManager.Get()
|
||||
dnsMgr.SetForceAAAA(true)
|
||||
|
||||
// Force fallback resolvers to Google and Cloudflare as an ultimate
|
||||
// fallback in case the Tailnet DNS servers are not set/forced. Normally
|
||||
// tailscaled would resort to using the OS DNS resolvers, but
|
||||
// tsnet/userspace binaries don't do that (yet?), so this is the
|
||||
// "Opionated" part of the "LOPOWER" name. The opinion is just using
|
||||
// big providers known to work. (Normally stock tailscaled never
|
||||
// makes such opinions and never defaults to any big provider, unless
|
||||
// you're already running on that big provider's network so have
|
||||
// already indicated you're fine with them.))
|
||||
dnsMgr.SetForceFallbackResolvers([]*dnstype.Resolver{
|
||||
{Addr: "8.8.8.8"},
|
||||
{Addr: "1.1.1.1"},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := ts.Up(ctx); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// filteredDNSQuery wraps the MagicDNS server response but filters out A record responses
|
||||
// for *.ts.net if IPv4 is not enabled. This is so the e.g. a phone on a CGNAT-using
|
||||
// network doesn't prefer the "A" record over AAAA when dialing and dial into the
|
||||
// the carrier's CGNAT range into of the AAAA record into the Tailscale IPv6 ULA range.
|
||||
func (lp *lpServer) filteredDNSQuery(ctx context.Context, q []byte, family string, from netip.AddrPort) ([]byte, error) {
|
||||
m, ok := lp.tsnet.Sys().DNSManager.GetOK()
|
||||
if !ok {
|
||||
return nil, errors.New("DNSManager not ready")
|
||||
}
|
||||
origRes, err := m.Query(ctx, q, family, from)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if *includeV4 {
|
||||
return origRes, nil
|
||||
}
|
||||
|
||||
// Filter out *.ts.net A records.
|
||||
|
||||
var msg dnsmessage.Message
|
||||
if err := msg.Unpack(origRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newAnswers := msg.Answers[:0]
|
||||
for _, a := range msg.Answers {
|
||||
name := a.Header.Name.String()
|
||||
if a.Header.Type == dnsmessage.TypeA && strings.HasSuffix(name, ".ts.net.") {
|
||||
// Drop.
|
||||
continue
|
||||
}
|
||||
newAnswers = append(newAnswers, a)
|
||||
}
|
||||
|
||||
if len(newAnswers) == len(msg.Answers) {
|
||||
// Nothing was filtered. No need to reencode it.
|
||||
return origRes, nil
|
||||
}
|
||||
|
||||
msg.Answers = newAnswers
|
||||
return msg.Pack()
|
||||
}
|
||||
|
||||
func (lp *lpServer) handleTCPDNSQuery(c net.Conn, src netip.AddrPort) {
|
||||
defer c.Close()
|
||||
var lenBuf [2]byte
|
||||
for {
|
||||
c.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
_, err := io.ReadFull(c, lenBuf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n := binary.BigEndian.Uint16(lenBuf[:])
|
||||
buf := make([]byte, n)
|
||||
c.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
_, err = io.ReadFull(c, buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res, err := lp.filteredDNSQuery(context.Background(), buf, "tcp", src)
|
||||
if err != nil {
|
||||
log.Printf("TCP DNS query error: %v", err)
|
||||
return
|
||||
}
|
||||
binary.BigEndian.PutUint16(lenBuf[:], uint16(len(res)))
|
||||
c.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
_, err = c.Write(lenBuf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
_, err = c.Write(res)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// caller owns the raw memory.
|
||||
func (lp *lpServer) handleDNSUDPQuery(raw []byte) {
|
||||
var pkt packet.Parsed
|
||||
pkt.Decode(raw)
|
||||
if pkt.IPProto != ipproto.UDP || pkt.Dst.Port() != 53 || !lp.c.IsLocalIP(pkt.Dst.Addr()) {
|
||||
panic("caller error")
|
||||
}
|
||||
|
||||
dnsRes, err := lp.filteredDNSQuery(context.Background(), pkt.Payload(), "udp", pkt.Src)
|
||||
if err != nil {
|
||||
log.Printf("DNS query error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipLayer := mkIPLayer(layers.IPProtocolUDP, pkt.Dst.Addr(), pkt.Src.Addr())
|
||||
udpLayer := &layers.UDP{
|
||||
SrcPort: 53,
|
||||
DstPort: layers.UDPPort(pkt.Src.Port()),
|
||||
}
|
||||
|
||||
resPkt, err := mkPacket(ipLayer, udpLayer, gopacket.Payload(dnsRes))
|
||||
if err != nil {
|
||||
log.Printf("mkPacket: %v", err)
|
||||
return
|
||||
}
|
||||
pktBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(resPkt),
|
||||
})
|
||||
select {
|
||||
case lp.readCh <- pktBuf:
|
||||
case <-lp.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
type serializableNetworkLayer interface {
|
||||
gopacket.SerializableLayer
|
||||
gopacket.NetworkLayer
|
||||
}
|
||||
|
||||
func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetworkLayer {
|
||||
if src.Is4() {
|
||||
return &layers.IPv4{
|
||||
Protocol: proto,
|
||||
SrcIP: src.AsSlice(),
|
||||
DstIP: dst.AsSlice(),
|
||||
}
|
||||
}
|
||||
if src.Is6() {
|
||||
return &layers.IPv6{
|
||||
NextHeader: proto,
|
||||
SrcIP: src.AsSlice(),
|
||||
DstIP: dst.AsSlice(),
|
||||
}
|
||||
}
|
||||
panic("invalid src IP")
|
||||
}
|
||||
|
||||
// mkPacket is a serializes a number of layers into a packet.
|
||||
//
|
||||
// It's a convenience wrapper around gopacket.SerializeLayers
|
||||
// that does some things automatically:
|
||||
//
|
||||
// * layers.IPv4/IPv6 Version is set to 4/6 if not already set
|
||||
// * layers.IPv4/IPv6 TTL/HopLimit is set to 64 if not already set
|
||||
// * the TCP/UDP/ICMPv6 checksum is set based on the network layer
|
||||
//
|
||||
// The provided layers in ll must be sorted from lowest (e.g. *layers.Ethernet)
|
||||
// to highest. (Depending on the need, the first layer will be either *layers.Ethernet
|
||||
// or *layers.IPv4/IPv6).
|
||||
func mkPacket(ll ...gopacket.SerializableLayer) ([]byte, error) {
|
||||
var nl gopacket.NetworkLayer
|
||||
for _, la := range ll {
|
||||
switch la := la.(type) {
|
||||
case *layers.IPv4:
|
||||
nl = la
|
||||
if la.Version == 0 {
|
||||
la.Version = 4
|
||||
}
|
||||
if la.TTL == 0 {
|
||||
la.TTL = 64
|
||||
}
|
||||
case *layers.IPv6:
|
||||
nl = la
|
||||
if la.Version == 0 {
|
||||
la.Version = 6
|
||||
}
|
||||
if la.HopLimit == 0 {
|
||||
la.HopLimit = 64
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, la := range ll {
|
||||
switch la := la.(type) {
|
||||
case *layers.TCP:
|
||||
la.SetNetworkLayerForChecksum(nl)
|
||||
case *layers.UDP:
|
||||
la.SetNetworkLayerForChecksum(nl)
|
||||
case *layers.ICMPv6:
|
||||
la.SetNetworkLayerForChecksum(nl)
|
||||
}
|
||||
}
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}
|
||||
if err := gopacket.SerializeLayers(buf, opts, ll...); err != nil {
|
||||
return nil, fmt.Errorf("serializing packet: %v", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.Printf("lopower starting")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
lp := newLP(ctx)
|
||||
|
||||
if *qrListenAddr != "" {
|
||||
go lp.serveQR()
|
||||
}
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, unix.SIGTERM, os.Interrupt)
|
||||
<-sigCh
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 114 KiB |
@@ -185,12 +185,10 @@ change in the future.
|
||||
logoutCmd,
|
||||
switchCmd,
|
||||
configureCmd,
|
||||
syspolicyCmd,
|
||||
netcheckCmd,
|
||||
ipCmd,
|
||||
dnsCmd,
|
||||
statusCmd,
|
||||
metricsCmd,
|
||||
pingCmd,
|
||||
ncCmd,
|
||||
sshCmd,
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/peterbourgon/ff/v3/ffcli"
|
||||
"tailscale.com/atomicfile"
|
||||
)
|
||||
|
||||
var metricsCmd = &ffcli.Command{
|
||||
Name: "metrics",
|
||||
ShortHelp: "Show Tailscale metrics",
|
||||
LongHelp: strings.TrimSpace(`
|
||||
|
||||
The 'tailscale metrics' command shows Tailscale user-facing metrics (as opposed
|
||||
to internal metrics printed by 'tailscale debug metrics').
|
||||
|
||||
For more information about Tailscale metrics, refer to
|
||||
https://tailscale.com/s/client-metrics
|
||||
|
||||
`),
|
||||
ShortUsage: "tailscale metrics <subcommand> [flags]",
|
||||
UsageFunc: usageFuncNoDefaultValues,
|
||||
Exec: runMetricsNoSubcommand,
|
||||
Subcommands: []*ffcli.Command{
|
||||
{
|
||||
Name: "print",
|
||||
ShortUsage: "tailscale metrics print",
|
||||
Exec: runMetricsPrint,
|
||||
ShortHelp: "Prints current metric values in the Prometheus text exposition format",
|
||||
},
|
||||
{
|
||||
Name: "write",
|
||||
ShortUsage: "tailscale metrics write <path>",
|
||||
Exec: runMetricsWrite,
|
||||
ShortHelp: "Writes metric values to a file",
|
||||
LongHelp: strings.TrimSpace(`
|
||||
|
||||
The 'tailscale metrics write' command writes metric values to a text file provided as its
|
||||
only argument. It's meant to be used alongside Prometheus node exporter, allowing Tailscale
|
||||
metrics to be consumed and exported by the textfile collector.
|
||||
|
||||
As an example, to export Tailscale metrics on an Ubuntu system running node exporter, you
|
||||
can regularly run 'tailscale metrics write /var/lib/prometheus/node-exporter/tailscaled.prom'
|
||||
using cron or a systemd timer.
|
||||
|
||||
`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// runMetricsNoSubcommand prints metric values if no subcommand is specified.
|
||||
func runMetricsNoSubcommand(ctx context.Context, args []string) error {
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("tailscale metrics: unknown subcommand: %s", args[0])
|
||||
}
|
||||
|
||||
return runMetricsPrint(ctx, args)
|
||||
}
|
||||
|
||||
// runMetricsPrint prints metric values to stdout.
|
||||
func runMetricsPrint(ctx context.Context, args []string) error {
|
||||
out, err := localClient.UserMetrics(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
Stdout.Write(out)
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMetricsWrite writes metric values to a file.
|
||||
func runMetricsWrite(ctx context.Context, args []string) error {
|
||||
if len(args) != 1 {
|
||||
return errors.New("usage: tailscale metrics write <path>")
|
||||
}
|
||||
path := args[0]
|
||||
out, err := localClient.UserMetrics(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return atomicfile.WriteFile(path, out, 0644)
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/peterbourgon/ff/v3/ffcli"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
var syspolicyArgs struct {
|
||||
json bool // JSON output mode
|
||||
}
|
||||
|
||||
var syspolicyCmd = &ffcli.Command{
|
||||
Name: "syspolicy",
|
||||
ShortHelp: "Diagnose the MDM and system policy configuration",
|
||||
LongHelp: "The 'tailscale syspolicy' command provides tools for diagnosing the MDM and system policy configuration.",
|
||||
ShortUsage: "tailscale syspolicy <subcommand>",
|
||||
UsageFunc: usageFuncNoDefaultValues,
|
||||
Subcommands: []*ffcli.Command{
|
||||
{
|
||||
Name: "list",
|
||||
ShortUsage: "tailscale syspolicy list",
|
||||
Exec: runSysPolicyList,
|
||||
ShortHelp: "Prints effective policy settings",
|
||||
LongHelp: "The 'tailscale syspolicy list' subcommand displays the effective policy settings and their sources (e.g., MDM or environment variables).",
|
||||
FlagSet: (func() *flag.FlagSet {
|
||||
fs := newFlagSet("syspolicy list")
|
||||
fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format")
|
||||
return fs
|
||||
})(),
|
||||
},
|
||||
{
|
||||
Name: "reload",
|
||||
ShortUsage: "tailscale syspolicy reload",
|
||||
Exec: runSysPolicyReload,
|
||||
ShortHelp: "Forces a reload of policy settings, even if no changes are detected, and prints the result",
|
||||
LongHelp: "The 'tailscale syspolicy reload' subcommand forces a reload of policy settings, even if no changes are detected, and prints the result.",
|
||||
FlagSet: (func() *flag.FlagSet {
|
||||
fs := newFlagSet("syspolicy reload")
|
||||
fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format")
|
||||
return fs
|
||||
})(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func runSysPolicyList(ctx context.Context, args []string) error {
|
||||
policy, err := localClient.GetEffectivePolicy(ctx, setting.DefaultScope())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
printPolicySettings(policy)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func runSysPolicyReload(ctx context.Context, args []string) error {
|
||||
policy, err := localClient.ReloadEffectivePolicy(ctx, setting.DefaultScope())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
printPolicySettings(policy)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printPolicySettings(policy *setting.Snapshot) {
|
||||
if syspolicyArgs.json {
|
||||
json, err := json.MarshalIndent(policy, "", "\t")
|
||||
if err != nil {
|
||||
errf("syspolicy marshalling error: %v", err)
|
||||
} else {
|
||||
outln(string(json))
|
||||
}
|
||||
return
|
||||
}
|
||||
if policy.Len() == 0 {
|
||||
outln("No policy settings")
|
||||
return
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "Name\tOrigin\tValue\tError")
|
||||
fmt.Fprintln(w, "----\t------\t-----\t-----")
|
||||
for _, k := range slices.Sorted(policy.Keys()) {
|
||||
setting, _ := policy.GetSetting(k)
|
||||
var origin string
|
||||
if o := setting.Origin(); o != nil {
|
||||
origin = o.String()
|
||||
}
|
||||
if err := setting.Error(); err != nil {
|
||||
fmt.Fprintf(w, "%s\t%s\t\t{%s}\n", k, origin, err)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\t\n", k, origin, setting.Value())
|
||||
}
|
||||
}
|
||||
w.Flush()
|
||||
|
||||
fmt.Println()
|
||||
return
|
||||
}
|
||||
@@ -174,18 +174,14 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
|
||||
tailscale.com/util/syspolicy from tailscale.com/ipn
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli
|
||||
tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli
|
||||
tailscale.com/util/usermetric from tailscale.com/health
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
W 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/version from tailscale.com/client/web+
|
||||
tailscale.com/version/distro from tailscale.com/client/web+
|
||||
|
||||
@@ -250,6 +250,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/control/controlclient from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/control/controlhttp from tailscale.com/control/controlclient
|
||||
tailscale.com/control/controlknobs from tailscale.com/control/controlclient+
|
||||
tailscale.com/control/keyfallback from tailscale.com/control/controlclient
|
||||
tailscale.com/derp from tailscale.com/derp/derphttp+
|
||||
tailscale.com/derp/derphttp from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/disco from tailscale.com/derp+
|
||||
@@ -401,11 +402,8 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
|
||||
tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/util/systemd from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+
|
||||
@@ -415,7 +413,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
|
||||
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+
|
||||
|
||||
@@ -788,6 +788,7 @@ func runDebugServer(mux *http.ServeMux, addr string) {
|
||||
}
|
||||
|
||||
func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) {
|
||||
tfs, _ := sys.DriveForLocal.GetOK()
|
||||
ret, err := netstack.Create(logf,
|
||||
sys.Tun.Get(),
|
||||
sys.Engine.Get(),
|
||||
@@ -795,6 +796,7 @@ func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) {
|
||||
sys.Dialer.Get(),
|
||||
sys.DNSManager.Get(),
|
||||
sys.ProxyMapper(),
|
||||
tfs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -42,7 +42,6 @@ type testAttempt struct {
|
||||
testName string // "TestFoo"
|
||||
outcome string // "pass", "fail", "skip"
|
||||
logs bytes.Buffer
|
||||
start, end time.Time
|
||||
isMarkedFlaky bool // set if the test is marked as flaky
|
||||
issueURL string // set if the test is marked as flaky
|
||||
|
||||
@@ -133,17 +132,11 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te
|
||||
}
|
||||
pkg := goOutput.Package
|
||||
pkgTests := resultMap[pkg]
|
||||
if pkgTests == nil {
|
||||
pkgTests = make(map[string]*testAttempt)
|
||||
resultMap[pkg] = pkgTests
|
||||
}
|
||||
if goOutput.Test == "" {
|
||||
switch goOutput.Action {
|
||||
case "start":
|
||||
pkgTests[""] = &testAttempt{start: goOutput.Time}
|
||||
case "fail", "pass", "skip":
|
||||
for _, test := range pkgTests {
|
||||
if test.testName != "" && test.outcome == "" {
|
||||
if test.outcome == "" {
|
||||
test.outcome = "fail"
|
||||
ch <- test
|
||||
}
|
||||
@@ -151,13 +144,15 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te
|
||||
ch <- &testAttempt{
|
||||
pkg: goOutput.Package,
|
||||
outcome: goOutput.Action,
|
||||
start: pkgTests[""].start,
|
||||
end: goOutput.Time,
|
||||
pkgFinished: true,
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if pkgTests == nil {
|
||||
pkgTests = make(map[string]*testAttempt)
|
||||
resultMap[pkg] = pkgTests
|
||||
}
|
||||
testName := goOutput.Test
|
||||
if test, _, isSubtest := strings.Cut(goOutput.Test, "/"); isSubtest {
|
||||
testName = test
|
||||
@@ -173,10 +168,8 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te
|
||||
pkgTests[testName] = &testAttempt{
|
||||
pkg: pkg,
|
||||
testName: testName,
|
||||
start: goOutput.Time,
|
||||
}
|
||||
case "skip", "pass", "fail":
|
||||
pkgTests[testName].end = goOutput.Time
|
||||
pkgTests[testName].outcome = goOutput.Action
|
||||
ch <- pkgTests[testName]
|
||||
case "output":
|
||||
@@ -220,7 +213,7 @@ func main() {
|
||||
firstRun.tests = append(firstRun.tests, &packageTests{Pattern: pkg})
|
||||
}
|
||||
toRun := []*nextRun{firstRun}
|
||||
printPkgOutcome := func(pkg, outcome string, attempt int, runtime time.Duration) {
|
||||
printPkgOutcome := func(pkg, outcome string, attempt int) {
|
||||
if outcome == "skip" {
|
||||
fmt.Printf("?\t%s [skipped/no tests] \n", pkg)
|
||||
return
|
||||
@@ -232,10 +225,10 @@ func main() {
|
||||
outcome = "FAIL"
|
||||
}
|
||||
if attempt > 1 {
|
||||
fmt.Printf("%s\t%s\t%.3fs\t[attempt=%d]\n", outcome, pkg, runtime.Seconds(), attempt)
|
||||
fmt.Printf("%s\t%s [attempt=%d]\n", outcome, pkg, attempt)
|
||||
return
|
||||
}
|
||||
fmt.Printf("%s\t%s\t%.3fs\n", outcome, pkg, runtime.Seconds())
|
||||
fmt.Printf("%s\t%s\n", outcome, pkg)
|
||||
}
|
||||
|
||||
// Check for -coverprofile argument and filter it out
|
||||
@@ -314,7 +307,7 @@ func main() {
|
||||
// when a package times out.
|
||||
failed = true
|
||||
}
|
||||
printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt, tr.end.Sub(tr.start))
|
||||
printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt)
|
||||
continue
|
||||
}
|
||||
if testingVerbose || tr.outcome == "fail" {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
@@ -77,10 +76,7 @@ func TestFlakeRun(t *testing.T) {
|
||||
t.Fatalf("go run . %s: %s with output:\n%s", testfile, err, out)
|
||||
}
|
||||
|
||||
// Replace the unpredictable timestamp with "0.00s".
|
||||
out = regexp.MustCompile(`\t\d+\.\d\d\ds\t`).ReplaceAll(out, []byte("\t0.00s\t"))
|
||||
|
||||
want := []byte("ok\t" + testfile + "\t0.00s\t[attempt=2]")
|
||||
want := []byte("ok\t" + testfile + " [attempt=2]")
|
||||
if !bytes.Contains(out, want) {
|
||||
t.Fatalf("wanted output containing %q but got:\n%s", want, out)
|
||||
}
|
||||
|
||||
@@ -150,7 +150,6 @@ func runEsbuildServe(buildOptions esbuild.BuildOptions) {
|
||||
log.Fatalf("Cannot start esbuild server: %v", err)
|
||||
}
|
||||
log.Printf("Listening on http://%s:%d\n", result.Host, result.Port)
|
||||
select {}
|
||||
}
|
||||
|
||||
func runEsbuild(buildOptions esbuild.BuildOptions) esbuild.BuildResult {
|
||||
|
||||
@@ -115,7 +115,7 @@ func newIPN(jsConfig js.Value) map[string]any {
|
||||
}
|
||||
sys.Set(eng)
|
||||
|
||||
ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper())
|
||||
ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
|
||||
if err != nil {
|
||||
log.Fatalf("netstack.Create: %v", err)
|
||||
}
|
||||
|
||||
@@ -29,9 +29,11 @@ import (
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/control/keyfallback"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/logtail"
|
||||
"tailscale.com/net/dnscache"
|
||||
@@ -87,9 +89,10 @@ type Direct struct {
|
||||
|
||||
dialPlan ControlDialPlanner // can be nil
|
||||
|
||||
mu sync.Mutex // mutex guards the following fields
|
||||
serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now
|
||||
serverNoiseKey key.MachinePublic
|
||||
mu sync.Mutex // mutex guards the following fields
|
||||
serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now
|
||||
serverNoiseKey key.MachinePublic
|
||||
usedFallbackNoiseKey bool // true if we used the baked-in fallback key
|
||||
|
||||
sfGroup singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation.
|
||||
noiseClient *NoiseClient
|
||||
@@ -498,6 +501,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
tryingNewKey := c.tryingNewKey
|
||||
serverKey := c.serverLegacyKey
|
||||
serverNoiseKey := c.serverNoiseKey
|
||||
usedFallback := c.usedFallbackNoiseKey
|
||||
authKey, isWrapped, wrappedSig, wrappedKey := tka.DecodeWrappedAuthkey(c.authKey, c.logf)
|
||||
hi := c.hostInfoLocked()
|
||||
backendLogID := hi.BackendLogID
|
||||
@@ -528,7 +532,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
}
|
||||
|
||||
c.logf("doLogin(regen=%v, hasUrl=%v)", regen, opt.URL != "")
|
||||
if serverKey.IsZero() {
|
||||
if serverKey.IsZero() || usedFallback {
|
||||
keys, err := loadServerPubKeys(ctx, c.httpc, c.serverURL)
|
||||
if err != nil && c.interceptedDial != nil && c.interceptedDial.Load() {
|
||||
c.health.SetUnhealthy(macOSScreenTime, nil)
|
||||
@@ -536,13 +540,21 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
c.health.SetHealthy(macOSScreenTime)
|
||||
}
|
||||
if err != nil {
|
||||
return regen, opt.URL, nil, err
|
||||
if k2, err := c.getFallbackServerPubKeys(); err == nil {
|
||||
keys = k2
|
||||
usedFallback = true
|
||||
} else {
|
||||
return regen, opt.URL, nil, err
|
||||
}
|
||||
} else {
|
||||
usedFallback = false
|
||||
c.logf("control server key from %s: ts2021=%s", c.serverURL, keys.PublicKey.ShortString())
|
||||
}
|
||||
c.logf("control server key from %s: ts2021=%s, legacy=%v", c.serverURL, keys.PublicKey.ShortString(), keys.LegacyPublicKey.ShortString())
|
||||
|
||||
c.mu.Lock()
|
||||
c.serverLegacyKey = keys.LegacyPublicKey
|
||||
c.serverNoiseKey = keys.PublicKey
|
||||
c.usedFallbackNoiseKey = usedFallback
|
||||
c.mu.Unlock()
|
||||
serverKey = keys.LegacyPublicKey
|
||||
serverNoiseKey = keys.PublicKey
|
||||
@@ -751,6 +763,22 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
|
||||
return false, resp.AuthURL, nil, nil
|
||||
}
|
||||
|
||||
func (c *Direct) getFallbackServerPubKeys() (*tailcfg.OverTLSPublicKeyResponse, error) {
|
||||
// If we saw an error, try to use the fallback key if
|
||||
// we're dialing the default control server.
|
||||
if ipn.IsLoginServerSynonym(c.serverURL) {
|
||||
return nil, errors.New("not using default control server")
|
||||
}
|
||||
|
||||
kf, err := keyfallback.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logf("using fallback server key: ts2021=%s", kf.PublicKey.ShortString())
|
||||
return kf, nil
|
||||
}
|
||||
|
||||
// newEndpoints acquires c.mu and sets the local port and endpoints and reports
|
||||
// whether they've changed.
|
||||
//
|
||||
|
||||
4
control/keyfallback/control-key.json
Normal file
4
control/keyfallback/control-key.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"legacyPublicKey": "mkey:9e5156a4c65121306dd2d8ed8f92cb8d738e2533011344b522c5d28409bc4970",
|
||||
"publicKey": "mkey:7d2792f9c98d753d2042471536801949104c247f95eac770f8fb321595e2173b"
|
||||
}
|
||||
32
control/keyfallback/keyfallback.go
Normal file
32
control/keyfallback/keyfallback.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package keyfallback contains a fallback mechanism for starting up Tailscale
|
||||
// when the control server cannot be reached to obtain the primary Noise key.
|
||||
//
|
||||
// The data is backed by a JSON file `control-key.json` that is updated by
|
||||
// `update.go`:
|
||||
//
|
||||
// (cd control/keyfallback; go run update.go)
|
||||
package keyfallback
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// Get returns the fallback control server public key that was baked into the
|
||||
// binary at compile time. It is only valid for the main Tailscale control
|
||||
// server instance.
|
||||
func Get() (*tailcfg.OverTLSPublicKeyResponse, error) {
|
||||
out := &tailcfg.OverTLSPublicKeyResponse{}
|
||||
if err := json.Unmarshal(controlKeyJSON, out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
//go:embed control-key.json
|
||||
var controlKeyJSON []byte
|
||||
77
control/keyfallback/keyfallback_test.go
Normal file
77
control/keyfallback/keyfallback_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package keyfallback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest/nettest"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
func TestHasValidControlKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
keys, err := Get()
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if keys.PublicKey.IsZero() {
|
||||
t.Fatalf("zero key")
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeyIsUpToDate fetches the control key from the control server and
|
||||
// compares it to the baked-in key, to verify that it's up-to-date. If the
|
||||
// control server is unreachable, the test is skipped.
|
||||
func TestKeyIsUpToDate(t *testing.T) {
|
||||
nettest.SkipIfNoNetwork(t)
|
||||
|
||||
// Optimistically fetch the control key and check if it's up to date,
|
||||
// but ignore if we don't have network access (e.g. running tests on an
|
||||
// airplane).
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
keyURL := fmt.Sprintf("%v/key?v=%d", ipn.DefaultControlURL, tailcfg.CurrentCapabilityVersion)
|
||||
req := must.Get(http.NewRequestWithContext(ctx, "GET", keyURL, nil))
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Logf("fetch control key: %v", err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
t.Fatalf("fetch control key: bad status; got %v, want 200", res.Status)
|
||||
}
|
||||
b, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read control key: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the key is up to date and matches the baked-in key.
|
||||
out := &tailcfg.OverTLSPublicKeyResponse{}
|
||||
if err := json.Unmarshal(b, out); err != nil {
|
||||
t.Fatalf("unmarshal control key: %v", err)
|
||||
}
|
||||
|
||||
keys, err := Get()
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(keys, out) {
|
||||
t.Errorf("control key is out of date")
|
||||
t.Logf("old key: %v", keys)
|
||||
t.Logf("new key: %v", out)
|
||||
}
|
||||
}
|
||||
47
control/keyfallback/update.go
Normal file
47
control/keyfallback/update.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func main() {
|
||||
keyURL := fmt.Sprintf("%v/key?v=%d", ipn.DefaultControlURL, tailcfg.CurrentCapabilityVersion)
|
||||
res, err := http.Get(keyURL)
|
||||
if err != nil {
|
||||
log.Fatalf("fetch control key: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
b, err := io.ReadAll(io.LimitReader(res.Body, 64<<10))
|
||||
if err != nil {
|
||||
log.Fatalf("read control key: %v", err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
log.Fatalf("fetch control key: bad status; got %v, want 200", res.Status)
|
||||
}
|
||||
|
||||
// Unmarshal to make sure it's valid.
|
||||
var out tailcfg.OverTLSPublicKeyResponse
|
||||
if err := json.Unmarshal(b, &out); err != nil {
|
||||
log.Fatalf("unmarshal control key: %v", err)
|
||||
}
|
||||
if out.PublicKey.IsZero() {
|
||||
log.Fatalf("control key is zero")
|
||||
}
|
||||
|
||||
if err := os.WriteFile("control-key.json", b, 0644); err != nil {
|
||||
log.Fatalf("write control key: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -32,8 +32,6 @@ type ConfigVAlpha struct {
|
||||
AdvertiseRoutes []netip.Prefix `json:",omitempty"`
|
||||
DisableSNAT opt.Bool `json:",omitempty"`
|
||||
|
||||
AppConnector *AppConnectorPrefs `json:",omitempty"` // advertise app connector; defaults to false (if nil or explicitly set to false)
|
||||
|
||||
NetfilterMode *string `json:",omitempty"` // "on", "off", "nodivert"
|
||||
NoStatefulFiltering opt.Bool `json:",omitempty"`
|
||||
|
||||
@@ -139,9 +137,5 @@ func (c *ConfigVAlpha) ToPrefs() (MaskedPrefs, error) {
|
||||
mp.AutoUpdate = *c.AutoUpdate
|
||||
mp.AutoUpdateSet = AutoUpdatePrefsMask{ApplySet: true, CheckSet: true}
|
||||
}
|
||||
if c.AppConnector != nil {
|
||||
mp.AppConnector = *c.AppConnector
|
||||
mp.AppConnectorSet = true
|
||||
}
|
||||
return mp, nil
|
||||
}
|
||||
|
||||
@@ -332,10 +332,12 @@ func handleC2NPostureIdentityGet(b *LocalBackend, w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
if choice.ShouldEnable(b.Prefs().PostureChecking()) {
|
||||
res.SerialNumbers, err = posture.GetSerialNumbers(b.logf)
|
||||
sns, err := posture.GetSerialNumbers(b.logf)
|
||||
if err != nil {
|
||||
b.logf("c2n: GetSerialNumbers returned error: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
res.SerialNumbers = sns
|
||||
|
||||
// TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release
|
||||
// and looks good in client metrics, remove this parameter and always report MAC
|
||||
|
||||
@@ -399,6 +399,11 @@ type metrics struct {
|
||||
// approvedRoutes is a metric that reports the number of network routes served by the local node and approved
|
||||
// by the control server.
|
||||
approvedRoutes *usermetric.Gauge
|
||||
|
||||
// primaryRoutes is a metric that reports the number of primary network routes served by the local node.
|
||||
// A route being a primary route implies that the route is currently served by this node, and not by another
|
||||
// subnet router in a high availability configuration.
|
||||
primaryRoutes *usermetric.Gauge
|
||||
}
|
||||
|
||||
// clientGen is a func that creates a control plane client.
|
||||
@@ -449,6 +454,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
|
||||
"tailscaled_advertised_routes", "Number of advertised network routes (e.g. by a subnet router)"),
|
||||
approvedRoutes: sys.UserMetricsRegistry().NewGauge(
|
||||
"tailscaled_approved_routes", "Number of approved network routes (e.g. by a subnet router)"),
|
||||
primaryRoutes: sys.UserMetricsRegistry().NewGauge(
|
||||
"tailscaled_primary_routes", "Number of network routes for which this node is a primary router (in high availability configuration)"),
|
||||
}
|
||||
|
||||
b := &LocalBackend{
|
||||
@@ -479,7 +486,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
|
||||
mConn.SetNetInfoCallback(b.setNetInfo)
|
||||
|
||||
if sys.InitialConfig != nil {
|
||||
if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil {
|
||||
if err := b.setConfigLocked(sys.InitialConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -712,8 +719,8 @@ func (b *LocalBackend) SetDirectFileRoot(dir string) {
|
||||
// It returns (false, nil) if not running in declarative mode, (true, nil) on
|
||||
// success, or (false, error) on failure.
|
||||
func (b *LocalBackend) ReloadConfig() (ok bool, err error) {
|
||||
unlock := b.lockAndGetUnlock()
|
||||
defer unlock()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.conf == nil {
|
||||
return false, nil
|
||||
}
|
||||
@@ -721,21 +728,18 @@ func (b *LocalBackend) ReloadConfig() (ok bool, err error) {
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := b.setConfigLockedOnEntry(conf, unlock); err != nil {
|
||||
if err := b.setConfigLocked(conf); err != nil {
|
||||
return false, fmt.Errorf("error setting config: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// initPrefsFromConfig initializes the backend's prefs from the provided config.
|
||||
// This should only be called once, at startup. For updates at runtime, use
|
||||
// [LocalBackend.setConfigLocked].
|
||||
func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error {
|
||||
// TODO(maisem,bradfitz): combine this with setConfigLocked. This is called
|
||||
// before anything is running, so there's no need to lock and we don't
|
||||
// update any subsystems. At runtime, we both need to lock and update
|
||||
// subsystems with the new prefs.
|
||||
func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error {
|
||||
|
||||
// TODO(irbekrm): notify the relevant components to consume any prefs
|
||||
// updates. Currently only initial configfile settings are applied
|
||||
// immediately.
|
||||
p := b.pm.CurrentPrefs().AsStruct()
|
||||
mp, err := conf.Parsed.ToPrefs()
|
||||
if err != nil {
|
||||
@@ -745,14 +749,13 @@ func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error {
|
||||
if err := b.pm.SetPrefs(p.View(), ipn.NetworkProfile{}); err != nil {
|
||||
return err
|
||||
}
|
||||
b.setStaticEndpointsFromConfigLocked(conf)
|
||||
b.conf = conf
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LocalBackend) setStaticEndpointsFromConfigLocked(conf *conffile.Config) {
|
||||
defer func() {
|
||||
b.conf = conf
|
||||
}()
|
||||
|
||||
if conf.Parsed.StaticEndpoints == nil && (b.conf == nil || b.conf.Parsed.StaticEndpoints == nil) {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure that magicsock conn has the up to date static wireguard
|
||||
@@ -766,22 +769,6 @@ func (b *LocalBackend) setStaticEndpointsFromConfigLocked(conf *conffile.Config)
|
||||
ms.SetStaticEndpoints(views.SliceOf(conf.Parsed.StaticEndpoints))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setConfigLockedOnEntry uses the provided config to update the backend's prefs
|
||||
// and other state.
|
||||
func (b *LocalBackend) setConfigLockedOnEntry(conf *conffile.Config, unlock unlockOnce) error {
|
||||
defer unlock()
|
||||
p := b.pm.CurrentPrefs().AsStruct()
|
||||
mp, err := conf.Parsed.ToPrefs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing config to prefs: %w", err)
|
||||
}
|
||||
p.ApplyEdits(&mp)
|
||||
b.setStaticEndpointsFromConfigLocked(conf)
|
||||
b.setPrefsLockedOnEntry(p, unlock)
|
||||
|
||||
b.conf = conf
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4194,11 +4181,7 @@ func (b *LocalBackend) authReconfig() {
|
||||
disableSubnetsIfPAC := nm.HasCap(tailcfg.NodeAttrDisableSubnetsIfPAC)
|
||||
userDialUseRoutes := nm.HasCap(tailcfg.NodeAttrUserDialUseRoutes)
|
||||
dohURL, dohURLOK := exitNodeCanProxyDNS(nm, b.peers, prefs.ExitNodeID())
|
||||
var forceAAAA bool
|
||||
if dm, ok := b.sys.DNSManager.GetOK(); ok {
|
||||
forceAAAA = dm.GetForceAAAA()
|
||||
}
|
||||
dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, forceAAAA, b.logf, version.OS())
|
||||
dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, b.logf, version.OS())
|
||||
// If the current node is an app connector, ensure the app connector machine is started
|
||||
b.reconfigAppConnectorLocked(nm, prefs)
|
||||
b.mu.Unlock()
|
||||
@@ -4298,7 +4281,7 @@ func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs,
|
||||
//
|
||||
// The versionOS is a Tailscale-style version ("iOS", "macOS") and not
|
||||
// a runtime.GOOS.
|
||||
func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired, forceAAAA bool, logf logger.Logf, versionOS string) *dns.Config {
|
||||
func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config {
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -4361,7 +4344,7 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.
|
||||
// https://github.com/tailscale/tailscale/issues/1152
|
||||
// tracks adding the right capability reporting to
|
||||
// enable AAAA in MagicDNS.
|
||||
if addr.Addr().Is6() && have4 && !forceAAAA {
|
||||
if addr.Addr().Is6() && have4 {
|
||||
continue
|
||||
}
|
||||
ips = append(ips, addr.Addr())
|
||||
@@ -5494,6 +5477,7 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) {
|
||||
// If there is no netmap, the client is going into a "turned off"
|
||||
// state so reset the metrics.
|
||||
b.metrics.approvedRoutes.Set(0)
|
||||
b.metrics.primaryRoutes.Set(0)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -5522,6 +5506,7 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) {
|
||||
}
|
||||
}
|
||||
b.metrics.approvedRoutes.Set(approved)
|
||||
b.metrics.primaryRoutes.Set(float64(tsaddr.WithoutExitRoute(nm.SelfNode.PrimaryRoutes()).Len()))
|
||||
}
|
||||
for _, p := range nm.Peers {
|
||||
addNode(p)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -33,7 +32,6 @@ import (
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/conffile"
|
||||
"tailscale.com/ipn/ipnauth"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
"tailscale.com/net/netcheck"
|
||||
@@ -56,8 +54,6 @@ import (
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/wgcfg"
|
||||
@@ -434,25 +430,16 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func newTestLocalBackend(t testing.TB) *LocalBackend {
|
||||
return newTestLocalBackendWithSys(t, new(tsd.System))
|
||||
}
|
||||
|
||||
// newTestLocalBackendWithSys creates a new LocalBackend with the given tsd.System.
|
||||
// If the state store or engine are not set in sys, they will be set to a new
|
||||
// in-memory store and fake userspace engine, respectively.
|
||||
func newTestLocalBackendWithSys(t testing.TB, sys *tsd.System) *LocalBackend {
|
||||
var logf logger.Logf = logger.Discard
|
||||
if _, ok := sys.StateStore.GetOK(); !ok {
|
||||
sys.Set(new(mem.Store))
|
||||
}
|
||||
if _, ok := sys.Engine.GetOK(); !ok {
|
||||
eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry())
|
||||
if err != nil {
|
||||
t.Fatalf("NewFakeUserspaceEngine: %v", err)
|
||||
}
|
||||
t.Cleanup(eng.Close)
|
||||
sys.Set(eng)
|
||||
sys := new(tsd.System)
|
||||
store := new(mem.Store)
|
||||
sys.Set(store)
|
||||
eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry())
|
||||
if err != nil {
|
||||
t.Fatalf("NewFakeUserspaceEngine: %v", err)
|
||||
}
|
||||
t.Cleanup(eng.Close)
|
||||
sys.Set(eng)
|
||||
lb, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("NewLocalBackend: %v", err)
|
||||
@@ -1298,7 +1285,7 @@ func TestDNSConfigForNetmapForExitNodeConfigs(t *testing.T) {
|
||||
}
|
||||
|
||||
prefs := &ipn.Prefs{ExitNodeID: tc.exitNode, CorpDNS: true}
|
||||
got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, false, t.Logf, "")
|
||||
got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, t.Logf, "")
|
||||
if !resolversEqual(t, got.DefaultResolvers, tc.wantDefaultResolvers) {
|
||||
t.Errorf("DefaultResolvers: got %#v, want %#v", got.DefaultResolvers, tc.wantDefaultResolvers)
|
||||
}
|
||||
@@ -1572,6 +1559,94 @@ func dnsResponse(domain, address string) []byte {
|
||||
return must.Get(b.Finish())
|
||||
}
|
||||
|
||||
type errorSyspolicyHandler struct {
|
||||
t *testing.T
|
||||
err error
|
||||
key syspolicy.Key
|
||||
allowKeys map[syspolicy.Key]*string
|
||||
}
|
||||
|
||||
func (h *errorSyspolicyHandler) ReadString(key string) (string, error) {
|
||||
sk := syspolicy.Key(key)
|
||||
if _, ok := h.allowKeys[sk]; !ok {
|
||||
h.t.Errorf("ReadString: %q is not in list of permitted keys", h.key)
|
||||
}
|
||||
if sk == h.key {
|
||||
return "", h.err
|
||||
}
|
||||
return "", syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (h *errorSyspolicyHandler) ReadUInt64(key string) (uint64, error) {
|
||||
h.t.Errorf("ReadUInt64(%q) unexpectedly called", key)
|
||||
return 0, syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (h *errorSyspolicyHandler) ReadBoolean(key string) (bool, error) {
|
||||
h.t.Errorf("ReadBoolean(%q) unexpectedly called", key)
|
||||
return false, syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (h *errorSyspolicyHandler) ReadStringArray(key string) ([]string, error) {
|
||||
h.t.Errorf("ReadStringArray(%q) unexpectedly called", key)
|
||||
return nil, syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
type mockSyspolicyHandler struct {
|
||||
t *testing.T
|
||||
// stringPolicies is the collection of policies that we expect to see
|
||||
// queried by the current test. If the policy is expected but unset, then
|
||||
// use nil, otherwise use a string equal to the policy's desired value.
|
||||
stringPolicies map[syspolicy.Key]*string
|
||||
// stringArrayPolicies is the collection of policies that we expected to see
|
||||
// queries by the current test, that return policy string arrays.
|
||||
stringArrayPolicies map[syspolicy.Key][]string
|
||||
// failUnknownPolicies is set if policies other than those in stringPolicies
|
||||
// (uint64 or bool policies are not supported by mockSyspolicyHandler yet)
|
||||
// should be considered a test failure if they are queried.
|
||||
failUnknownPolicies bool
|
||||
}
|
||||
|
||||
func (h *mockSyspolicyHandler) ReadString(key string) (string, error) {
|
||||
if s, ok := h.stringPolicies[syspolicy.Key(key)]; ok {
|
||||
if s == nil {
|
||||
return "", syspolicy.ErrNoSuchKey
|
||||
}
|
||||
return *s, nil
|
||||
}
|
||||
if h.failUnknownPolicies {
|
||||
h.t.Errorf("ReadString(%q) unexpectedly called", key)
|
||||
}
|
||||
return "", syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (h *mockSyspolicyHandler) ReadUInt64(key string) (uint64, error) {
|
||||
if h.failUnknownPolicies {
|
||||
h.t.Errorf("ReadUInt64(%q) unexpectedly called", key)
|
||||
}
|
||||
return 0, syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (h *mockSyspolicyHandler) ReadBoolean(key string) (bool, error) {
|
||||
if h.failUnknownPolicies {
|
||||
h.t.Errorf("ReadBoolean(%q) unexpectedly called", key)
|
||||
}
|
||||
return false, syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (h *mockSyspolicyHandler) ReadStringArray(key string) ([]string, error) {
|
||||
if h.failUnknownPolicies {
|
||||
h.t.Errorf("ReadStringArray(%q) unexpectedly called", key)
|
||||
}
|
||||
if s, ok := h.stringArrayPolicies[syspolicy.Key(key)]; ok {
|
||||
if s == nil {
|
||||
return []string{}, syspolicy.ErrNoSuchKey
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
return nil, syspolicy.ErrNoSuchKey
|
||||
}
|
||||
|
||||
func TestSetExitNodeIDPolicy(t *testing.T) {
|
||||
pfx := netip.MustParsePrefix
|
||||
tests := []struct {
|
||||
@@ -1781,18 +1856,23 @@ func TestSetExitNodeIDPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := newTestBackend(t)
|
||||
|
||||
policyStore := source.NewTestStoreOf(t,
|
||||
source.TestSettingOf(syspolicy.ExitNodeID, test.exitNodeID),
|
||||
source.TestSettingOf(syspolicy.ExitNodeIP, test.exitNodeIP),
|
||||
)
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
|
||||
msh := &mockSyspolicyHandler{
|
||||
t: t,
|
||||
stringPolicies: map[syspolicy.Key]*string{
|
||||
syspolicy.ExitNodeID: nil,
|
||||
syspolicy.ExitNodeIP: nil,
|
||||
},
|
||||
}
|
||||
if test.exitNodeIDKey {
|
||||
msh.stringPolicies[syspolicy.ExitNodeID] = &test.exitNodeID
|
||||
}
|
||||
if test.exitNodeIPKey {
|
||||
msh.stringPolicies[syspolicy.ExitNodeIP] = &test.exitNodeIP
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, msh)
|
||||
if test.nm == nil {
|
||||
test.nm = new(netmap.NetworkMap)
|
||||
}
|
||||
@@ -1914,13 +1994,13 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) {
|
||||
report: report,
|
||||
},
|
||||
}
|
||||
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
||||
syspolicy.ExitNodeID, "auto:any",
|
||||
))
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
|
||||
msh := &mockSyspolicyHandler{
|
||||
t: t,
|
||||
stringPolicies: map[syspolicy.Key]*string{
|
||||
syspolicy.ExitNodeID: ptr.To("auto:any"),
|
||||
},
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, msh)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := newTestLocalBackend(t)
|
||||
@@ -1969,11 +2049,13 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) {
|
||||
}
|
||||
cc = newClient(t, opts)
|
||||
b.cc = cc
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
||||
syspolicy.ExitNodeID, "auto:any",
|
||||
))
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
msh := &mockSyspolicyHandler{
|
||||
t: t,
|
||||
stringPolicies: map[syspolicy.Key]*string{
|
||||
syspolicy.ExitNodeID: ptr.To("auto:any"),
|
||||
},
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, msh)
|
||||
peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes())
|
||||
peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes())
|
||||
selfNode := tailcfg.Node{
|
||||
@@ -2078,11 +2160,13 @@ func TestSetControlClientStatusAutoExitNode(t *testing.T) {
|
||||
DERPMap: derpMap,
|
||||
}
|
||||
b := newTestLocalBackend(t)
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
||||
syspolicy.ExitNodeID, "auto:any",
|
||||
))
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
msh := &mockSyspolicyHandler{
|
||||
t: t,
|
||||
stringPolicies: map[syspolicy.Key]*string{
|
||||
syspolicy.ExitNodeID: ptr.To("auto:any"),
|
||||
},
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, msh)
|
||||
b.netMap = nm
|
||||
b.lastSuggestedExitNode = peer1.StableID()
|
||||
b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, report)
|
||||
@@ -2316,16 +2400,17 @@ func TestApplySysPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
settings := make([]source.TestSetting[string], 0, len(tt.stringPolicies))
|
||||
for p, v := range tt.stringPolicies {
|
||||
settings = append(settings, source.TestSettingOf(p, v))
|
||||
msh := &mockSyspolicyHandler{
|
||||
t: t,
|
||||
stringPolicies: make(map[syspolicy.Key]*string, len(tt.stringPolicies)),
|
||||
}
|
||||
policyStore := source.NewTestStoreOf(t, settings...)
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
for p, v := range tt.stringPolicies {
|
||||
v := v // construct a unique pointer for each policy value
|
||||
msh.stringPolicies[p] = &v
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, msh)
|
||||
|
||||
t.Run("unit", func(t *testing.T) {
|
||||
prefs := tt.prefs.Clone()
|
||||
@@ -2461,19 +2546,35 @@ func TestPreferencePolicyInfo(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for _, pp := range preferencePolicies {
|
||||
t.Run(string(pp.key), func(t *testing.T) {
|
||||
s := source.TestSetting[string]{
|
||||
Key: pp.key,
|
||||
Error: tt.policyError,
|
||||
Value: tt.policyValue,
|
||||
var h syspolicy.Handler
|
||||
|
||||
allPolicies := make(map[syspolicy.Key]*string, len(preferencePolicies)+1)
|
||||
allPolicies[syspolicy.ControlURL] = nil
|
||||
for _, pp := range preferencePolicies {
|
||||
allPolicies[pp.key] = nil
|
||||
}
|
||||
policyStore := source.NewTestStoreOf(t, s)
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
|
||||
if tt.policyError != nil {
|
||||
h = &errorSyspolicyHandler{
|
||||
t: t,
|
||||
err: tt.policyError,
|
||||
key: pp.key,
|
||||
allowKeys: allPolicies,
|
||||
}
|
||||
} else {
|
||||
msh := &mockSyspolicyHandler{
|
||||
t: t,
|
||||
stringPolicies: allPolicies,
|
||||
failUnknownPolicies: true,
|
||||
}
|
||||
msh.stringPolicies[pp.key] = &tt.policyValue
|
||||
h = msh
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, h)
|
||||
|
||||
prefs := defaultPrefs.AsStruct()
|
||||
pp.set(prefs, tt.initialValue)
|
||||
@@ -3724,16 +3825,15 @@ func TestShouldAutoExitNode(t *testing.T) {
|
||||
expectedBool: false,
|
||||
},
|
||||
}
|
||||
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
||||
syspolicy.ExitNodeID, tt.exitNodeIDPolicyValue,
|
||||
))
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
|
||||
msh := &mockSyspolicyHandler{
|
||||
t: t,
|
||||
stringPolicies: map[syspolicy.Key]*string{
|
||||
syspolicy.ExitNodeID: ptr.To(tt.exitNodeIDPolicyValue),
|
||||
},
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, msh)
|
||||
got := shouldAutoExitNode()
|
||||
if got != tt.expectedBool {
|
||||
t.Fatalf("expected %v got %v for %v policy value", tt.expectedBool, got, tt.exitNodeIDPolicyValue)
|
||||
@@ -3871,13 +3971,17 @@ func TestFillAllowedSuggestions(t *testing.T) {
|
||||
want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"},
|
||||
},
|
||||
}
|
||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
||||
syspolicy.AllowedSuggestedExitNodes, tt.allowPolicy,
|
||||
))
|
||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
||||
mh := mockSyspolicyHandler{
|
||||
t: t,
|
||||
}
|
||||
if tt.allowPolicy != nil {
|
||||
mh.stringArrayPolicies = map[syspolicy.Key][]string{
|
||||
syspolicy.AllowedSuggestedExitNodes: tt.allowPolicy,
|
||||
}
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, &mh)
|
||||
|
||||
got := fillAllowedSuggestions()
|
||||
if got == nil {
|
||||
@@ -4434,35 +4538,3 @@ func TestLoginNotifications(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigFileReload tests that the LocalBackend reloads its configuration
|
||||
// when the configuration file changes.
|
||||
func TestConfigFileReload(t *testing.T) {
|
||||
cfg1 := `{"Hostname": "foo", "Version": "alpha0"}`
|
||||
f := filepath.Join(t.TempDir(), "cfg")
|
||||
must.Do(os.WriteFile(f, []byte(cfg1), 0600))
|
||||
sys := new(tsd.System)
|
||||
sys.InitialConfig = must.Get(conffile.Load(f))
|
||||
lb := newTestLocalBackendWithSys(t, sys)
|
||||
must.Do(lb.Start(ipn.Options{}))
|
||||
|
||||
lb.mu.Lock()
|
||||
hn := lb.hostinfo.Hostname
|
||||
lb.mu.Unlock()
|
||||
if hn != "foo" {
|
||||
t.Fatalf("got %q; want %q", hn, "foo")
|
||||
}
|
||||
|
||||
cfg2 := `{"Hostname": "bar", "Version": "alpha0"}`
|
||||
must.Do(os.WriteFile(f, []byte(cfg2), 0600))
|
||||
if !must.Get(lb.ReloadConfig()) {
|
||||
t.Fatal("reload failed")
|
||||
}
|
||||
|
||||
lb.mu.Lock()
|
||||
hn = lb.hostinfo.Hostname
|
||||
lb.mu.Unlock()
|
||||
if hn != "bar" {
|
||||
t.Fatalf("got %q; want %q", hn, "bar")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,8 +62,7 @@ import (
|
||||
"tailscale.com/util/osdiag"
|
||||
"tailscale.com/util/progresstracking"
|
||||
"tailscale.com/util/rands"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/testenv"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/wgengine/magicsock"
|
||||
)
|
||||
@@ -78,7 +77,6 @@ var handler = map[string]localAPIHandler{
|
||||
"cert/": (*Handler).serveCert,
|
||||
"file-put/": (*Handler).serveFilePut,
|
||||
"files/": (*Handler).serveFiles,
|
||||
"policy/": (*Handler).servePolicy,
|
||||
"profiles/": (*Handler).serveProfiles,
|
||||
|
||||
// The other /localapi/v0/NAME handlers are exact matches and contain only NAME
|
||||
@@ -572,9 +570,15 @@ func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
clientmetric.WritePrometheusExpositionFormat(w)
|
||||
}
|
||||
|
||||
// serveUserMetrics returns user-facing metrics in Prometheus text
|
||||
// exposition format.
|
||||
// TODO(kradalby): Remove this once we have landed on a final set of
|
||||
// metrics to export to clients and consider the metrics stable.
|
||||
var debugUsermetricsEndpoint = envknob.RegisterBool("TS_DEBUG_USER_METRICS")
|
||||
|
||||
func (h *Handler) serveUserMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if !testenv.InTest() && !debugUsermetricsEndpoint() {
|
||||
http.Error(w, "usermetrics debug flag not enabled", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
h.b.UserMetricsRegistry().Handler(w, r)
|
||||
}
|
||||
|
||||
@@ -1335,53 +1339,6 @@ func (h *Handler) servePrefs(w http.ResponseWriter, r *http.Request) {
|
||||
e.Encode(prefs)
|
||||
}
|
||||
|
||||
func (h *Handler) servePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.PermitRead {
|
||||
http.Error(w, "policy access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/policy/")
|
||||
if !ok {
|
||||
http.Error(w, "misconfigured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var scope setting.PolicyScope
|
||||
if suffix == "" {
|
||||
scope = setting.DefaultScope()
|
||||
} else if err := scope.UnmarshalText([]byte(suffix)); err != nil {
|
||||
http.Error(w, fmt.Sprintf("%q is not a valid scope", suffix), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := rsop.PolicyFor(scope)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var effectivePolicy *setting.Snapshot
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
effectivePolicy = policy.Get()
|
||||
case "POST":
|
||||
effectivePolicy, err = policy.Reload()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
default:
|
||||
http.Error(w, "unsupported method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w)
|
||||
e.SetIndent("", "\t")
|
||||
e.Encode(effectivePolicy)
|
||||
}
|
||||
|
||||
type resJSON struct {
|
||||
Error string `json:",omitempty"`
|
||||
}
|
||||
|
||||
@@ -13,27 +13,19 @@ import (
|
||||
"time"
|
||||
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
"tailscale.com/kube/kubeapi"
|
||||
"tailscale.com/kube/kubeclient"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// TODO(irbekrm): should we bump this? should we have retries? See tailscale/tailscale#13024
|
||||
const timeout = 5 * time.Second
|
||||
|
||||
// Store is an ipn.StateStore that uses a Kubernetes Secret for persistence.
|
||||
type Store struct {
|
||||
client kubeclient.Client
|
||||
canPatch bool
|
||||
secretName string
|
||||
|
||||
// memory holds the latest tailscale state. Writes write state to a kube Secret and memory, Reads read from
|
||||
// memory.
|
||||
memory mem.Store
|
||||
}
|
||||
|
||||
// New returns a new Store that persists to the named Secret.
|
||||
// New returns a new Store that persists to the named secret.
|
||||
func New(_ logger.Logf, secretName string) (*Store, error) {
|
||||
c, err := kubeclient.New()
|
||||
if err != nil {
|
||||
@@ -47,16 +39,11 @@ func New(_ logger.Logf, secretName string) (*Store, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &Store{
|
||||
return &Store{
|
||||
client: c,
|
||||
canPatch: canPatch,
|
||||
secretName: secretName,
|
||||
}
|
||||
// Load latest state from kube Secret if it already exists.
|
||||
if err := s.loadState(); err != nil && err != ipn.ErrStateNotExist {
|
||||
return nil, fmt.Errorf("error loading state from kube Secret: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) SetDialer(d func(ctx context.Context, network, address string) (net.Conn, error)) {
|
||||
@@ -67,17 +54,37 @@ func (s *Store) String() string { return "kube.Store" }
|
||||
|
||||
// ReadState implements the StateStore interface.
|
||||
func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) {
|
||||
return s.memory.ReadState(ipn.StateKey(sanitizeKey(id)))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
secret, err := s.client.GetSecret(ctx, s.secretName)
|
||||
if err != nil {
|
||||
if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 {
|
||||
return nil, ipn.ErrStateNotExist
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
b, ok := secret.Data[sanitizeKey(id)]
|
||||
if !ok {
|
||||
return nil, ipn.ErrStateNotExist
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func sanitizeKey(k ipn.StateKey) string {
|
||||
// The only valid characters in a Kubernetes secret key are alphanumeric, -,
|
||||
// _, and .
|
||||
return strings.Map(func(r rune) rune {
|
||||
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' {
|
||||
return r
|
||||
}
|
||||
return '_'
|
||||
}, string(k))
|
||||
}
|
||||
|
||||
// WriteState implements the StateStore interface.
|
||||
func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs)
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
func (s *Store) WriteState(id ipn.StateKey, bs []byte) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
secret, err := s.client.GetSecret(ctx, s.secretName)
|
||||
@@ -130,29 +137,3 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) loadState() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
secret, err := s.client.GetSecret(ctx, s.secretName)
|
||||
if err != nil {
|
||||
if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 {
|
||||
return ipn.ErrStateNotExist
|
||||
}
|
||||
return err
|
||||
}
|
||||
s.memory.LoadFromMap(secret.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitizeKey(k ipn.StateKey) string {
|
||||
// The only valid characters in a Kubernetes secret key are alphanumeric, -,
|
||||
// _, and .
|
||||
return strings.Map(func(r rune) rune {
|
||||
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' {
|
||||
return r
|
||||
}
|
||||
return '_'
|
||||
}, string(k))
|
||||
}
|
||||
|
||||
@@ -9,10 +9,8 @@ import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/mak"
|
||||
)
|
||||
|
||||
// New returns a new Store.
|
||||
@@ -30,7 +28,6 @@ type Store struct {
|
||||
func (s *Store) String() string { return "mem.Store" }
|
||||
|
||||
// ReadState implements the StateStore interface.
|
||||
// It returns ipn.ErrStateNotExist if the state does not exist.
|
||||
func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -42,7 +39,6 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) {
|
||||
}
|
||||
|
||||
// WriteState implements the StateStore interface.
|
||||
// It never returns an error.
|
||||
func (s *Store) WriteState(id ipn.StateKey, bs []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -53,19 +49,6 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromMap loads the in-memory cache from the provided map.
|
||||
// Any existing content is cleared, and the provided map is
|
||||
// copied into the cache.
|
||||
func (s *Store) LoadFromMap(m map[string][]byte) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
xmaps.Clear(s.cache)
|
||||
for k, v := range m {
|
||||
mak.Set(&s.cache, ipn.StateKey(k), v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// LoadFromJSON attempts to unmarshal json content into the
|
||||
// in-memory cache.
|
||||
func (s *Store) LoadFromJSON(data []byte) error {
|
||||
|
||||
@@ -381,7 +381,6 @@ _Appears in:_
|
||||
| `nodeName` _string_ | Proxy Pod's node name.<br />https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | |
|
||||
| `nodeSelector` _object (keys:string, values:string)_ | Proxy Pod's node selector.<br />By default Tailscale Kubernetes operator does not apply any node<br />selector.<br />https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | |
|
||||
| `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | Proxy Pod's tolerations.<br />By default Tailscale Kubernetes operator does not apply any<br />tolerations.<br />https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | |
|
||||
| `topologySpreadConstraints` _[TopologySpreadConstraint](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#topologyspreadconstraint-v1-core) array_ | Proxy Pod's topology spread constraints.<br />By default Tailscale Kubernetes operator does not apply any topology spread constraints.<br />https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ | | |
|
||||
|
||||
|
||||
#### ProxyClass
|
||||
|
||||
@@ -154,11 +154,7 @@ type Pod struct {
|
||||
// https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling
|
||||
// +optional
|
||||
Tolerations []corev1.Toleration `json:"tolerations,omitempty"`
|
||||
// Proxy Pod's topology spread constraints.
|
||||
// By default Tailscale Kubernetes operator does not apply any topology spread constraints.
|
||||
// https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/
|
||||
// +optional
|
||||
TopologySpreadConstraints []corev1.TopologySpreadConstraint `json:"topologySpreadConstraints,omitempty"`
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
|
||||
@@ -392,13 +392,6 @@ func (in *Pod) DeepCopyInto(out *Pod) {
|
||||
(*in)[i].DeepCopyInto(&(*out)[i])
|
||||
}
|
||||
}
|
||||
if in.TopologySpreadConstraints != nil {
|
||||
in, out := &in.TopologySpreadConstraints, &out.TopologySpreadConstraints
|
||||
*out = make([]corev1.TopologySpreadConstraint, len(*in))
|
||||
for i := range *in {
|
||||
(*in)[i].DeepCopyInto(&(*out)[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Pod.
|
||||
|
||||
@@ -73,13 +73,13 @@ See also the dependencies in the [Tailscale CLI][].
|
||||
- [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE))
|
||||
- [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE))
|
||||
- [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE))
|
||||
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE))
|
||||
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
|
||||
- [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE))
|
||||
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
|
||||
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
|
||||
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE))
|
||||
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE))
|
||||
- [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE))
|
||||
|
||||
@@ -65,15 +65,15 @@ Windows][]. See also the dependencies in the [Tailscale CLI][].
|
||||
- [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE))
|
||||
- [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE))
|
||||
- [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE))
|
||||
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE))
|
||||
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
|
||||
- [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE))
|
||||
- [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE))
|
||||
- [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE))
|
||||
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
|
||||
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE))
|
||||
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))
|
||||
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
|
||||
- [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2))
|
||||
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3))
|
||||
- [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE))
|
||||
|
||||
@@ -63,9 +63,7 @@ type Manager struct {
|
||||
mu sync.Mutex // guards following
|
||||
// config is the last configuration we successfully compiled or nil if there
|
||||
// was any failure applying the last configuration.
|
||||
config *Config
|
||||
forceAAAA bool // whether client wants MagicDNS AAAA even if unsure of host's IPv6 status
|
||||
forceFallbackResolvers []*dnstype.Resolver
|
||||
config *Config
|
||||
}
|
||||
|
||||
// NewManagers created a new manager from the given config.
|
||||
@@ -130,28 +128,6 @@ func (m *Manager) GetBaseConfig() (OSConfig, error) {
|
||||
return m.os.GetBaseConfig()
|
||||
}
|
||||
|
||||
func (m *Manager) GetForceAAAA() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.forceAAAA
|
||||
}
|
||||
|
||||
func (m *Manager) SetForceAAAA(v bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.forceAAAA = v
|
||||
}
|
||||
|
||||
// SetForceFallbackResolvers sets the resolvers to use to override
|
||||
// the fallback resolvers if the control plane doesn't send any.
|
||||
//
|
||||
// It takes ownership of the provided slice.
|
||||
func (m *Manager) SetForceFallbackResolvers(resolvers []*dnstype.Resolver) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.forceFallbackResolvers = resolvers
|
||||
}
|
||||
|
||||
// setLocked sets the DNS configuration.
|
||||
//
|
||||
// m.mu must be held.
|
||||
@@ -170,10 +146,6 @@ func (m *Manager) setLocked(cfg Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := rcfg.Routes["."]; !ok && len(m.forceFallbackResolvers) > 0 {
|
||||
rcfg.Routes["."] = m.forceFallbackResolvers
|
||||
}
|
||||
|
||||
m.logf("Resolvercfg: %v", logger.ArgWriter(func(w *bufio.Writer) {
|
||||
rcfg.WriteToBufioWriter(w)
|
||||
}))
|
||||
|
||||
@@ -57,7 +57,6 @@ func (m *resolvdManager) SetDNS(config OSConfig) error {
|
||||
|
||||
if len(newSearch) > 1 {
|
||||
newResolvConf = append(newResolvConf, []byte(strings.Join(newSearch, " "))...)
|
||||
newResolvConf = append(newResolvConf, '\n')
|
||||
}
|
||||
|
||||
err = m.fs.WriteFile(resolvConf, newResolvConf, 0644)
|
||||
@@ -124,6 +123,6 @@ func (m resolvdManager) readResolvConf() (config OSConfig, err error) {
|
||||
}
|
||||
|
||||
func removeSearchLines(orig []byte) []byte {
|
||||
re := regexp.MustCompile(`(?ms)^search\s+.+$`)
|
||||
re := regexp.MustCompile(`(?m)^search\s+.+$`)
|
||||
return re.ReplaceAll(orig, []byte(""))
|
||||
}
|
||||
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/dnstype"
|
||||
)
|
||||
|
||||
@@ -277,8 +276,6 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
|
||||
tb.Fatal("cannot skip both UDP and TCP servers")
|
||||
}
|
||||
|
||||
logf := tstest.WhileTestRunningLogger(tb)
|
||||
|
||||
tcpResponse := make([]byte, len(response)+2)
|
||||
binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
|
||||
copy(tcpResponse[2:], response)
|
||||
@@ -332,13 +329,13 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
|
||||
// Read the length header, then the buffer
|
||||
var length uint16
|
||||
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
|
||||
logf("error reading length header: %v", err)
|
||||
tb.Logf("error reading length header: %v", err)
|
||||
return
|
||||
}
|
||||
req := make([]byte, length)
|
||||
n, err := io.ReadFull(conn, req)
|
||||
if err != nil {
|
||||
logf("error reading query: %v", err)
|
||||
tb.Logf("error reading query: %v", err)
|
||||
return
|
||||
}
|
||||
req = req[:n]
|
||||
@@ -346,7 +343,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
|
||||
|
||||
// Write response
|
||||
if _, err := conn.Write(tcpResponse); err != nil {
|
||||
logf("error writing response: %v", err)
|
||||
tb.Logf("error writing response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -370,7 +367,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
|
||||
handleUDP := func(addr netip.AddrPort, req []byte) {
|
||||
onRequest(false, req)
|
||||
if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
|
||||
logf("error writing response: %v", err)
|
||||
tb.Logf("error writing response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,7 +390,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
|
||||
tb.Cleanup(func() {
|
||||
tcpLn.Close()
|
||||
udpLn.Close()
|
||||
logf("waiting for listeners to finish...")
|
||||
tb.Logf("waiting for listeners to finish...")
|
||||
wg.Wait()
|
||||
})
|
||||
return
|
||||
@@ -453,8 +450,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
|
||||
}
|
||||
|
||||
func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
|
||||
logf := tstest.WhileTestRunningLogger(tb)
|
||||
netMon, err := netmon.New(logf)
|
||||
netMon, err := netmon.New(tb.Logf)
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
@@ -462,7 +458,7 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports
|
||||
var dialer tsdial.Dialer
|
||||
dialer.SetNetMon(netMon)
|
||||
|
||||
fwd := newForwarder(logf, netMon, nil, &dialer, new(health.Tracker), nil)
|
||||
fwd := newForwarder(tb.Logf, netMon, nil, &dialer, new(health.Tracker), nil)
|
||||
if modify != nil {
|
||||
modify(fwd)
|
||||
}
|
||||
|
||||
@@ -392,11 +392,10 @@ type probePlan map[string][]probe
|
||||
// sortRegions returns the regions of dm first sorted
|
||||
// from fastest to slowest (based on the 'last' report),
|
||||
// end in regions that have no data.
|
||||
func sortRegions(dm *tailcfg.DERPMap, last *Report, preferredDERP int) (prev []*tailcfg.DERPRegion) {
|
||||
func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) {
|
||||
prev = make([]*tailcfg.DERPRegion, 0, len(dm.Regions))
|
||||
for _, reg := range dm.Regions {
|
||||
// include an otherwise avoid region if it is the current preferred region
|
||||
if reg.Avoid && reg.RegionID != preferredDERP {
|
||||
if reg.Avoid {
|
||||
continue
|
||||
}
|
||||
prev = append(prev, reg)
|
||||
@@ -421,19 +420,9 @@ func sortRegions(dm *tailcfg.DERPMap, last *Report, preferredDERP int) (prev []*
|
||||
// a full report, all regions are scanned.)
|
||||
const numIncrementalRegions = 3
|
||||
|
||||
// makeProbePlan generates the probe plan for a DERPMap, given the most recent
|
||||
// report and the current home DERP. preferredDERP is passed independently of
|
||||
// last (report) because last is currently nil'd to indicate a desire for a full
|
||||
// netcheck.
|
||||
//
|
||||
// TODO(raggi,jwhited): refactor the callers and this function to be more clear
|
||||
// about full vs. incremental netchecks, and remove the need for the history
|
||||
// hiding. This was avoided in an incremental change due to exactly this kind of
|
||||
// distant coupling.
|
||||
// TODO(raggi): change from "preferred DERP" from a historical report to "home
|
||||
// DERP" as in what DERP is the current home connection, this would further
|
||||
// reduce flap events.
|
||||
func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, preferredDERP int) (plan probePlan) {
|
||||
// makeProbePlan generates the probe plan for a DERPMap, given the most
|
||||
// recent report and whether IPv6 is configured on an interface.
|
||||
func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (plan probePlan) {
|
||||
if last == nil || len(last.RegionLatency) == 0 {
|
||||
return makeProbePlanInitial(dm, ifState)
|
||||
}
|
||||
@@ -444,34 +433,9 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, pre
|
||||
had4 := len(last.RegionV4Latency) > 0
|
||||
had6 := len(last.RegionV6Latency) > 0
|
||||
hadBoth := have6if && had4 && had6
|
||||
// #13969 ensure that the home region is always probed.
|
||||
// If a netcheck has unstable latency, such as a user with large amounts of
|
||||
// bufferbloat or a highly congested connection, there are cases where a full
|
||||
// netcheck may observe a one-off high latency to the current home DERP. Prior
|
||||
// to the forced inclusion of the home DERP, this would result in an
|
||||
// incremental netcheck following such an event to cause a home DERP move, with
|
||||
// restoration back to the home DERP on the next full netcheck ~5 minutes later
|
||||
// - which is highly disruptive when it causes shifts in geo routed subnet
|
||||
// routers. By always including the home DERP in the incremental netcheck, we
|
||||
// ensure that the home DERP is always probed, even if it observed a recenet
|
||||
// poor latency sample. This inclusion enables the latency history checks in
|
||||
// home DERP selection to still take effect.
|
||||
// planContainsHome indicates whether the home DERP has been added to the probePlan,
|
||||
// if there is no prior home, then there's no home to additionally include.
|
||||
planContainsHome := preferredDERP == 0
|
||||
for ri, reg := range sortRegions(dm, last, preferredDERP) {
|
||||
regIsHome := reg.RegionID == preferredDERP
|
||||
if ri >= numIncrementalRegions {
|
||||
// planned at least numIncrementalRegions regions and that includes the
|
||||
// last home region (or there was none), plan complete.
|
||||
if planContainsHome {
|
||||
break
|
||||
}
|
||||
// planned at least numIncrementalRegions regions, but not the home region,
|
||||
// check if this is the home region, if not, skip it.
|
||||
if !regIsHome {
|
||||
continue
|
||||
}
|
||||
for ri, reg := range sortRegions(dm, last) {
|
||||
if ri == numIncrementalRegions {
|
||||
break
|
||||
}
|
||||
var p4, p6 []probe
|
||||
do4 := have4if
|
||||
@@ -482,7 +446,7 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, pre
|
||||
tries := 1
|
||||
isFastestTwo := ri < 2
|
||||
|
||||
if isFastestTwo || regIsHome {
|
||||
if isFastestTwo {
|
||||
tries = 2
|
||||
} else if hadBoth {
|
||||
// For dual stack machines, make the 3rd & slower nodes alternate
|
||||
@@ -493,15 +457,14 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, pre
|
||||
do4, do6 = false, true
|
||||
}
|
||||
}
|
||||
if !regIsHome && !isFastestTwo && !had6 {
|
||||
if !isFastestTwo && !had6 {
|
||||
do6 = false
|
||||
}
|
||||
|
||||
if regIsHome {
|
||||
if reg.RegionID == last.PreferredDERP {
|
||||
// But if we already had a DERP home, try extra hard to
|
||||
// make sure it's there so we don't flip flop around.
|
||||
tries = 4
|
||||
planContainsHome = true
|
||||
}
|
||||
|
||||
for try := 0; try < tries; try++ {
|
||||
@@ -826,10 +789,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
|
||||
c.curState = rs
|
||||
last := c.last
|
||||
|
||||
// Extract preferredDERP from the last report, if available. This will be used
|
||||
// in captive portal detection and DERP flapping suppression. Ideally this would
|
||||
// be the current active home DERP rather than the last report preferred DERP,
|
||||
// but only the latter is presently available.
|
||||
// Even if we're doing a non-incremental update, we may want to try our
|
||||
// preferred DERP region for captive portal detection. Save that, if we
|
||||
// have it.
|
||||
var preferredDERP int
|
||||
if last != nil {
|
||||
preferredDERP = last.PreferredDERP
|
||||
@@ -886,7 +848,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
|
||||
|
||||
var plan probePlan
|
||||
if opts == nil || !opts.OnlyTCP443 {
|
||||
plan = makeProbePlan(dm, ifState, last, preferredDERP)
|
||||
plan = makeProbePlan(dm, ifState, last)
|
||||
}
|
||||
|
||||
// If we're doing a full probe, also check for a captive portal. We
|
||||
|
||||
@@ -357,15 +357,6 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) {
|
||||
wantPrevLen: 3,
|
||||
wantDERP: 2, // moved to d2 since d1 is gone
|
||||
},
|
||||
{
|
||||
name: "preferred_derp_hysteresis_no_switch_pct",
|
||||
steps: []step{
|
||||
{0 * time.Second, report("d1", 34*time.Millisecond, "d2", 35*time.Millisecond)},
|
||||
{1 * time.Second, report("d1", 34*time.Millisecond, "d2", 23*time.Millisecond)},
|
||||
},
|
||||
wantPrevLen: 2,
|
||||
wantDERP: 1, // diff is 11ms, but d2 is greater than 2/3s of d1
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -599,40 +590,6 @@ func TestMakeProbePlan(t *testing.T) {
|
||||
"region-3-v4": []probe{p("3a", 4)},
|
||||
},
|
||||
},
|
||||
{
|
||||
// #13969: ensure that the prior/current home region is always included in
|
||||
// probe plans, so that we don't flap between regions due to a single major
|
||||
// netcheck having excluded the home region due to a spuriously high sample.
|
||||
name: "ensure_home_region_inclusion",
|
||||
dm: basicMap,
|
||||
have6if: true,
|
||||
last: &Report{
|
||||
RegionLatency: map[int]time.Duration{
|
||||
1: 50 * time.Millisecond,
|
||||
2: 20 * time.Millisecond,
|
||||
3: 30 * time.Millisecond,
|
||||
4: 40 * time.Millisecond,
|
||||
},
|
||||
RegionV4Latency: map[int]time.Duration{
|
||||
1: 50 * time.Millisecond,
|
||||
2: 20 * time.Millisecond,
|
||||
},
|
||||
RegionV6Latency: map[int]time.Duration{
|
||||
3: 30 * time.Millisecond,
|
||||
4: 40 * time.Millisecond,
|
||||
},
|
||||
PreferredDERP: 1,
|
||||
},
|
||||
want: probePlan{
|
||||
"region-1-v4": []probe{p("1a", 4), p("1a", 4, 60*ms), p("1a", 4, 220*ms), p("1a", 4, 330*ms)},
|
||||
"region-1-v6": []probe{p("1a", 6), p("1a", 6, 60*ms), p("1a", 6, 220*ms), p("1a", 6, 330*ms)},
|
||||
"region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)},
|
||||
"region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)},
|
||||
"region-3-v4": []probe{p("3a", 4), p("3b", 4, 36*ms)},
|
||||
"region-3-v6": []probe{p("3a", 6), p("3b", 6, 36*ms)},
|
||||
"region-4-v4": []probe{p("4a", 4)},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -640,11 +597,7 @@ func TestMakeProbePlan(t *testing.T) {
|
||||
HaveV6: tt.have6if,
|
||||
HaveV4: !tt.no4,
|
||||
}
|
||||
preferredDERP := 0
|
||||
if tt.last != nil {
|
||||
preferredDERP = tt.last.PreferredDERP
|
||||
}
|
||||
got := makeProbePlan(tt.dm, ifState, tt.last, preferredDERP)
|
||||
got := makeProbePlan(tt.dm, ifState, tt.last)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want)
|
||||
}
|
||||
@@ -817,7 +770,7 @@ func TestSortRegions(t *testing.T) {
|
||||
report.RegionLatency[3] = time.Second * time.Duration(6)
|
||||
report.RegionLatency[4] = time.Second * time.Duration(0)
|
||||
report.RegionLatency[5] = time.Second * time.Duration(2)
|
||||
sortedMap := sortRegions(unsortedMap, report, 0)
|
||||
sortedMap := sortRegions(unsortedMap, report)
|
||||
|
||||
// Sorting by latency this should result in rid: 5, 2, 1, 3
|
||||
// rid 4 with latency 0 should be at the end
|
||||
|
||||
@@ -81,12 +81,6 @@ const (
|
||||
addrTypeNotSupported replyCode = 8
|
||||
)
|
||||
|
||||
// UDP conn default buffer size and read timeout.
|
||||
const (
|
||||
bufferSize = 8 * 1024
|
||||
readTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Server is a SOCKS5 proxy server.
|
||||
type Server struct {
|
||||
// Logf optionally specifies the logger to use.
|
||||
@@ -149,8 +143,7 @@ type Conn struct {
|
||||
clientConn net.Conn
|
||||
request *request
|
||||
|
||||
udpClientAddr net.Addr
|
||||
udpTargetConns map[socksAddr]net.Conn
|
||||
udpClientAddr net.Addr
|
||||
}
|
||||
|
||||
// Run starts the new connection.
|
||||
@@ -283,6 +276,15 @@ func (c *Conn) handleUDP() error {
|
||||
}
|
||||
defer clientUDPConn.Close()
|
||||
|
||||
serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
|
||||
if err != nil {
|
||||
res := errorResponse(generalFailure)
|
||||
buf, _ := res.marshal()
|
||||
c.clientConn.Write(buf)
|
||||
return err
|
||||
}
|
||||
defer serverUDPConn.Close()
|
||||
|
||||
bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -303,32 +305,25 @@ func (c *Conn) handleUDP() error {
|
||||
}
|
||||
c.clientConn.Write(buf)
|
||||
|
||||
return c.transferUDP(c.clientConn, clientUDPConn)
|
||||
return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
|
||||
}
|
||||
|
||||
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error {
|
||||
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
const bufferSize = 8 * 1024
|
||||
const readTimeout = 5 * time.Second
|
||||
|
||||
// client -> target
|
||||
go func() {
|
||||
defer cancel()
|
||||
|
||||
c.udpTargetConns = make(map[socksAddr]net.Conn)
|
||||
// close all target udp connections when the client connection is closed
|
||||
defer func() {
|
||||
for _, conn := range c.udpTargetConns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
err := c.handleUDPRequest(ctx, clientConn, buf)
|
||||
err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
@@ -342,6 +337,29 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
|
||||
}
|
||||
}()
|
||||
|
||||
// target -> client
|
||||
go func() {
|
||||
defer cancel()
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
c.logf("udp transfer: handle udp response fail: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// A UDP association terminates when the TCP connection that the UDP
|
||||
// ASSOCIATE request arrived on terminates. RFC1928
|
||||
_, err := io.Copy(io.Discard, associatedTCP)
|
||||
@@ -351,50 +369,11 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) getOrDialTargetConn(
|
||||
ctx context.Context,
|
||||
clientConn net.PacketConn,
|
||||
targetAddr socksAddr,
|
||||
) (net.Conn, error) {
|
||||
conn, exist := c.udpTargetConns[targetAddr]
|
||||
if exist {
|
||||
return conn, nil
|
||||
}
|
||||
conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.udpTargetConns[targetAddr] = conn
|
||||
|
||||
// target -> client
|
||||
go func() {
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
c.logf("udp transfer: handle udp response fail: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleUDPRequest(
|
||||
ctx context.Context,
|
||||
clientConn net.PacketConn,
|
||||
targetConn net.PacketConn,
|
||||
buf []byte,
|
||||
readTimeout time.Duration,
|
||||
) error {
|
||||
// add a deadline for the read to avoid blocking forever
|
||||
_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
@@ -407,35 +386,38 @@ func (c *Conn) handleUDPRequest(
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse udp request: %w", err)
|
||||
}
|
||||
|
||||
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
|
||||
targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial target %s fail: %w", req.addr, err)
|
||||
c.logf("resolve target addr fail: %v", err)
|
||||
}
|
||||
|
||||
nn, err := targetConn.Write(data)
|
||||
nn, err := targetConn.WriteTo(data, targetAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to target %s fail: %w", req.addr, err)
|
||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
|
||||
}
|
||||
if nn != len(data) {
|
||||
return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
|
||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleUDPResponse(
|
||||
targetConn net.PacketConn,
|
||||
clientConn net.PacketConn,
|
||||
targetAddr socksAddr,
|
||||
targetConn net.Conn,
|
||||
buf []byte,
|
||||
readTimeout time.Duration,
|
||||
) error {
|
||||
// add a deadline for the read to avoid blocking forever
|
||||
_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
n, err := targetConn.Read(buf)
|
||||
n, addr, err := targetConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read from target: %w", err)
|
||||
}
|
||||
hdr := udpRequest{addr: targetAddr}
|
||||
host, port, err := splitHostPort(addr.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
|
||||
pkt, err := hdr.marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal udp request: %w", err)
|
||||
@@ -645,15 +627,10 @@ func (s socksAddr) marshal() ([]byte, error) {
|
||||
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func (s socksAddr) hostPort() string {
|
||||
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
|
||||
}
|
||||
|
||||
func (s socksAddr) String() string {
|
||||
return s.hostPort()
|
||||
}
|
||||
|
||||
// response contains the contents of
|
||||
// a response packet sent from the proxy
|
||||
// to the client.
|
||||
|
||||
@@ -169,25 +169,12 @@ func TestReadPassword(t *testing.T) {
|
||||
|
||||
func TestUDP(t *testing.T) {
|
||||
// backend UDP server which we'll use SOCKS5 to connect to
|
||||
newUDPEchoServer := func() net.PacketConn {
|
||||
listener, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go udpEchoServer(listener)
|
||||
return listener
|
||||
listener, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const echoServerNumber = 3
|
||||
echoServerListener := make([]net.PacketConn, echoServerNumber)
|
||||
for i := 0; i < echoServerNumber; i++ {
|
||||
echoServerListener[i] = newUDPEchoServer()
|
||||
}
|
||||
defer func() {
|
||||
for i := 0; i < echoServerNumber; i++ {
|
||||
_ = echoServerListener[i].Close()
|
||||
}
|
||||
}()
|
||||
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
|
||||
go udpEchoServer(listener)
|
||||
|
||||
// SOCKS5 server
|
||||
socks5, err := net.Listen("tcp", ":0")
|
||||
@@ -197,93 +184,84 @@ func TestUDP(t *testing.T) {
|
||||
socks5Port := socks5.Addr().(*net.TCPAddr).Port
|
||||
go socks5Server(socks5)
|
||||
|
||||
// make a socks5 udpAssociate conn
|
||||
newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) {
|
||||
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf) // server hello
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired {
|
||||
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
|
||||
}
|
||||
|
||||
targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
|
||||
targetAddrPkt, err := targetAddr.marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n, err = conn.Read(buf) // server response
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) {
|
||||
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
|
||||
}
|
||||
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return conn, udpProxySocksAddr
|
||||
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf) // server hello
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
|
||||
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
|
||||
}
|
||||
|
||||
conn, udpProxySocksAddr := newUdpAssociateConn()
|
||||
defer conn.Close()
|
||||
targetAddr := socksAddr{
|
||||
addrType: domainName,
|
||||
addr: "localhost",
|
||||
port: uint16(backendServerPort),
|
||||
}
|
||||
targetAddrPkt, err := targetAddr.marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) {
|
||||
udpPayload, err := (&udpRequest{addr: addr}).marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
udpPayload = append(udpPayload, body...)
|
||||
_, err = socks5UDPConn.Write(udpPayload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, err := socks5UDPConn.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, responseBody, err = parseUDPRequest(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return responseBody
|
||||
n, err = conn.Read(buf) // server response
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
|
||||
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
|
||||
}
|
||||
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr)
|
||||
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer socks5UDPConn.Close()
|
||||
|
||||
for i := 0; i < echoServerNumber; i++ {
|
||||
port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port
|
||||
addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)}
|
||||
requestBody := []byte(fmt.Sprintf("Test %d", i))
|
||||
responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody)
|
||||
if !bytes.Equal(requestBody, responseBody) {
|
||||
t.Fatalf("got: %q want: %q", responseBody, requestBody)
|
||||
}
|
||||
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
udpPayload = append(udpPayload, []byte("Test")...)
|
||||
_, err = udpConn.Write(udpPayload) // send udp package
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
n, _, err = udpConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(responseBody) != "Test" {
|
||||
t.Fatalf("got: %q want: Test", responseBody)
|
||||
}
|
||||
err = udpConn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,13 +279,7 @@ func setNetMon(netMon *netmon.Monitor) {
|
||||
if ifName == "" {
|
||||
return
|
||||
}
|
||||
// DefaultRouteInterface and Interface are gathered at different points in time.
|
||||
// Check for existence first, to avoid a nil pointer dereference.
|
||||
iface, ok := state.Interface[ifName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ifIndex := iface.Index
|
||||
ifIndex := state.Interface[ifName].Index
|
||||
sockStats.mu.Lock()
|
||||
defer sockStats.mu.Unlock()
|
||||
// Ignore changes to unknown interfaces -- it would require
|
||||
|
||||
@@ -213,14 +213,24 @@ type Wrapper struct {
|
||||
}
|
||||
|
||||
type metrics struct {
|
||||
inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels]
|
||||
outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels]
|
||||
inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel]
|
||||
outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel]
|
||||
}
|
||||
|
||||
func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
return &metrics{
|
||||
inboundDroppedPacketsTotal: reg.DroppedPacketsInbound(),
|
||||
outboundDroppedPacketsTotal: reg.DroppedPacketsOutbound(),
|
||||
inboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel](
|
||||
reg,
|
||||
"tailscaled_inbound_dropped_packets_total",
|
||||
"counter",
|
||||
"Counts the number of dropped packets received by the node from other peers",
|
||||
),
|
||||
outboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel](
|
||||
reg,
|
||||
"tailscaled_outbound_dropped_packets_total",
|
||||
"counter",
|
||||
"Counts the number of packets dropped while being sent to other peers",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -861,22 +871,7 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf
|
||||
return res, gro
|
||||
}
|
||||
}
|
||||
if resp := t.filtRunOut(p, pc); resp != filter.Accept {
|
||||
return resp, gro
|
||||
}
|
||||
|
||||
if t.PostFilterPacketOutboundToWireGuard != nil {
|
||||
if res := t.PostFilterPacketOutboundToWireGuard(p, t); res.IsDrop() {
|
||||
return res, gro
|
||||
}
|
||||
}
|
||||
return filter.Accept, gro
|
||||
}
|
||||
|
||||
// filtRunOut runs the outbound packet filter on p.
|
||||
// It uses pc to determine if the packet is to a jailed peer and should be
|
||||
// filtered with the jailed filter.
|
||||
func (t *Wrapper) filtRunOut(p *packet.Parsed, pc *peerConfigTable) filter.Response {
|
||||
// If the outbound packet is to a jailed peer, use our jailed peer
|
||||
// packet filter.
|
||||
var filt *filter.Filter
|
||||
@@ -886,17 +881,23 @@ func (t *Wrapper) filtRunOut(p *packet.Parsed, pc *peerConfigTable) filter.Respo
|
||||
filt = t.filter.Load()
|
||||
}
|
||||
if filt == nil {
|
||||
return filter.Drop
|
||||
return filter.Drop, gro
|
||||
}
|
||||
|
||||
if filt.RunOut(p, t.filterFlags) != filter.Accept {
|
||||
metricPacketOutDropFilter.Add(1)
|
||||
t.metrics.outboundDroppedPacketsTotal.Add(usermetric.DropLabels{
|
||||
Reason: usermetric.ReasonACL,
|
||||
t.metrics.outboundDroppedPacketsTotal.Add(dropPacketLabel{
|
||||
Reason: DropReasonACL,
|
||||
}, 1)
|
||||
return filter.Drop
|
||||
return filter.Drop, gro
|
||||
}
|
||||
return filter.Accept
|
||||
|
||||
if t.PostFilterPacketOutboundToWireGuard != nil {
|
||||
if res := t.PostFilterPacketOutboundToWireGuard(p, t); res.IsDrop() {
|
||||
return res, gro
|
||||
}
|
||||
}
|
||||
return filter.Accept, gro
|
||||
}
|
||||
|
||||
// noteActivity records that there was a read or write at the current time.
|
||||
@@ -1060,11 +1061,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
|
||||
p := parsedPacketPool.Get().(*packet.Parsed)
|
||||
defer parsedPacketPool.Put(p)
|
||||
p.Decode(pkt)
|
||||
response, _ := t.filterPacketOutboundToWireGuard(p, pc, nil)
|
||||
if response != filter.Accept {
|
||||
metricPacketOutDrop.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
invertGSOChecksum(pkt, gso)
|
||||
pc.snat(p)
|
||||
@@ -1162,8 +1158,8 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
|
||||
|
||||
if outcome != filter.Accept {
|
||||
metricPacketInDropFilter.Add(1)
|
||||
t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{
|
||||
Reason: usermetric.ReasonACL,
|
||||
t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{
|
||||
Reason: DropReasonACL,
|
||||
}, 1)
|
||||
|
||||
// Tell them, via TSMP, we're dropping them due to the ACL.
|
||||
@@ -1243,8 +1239,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
|
||||
t.noteActivity()
|
||||
_, err := t.tdevWrite(buffs, offset)
|
||||
if err != nil {
|
||||
t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{
|
||||
Reason: usermetric.ReasonError,
|
||||
t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{
|
||||
Reason: DropReasonError,
|
||||
}, int64(len(buffs)))
|
||||
}
|
||||
return len(buffs), err
|
||||
@@ -1486,6 +1482,20 @@ var (
|
||||
metricPacketOutDropSelfDisco = clientmetric.NewCounter("tstun_out_to_wg_drop_self_disco")
|
||||
)
|
||||
|
||||
type DropReason string
|
||||
|
||||
const (
|
||||
DropReasonACL DropReason = "acl"
|
||||
DropReasonError DropReason = "error"
|
||||
)
|
||||
|
||||
type dropPacketLabel struct {
|
||||
// Reason indicates what we have done with the packet, and has the following values:
|
||||
// - acl (rejected packets because of ACL)
|
||||
// - error (rejected packets because of an error)
|
||||
Reason DropReason
|
||||
}
|
||||
|
||||
func (t *Wrapper) InstallCaptureHook(cb capture.Callback) {
|
||||
t.captureHook.Store(cb)
|
||||
}
|
||||
|
||||
@@ -441,13 +441,13 @@ func TestFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
var metricInboundDroppedPacketsACL, metricInboundDroppedPacketsErr, metricOutboundDroppedPacketsACL int64
|
||||
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok {
|
||||
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok {
|
||||
metricInboundDroppedPacketsACL = m.Value()
|
||||
}
|
||||
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonError}).(*expvar.Int); ok {
|
||||
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonError}).(*expvar.Int); ok {
|
||||
metricInboundDroppedPacketsErr = m.Value()
|
||||
}
|
||||
if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok {
|
||||
if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok {
|
||||
metricOutboundDroppedPacketsACL = m.Value()
|
||||
}
|
||||
|
||||
|
||||
105
safeweb/http.go
105
safeweb/http.go
@@ -74,74 +74,25 @@ import (
|
||||
crand "crypto/rand"
|
||||
"fmt"
|
||||
"log"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/csrf"
|
||||
)
|
||||
|
||||
// CSP is the value of a Content-Security-Policy header. Keys are CSP
|
||||
// directives (like "default-src") and values are source expressions (like
|
||||
// "'self'" or "https://tailscale.com"). A nil slice value is allowed for some
|
||||
// directives like "upgrade-insecure-requests" that don't expect a list of
|
||||
// source definitions.
|
||||
type CSP map[string][]string
|
||||
|
||||
// DefaultCSP is the recommended CSP to use when not loading resources from
|
||||
// other domains and not embedding the current website. If you need to tweak
|
||||
// the CSP, it is recommended to extend DefaultCSP instead of writing your own
|
||||
// from scratch.
|
||||
func DefaultCSP() CSP {
|
||||
return CSP{
|
||||
"default-src": {"self"}, // origin is the only valid source for all content types
|
||||
"frame-ancestors": {"none"}, // disallow framing of the page
|
||||
"form-action": {"self"}, // disallow form submissions to other origins
|
||||
"base-uri": {"self"}, // disallow base URIs from other origins
|
||||
// TODO(awly): consider upgrade-insecure-requests in SecureContext
|
||||
// instead, as this is deprecated.
|
||||
"block-all-mixed-content": nil, // disallow mixed content when serving over HTTPS
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the values for a given directive. Empty values are allowed, if the
|
||||
// directive doesn't expect any (like "upgrade-insecure-requests").
|
||||
func (csp CSP) Set(directive string, values ...string) {
|
||||
csp[directive] = values
|
||||
}
|
||||
|
||||
// Add adds a source expression to an existing directive.
|
||||
func (csp CSP) Add(directive, value string) {
|
||||
csp[directive] = append(csp[directive], value)
|
||||
}
|
||||
|
||||
// Del deletes a directive and all its values.
|
||||
func (csp CSP) Del(directive string) {
|
||||
delete(csp, directive)
|
||||
}
|
||||
|
||||
func (csp CSP) String() string {
|
||||
keys := slices.Collect(maps.Keys(csp))
|
||||
slices.Sort(keys)
|
||||
var s strings.Builder
|
||||
for _, k := range keys {
|
||||
s.WriteString(k)
|
||||
for _, v := range csp[k] {
|
||||
// Special values like 'self', 'none', 'unsafe-inline', etc., must
|
||||
// be quoted. Do it implicitly as a convenience here.
|
||||
if !strings.Contains(v, ".") && len(v) > 1 && v[0] != '\'' && v[len(v)-1] != '\'' {
|
||||
v = "'" + v + "'"
|
||||
}
|
||||
s.WriteString(" " + v)
|
||||
}
|
||||
s.WriteString("; ")
|
||||
}
|
||||
return strings.TrimSpace(s.String())
|
||||
}
|
||||
// The default Content-Security-Policy header.
|
||||
var defaultCSP = strings.Join([]string{
|
||||
`default-src 'self'`, // origin is the only valid source for all content types
|
||||
`script-src 'self'`, // disallow inline javascript
|
||||
`frame-ancestors 'none'`, // disallow framing of the page
|
||||
`form-action 'self'`, // disallow form submissions to other origins
|
||||
`base-uri 'self'`, // disallow base URIs from other origins
|
||||
`block-all-mixed-content`, // disallow mixed content when serving over HTTPS
|
||||
`object-src 'self'`, // disallow embedding of resources from other origins
|
||||
}, "; ")
|
||||
|
||||
// The default Strict-Transport-Security header. This header tells the browser
|
||||
// to exclusively use HTTPS for all requests to the origin for the next year.
|
||||
@@ -179,9 +130,6 @@ type Config struct {
|
||||
// startup.
|
||||
CSRFSecret []byte
|
||||
|
||||
// CSP is the Content-Security-Policy header to return with BrowserMux
|
||||
// responses.
|
||||
CSP CSP
|
||||
// CSPAllowInlineStyles specifies whether to include `style-src:
|
||||
// unsafe-inline` in the Content-Security-Policy header to permit the use of
|
||||
// inline CSS.
|
||||
@@ -220,10 +168,6 @@ func (c *Config) setDefaults() error {
|
||||
}
|
||||
}
|
||||
|
||||
if c.CSP == nil {
|
||||
c.CSP = DefaultCSP()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -255,20 +199,16 @@ func NewServer(config Config) (*Server, error) {
|
||||
if config.CookiesSameSiteLax {
|
||||
sameSite = csrf.SameSiteLaxMode
|
||||
}
|
||||
if config.CSPAllowInlineStyles {
|
||||
if _, ok := config.CSP["style-src"]; ok {
|
||||
config.CSP.Add("style-src", "unsafe-inline")
|
||||
} else {
|
||||
config.CSP.Set("style-src", "self", "unsafe-inline")
|
||||
}
|
||||
}
|
||||
s := &Server{
|
||||
Config: config,
|
||||
csp: config.CSP.String(),
|
||||
csp: defaultCSP,
|
||||
// only set Secure flag on CSRF cookies if we are in a secure context
|
||||
// as otherwise the browser will reject the cookie
|
||||
csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext), csrf.SameSite(sameSite)),
|
||||
}
|
||||
if config.CSPAllowInlineStyles {
|
||||
s.csp = defaultCSP + `; style-src 'self' 'unsafe-inline'`
|
||||
}
|
||||
s.h = cmp.Or(config.HTTPServer, &http.Server{})
|
||||
if s.h.Handler != nil {
|
||||
return nil, fmt.Errorf("use safeweb.Config.APIMux and safeweb.Config.BrowserMux instead of http.Server.Handler")
|
||||
@@ -285,27 +225,12 @@ const (
|
||||
browserHandler
|
||||
)
|
||||
|
||||
func (h handlerType) String() string {
|
||||
switch h {
|
||||
case browserHandler:
|
||||
return "browser"
|
||||
case apiHandler:
|
||||
return "api"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// checkHandlerType returns either apiHandler or browserHandler, depending on
|
||||
// whether apiPattern or browserPattern is more specific (i.e. which pattern
|
||||
// contains more pathname components). If they are equally specific, it returns
|
||||
// unknownHandler.
|
||||
func checkHandlerType(apiPattern, browserPattern string) handlerType {
|
||||
apiPattern, browserPattern = path.Clean(apiPattern), path.Clean(browserPattern)
|
||||
c := cmp.Compare(strings.Count(apiPattern, "/"), strings.Count(browserPattern, "/"))
|
||||
if apiPattern == "/" || browserPattern == "/" {
|
||||
c = cmp.Compare(len(apiPattern), len(browserPattern))
|
||||
}
|
||||
c := cmp.Compare(strings.Count(path.Clean(apiPattern), "/"), strings.Count(path.Clean(browserPattern), "/"))
|
||||
switch {
|
||||
case c > 0:
|
||||
return apiHandler
|
||||
|
||||
@@ -241,26 +241,18 @@ func TestCSRFProtection(t *testing.T) {
|
||||
func TestContentSecurityPolicyHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
csp CSP
|
||||
apiRoute bool
|
||||
wantCSP string
|
||||
wantCSP bool
|
||||
}{
|
||||
{
|
||||
name: "default CSP",
|
||||
wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`,
|
||||
},
|
||||
{
|
||||
name: "custom CSP",
|
||||
csp: CSP{
|
||||
"default-src": {"'self'", "https://tailscale.com"},
|
||||
"upgrade-insecure-requests": nil,
|
||||
},
|
||||
wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`,
|
||||
name: "default routes get CSP headers",
|
||||
apiRoute: false,
|
||||
wantCSP: true,
|
||||
},
|
||||
{
|
||||
name: "`/api/*` routes do not get CSP headers",
|
||||
apiRoute: true,
|
||||
wantCSP: "",
|
||||
wantCSP: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -273,9 +265,9 @@ func TestContentSecurityPolicyHeader(t *testing.T) {
|
||||
var s *Server
|
||||
var err error
|
||||
if tt.apiRoute {
|
||||
s, err = NewServer(Config{APIMux: h, CSP: tt.csp})
|
||||
s, err = NewServer(Config{APIMux: h})
|
||||
} else {
|
||||
s, err = NewServer(Config{BrowserMux: h, CSP: tt.csp})
|
||||
s, err = NewServer(Config{BrowserMux: h})
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -287,8 +279,8 @@ func TestContentSecurityPolicyHeader(t *testing.T) {
|
||||
s.h.Handler.ServeHTTP(w, req)
|
||||
resp := w.Result()
|
||||
|
||||
if got := resp.Header.Get("Content-Security-Policy"); got != tt.wantCSP {
|
||||
t.Fatalf("content security policy want: %q; got: %q", tt.wantCSP, got)
|
||||
if (resp.Header.Get("Content-Security-Policy") == "") == tt.wantCSP {
|
||||
t.Fatalf("content security policy want: %v; got: %v", tt.wantCSP, resp.Header.Get("Content-Security-Policy"))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -405,7 +397,7 @@ func TestCSPAllowInlineStyles(t *testing.T) {
|
||||
csp := resp.Header.Get("Content-Security-Policy")
|
||||
allowsStyles := strings.Contains(csp, "style-src 'self' 'unsafe-inline'")
|
||||
if allowsStyles != allow {
|
||||
t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp)
|
||||
t.Fatalf("CSP inline styles want: %v; got: %v", allow, allowsStyles)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -535,13 +527,13 @@ func TestGetMoreSpecificPattern(t *testing.T) {
|
||||
{
|
||||
desc: "same prefix",
|
||||
a: "/foo/bar/quux",
|
||||
b: "/foo/bar/", // path.Clean will strip the trailing slash.
|
||||
b: "/foo/bar/",
|
||||
want: apiHandler,
|
||||
},
|
||||
{
|
||||
desc: "almost same prefix, but not a path component",
|
||||
a: "/goat/sheep/cheese",
|
||||
b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash.
|
||||
b: "/goat/sheepcheese/",
|
||||
want: apiHandler,
|
||||
},
|
||||
{
|
||||
@@ -562,12 +554,6 @@ func TestGetMoreSpecificPattern(t *testing.T) {
|
||||
b: "///////",
|
||||
want: unknownHandler,
|
||||
},
|
||||
{
|
||||
desc: "root-level",
|
||||
a: "/latest",
|
||||
b: "/", // path.Clean will NOT strip the trailing slash.
|
||||
want: apiHandler,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
got := checkHandlerType(tt.a, tt.b)
|
||||
|
||||
@@ -149,8 +149,7 @@ type CapabilityVersion int
|
||||
// - 104: 2024-08-03: SelfNodeV6MasqAddrForThisPeer now works
|
||||
// - 105: 2024-08-05: Fixed SSH behavior on systems that use busybox (issue #12849)
|
||||
// - 106: 2024-09-03: fix panic regression from cryptokey routing change (65fe0ba7b5)
|
||||
// - 107: 2024-10-30: add App Connector to conffile (PR #13942)
|
||||
const CurrentCapabilityVersion CapabilityVersion = 107
|
||||
const CurrentCapabilityVersion CapabilityVersion = 106
|
||||
|
||||
type StableID string
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ
|
||||
cc = "cc"
|
||||
targetOS = cmp.Or(env.Get("GOOS", ""), nativeGOOS)
|
||||
targetArch = cmp.Or(env.Get("GOARCH", ""), nativeGOARCH)
|
||||
buildFlags = []string{}
|
||||
buildFlags = []string{"-trimpath"}
|
||||
cgoCflags = []string{"-O3", "-std=gnu11", "-g"}
|
||||
cgoLdflags []string
|
||||
ldflags []string
|
||||
@@ -47,10 +47,6 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ
|
||||
subcommand = argv[1]
|
||||
}
|
||||
|
||||
if subcommand != "test" {
|
||||
buildFlags = append(buildFlags, "-trimpath")
|
||||
}
|
||||
|
||||
switch subcommand {
|
||||
case "build", "env", "install", "run", "test", "list":
|
||||
default:
|
||||
|
||||
@@ -163,6 +163,7 @@ GOTOOLCHAIN=local (was <nil>)
|
||||
TS_LINK_FAIL_REFLECT=0 (was <nil>)`,
|
||||
wantArgv: []string{
|
||||
"gocross", "test",
|
||||
"-trimpath",
|
||||
"-tags=tailscale_go,osusergo,netgo",
|
||||
"-ldflags", "-X tailscale.com/version.longStamp=1.2.3-long -X tailscale.com/version.shortStamp=1.2.3 -X tailscale.com/version.gitCommitStamp=abcd -X tailscale.com/version.extraGitCommitStamp=defg '-extldflags=-static'",
|
||||
"-race",
|
||||
|
||||
@@ -121,17 +121,11 @@ type Server struct {
|
||||
// field at zero unless you know what you are doing.
|
||||
Port uint16
|
||||
|
||||
// PreStart is an optional hook to run just before LocalBackend.Start,
|
||||
// to reconfigure internals. If it returns an error, Server.Start
|
||||
// will return that error, wrapper.
|
||||
PreStart func() error
|
||||
|
||||
getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
lb *ipnlocal.LocalBackend
|
||||
sys *tsd.System
|
||||
netstack *netstack.Impl
|
||||
netMon *netmon.Monitor
|
||||
rootPath string // the state directory
|
||||
@@ -524,7 +518,6 @@ func (s *Server) start() (reterr error) {
|
||||
}
|
||||
|
||||
sys := new(tsd.System)
|
||||
s.sys = sys
|
||||
if err := s.startLogger(&closePool, sys.HealthTracker(), tsLogf); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -553,7 +546,7 @@ func (s *Server) start() (reterr error) {
|
||||
sys.HealthTracker().SetMetricsRegistry(sys.UserMetricsRegistry())
|
||||
|
||||
// TODO(oxtoacart): do we need to support Taildrive on tsnet, and if so, how?
|
||||
ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper())
|
||||
ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("netstack.Create: %w", err)
|
||||
}
|
||||
@@ -621,13 +614,6 @@ func (s *Server) start() (reterr error) {
|
||||
prefs.ControlURL = s.ControlURL
|
||||
prefs.RunWebClient = s.RunWebClient
|
||||
authKey := s.getAuthKey()
|
||||
|
||||
if f := s.PreStart; f != nil {
|
||||
if err := f(); err != nil {
|
||||
return fmt.Errorf("PreStart: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = lb.Start(ipn.Options{
|
||||
UpdatePrefs: prefs,
|
||||
AuthKey: authKey,
|
||||
@@ -1241,13 +1227,6 @@ func (s *Server) CapturePcap(ctx context.Context, pcapFile string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sys returns a handle to the Tailscale subsystems of this node.
|
||||
//
|
||||
// This is not a stable API, nor are the APIs of the returned subsystems.
|
||||
func (s *Server) Sys() *tsd.System {
|
||||
return s.sys
|
||||
}
|
||||
|
||||
type listenKey struct {
|
||||
network string
|
||||
host netip.Addr // or zero value for unspecified
|
||||
|
||||
@@ -1080,6 +1080,13 @@ func TestUserMetrics(t *testing.T) {
|
||||
t.Errorf("metrics1, tailscaled_health_messages: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
// The node is the primary subnet router for 2 routes:
|
||||
// - 192.0.2.0/24
|
||||
// - 192.0.5.1/32
|
||||
if got, want := parsedMetrics1["tailscaled_primary_routes"], wantRoutes; got != want {
|
||||
t.Errorf("metrics1, tailscaled_primary_routes: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
// Verify that the amount of data recorded in bytes is higher or equal to the
|
||||
// 10 megabytes sent.
|
||||
inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`]
|
||||
@@ -1124,6 +1131,11 @@ func TestUserMetrics(t *testing.T) {
|
||||
t.Errorf("metrics2, tailscaled_health_messages: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
// The node is the primary subnet router for 0 routes
|
||||
if got, want := parsedMetrics2["tailscaled_primary_routes"], 0.0; got != want {
|
||||
t.Errorf("metrics2, tailscaled_primary_routes: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
// Verify that the amount of data recorded in bytes is higher or equal than the
|
||||
// 10 megabytes sent.
|
||||
outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`]
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/gopacket/pcapgo"
|
||||
@@ -280,28 +279,10 @@ type Network struct {
|
||||
|
||||
svcs set.Set[NetworkService]
|
||||
|
||||
latency time.Duration // latency applied to interface writes
|
||||
lossRate float64 // chance of packet loss (0.0 to 1.0)
|
||||
|
||||
// ...
|
||||
err error // carried error
|
||||
}
|
||||
|
||||
// SetLatency sets the simulated network latency for this network.
|
||||
func (n *Network) SetLatency(d time.Duration) {
|
||||
n.latency = d
|
||||
}
|
||||
|
||||
// SetPacketLoss sets the packet loss rate for this network 0.0 (no loss) to 1.0 (total loss).
|
||||
func (n *Network) SetPacketLoss(rate float64) {
|
||||
if rate < 0 {
|
||||
rate = 0
|
||||
} else if rate > 1 {
|
||||
rate = 1
|
||||
}
|
||||
n.lossRate = rate
|
||||
}
|
||||
|
||||
// SetBlackholedIPv4 sets whether the network should blackhole all IPv4 traffic
|
||||
// out to the Internet. (DHCP etc continues to work on the LAN.)
|
||||
func (n *Network) SetBlackholedIPv4(v bool) {
|
||||
@@ -380,8 +361,6 @@ func (s *Server) initFromConfig(c *Config) error {
|
||||
wanIP4: conf.wanIP4,
|
||||
lanIP4: conf.lanIP4,
|
||||
breakWAN4: conf.breakWAN4,
|
||||
latency: conf.latency,
|
||||
lossRate: conf.lossRate,
|
||||
nodesByIP4: map[netip.Addr]*node{},
|
||||
nodesByMAC: map[MAC]*node{},
|
||||
logf: logger.WithPrefix(s.logf, fmt.Sprintf("[net-%v] ", conf.mac)),
|
||||
|
||||
@@ -3,10 +3,7 @@
|
||||
|
||||
package vnet
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
import "testing"
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -21,16 +18,6 @@ func TestConfig(t *testing.T) {
|
||||
c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "latency-and-loss",
|
||||
setup: func(c *Config) {
|
||||
n1 := c.AddNetwork("2.1.1.1", "192.168.1.1/24", EasyNAT, NATPMP)
|
||||
n1.SetLatency(time.Second)
|
||||
n1.SetPacketLoss(0.1)
|
||||
c.AddNode(n1)
|
||||
c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "indirect",
|
||||
setup: func(c *Config) {
|
||||
|
||||
@@ -515,8 +515,6 @@ type network struct {
|
||||
wanIP4 netip.Addr // router's LAN IPv4, if any
|
||||
lanIP4 netip.Prefix // router's LAN IP + CIDR (e.g. 192.168.2.1/24)
|
||||
breakWAN4 bool // break WAN IPv4 connectivity
|
||||
latency time.Duration // latency applied to interface writes
|
||||
lossRate float64 // probability of dropping a packet (0.0 to 1.0)
|
||||
nodesByIP4 map[netip.Addr]*node // by LAN IPv4
|
||||
nodesByMAC map[MAC]*node
|
||||
logf func(format string, args ...any)
|
||||
@@ -979,7 +977,7 @@ func (n *network) writeEth(res []byte) bool {
|
||||
for mac, nw := range n.writers.All() {
|
||||
if mac != srcMAC {
|
||||
num++
|
||||
n.conditionedWrite(nw, res)
|
||||
nw.write(res)
|
||||
}
|
||||
}
|
||||
return num > 0
|
||||
@@ -989,7 +987,7 @@ func (n *network) writeEth(res []byte) bool {
|
||||
return false
|
||||
}
|
||||
if nw, ok := n.writers.Load(dstMAC); ok {
|
||||
n.conditionedWrite(nw, res)
|
||||
nw.write(res)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1002,23 +1000,6 @@ func (n *network) writeEth(res []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *network) conditionedWrite(nw networkWriter, packet []byte) {
|
||||
if n.lossRate > 0 && rand.Float64() < n.lossRate {
|
||||
// packet lost
|
||||
return
|
||||
}
|
||||
if n.latency > 0 {
|
||||
// copy the packet as there's no guarantee packet is owned long enough.
|
||||
// TODO(raggi): this could be optimized substantially if necessary,
|
||||
// a pool of buffers and a cheaper delay mechanism are both obvious improvements.
|
||||
var pkt = make([]byte, len(packet))
|
||||
copy(pkt, packet)
|
||||
time.AfterFunc(n.latency, func() { nw.write(pkt) })
|
||||
} else {
|
||||
nw.write(packet)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
macAllNodes = MAC{0: 0x33, 1: 0x33, 5: 0x01}
|
||||
macAllRouters = MAC{0: 0x33, 1: 0x33, 5: 0x02}
|
||||
|
||||
@@ -14,7 +14,6 @@ class Config: Codable {
|
||||
var mac = "52:cc:cc:cc:cc:01"
|
||||
var ethermac = "52:cc:cc:cc:ce:01"
|
||||
var port: UInt32 = 51009
|
||||
var sharedDir: String?
|
||||
|
||||
// The virtual machines ID. Also double as the directory name under which
|
||||
// we will store configuration, block device, etc.
|
||||
|
||||
@@ -141,18 +141,5 @@ struct TailMacConfigHelper {
|
||||
func createKeyboardConfiguration() -> VZKeyboardConfiguration {
|
||||
return VZMacKeyboardConfiguration()
|
||||
}
|
||||
|
||||
func createDirectoryShareConfiguration(tag: String) -> VZDirectorySharingDeviceConfiguration? {
|
||||
guard let dir = config.sharedDir else { return nil }
|
||||
|
||||
let sharedDir = VZSharedDirectory(url: URL(fileURLWithPath: dir), readOnly: false)
|
||||
let share = VZSingleDirectoryShare(directory: sharedDir)
|
||||
|
||||
// Create the VZVirtioFileSystemDeviceConfiguration and assign it a unique tag.
|
||||
let sharingConfiguration = VZVirtioFileSystemDeviceConfiguration(tag: tag)
|
||||
sharingConfiguration.share = share
|
||||
|
||||
return sharingConfiguration
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,12 +19,10 @@ var config: Config = Config()
|
||||
extension HostCli {
|
||||
struct Run: ParsableCommand {
|
||||
@Option var id: String
|
||||
@Option var share: String?
|
||||
|
||||
mutating func run() {
|
||||
print("Running vm with identifier \(id)")
|
||||
config = Config(id)
|
||||
config.sharedDir = share
|
||||
print("Running vm with identifier \(id) and sharedDir \(share ?? "<none>")")
|
||||
_ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,13 +95,6 @@ class VMController: NSObject, VZVirtualMachineDelegate {
|
||||
virtualMachineConfiguration.keyboards = [helper.createKeyboardConfiguration()]
|
||||
virtualMachineConfiguration.socketDevices = [helper.createSocketDeviceConfiguration()]
|
||||
|
||||
if let dir = config.sharedDir, let shareConfig = helper.createDirectoryShareConfiguration(tag: "vmshare") {
|
||||
print("Sharing \(dir) as vmshare. Use: mount_virtiofs vmshare <path> in the guest to mount.")
|
||||
virtualMachineConfiguration.directorySharingDevices = [shareConfig]
|
||||
} else {
|
||||
print("No shared directory created. \(config.sharedDir ?? "none") was requested.")
|
||||
}
|
||||
|
||||
try! virtualMachineConfiguration.validate()
|
||||
try! virtualMachineConfiguration.validateSaveRestoreSupport()
|
||||
|
||||
|
||||
@@ -95,16 +95,12 @@ extension Tailmac {
|
||||
extension Tailmac {
|
||||
struct Run: ParsableCommand {
|
||||
@Option(help: "The vm identifier") var id: String
|
||||
@Option(help: "Optional share directory") var share: String?
|
||||
@Flag(help: "Tail the TailMac log output instead of returning immediatly") var tail
|
||||
|
||||
mutating func run() {
|
||||
let process = Process()
|
||||
let stdOutPipe = Pipe()
|
||||
|
||||
let executablePath = CommandLine.arguments[0]
|
||||
let executableDirectory = (executablePath as NSString).deletingLastPathComponent
|
||||
let appPath = executableDirectory + "/Host.app/Contents/MacOS/Host"
|
||||
let appPath = "./Host.app/Contents/MacOS/Host"
|
||||
|
||||
process.executableURL = URL(
|
||||
fileURLWithPath: appPath,
|
||||
@@ -113,15 +109,10 @@ extension Tailmac {
|
||||
)
|
||||
|
||||
if !FileManager.default.fileExists(atPath: appPath) {
|
||||
fatalError("Could not find Host.app at \(appPath). This must be co-located with the tailmac utility")
|
||||
fatalError("Could not find Host.app. This must be co-located with the tailmac utility")
|
||||
}
|
||||
|
||||
var args = ["run", "--id", id]
|
||||
if let share {
|
||||
args.append("--share")
|
||||
args.append(share)
|
||||
}
|
||||
process.arguments = args
|
||||
process.arguments = ["run", "--id", id]
|
||||
|
||||
do {
|
||||
process.standardOutput = stdOutPipe
|
||||
@@ -130,18 +121,26 @@ extension Tailmac {
|
||||
fatalError("Unable to launch the vm process")
|
||||
}
|
||||
|
||||
// This doesn't print until we exit which is not ideal, but at least we
|
||||
// get the output
|
||||
if tail != 0 {
|
||||
// (jonathan)TODO: How do we get the process output in real time?
|
||||
// The child process only seems to flush to stdout on completion
|
||||
let outHandle = stdOutPipe.fileHandleForReading
|
||||
outHandle.readabilityHandler = { handle in
|
||||
let data = handle.availableData
|
||||
|
||||
let queue = OperationQueue()
|
||||
NotificationCenter.default.addObserver(
|
||||
forName: NSNotification.Name.NSFileHandleDataAvailable,
|
||||
object: outHandle, queue: queue)
|
||||
{
|
||||
notification -> Void in
|
||||
let data = outHandle.availableData
|
||||
if data.count > 0 {
|
||||
if let str = String(data: data, encoding: String.Encoding.utf8) {
|
||||
print(str)
|
||||
}
|
||||
}
|
||||
outHandle.waitForDataInBackgroundAndNotify()
|
||||
}
|
||||
outHandle.waitForDataInBackgroundAndNotify()
|
||||
process.waitUntilExit()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func ValueOf[T any](v T) Value[T] {
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (o Value[T]) String() string {
|
||||
func (o *Value[T]) String() string {
|
||||
if !o.set {
|
||||
return fmt.Sprintf("(empty[%T])", o.value)
|
||||
}
|
||||
|
||||
122
util/syspolicy/caching_handler.go
Normal file
122
util/syspolicy/caching_handler.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
|
||||
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
|
||||
// otherwise the actual error is returned and the next read for that key will retry using the handler.
|
||||
type CachingHandler struct {
|
||||
mu sync.Mutex
|
||||
strings map[string]string
|
||||
uint64s map[string]uint64
|
||||
bools map[string]bool
|
||||
strArrs map[string][]string
|
||||
notFound map[string]bool
|
||||
handler Handler
|
||||
}
|
||||
|
||||
// NewCachingHandler creates a CachingHandler given a handler.
|
||||
func NewCachingHandler(handler Handler) *CachingHandler {
|
||||
return &CachingHandler{
|
||||
handler: handler,
|
||||
strings: make(map[string]string),
|
||||
uint64s: make(map[string]uint64),
|
||||
bools: make(map[string]bool),
|
||||
strArrs: make(map[string][]string),
|
||||
notFound: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString reads the policy settings value string given the key.
|
||||
// ReadString first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadString(key string) (string, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.strings[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return "", ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadString(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return "", err
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ch.strings[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 reads the policy settings uint64 value given the key.
|
||||
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.uint64s[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return 0, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadUInt64(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return 0, err
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ch.uint64s[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.bools[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return false, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadBoolean(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return false, err
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
ch.bools[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.strArrs[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadStringArray(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.strArrs[key] = val
|
||||
return val, nil
|
||||
}
|
||||
262
util/syspolicy/caching_handler_test.go
Normal file
262
util/syspolicy/caching_handler_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandlerReadString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue string
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue string
|
||||
wantErr error
|
||||
strings map[string]string
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
strings: map[string]string{"test": "foo"},
|
||||
wantValue: "foo",
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: "foo",
|
||||
wantValue: "foo",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.strings != nil {
|
||||
cache.strings = tt.strings
|
||||
}
|
||||
got, err := cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerReadUint64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue uint64
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue uint64
|
||||
wantErr error
|
||||
uint64s map[string]uint64
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
uint64s: map[string]uint64{"test": 1},
|
||||
wantValue: 1,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: 1,
|
||||
wantValue: 1,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
u64: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.uint64s != nil {
|
||||
cache.uint64s = tt.uint64s
|
||||
}
|
||||
got, err := cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestHandlerReadBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue bool
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue bool
|
||||
wantErr error
|
||||
bools map[string]bool
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
bools: map[string]bool{"test": true},
|
||||
wantValue: true,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: true,
|
||||
wantValue: true,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
b: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.bools != nil {
|
||||
cache.bools = tt.bools
|
||||
}
|
||||
got, err := cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,17 +4,16 @@
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// TODO(nickkhyl): delete this file once other repos are updated.
|
||||
var (
|
||||
handlerUsed atomic.Bool
|
||||
handler Handler = defaultHandler{}
|
||||
)
|
||||
|
||||
// Handler reads system policies from OS-specific storage.
|
||||
//
|
||||
// Deprecated: implementing a [source.Store] should be preferred.
|
||||
type Handler interface {
|
||||
// ReadString reads the policy setting's string value for the given key.
|
||||
// It should return ErrNoSuchKey if the key does not have a value set.
|
||||
@@ -30,88 +29,55 @@ type Handler interface {
|
||||
ReadStringArray(key string) ([]string, error)
|
||||
}
|
||||
|
||||
// RegisterHandler wraps and registers the specified handler as the device's
|
||||
// policy [source.Store] for the program's lifetime.
|
||||
//
|
||||
// Deprecated: using [RegisterStore] should be preferred.
|
||||
// ErrNoSuchKey is returned by a Handler when the specified key does not have a
|
||||
// value set.
|
||||
var ErrNoSuchKey = errors.New("no such key")
|
||||
|
||||
// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple.
|
||||
type defaultHandler struct{}
|
||||
|
||||
func (defaultHandler) ReadString(_ string) (string, error) {
|
||||
return "", ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadUInt64(_ string) (uint64, error) {
|
||||
return 0, ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadBoolean(_ string) (bool, error) {
|
||||
return false, ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadStringArray(_ string) ([]string, error) {
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// markHandlerInUse is called before handler methods are called.
|
||||
func markHandlerInUse() {
|
||||
handlerUsed.Store(true)
|
||||
}
|
||||
|
||||
// RegisterHandler initializes the policy handler and ensures registration will happen once.
|
||||
func RegisterHandler(h Handler) {
|
||||
rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h))
|
||||
// Technically this assignment is not concurrency safe, but in the
|
||||
// event that there was any risk of a data race, we will panic due to
|
||||
// the CompareAndSwap failing.
|
||||
handler = h
|
||||
if !handlerUsed.CompareAndSwap(false, true) {
|
||||
panic("handler was already used before registration")
|
||||
}
|
||||
}
|
||||
|
||||
// TB is a subset of testing.TB that we use to set up test helpers.
|
||||
// It's defined here to avoid pulling in the testing package.
|
||||
type TB = internal.TB
|
||||
type TB interface {
|
||||
Helper()
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// SetHandlerForTest wraps and sets the specified handler as the device's policy
|
||||
// [source.Store] for the duration of tb.
|
||||
//
|
||||
// Deprecated: using [MustRegisterStoreForTest] should be preferred.
|
||||
func SetHandlerForTest(tb TB, h Handler) {
|
||||
RegisterWellKnownSettingsForTest(tb)
|
||||
MustRegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.DefaultScope(), WrapHandler(h))
|
||||
}
|
||||
|
||||
var _ source.Store = (*handlerStore)(nil)
|
||||
|
||||
// handlerStore is a [source.Store] that calls the underlying [Handler].
|
||||
//
|
||||
// TODO(nickkhyl): remove it when the corp and android repos are updated.
|
||||
type handlerStore struct {
|
||||
h Handler
|
||||
}
|
||||
|
||||
// WrapHandler returns a [source.Store] that wraps the specified [Handler].
|
||||
func WrapHandler(h Handler) source.Store {
|
||||
return handlerStore{h}
|
||||
}
|
||||
|
||||
// Lock implements [source.Lockable].
|
||||
func (s handlerStore) Lock() error {
|
||||
if lockable, ok := s.h.(source.Lockable); ok {
|
||||
return lockable.Lock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock implements [source.Lockable].
|
||||
func (s handlerStore) Unlock() {
|
||||
if lockable, ok := s.h.(source.Lockable); ok {
|
||||
lockable.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterChangeCallback implements [source.Changeable].
|
||||
func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
||||
if changeable, ok := s.h.(source.Changeable); ok {
|
||||
return changeable.RegisterChangeCallback(callback)
|
||||
}
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
// ReadString implements [source.Store].
|
||||
func (s handlerStore) ReadString(key setting.Key) (string, error) {
|
||||
return s.h.ReadString(string(key))
|
||||
}
|
||||
|
||||
// ReadUInt64 implements [source.Store].
|
||||
func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
return s.h.ReadUInt64(string(key))
|
||||
}
|
||||
|
||||
// ReadBoolean implements [source.Store].
|
||||
func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
return s.h.ReadBoolean(string(key))
|
||||
}
|
||||
|
||||
// ReadStringArray implements [source.Store].
|
||||
func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
return s.h.ReadStringArray(string(key))
|
||||
}
|
||||
|
||||
// Done implements [source.Expirable].
|
||||
func (s handlerStore) Done() <-chan struct{} {
|
||||
if expirable, ok := s.h.(source.Expirable); ok {
|
||||
return expirable.Done()
|
||||
}
|
||||
return nil
|
||||
tb.Helper()
|
||||
oldHandler := handler
|
||||
handler = h
|
||||
tb.Cleanup(func() { handler = oldHandler })
|
||||
}
|
||||
|
||||
19
util/syspolicy/handler_test.go
Normal file
19
util/syspolicy/handler_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultHandlerReadValues(t *testing.T) {
|
||||
var h defaultHandler
|
||||
|
||||
got, err := h.ReadString(string(AdminConsoleVisibility))
|
||||
if got != "" || err != ErrNoSuchKey {
|
||||
t.Fatalf("got %v err %v", got, err)
|
||||
}
|
||||
result, err := h.ReadUInt64(string(LogSCMInteractions))
|
||||
if result != 0 || err != ErrNoSuchKey {
|
||||
t.Fatalf("got %v err %v", result, err)
|
||||
}
|
||||
}
|
||||
105
util/syspolicy/handler_windows.go
Normal file
105
util/syspolicy/handler_windows.go
Normal file
@@ -0,0 +1,105 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/winutil"
|
||||
)
|
||||
|
||||
var (
|
||||
windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors")
|
||||
windowsAny = clientmetric.NewGauge("windows_syspolicy_any")
|
||||
)
|
||||
|
||||
type windowsHandler struct{}
|
||||
|
||||
func init() {
|
||||
RegisterHandler(NewCachingHandler(windowsHandler{}))
|
||||
|
||||
keyList := []struct {
|
||||
isSet func(Key) bool
|
||||
keys []Key
|
||||
}{
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadString(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: stringKeys,
|
||||
},
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadBoolean(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: boolKeys,
|
||||
},
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadUInt64(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: uint64Keys,
|
||||
},
|
||||
}
|
||||
|
||||
var anySet bool
|
||||
for _, l := range keyList {
|
||||
for _, k := range l.keys {
|
||||
if !l.isSet(k) {
|
||||
continue
|
||||
}
|
||||
clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1)
|
||||
anySet = true
|
||||
}
|
||||
}
|
||||
if anySet {
|
||||
windowsAny.Set(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadString(key string) (string, error) {
|
||||
s, err := winutil.GetPolicyString(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
|
||||
return s, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadUInt64(key string) (uint64, error) {
|
||||
value, err := winutil.GetPolicyInteger(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadBoolean(key string) (bool, error) {
|
||||
value, err := winutil.GetPolicyInteger(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value != 0, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadStringArray(key string) ([]string, error) {
|
||||
value, err := winutil.GetPolicyStringArray(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
@@ -284,7 +284,7 @@ func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) {
|
||||
}
|
||||
|
||||
func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric {
|
||||
name := strings.ReplaceAll(string(key), string(setting.KeyPathSeparator), "_")
|
||||
name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_")
|
||||
return newMetric([]string{name, metricScopeName(scope), suffix}, typ)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,24 +3,10 @@
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
import "tailscale.com/util/syspolicy/setting"
|
||||
|
||||
// Key is a string that uniquely identifies a policy and must remain unchanged
|
||||
// once established and documented for a given policy setting. It may contain
|
||||
// alphanumeric characters and zero or more [KeyPathSeparator]s to group
|
||||
// individual policy settings into categories.
|
||||
type Key = setting.Key
|
||||
|
||||
// The const block below lists known policy keys.
|
||||
// When adding a key to this list, remember to add a corresponding
|
||||
// [setting.Definition] to [implicitDefinitions] below.
|
||||
// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder.
|
||||
|
||||
const (
|
||||
// Keys with a string value
|
||||
ControlURL Key = "LoginURL" // default ""; if blank, ipn uses ipn.DefaultControlURL.
|
||||
@@ -77,9 +63,6 @@ const (
|
||||
// SuggestedExitNodeVisibility controls the visibility of suggested exit nodes in the client GUI.
|
||||
// When this system policy is set to 'hide', an exit node suggestion won't be presented to the user as part of the exit nodes picker.
|
||||
SuggestedExitNodeVisibility Key = "SuggestedExitNode"
|
||||
// OnboardingFlowVisibility controls the visibility of the onboarding flow in the client GUI.
|
||||
// When this system policy is set to 'hide', the onboarding flow is never shown to the user.
|
||||
OnboardingFlowVisibility Key = "OnboardingFlow"
|
||||
|
||||
// Keys with a string value formatted for use with time.ParseDuration().
|
||||
KeyExpirationNoticeTime Key = "KeyExpirationNotice" // default 24 hours
|
||||
@@ -127,91 +110,3 @@ const (
|
||||
// AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes.
|
||||
AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes"
|
||||
)
|
||||
|
||||
// implicitDefinitions is a list of [setting.Definition] that will be registered
|
||||
// automatically when the policy setting definitions are first used by the syspolicy package hierarchy.
|
||||
// This includes the first time a policy needs to be read from any source.
|
||||
var implicitDefinitions = []*setting.Definition{
|
||||
// Device policy settings (can only be configured on a per-device basis):
|
||||
setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(AuthKey, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(MachineCertificateSubject, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue),
|
||||
|
||||
// User policy settings (can be configured on a user- or device-basis):
|
||||
setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(OnboardingFlowVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
}
|
||||
|
||||
func init() {
|
||||
internal.Init.MustDefer(func() error {
|
||||
// Avoid implicit [setting.Definition] registration during tests.
|
||||
// Each test should control which policy settings to register.
|
||||
// Use [setting.SetDefinitionsForTest] to specify necessary definitions,
|
||||
// or [setWellKnownSettingsForTest] to set implicit definitions for the test duration.
|
||||
if testenv.InTest() {
|
||||
return nil
|
||||
}
|
||||
for _, d := range implicitDefinitions {
|
||||
setting.RegisterDefinition(d)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap]
|
||||
|
||||
// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key,
|
||||
// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist
|
||||
// among implicit policy definitions.
|
||||
func WellKnownSettingDefinition(k Key) (*setting.Definition, error) {
|
||||
m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) {
|
||||
return setting.DefinitionMapOf(implicitDefinitions)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d, ok := m[k]; ok {
|
||||
return d, nil
|
||||
}
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// RegisterWellKnownSettingsForTest registers all implicit setting definitions
|
||||
// for the duration of the test.
|
||||
func RegisterWellKnownSettingsForTest(tb TB) {
|
||||
tb.Helper()
|
||||
err := setting.SetDefinitionsForTest(tb, implicitDefinitions...)
|
||||
if err != nil {
|
||||
tb.Fatalf("Failed to register well-known settings: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestKnownKeysRegistered(t *testing.T) {
|
||||
keyConsts, err := listStringConsts[Key]("policy_keys.go")
|
||||
if err != nil {
|
||||
t.Fatalf("listStringConsts failed: %v", err)
|
||||
}
|
||||
|
||||
m, err := setting.DefinitionMapOf(implicitDefinitions)
|
||||
if err != nil {
|
||||
t.Fatalf("definitionMapOf failed: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range keyConsts {
|
||||
t.Run(string(key), func(t *testing.T) {
|
||||
d := m[key]
|
||||
if d == nil {
|
||||
t.Fatalf("%q was not registered", key)
|
||||
}
|
||||
if d.Key() != key {
|
||||
t.Fatalf("d.Key got: %s, want %s", d.Key(), key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotAWellKnownSetting(t *testing.T) {
|
||||
d, err := WellKnownSettingDefinition("TestSettingDoesNotExist")
|
||||
if d != nil || err == nil {
|
||||
t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey)
|
||||
}
|
||||
}
|
||||
|
||||
func listStringConsts[T ~string](filename string) (map[string]T, error) {
|
||||
fset := token.NewFileSet()
|
||||
src, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := parser.ParseFile(fset, filename, src, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
consts := make(map[string]T)
|
||||
typeName := reflect.TypeFor[T]().Name()
|
||||
for _, d := range f.Decls {
|
||||
g, ok := d.(*ast.GenDecl)
|
||||
if !ok || g.Tok != token.CONST {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, s := range g.Specs {
|
||||
vs, ok := s.(*ast.ValueSpec)
|
||||
if !ok || len(vs.Names) != len(vs.Values) {
|
||||
continue
|
||||
}
|
||||
if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, n := range vs.Names {
|
||||
lit, ok := vs.Values[i].(*ast.BasicLit)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i]))
|
||||
}
|
||||
val, err := strconv.Unquote(lit.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value)
|
||||
}
|
||||
consts[n.Name] = T(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return consts, nil
|
||||
}
|
||||
38
util/syspolicy/policy_keys_windows.go
Normal file
38
util/syspolicy/policy_keys_windows.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
var stringKeys = []Key{
|
||||
ControlURL,
|
||||
LogTarget,
|
||||
Tailnet,
|
||||
ExitNodeID,
|
||||
ExitNodeIP,
|
||||
EnableIncomingConnections,
|
||||
EnableServerMode,
|
||||
ExitNodeAllowLANAccess,
|
||||
EnableTailscaleDNS,
|
||||
EnableTailscaleSubnets,
|
||||
AdminConsoleVisibility,
|
||||
NetworkDevicesVisibility,
|
||||
TestMenuVisibility,
|
||||
UpdateMenuVisibility,
|
||||
RunExitNodeVisibility,
|
||||
PreferencesMenuVisibility,
|
||||
ExitNodeMenuVisibility,
|
||||
AutoUpdateVisibility,
|
||||
ResetToDefaultsVisibility,
|
||||
KeyExpirationNoticeTime,
|
||||
PostureChecking,
|
||||
ManagedByOrganizationName,
|
||||
ManagedByCaption,
|
||||
ManagedByURL,
|
||||
}
|
||||
|
||||
var boolKeys = []Key{
|
||||
LogSCMInteractions,
|
||||
FlushDNSOnSessionUnlock,
|
||||
}
|
||||
|
||||
var uint64Keys = []Key{}
|
||||
@@ -10,4 +10,4 @@ package setting
|
||||
type Key string
|
||||
|
||||
// KeyPathSeparator allows logical grouping of policy settings into categories.
|
||||
const KeyPathSeparator = '/'
|
||||
const KeyPathSeparator = "/"
|
||||
|
||||
@@ -5,11 +5,7 @@ package setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
"tailscale.com/types/opt"
|
||||
"tailscale.com/types/structs"
|
||||
)
|
||||
|
||||
@@ -21,15 +17,10 @@ import (
|
||||
// or converted from strings, these setting types predate the typed policy
|
||||
// hierarchies, and must be supported at this layer.
|
||||
type RawItem struct {
|
||||
_ structs.Incomparable
|
||||
data rawItemJSON
|
||||
}
|
||||
|
||||
// rawItemJSON holds JSON-marshallable data for [RawItem].
|
||||
type rawItemJSON struct {
|
||||
Value RawValue `json:",omitzero"`
|
||||
Error *ErrorText `json:",omitzero"` // or nil
|
||||
Origin *Origin `json:",omitzero"` // or nil
|
||||
_ structs.Incomparable
|
||||
value any
|
||||
err *ErrorText
|
||||
origin *Origin // or nil
|
||||
}
|
||||
|
||||
// RawItemOf returns a [RawItem] with the specified value.
|
||||
@@ -39,20 +30,20 @@ func RawItemOf(value any) RawItem {
|
||||
|
||||
// RawItemWith returns a [RawItem] with the specified value, error and origin.
|
||||
func RawItemWith(value any, err *ErrorText, origin *Origin) RawItem {
|
||||
return RawItem{data: rawItemJSON{Value: RawValue{opt.ValueOf(value)}, Error: err, Origin: origin}}
|
||||
return RawItem{value: value, err: err, origin: origin}
|
||||
}
|
||||
|
||||
// Value returns the value of the policy setting, or nil if the policy setting
|
||||
// is not configured, or an error occurred while reading it.
|
||||
func (i RawItem) Value() any {
|
||||
return i.data.Value.Get()
|
||||
return i.value
|
||||
}
|
||||
|
||||
// Error returns the error that occurred when reading the policy setting,
|
||||
// or nil if no error occurred.
|
||||
func (i RawItem) Error() error {
|
||||
if i.data.Error != nil {
|
||||
return i.data.Error
|
||||
if i.err != nil {
|
||||
return i.err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -60,103 +51,17 @@ func (i RawItem) Error() error {
|
||||
// Origin returns an optional [Origin] indicating where the policy setting is
|
||||
// configured.
|
||||
func (i RawItem) Origin() *Origin {
|
||||
return i.data.Origin
|
||||
return i.origin
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (i RawItem) String() string {
|
||||
var suffix string
|
||||
if i.data.Origin != nil {
|
||||
suffix = fmt.Sprintf(" - {%v}", i.data.Origin)
|
||||
if i.origin != nil {
|
||||
suffix = fmt.Sprintf(" - {%v}", i.origin)
|
||||
}
|
||||
if i.data.Error != nil {
|
||||
return fmt.Sprintf("Error{%q}%s", i.data.Error.Error(), suffix)
|
||||
if i.err != nil {
|
||||
return fmt.Sprintf("Error{%q}%s", i.err.Error(), suffix)
|
||||
}
|
||||
return fmt.Sprintf("%v%s", i.data.Value.Value, suffix)
|
||||
return fmt.Sprintf("%v%s", i.value, suffix)
|
||||
}
|
||||
|
||||
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
|
||||
func (i RawItem) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
|
||||
return jsonv2.MarshalEncode(out, &i.data, opts)
|
||||
}
|
||||
|
||||
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
|
||||
func (i *RawItem) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
|
||||
return jsonv2.UnmarshalDecode(in, &i.data, opts)
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler].
|
||||
func (i RawItem) MarshalJSON() ([]byte, error) {
|
||||
return jsonv2.Marshal(i) // uses MarshalJSONV2
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler].
|
||||
func (i *RawItem) UnmarshalJSON(b []byte) error {
|
||||
return jsonv2.Unmarshal(b, i) // uses UnmarshalJSONV2
|
||||
}
|
||||
|
||||
// RawValue represents a raw policy setting value read from a policy store.
|
||||
// It is JSON-marshallable and facilitates unmarshalling of JSON values
|
||||
// into corresponding policy setting types, with special handling for JSON numbers
|
||||
// (unmarshalled as float64) and JSON string arrays (unmarshalled as []string).
|
||||
// See also [RawValue.UnmarshalJSONV2].
|
||||
type RawValue struct {
|
||||
opt.Value[any]
|
||||
}
|
||||
|
||||
// RawValueType is a constraint that permits raw setting value types.
|
||||
type RawValueType interface {
|
||||
bool | uint64 | string | []string
|
||||
}
|
||||
|
||||
// RawValueOf returns a new [RawValue] holding the specified value.
|
||||
func RawValueOf[T RawValueType](v T) RawValue {
|
||||
return RawValue{opt.ValueOf[any](v)}
|
||||
}
|
||||
|
||||
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
|
||||
func (v RawValue) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
|
||||
return jsonv2.MarshalEncode(out, v.Value, opts)
|
||||
}
|
||||
|
||||
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2] by attempting to unmarshal
|
||||
// a JSON value as one of the supported policy setting value types (bool, string, uint64, or []string),
|
||||
// based on the JSON value type. It fails if the JSON value is an object, if it's a JSON number that
|
||||
// cannot be represented as a uint64, or if a JSON array contains anything other than strings.
|
||||
func (v *RawValue) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
|
||||
var valPtr any
|
||||
switch k := in.PeekKind(); k {
|
||||
case 't', 'f':
|
||||
valPtr = new(bool)
|
||||
case '"':
|
||||
valPtr = new(string)
|
||||
case '0':
|
||||
valPtr = new(uint64) // unmarshal JSON numbers as uint64
|
||||
case '[', 'n':
|
||||
valPtr = new([]string) // unmarshal arrays as string slices
|
||||
case '{':
|
||||
return fmt.Errorf("unexpected token: %v", k)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
if err := jsonv2.UnmarshalDecode(in, valPtr, opts); err != nil {
|
||||
v.Value.Clear()
|
||||
return err
|
||||
}
|
||||
value := reflect.ValueOf(valPtr).Elem().Interface()
|
||||
v.Value = opt.ValueOf(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler].
|
||||
func (v RawValue) MarshalJSON() ([]byte, error) {
|
||||
return jsonv2.Marshal(v) // uses MarshalJSONV2
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler].
|
||||
func (v *RawValue) UnmarshalJSON(b []byte) error {
|
||||
return jsonv2.Unmarshal(b, v) // uses UnmarshalJSONV2
|
||||
}
|
||||
|
||||
// RawValues is a map of keyed setting values that can be read from a JSON.
|
||||
type RawValues map[Key]RawValue
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
)
|
||||
|
||||
func TestMarshalUnmarshalRawValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
json string
|
||||
want RawValue
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Bool/True",
|
||||
json: `true`,
|
||||
want: RawValueOf(true),
|
||||
},
|
||||
{
|
||||
name: "Bool/False",
|
||||
json: `false`,
|
||||
want: RawValueOf(false),
|
||||
},
|
||||
{
|
||||
name: "String/Empty",
|
||||
json: `""`,
|
||||
want: RawValueOf(""),
|
||||
},
|
||||
{
|
||||
name: "String/NonEmpty",
|
||||
json: `"Test"`,
|
||||
want: RawValueOf("Test"),
|
||||
},
|
||||
{
|
||||
name: "StringSlice/Null",
|
||||
json: `null`,
|
||||
want: RawValueOf([]string(nil)),
|
||||
},
|
||||
{
|
||||
name: "StringSlice/Empty",
|
||||
json: `[]`,
|
||||
want: RawValueOf([]string{}),
|
||||
},
|
||||
{
|
||||
name: "StringSlice/NonEmpty",
|
||||
json: `["A", "B", "C"]`,
|
||||
want: RawValueOf([]string{"A", "B", "C"}),
|
||||
},
|
||||
{
|
||||
name: "StringSlice/NonStrings",
|
||||
json: `[1, 2, 3]`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Number/Integer/0",
|
||||
json: `0`,
|
||||
want: RawValueOf(uint64(0)),
|
||||
},
|
||||
{
|
||||
name: "Number/Integer/1",
|
||||
json: `1`,
|
||||
want: RawValueOf(uint64(1)),
|
||||
},
|
||||
{
|
||||
name: "Number/Integer/MaxUInt64",
|
||||
json: strconv.FormatUint(math.MaxUint64, 10),
|
||||
want: RawValueOf(uint64(math.MaxUint64)),
|
||||
},
|
||||
{
|
||||
name: "Number/Integer/Negative",
|
||||
json: `-1`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Object",
|
||||
json: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got RawValue
|
||||
gotErr := jsonv2.Unmarshal([]byte(tt.json), &got)
|
||||
if (gotErr != nil) != tt.wantErr {
|
||||
t.Fatalf("Error: got %v; want %v", gotErr, tt.wantErr)
|
||||
}
|
||||
|
||||
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Fatalf("Value: got %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,14 +4,11 @@
|
||||
package setting
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"iter"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
@@ -68,9 +65,6 @@ func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) {
|
||||
|
||||
// Equal reports whether s and s2 are equal.
|
||||
func (s *Snapshot) Equal(s2 *Snapshot) bool {
|
||||
if s == s2 {
|
||||
return true
|
||||
}
|
||||
if !s.EqualItems(s2) {
|
||||
return false
|
||||
}
|
||||
@@ -141,45 +135,6 @@ func (s *Snapshot) String() string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// snapshotJSON holds JSON-marshallable data for [Snapshot].
|
||||
type snapshotJSON struct {
|
||||
Summary Summary `json:",omitzero"`
|
||||
Settings map[Key]RawItem `json:",omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
|
||||
func (s *Snapshot) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
|
||||
data := &snapshotJSON{}
|
||||
if s != nil {
|
||||
data.Summary = s.summary
|
||||
data.Settings = s.m
|
||||
}
|
||||
return jsonv2.MarshalEncode(out, data, opts)
|
||||
}
|
||||
|
||||
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
|
||||
func (s *Snapshot) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
|
||||
if s == nil {
|
||||
return errors.New("s must not be nil")
|
||||
}
|
||||
data := &snapshotJSON{}
|
||||
if err := jsonv2.UnmarshalDecode(in, data, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
*s = Snapshot{m: data.Settings, sig: deephash.Hash(&data.Settings), summary: data.Summary}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler].
|
||||
func (s *Snapshot) MarshalJSON() ([]byte, error) {
|
||||
return jsonv2.Marshal(s) // uses MarshalJSONV2
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler].
|
||||
func (s *Snapshot) UnmarshalJSON(b []byte) error {
|
||||
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
|
||||
}
|
||||
|
||||
// MergeSnapshots returns a [Snapshot] that contains all [RawItem]s
|
||||
// from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope].
|
||||
// If there's a conflict between policy settings in the two snapshots,
|
||||
|
||||
@@ -4,13 +4,8 @@
|
||||
package setting
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
func TestMergeSnapshots(t *testing.T) {
|
||||
@@ -35,134 +30,134 @@ func TestMergeSnapshots(t *testing.T) {
|
||||
name: "first-nil",
|
||||
s1: nil,
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "first-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "second-nil",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
s2: nil,
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "second-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "no-conflicts",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting5": RawItemOf(VisibleByPolicy),
|
||||
"Setting6": RawItemOf(ShowChoiceByPolicy),
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting5": RawItemOf(VisibleByPolicy),
|
||||
"Setting6": RawItemOf(ShowChoiceByPolicy),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "with-conflicts",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(456),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(456),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting1": {value: 456},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "with-scope-first-wins",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, DeviceScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(456),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-second-wins",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(456),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(456),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting1": {value: 456},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
@@ -175,27 +170,28 @@ func TestMergeSnapshots(t *testing.T) {
|
||||
name: "with-scope-first-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true)}, DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true}},
|
||||
DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope, NewNamedOrigin("TestPolicy", DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "with-scope-second-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
}
|
||||
@@ -248,9 +244,9 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
name: "first-nil",
|
||||
s1: nil,
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
@@ -259,9 +255,9 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
name: "first-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
@@ -269,9 +265,9 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
{
|
||||
name: "second-nil",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(true),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
s2: nil,
|
||||
wantEqual: false,
|
||||
@@ -280,9 +276,9 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
{
|
||||
name: "second-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{}),
|
||||
wantEqual: false,
|
||||
@@ -291,14 +287,14 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
{
|
||||
name: "same-items-same-order-no-scope",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
@@ -306,14 +302,14 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
{
|
||||
name: "same-items-same-order-same-scope",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
@@ -321,14 +317,14 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
{
|
||||
name: "same-items-different-order-same-scope",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": {value: false},
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
}, DeviceScope),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
@@ -336,14 +332,14 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
{
|
||||
name: "same-items-same-order-different-scope",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, CurrentUserScope),
|
||||
wantEqual: false,
|
||||
wantEqualItems: true,
|
||||
@@ -351,14 +347,14 @@ func TestSnapshotEqual(t *testing.T) {
|
||||
{
|
||||
name: "different-items-same-scope",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(123),
|
||||
"Setting2": RawItemOf("String"),
|
||||
"Setting3": RawItemOf(false),
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting4": RawItemOf(2 * time.Hour),
|
||||
"Setting5": RawItemOf(VisibleByPolicy),
|
||||
"Setting6": RawItemOf(ShowChoiceByPolicy),
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}, DeviceScope),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
@@ -405,9 +401,9 @@ func TestSnapshotString(t *testing.T) {
|
||||
{
|
||||
name: "non-empty",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemOf(2 * time.Hour),
|
||||
"Setting2": RawItemOf(VisibleByPolicy),
|
||||
"Setting3": RawItemOf(ShowChoiceByPolicy),
|
||||
"Setting1": {value: 2 * time.Hour},
|
||||
"Setting2": {value: VisibleByPolicy},
|
||||
"Setting3": {value: ShowChoiceByPolicy},
|
||||
}, NewNamedOrigin("Test Policy", DeviceScope)),
|
||||
wantString: `{Test Policy (Device)}
|
||||
Setting1 = 2h0m0s
|
||||
@@ -417,14 +413,14 @@ Setting3 = user-decides`,
|
||||
{
|
||||
name: "non-empty-with-item-origin",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemWith(42, nil, NewNamedOrigin("Test Policy", DeviceScope)),
|
||||
"Setting1": {value: 42, origin: NewNamedOrigin("Test Policy", DeviceScope)},
|
||||
}),
|
||||
wantString: `Setting1 = 42 - {Test Policy (Device)}`,
|
||||
},
|
||||
{
|
||||
name: "non-empty-with-item-error",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": RawItemWith(nil, NewErrorText("bang!"), nil),
|
||||
"Setting1": {err: NewErrorText("bang!")},
|
||||
}),
|
||||
wantString: `Setting1 = Error{"bang!"}`,
|
||||
},
|
||||
@@ -437,133 +433,3 @@ Setting3 = user-decides`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalSnapshot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
snapshot *Snapshot
|
||||
wantJSON string
|
||||
wantBack *Snapshot
|
||||
}{
|
||||
{
|
||||
name: "Nil",
|
||||
snapshot: (*Snapshot)(nil),
|
||||
wantJSON: "null",
|
||||
wantBack: NewSnapshot(nil),
|
||||
},
|
||||
{
|
||||
name: "Zero",
|
||||
snapshot: &Snapshot{},
|
||||
wantJSON: "{}",
|
||||
},
|
||||
{
|
||||
name: "Bool/True",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(true)}),
|
||||
wantJSON: `{"Settings": {"BoolPolicy": {"Value": true}}}`,
|
||||
},
|
||||
{
|
||||
name: "Bool/False",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(false)}),
|
||||
wantJSON: `{"Settings": {"BoolPolicy": {"Value": false}}}`,
|
||||
},
|
||||
{
|
||||
name: "String/Non-Empty",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("StringValue")}),
|
||||
wantJSON: `{"Settings": {"StringPolicy": {"Value": "StringValue"}}}`,
|
||||
},
|
||||
{
|
||||
name: "String/Empty",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("")}),
|
||||
wantJSON: `{"Settings": {"StringPolicy": {"Value": ""}}}`,
|
||||
},
|
||||
{
|
||||
name: "Integer/NonZero",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(42))}),
|
||||
wantJSON: `{"Settings": {"IntPolicy": {"Value": 42}}}`,
|
||||
},
|
||||
{
|
||||
name: "Integer/Zero",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(0))}),
|
||||
wantJSON: `{"Settings": {"IntPolicy": {"Value": 0}}}`,
|
||||
},
|
||||
{
|
||||
name: "String-List",
|
||||
snapshot: NewSnapshot(map[Key]RawItem{"ListPolicy": RawItemOf([]string{"Value1", "Value2"})}),
|
||||
wantJSON: `{"Settings": {"ListPolicy": {"Value": ["Value1", "Value2"]}}}`,
|
||||
},
|
||||
{
|
||||
name: "Empty/With-Summary",
|
||||
snapshot: NewSnapshot(
|
||||
map[Key]RawItem{},
|
||||
SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)),
|
||||
),
|
||||
wantJSON: `{"Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}}`,
|
||||
},
|
||||
{
|
||||
name: "Setting/With-Summary",
|
||||
snapshot: NewSnapshot(
|
||||
map[Key]RawItem{"PolicySetting": RawItemOf(uint64(42))},
|
||||
SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)),
|
||||
),
|
||||
wantJSON: `{
|
||||
"Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"},
|
||||
"Settings": {"PolicySetting": {"Value": 42}}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Settings/With-Origins",
|
||||
snapshot: NewSnapshot(
|
||||
map[Key]RawItem{
|
||||
"SettingA": RawItemWith(uint64(42), nil, NewNamedOrigin("SourceA", DeviceScope)),
|
||||
"SettingB": RawItemWith("B", nil, NewNamedOrigin("SourceB", CurrentProfileScope)),
|
||||
"SettingC": RawItemWith(true, nil, NewNamedOrigin("SourceC", CurrentUserScope)),
|
||||
},
|
||||
),
|
||||
wantJSON: `{
|
||||
"Settings": {
|
||||
"SettingA": {"Value": 42, "Origin": {"Name": "SourceA", "Scope": "Device"}},
|
||||
"SettingB": {"Value": "B", "Origin": {"Name": "SourceB", "Scope": "Profile"}},
|
||||
"SettingC": {"Value": true, "Origin": {"Name": "SourceC", "Scope": "User"}}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
doTest := func(t *testing.T, useJSONv2 bool) {
|
||||
var gotJSON []byte
|
||||
var err error
|
||||
if useJSONv2 {
|
||||
gotJSON, err = jsonv2.Marshal(tt.snapshot)
|
||||
} else {
|
||||
gotJSON, err = json.Marshal(tt.snapshot)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got, want, equal := internal.EqualJSONForTest(t, gotJSON, []byte(tt.wantJSON)); !equal {
|
||||
t.Errorf("JSON: got %s; want %s", got, want)
|
||||
}
|
||||
|
||||
gotBack := &Snapshot{}
|
||||
if useJSONv2 {
|
||||
err = jsonv2.Unmarshal(gotJSON, &gotBack)
|
||||
} else {
|
||||
err = json.Unmarshal(gotJSON, &gotBack)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if wantBack := cmp.Or(tt.wantBack, tt.snapshot); !gotBack.Equal(wantBack) {
|
||||
t.Errorf("Snapshot: got %+v; want %+v", gotBack, wantBack)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) { doTest(t, false) })
|
||||
t.Run("jsonv2", func(t *testing.T) { doTest(t, true) })
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
var lookupEnv = os.LookupEnv // test hook
|
||||
|
||||
var _ Store = (*EnvPolicyStore)(nil)
|
||||
|
||||
// EnvPolicyStore is a [Store] that reads policy settings from environment variables.
|
||||
type EnvPolicyStore struct{}
|
||||
|
||||
// ReadString implements [Store].
|
||||
func (s *EnvPolicyStore) ReadString(key setting.Key) (string, error) {
|
||||
_, str, err := s.lookupSettingVariable(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 implements [Store].
|
||||
func (s *EnvPolicyStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
name, str, err := s.lookupSettingVariable(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if str == "" {
|
||||
return 0, setting.ErrNotConfigured
|
||||
}
|
||||
value, err := strconv.ParseUint(str, 0, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s: %w: %q is not a valid uint64", name, setting.ErrTypeMismatch, str)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ReadBoolean implements [Store].
|
||||
func (s *EnvPolicyStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
name, str, err := s.lookupSettingVariable(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if str == "" {
|
||||
return false, setting.ErrNotConfigured
|
||||
}
|
||||
value, err := strconv.ParseBool(str)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%s: %w: %q is not a valid bool", name, setting.ErrTypeMismatch, str)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ReadStringArray implements [Store].
|
||||
func (s *EnvPolicyStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
_, str, err := s.lookupSettingVariable(key)
|
||||
if err != nil || str == "" {
|
||||
return nil, err
|
||||
}
|
||||
var dst int
|
||||
res := strings.Split(str, ",")
|
||||
for src := range res {
|
||||
res[dst] = strings.TrimSpace(res[src])
|
||||
if res[dst] != "" {
|
||||
dst++
|
||||
}
|
||||
}
|
||||
return res[0:dst], nil
|
||||
}
|
||||
|
||||
func (s *EnvPolicyStore) lookupSettingVariable(key setting.Key) (name, value string, err error) {
|
||||
name, err = keyToEnvVarName(key)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
value, ok := lookupEnv(name)
|
||||
if !ok {
|
||||
return name, "", setting.ErrNotConfigured
|
||||
}
|
||||
return name, value, nil
|
||||
}
|
||||
|
||||
var (
|
||||
errEmptyKey = errors.New("key must not be empty")
|
||||
errInvalidKey = errors.New("key must consist of alphanumeric characters and slashes")
|
||||
)
|
||||
|
||||
// keyToEnvVarName returns the environment variable name for a given policy
|
||||
// setting key, or an error if the key is invalid. It converts CamelCase keys into
|
||||
// underscore-separated words and prepends the variable name with the TS prefix.
|
||||
// For example: AuthKey => TS_AUTH_KEY, ExitNodeAllowLANAccess => TS_EXIT_NODE_ALLOW_LAN_ACCESS, etc.
|
||||
//
|
||||
// It's fine to use this in [EnvPolicyStore] without caching variable names since it's not a hot path.
|
||||
// [EnvPolicyStore] is not a [Changeable] policy store, so the conversion will only happen once.
|
||||
func keyToEnvVarName(key setting.Key) (string, error) {
|
||||
if len(key) == 0 {
|
||||
return "", errEmptyKey
|
||||
}
|
||||
|
||||
isLower := func(c byte) bool { return 'a' <= c && c <= 'z' }
|
||||
isUpper := func(c byte) bool { return 'A' <= c && c <= 'Z' }
|
||||
isLetter := func(c byte) bool { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') }
|
||||
isDigit := func(c byte) bool { return '0' <= c && c <= '9' }
|
||||
|
||||
words := make([]string, 0, 8)
|
||||
words = append(words, "TS_DEBUGSYSPOLICY")
|
||||
var currentWord strings.Builder
|
||||
for i := 0; i < len(key); i++ {
|
||||
c := key[i]
|
||||
if c >= utf8.RuneSelf {
|
||||
return "", errInvalidKey
|
||||
}
|
||||
|
||||
var split bool
|
||||
switch {
|
||||
case isLower(c):
|
||||
c -= 'a' - 'A' // make upper
|
||||
split = currentWord.Len() > 0 && !isLetter(key[i-1])
|
||||
case isUpper(c):
|
||||
if currentWord.Len() > 0 {
|
||||
prevUpper := isUpper(key[i-1])
|
||||
nextLower := i < len(key)-1 && isLower(key[i+1])
|
||||
split = !prevUpper || nextLower // split on case transition
|
||||
}
|
||||
case isDigit(c):
|
||||
split = currentWord.Len() > 0 && !isDigit(key[i-1])
|
||||
case c == setting.KeyPathSeparator:
|
||||
words = append(words, currentWord.String())
|
||||
currentWord.Reset()
|
||||
continue
|
||||
default:
|
||||
return "", errInvalidKey
|
||||
}
|
||||
|
||||
if split {
|
||||
words = append(words, currentWord.String())
|
||||
currentWord.Reset()
|
||||
}
|
||||
|
||||
currentWord.WriteByte(c)
|
||||
}
|
||||
|
||||
if currentWord.Len() > 0 {
|
||||
words = append(words, currentWord.String())
|
||||
}
|
||||
|
||||
return strings.Join(words, "_"), nil
|
||||
}
|
||||
@@ -1,359 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestKeyToEnvVarName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key setting.Key
|
||||
want string // suffix after "TS_DEBUGSYSPOLICY_"
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
key: "",
|
||||
wantErr: errEmptyKey,
|
||||
},
|
||||
{
|
||||
name: "lowercase",
|
||||
key: "tailnet",
|
||||
want: "TAILNET",
|
||||
},
|
||||
{
|
||||
name: "CamelCase",
|
||||
key: "AuthKey",
|
||||
want: "AUTH_KEY",
|
||||
},
|
||||
{
|
||||
name: "LongerCamelCase",
|
||||
key: "ManagedByOrganizationName",
|
||||
want: "MANAGED_BY_ORGANIZATION_NAME",
|
||||
},
|
||||
{
|
||||
name: "UPPERCASE",
|
||||
key: "UPPERCASE",
|
||||
want: "UPPERCASE",
|
||||
},
|
||||
{
|
||||
name: "WithAbbrev/Front",
|
||||
key: "DNSServer",
|
||||
want: "DNS_SERVER",
|
||||
},
|
||||
{
|
||||
name: "WithAbbrev/Middle",
|
||||
key: "ExitNodeAllowLANAccess",
|
||||
want: "EXIT_NODE_ALLOW_LAN_ACCESS",
|
||||
},
|
||||
{
|
||||
name: "WithAbbrev/Back",
|
||||
key: "ExitNodeID",
|
||||
want: "EXIT_NODE_ID",
|
||||
},
|
||||
{
|
||||
name: "WithDigits/Single/Front",
|
||||
key: "0TestKey",
|
||||
want: "0_TEST_KEY",
|
||||
},
|
||||
{
|
||||
name: "WithDigits/Multi/Front",
|
||||
key: "64TestKey",
|
||||
want: "64_TEST_KEY",
|
||||
},
|
||||
{
|
||||
name: "WithDigits/Single/Middle",
|
||||
key: "Test0Key",
|
||||
want: "TEST_0_KEY",
|
||||
},
|
||||
{
|
||||
name: "WithDigits/Multi/Middle",
|
||||
key: "Test64Key",
|
||||
want: "TEST_64_KEY",
|
||||
},
|
||||
{
|
||||
name: "WithDigits/Single/Back",
|
||||
key: "TestKey0",
|
||||
want: "TEST_KEY_0",
|
||||
},
|
||||
{
|
||||
name: "WithDigits/Multi/Back",
|
||||
key: "TestKey64",
|
||||
want: "TEST_KEY_64",
|
||||
},
|
||||
{
|
||||
name: "WithDigits/Multi/Back",
|
||||
key: "TestKey64",
|
||||
want: "TEST_KEY_64",
|
||||
},
|
||||
{
|
||||
name: "WithPathSeparators/Single",
|
||||
key: "Key/Subkey",
|
||||
want: "KEY_SUBKEY",
|
||||
},
|
||||
{
|
||||
name: "WithPathSeparators/Multi",
|
||||
key: "Root/Level1/Level2",
|
||||
want: "ROOT_LEVEL_1_LEVEL_2",
|
||||
},
|
||||
{
|
||||
name: "Mixed",
|
||||
key: "Network/DNSServer/IPAddress",
|
||||
want: "NETWORK_DNS_SERVER_IP_ADDRESS",
|
||||
},
|
||||
{
|
||||
name: "Non-Alphanumeric/NonASCII/1",
|
||||
key: "ж",
|
||||
wantErr: errInvalidKey,
|
||||
},
|
||||
{
|
||||
name: "Non-Alphanumeric/NonASCII/2",
|
||||
key: "KeyжName",
|
||||
wantErr: errInvalidKey,
|
||||
},
|
||||
{
|
||||
name: "Non-Alphanumeric/Space",
|
||||
key: "Key Name",
|
||||
wantErr: errInvalidKey,
|
||||
},
|
||||
{
|
||||
name: "Non-Alphanumeric/Punct",
|
||||
key: "Key!Name",
|
||||
wantErr: errInvalidKey,
|
||||
},
|
||||
{
|
||||
name: "Non-Alphanumeric/Backslash",
|
||||
key: `Key\Name`,
|
||||
wantErr: errInvalidKey,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) {
|
||||
got, err := keyToEnvVarName(tt.key)
|
||||
checkError(t, err, tt.wantErr, true)
|
||||
|
||||
want := tt.want
|
||||
if want != "" {
|
||||
want = "TS_DEBUGSYSPOLICY_" + want
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("got %q; want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvPolicyStore(t *testing.T) {
|
||||
blankEnv := func(string) (string, bool) { return "", false }
|
||||
makeEnv := func(wantName, value string) func(string) (string, bool) {
|
||||
wantName = "TS_DEBUGSYSPOLICY_" + wantName
|
||||
return func(gotName string) (string, bool) {
|
||||
if gotName != wantName {
|
||||
return "", false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
key setting.Key
|
||||
lookup func(string) (string, bool)
|
||||
want any
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "NotConfigured/String",
|
||||
key: "AuthKey",
|
||||
lookup: blankEnv,
|
||||
wantErr: setting.ErrNotConfigured,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "Configured/String/Empty",
|
||||
key: "AuthKey",
|
||||
lookup: makeEnv("AUTH_KEY", ""),
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "Configured/String/NonEmpty",
|
||||
key: "AuthKey",
|
||||
lookup: makeEnv("AUTH_KEY", "ABC123"),
|
||||
want: "ABC123",
|
||||
},
|
||||
{
|
||||
name: "NotConfigured/UInt64",
|
||||
key: "IntegerSetting",
|
||||
lookup: blankEnv,
|
||||
wantErr: setting.ErrNotConfigured,
|
||||
want: uint64(0),
|
||||
},
|
||||
{
|
||||
name: "Configured/UInt64/Empty",
|
||||
key: "IntegerSetting",
|
||||
lookup: makeEnv("INTEGER_SETTING", ""),
|
||||
wantErr: setting.ErrNotConfigured,
|
||||
want: uint64(0),
|
||||
},
|
||||
{
|
||||
name: "Configured/UInt64/Zero",
|
||||
key: "IntegerSetting",
|
||||
lookup: makeEnv("INTEGER_SETTING", "0"),
|
||||
want: uint64(0),
|
||||
},
|
||||
{
|
||||
name: "Configured/UInt64/NonZero",
|
||||
key: "IntegerSetting",
|
||||
lookup: makeEnv("INTEGER_SETTING", "12345"),
|
||||
want: uint64(12345),
|
||||
},
|
||||
{
|
||||
name: "Configured/UInt64/MaxUInt64",
|
||||
key: "IntegerSetting",
|
||||
lookup: makeEnv("INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)),
|
||||
want: uint64(math.MaxUint64),
|
||||
},
|
||||
{
|
||||
name: "Configured/UInt64/Negative",
|
||||
key: "IntegerSetting",
|
||||
lookup: makeEnv("INTEGER_SETTING", "-1"),
|
||||
wantErr: setting.ErrTypeMismatch,
|
||||
want: uint64(0),
|
||||
},
|
||||
{
|
||||
name: "Configured/UInt64/Hex",
|
||||
key: "IntegerSetting",
|
||||
lookup: makeEnv("INTEGER_SETTING", "0xDEADBEEF"),
|
||||
want: uint64(0xDEADBEEF),
|
||||
},
|
||||
{
|
||||
name: "NotConfigured/Bool",
|
||||
key: "LogSCMInteractions",
|
||||
lookup: blankEnv,
|
||||
wantErr: setting.ErrNotConfigured,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Configured/Bool/Empty",
|
||||
key: "LogSCMInteractions",
|
||||
lookup: makeEnv("LOG_SCM_INTERACTIONS", ""),
|
||||
wantErr: setting.ErrNotConfigured,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Configured/Bool/True",
|
||||
key: "LogSCMInteractions",
|
||||
lookup: makeEnv("LOG_SCM_INTERACTIONS", "true"),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Configured/Bool/False",
|
||||
key: "LogSCMInteractions",
|
||||
lookup: makeEnv("LOG_SCM_INTERACTIONS", "False"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Configured/Bool/1",
|
||||
key: "LogSCMInteractions",
|
||||
lookup: makeEnv("LOG_SCM_INTERACTIONS", "1"),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Configured/Bool/0",
|
||||
key: "LogSCMInteractions",
|
||||
lookup: makeEnv("LOG_SCM_INTERACTIONS", "0"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Configured/Bool/Invalid",
|
||||
key: "IntegerSetting",
|
||||
lookup: makeEnv("INTEGER_SETTING", "NotABool"),
|
||||
wantErr: setting.ErrTypeMismatch,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "NotConfigured/StringArray",
|
||||
key: "AllowedSuggestedExitNodes",
|
||||
lookup: blankEnv,
|
||||
wantErr: setting.ErrNotConfigured,
|
||||
want: []string(nil),
|
||||
},
|
||||
{
|
||||
name: "Configured/StringArray/Empty",
|
||||
key: "AllowedSuggestedExitNodes",
|
||||
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", ""),
|
||||
want: []string(nil),
|
||||
},
|
||||
{
|
||||
name: "Configured/StringArray/Spaces",
|
||||
key: "AllowedSuggestedExitNodes",
|
||||
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", " \t "),
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "Configured/StringArray/Single",
|
||||
key: "AllowedSuggestedExitNodes",
|
||||
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"),
|
||||
want: []string{"NodeA"},
|
||||
},
|
||||
{
|
||||
name: "Configured/StringArray/Multi",
|
||||
key: "AllowedSuggestedExitNodes",
|
||||
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"),
|
||||
want: []string{"NodeA", "NodeB", "NodeC"},
|
||||
},
|
||||
{
|
||||
name: "Configured/StringArray/WithBlank",
|
||||
key: "AllowedSuggestedExitNodes",
|
||||
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"),
|
||||
want: []string{"NodeA", "NodeB"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) {
|
||||
oldLookupEnv := lookupEnv
|
||||
t.Cleanup(func() { lookupEnv = oldLookupEnv })
|
||||
lookupEnv = tt.lookup
|
||||
|
||||
var got any
|
||||
var err error
|
||||
var store EnvPolicyStore
|
||||
switch tt.want.(type) {
|
||||
case string:
|
||||
got, err = store.ReadString(tt.key)
|
||||
case uint64:
|
||||
got, err = store.ReadUInt64(tt.key)
|
||||
case bool:
|
||||
got, err = store.ReadBoolean(tt.key)
|
||||
case []string:
|
||||
got, err = store.ReadStringArray(tt.key)
|
||||
}
|
||||
checkError(t, err, tt.wantErr, false)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("got %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkError(tb testing.TB, got, want error, fatal bool) {
|
||||
tb.Helper()
|
||||
f := tb.Errorf
|
||||
if fatal {
|
||||
f = tb.Fatalf
|
||||
}
|
||||
if (want == nil && got != nil) ||
|
||||
(want != nil && got == nil) ||
|
||||
(want != nil && got != nil && !errors.Is(got, want) && want.Error() != got.Error()) {
|
||||
f("gotErr: %v; wantErr: %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -319,9 +319,9 @@ func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error
|
||||
// If there are no [setting.KeyPathSeparator]s in the key, the policy setting value
|
||||
// is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale.
|
||||
func splitSettingKey(key setting.Key) (path, valueName string) {
|
||||
if idx := strings.LastIndexByte(string(key), setting.KeyPathSeparator); idx != -1 {
|
||||
path = strings.ReplaceAll(string(key[:idx]), string(setting.KeyPathSeparator), `\`)
|
||||
valueName = string(key[idx+1:])
|
||||
if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 {
|
||||
path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`)
|
||||
valueName = string(key[idx+len(setting.KeyPathSeparator):])
|
||||
return path, valueName
|
||||
}
|
||||
return "", string(key)
|
||||
|
||||
@@ -1,82 +1,51 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package syspolicy facilitates retrieval of the current policy settings
|
||||
// applied to the device or user and receiving notifications when the policy
|
||||
// changes.
|
||||
//
|
||||
// It provides functions that return specific policy settings by their unique
|
||||
// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString],
|
||||
// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration].
|
||||
// Package syspolicy provides functions to retrieve system settings of a device.
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotConfigured is returned when the requested policy setting is not configured.
|
||||
ErrNotConfigured = setting.ErrNotConfigured
|
||||
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
|
||||
// of the setting value and the expected type.
|
||||
ErrTypeMismatch = setting.ErrTypeMismatch
|
||||
// ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting
|
||||
// has been registered with the specified key.
|
||||
//
|
||||
// This error is also returned by a (now deprecated) [Handler] when the specified
|
||||
// key does not have a value set. While the package maintains compatibility with this
|
||||
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
|
||||
// [source.Store] implementations.
|
||||
ErrNoSuchKey = setting.ErrNoSuchKey
|
||||
)
|
||||
|
||||
// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
|
||||
//
|
||||
// It is a shorthand for [rsop.RegisterStore].
|
||||
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*rsop.StoreRegistration, error) {
|
||||
return rsop.RegisterStore(name, scope, store)
|
||||
}
|
||||
|
||||
// MustRegisterStoreForTest is like [rsop.RegisterStoreForTest], but it fails the test if the store could not be registered.
|
||||
func MustRegisterStoreForTest(tb TB, name string, scope setting.PolicyScope, store source.Store) *rsop.StoreRegistration {
|
||||
tb.Helper()
|
||||
reg, err := rsop.RegisterStoreForTest(tb, name, scope, store)
|
||||
if err != nil {
|
||||
tb.Fatalf("Failed to register policy store %q as a %v policy source: %v", name, scope, err)
|
||||
}
|
||||
return reg
|
||||
}
|
||||
|
||||
// GetString returns a string policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetString(key Key, defaultValue string) (string, error) {
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadString(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// GetUint64 returns a numeric policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetUint64(key Key, defaultValue uint64) (uint64, error) {
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadUInt64(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// GetBoolean returns a boolean policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetBoolean(key Key, defaultValue bool) (bool, error) {
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadBoolean(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// GetStringArray returns a multi-string policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetStringArray(key Key, defaultValue []string) ([]string, error) {
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadStringArray(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
// GetPreferenceOption loads a policy from the registry that can be
|
||||
@@ -86,7 +55,13 @@ func GetStringArray(key Key, defaultValue []string) ([]string, error) {
|
||||
// "always" and "never" remove the user's ability to make a selection. If not
|
||||
// present or set to a different value, "user-decides" is the default.
|
||||
func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
|
||||
return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy)
|
||||
s, err := GetString(name, "user-decides")
|
||||
if err != nil {
|
||||
return setting.ShowChoiceByPolicy, err
|
||||
}
|
||||
var opt setting.PreferenceOption
|
||||
err = opt.UnmarshalText([]byte(s))
|
||||
return opt, err
|
||||
}
|
||||
|
||||
// GetVisibility loads a policy from the registry that can be managed
|
||||
@@ -95,7 +70,13 @@ func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
|
||||
// true) or "hide" (return true). If not present or set to a different value,
|
||||
// "show" (return false) is the default.
|
||||
func GetVisibility(name Key) (setting.Visibility, error) {
|
||||
return getCurrentPolicySettingValue(name, setting.VisibleByPolicy)
|
||||
s, err := GetString(name, "show")
|
||||
if err != nil {
|
||||
return setting.VisibleByPolicy, err
|
||||
}
|
||||
var visibility setting.Visibility
|
||||
visibility.UnmarshalText([]byte(s))
|
||||
return visibility, nil
|
||||
}
|
||||
|
||||
// GetDuration loads a policy from the registry that can be managed
|
||||
@@ -104,58 +85,15 @@ func GetVisibility(name Key) (setting.Visibility, error) {
|
||||
// understands. If the registry value is "" or can not be processed,
|
||||
// defaultValue is returned instead.
|
||||
func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) {
|
||||
d, err := getCurrentPolicySettingValue(name, defaultValue)
|
||||
if err != nil {
|
||||
return d, err
|
||||
opt, err := GetString(name, "")
|
||||
if opt == "" || err != nil {
|
||||
return defaultValue, err
|
||||
}
|
||||
if d < 0 {
|
||||
v, err := time.ParseDuration(opt)
|
||||
if err != nil || v < 0 {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// RegisterChangeCallback adds a function that will be called whenever the effective policy
|
||||
// for the default scope changes. The returned function can be used to unregister the callback.
|
||||
func RegisterChangeCallback(cb rsop.PolicyChangeCallback) (unregister func(), err error) {
|
||||
effective, err := rsop.PolicyFor(setting.DefaultScope())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return effective.RegisterChangeCallback(cb), nil
|
||||
}
|
||||
|
||||
// getCurrentPolicySettingValue returns the value of the policy setting
|
||||
// specified by its key from the [rsop.Policy] of the [setting.DefaultScope]. It
|
||||
// returns def if the policy setting is not configured, or an error if it has
|
||||
// an error or could not be converted to the specified type T.
|
||||
func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) {
|
||||
effective, err := rsop.PolicyFor(setting.DefaultScope())
|
||||
if err != nil {
|
||||
return def, err
|
||||
}
|
||||
value, err := effective.Get().GetErr(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) {
|
||||
return def, nil
|
||||
}
|
||||
return def, err
|
||||
}
|
||||
if res, ok := value.(T); ok {
|
||||
return res, nil
|
||||
}
|
||||
return convertPolicySettingValueTo(value, def)
|
||||
}
|
||||
|
||||
func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) {
|
||||
// Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string
|
||||
// if someone requests a string instead of the actual setting's value.
|
||||
// TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests.
|
||||
if reflect.TypeFor[T]().Kind() == reflect.String {
|
||||
if str, ok := value.(fmt.Stringer); ok {
|
||||
return any(str.String()).(T), nil
|
||||
}
|
||||
}
|
||||
return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def)
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// SelectControlURL returns the ControlURL to use based on a value in
|
||||
|
||||
@@ -9,15 +9,57 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/internal/metrics"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
// testHandler encompasses all data types returned when testing any of the syspolicy
|
||||
// methods that involve getting a policy value.
|
||||
// For keys and the corresponding values, check policy_keys.go.
|
||||
type testHandler struct {
|
||||
t *testing.T
|
||||
key Key
|
||||
s string
|
||||
u64 uint64
|
||||
b bool
|
||||
sArr []string
|
||||
err error
|
||||
calls int // used for testing reads from cache vs. handler
|
||||
}
|
||||
|
||||
var someOtherError = errors.New("error other than not found")
|
||||
|
||||
func (th *testHandler) ReadString(key string) (string, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadString(%q) want %q", key, th.key)
|
||||
}
|
||||
th.calls++
|
||||
return th.s, th.err
|
||||
}
|
||||
|
||||
func (th *testHandler) ReadUInt64(key string) (uint64, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
|
||||
}
|
||||
th.calls++
|
||||
return th.u64, th.err
|
||||
}
|
||||
|
||||
func (th *testHandler) ReadBoolean(key string) (bool, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadBool(%q) want %q", key, th.key)
|
||||
}
|
||||
th.calls++
|
||||
return th.b, th.err
|
||||
}
|
||||
|
||||
func (th *testHandler) ReadStringArray(key string) ([]string, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadStringArray(%q) want %q", key, th.key)
|
||||
}
|
||||
th.calls++
|
||||
return th.sArr, th.err
|
||||
}
|
||||
|
||||
func TestGetString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -27,28 +69,23 @@ func TestGetString(t *testing.T) {
|
||||
defaultValue string
|
||||
wantValue string
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "hide",
|
||||
wantValue: "hide",
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: EnableServerMode,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non-blank default",
|
||||
key: EnableServerMode,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
defaultValue: "test",
|
||||
wantValue: "test",
|
||||
wantError: nil,
|
||||
@@ -58,43 +95,24 @@ func TestGetString(t *testing.T) {
|
||||
key: NetworkDevicesVisibility,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_NetworkDevices_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
s := source.TestSetting[string]{
|
||||
Key: tt.key,
|
||||
Value: tt.handlerValue,
|
||||
Error: tt.handlerError,
|
||||
}
|
||||
registerSingleSettingStoreForTest(t, s)
|
||||
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetString(tt.key, tt.defaultValue)
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
if err != tt.wantError {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-09-04, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -111,7 +129,7 @@ func TestGetUint64(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: LogSCMInteractions,
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerValue: 1,
|
||||
wantValue: 1,
|
||||
},
|
||||
@@ -119,14 +137,14 @@ func TestGetUint64(t *testing.T) {
|
||||
name: "read non-existing value",
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: 0,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: 0,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non-zero default",
|
||||
key: LogSCMInteractions,
|
||||
defaultValue: 2,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: 2,
|
||||
},
|
||||
{
|
||||
@@ -139,23 +157,14 @@ func TestGetUint64(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// None of the policy settings tested here are integers.
|
||||
// In fact, we don't have any integer policies as of 2024-10-08.
|
||||
// However, we can register each of them as an integer policy setting
|
||||
// for the duration of the test, providing us with something to test against.
|
||||
if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
|
||||
s := source.TestSetting[uint64]{
|
||||
Key: tt.key,
|
||||
Value: tt.handlerValue,
|
||||
Error: tt.handlerError,
|
||||
}
|
||||
registerSingleSettingStoreForTest(t, s)
|
||||
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
u64: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetUint64(tt.key, tt.defaultValue)
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
if err != tt.wantError {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
@@ -174,69 +183,45 @@ func TestGetBoolean(t *testing.T) {
|
||||
defaultValue bool
|
||||
wantValue bool
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: FlushDNSOnSessionUnlock,
|
||||
handlerValue: true,
|
||||
wantValue: true,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: false,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: false,
|
||||
},
|
||||
{
|
||||
name: "reading value returns other error",
|
||||
key: FlushDNSOnSessionUnlock,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError, // expect error...
|
||||
wantError: someOtherError,
|
||||
defaultValue: true,
|
||||
wantValue: true, // ...AND default value if the handler fails.
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1},
|
||||
},
|
||||
wantValue: false,
|
||||
},
|
||||
}
|
||||
|
||||
RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
s := source.TestSetting[bool]{
|
||||
Key: tt.key,
|
||||
Value: tt.handlerValue,
|
||||
Error: tt.handlerError,
|
||||
}
|
||||
registerSingleSettingStoreForTest(t, s)
|
||||
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
b: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetBoolean(tt.key, tt.defaultValue)
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
if err != tt.wantError {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-09-04, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -249,42 +234,29 @@ func TestGetPreferenceOption(t *testing.T) {
|
||||
handlerError error
|
||||
wantValue setting.PreferenceOption
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "always by policy",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "always",
|
||||
wantValue: setting.AlwaysByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "never by policy",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "never",
|
||||
wantValue: setting.NeverByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "use default",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "",
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: EnableIncomingConnections,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
},
|
||||
{
|
||||
@@ -293,43 +265,24 @@ func TestGetPreferenceOption(t *testing.T) {
|
||||
handlerError: someOtherError,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
s := source.TestSetting[string]{
|
||||
Key: tt.key,
|
||||
Value: tt.handlerValue,
|
||||
Error: tt.handlerError,
|
||||
}
|
||||
registerSingleSettingStoreForTest(t, s)
|
||||
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
option, err := GetPreferenceOption(tt.key)
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
if err != tt.wantError {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if option != tt.wantValue {
|
||||
t.Errorf("option=%v, want %v", option, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-09-04, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -342,33 +295,24 @@ func TestGetVisibility(t *testing.T) {
|
||||
handlerError error
|
||||
wantValue setting.Visibility
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "hidden by policy",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "hide",
|
||||
wantValue: setting.HiddenByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "visibility default",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
},
|
||||
{
|
||||
@@ -378,43 +322,24 @@ func TestGetVisibility(t *testing.T) {
|
||||
handlerError: someOtherError,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
s := source.TestSetting[string]{
|
||||
Key: tt.key,
|
||||
Value: tt.handlerValue,
|
||||
Error: tt.handlerError,
|
||||
}
|
||||
registerSingleSettingStoreForTest(t, s)
|
||||
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
visibility, err := GetVisibility(tt.key)
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
if err != tt.wantError {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if visibility != tt.wantValue {
|
||||
t.Errorf("visibility=%v, want %v", visibility, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-09-04, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -428,7 +353,6 @@ func TestGetDuration(t *testing.T) {
|
||||
defaultValue time.Duration
|
||||
wantValue time.Duration
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
@@ -436,34 +360,25 @@ func TestGetDuration(t *testing.T) {
|
||||
handlerValue: "2h",
|
||||
wantValue: 2 * time.Hour,
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid duration value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerValue: "-20",
|
||||
wantValue: 24 * time.Hour,
|
||||
wantError: errors.New(`time: missing unit in duration "-20"`),
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: 24 * time.Hour,
|
||||
defaultValue: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value different default",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: 0 * time.Second,
|
||||
defaultValue: 0 * time.Second,
|
||||
},
|
||||
@@ -474,43 +389,24 @@ func TestGetDuration(t *testing.T) {
|
||||
wantValue: 24 * time.Hour,
|
||||
wantError: someOtherError,
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
s := source.TestSetting[string]{
|
||||
Key: tt.key,
|
||||
Value: tt.handlerValue,
|
||||
Error: tt.handlerError,
|
||||
}
|
||||
registerSingleSettingStoreForTest(t, s)
|
||||
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
duration, err := GetDuration(tt.key, tt.defaultValue)
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
if err != tt.wantError {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if duration != tt.wantValue {
|
||||
t.Errorf("duration=%v, want %v", duration, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-09-04, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -524,28 +420,23 @@ func TestGetStringArray(t *testing.T) {
|
||||
defaultValue []string
|
||||
wantValue []string
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerValue: []string{"foo", "bar"},
|
||||
wantValue: []string{"foo", "bar"},
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non nil default",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: ErrNotConfigured,
|
||||
handlerError: ErrNoSuchKey,
|
||||
defaultValue: []string{"foo", "bar"},
|
||||
wantValue: []string{"foo", "bar"},
|
||||
wantError: nil,
|
||||
@@ -555,68 +446,28 @@ func TestGetStringArray(t *testing.T) {
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RegisterWellKnownSettingsForTest(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
s := source.TestSetting[[]string]{
|
||||
Key: tt.key,
|
||||
Value: tt.handlerValue,
|
||||
Error: tt.handlerError,
|
||||
}
|
||||
registerSingleSettingStoreForTest(t, s)
|
||||
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
sArr: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetStringArray(tt.key, tt.defaultValue)
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
if err != tt.wantError {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if !slices.Equal(tt.wantValue, value) {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-09-04, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func registerSingleSettingStoreForTest[T source.TestValueType](tb TB, s source.TestSetting[T]) {
|
||||
policyStore := source.NewTestStoreOf(tb, s)
|
||||
MustRegisterStoreForTest(tb, "TestStore", setting.DeviceScope, policyStore)
|
||||
}
|
||||
|
||||
func BenchmarkGetString(b *testing.B) {
|
||||
loggerx.SetForTest(b, logger.Discard, logger.Discard)
|
||||
RegisterWellKnownSettingsForTest(b)
|
||||
|
||||
wantControlURL := "https://login.tailscale.com"
|
||||
registerSingleSettingStoreForTest(b, source.TestSettingOf(ControlURL, wantControlURL))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com")
|
||||
if gotControlURL != wantControlURL {
|
||||
b.Fatalf("got %v; want %v", gotControlURL, wantControlURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectControlURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
reg, disk, want string
|
||||
@@ -648,13 +499,3 @@ func TestSelectControlURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func errorsMatchForTest(got, want error) bool {
|
||||
if got == nil && want == nil {
|
||||
return true
|
||||
}
|
||||
if got == nil || want == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(got, want) || got.Error() == want.Error()
|
||||
}
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/user"
|
||||
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// On Windows, we should automatically register the Registry-based policy
|
||||
// store for the device. If we are running in a user's security context
|
||||
// (e.g., we're the GUI), we should also register the Registry policy store for
|
||||
// the user. In the future, we should register (and unregister) user policy
|
||||
// stores whenever a user connects to (or disconnects from) the local backend.
|
||||
// This ensures the backend is aware of the user's policy settings and can send
|
||||
// them to the GUI/CLI/Web clients on demand or whenever they change.
|
||||
//
|
||||
// Other platforms, such as macOS, iOS and Android, should register their
|
||||
// platform-specific policy stores via [RegisterStore]
|
||||
// (or [RegisterHandler] until they implement the [source.Store] interface).
|
||||
//
|
||||
// External code, such as the ipnlocal package, may choose to register
|
||||
// additional policy stores, such as config files and policies received from
|
||||
// the control plane.
|
||||
internal.Init.MustDefer(func() error {
|
||||
// Do not register or use default policy stores during tests.
|
||||
// Each test should set up its own necessary configurations.
|
||||
if testenv.InTest() {
|
||||
return nil
|
||||
}
|
||||
return configureSyspolicy(nil)
|
||||
})
|
||||
}
|
||||
|
||||
// configureSyspolicy configures syspolicy for use on Windows,
|
||||
// either in test or regular builds depending on whether tb has a non-nil value.
|
||||
func configureSyspolicy(tb internal.TB) error {
|
||||
const localSystemSID = "S-1-5-18"
|
||||
// Always create and register a machine policy store that reads
|
||||
// policy settings from the HKEY_LOCAL_MACHINE registry hive.
|
||||
machineStore, err := source.NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create the machine policy store: %v", err)
|
||||
}
|
||||
if tb == nil {
|
||||
_, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore)
|
||||
} else {
|
||||
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check whether the current process is running as Local System or not.
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Uid == localSystemSID {
|
||||
return nil
|
||||
}
|
||||
// If it's not a Local System's process (e.g., it's the GUI rather than the tailscaled service),
|
||||
// we should create and use a policy store for the current user that reads
|
||||
// policy settings from that user's registry hive (HKEY_CURRENT_USER).
|
||||
userStore, err := source.NewUserPlatformPolicyStore(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create the current user's policy store: %v", err)
|
||||
}
|
||||
if tb == nil {
|
||||
_, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore)
|
||||
} else {
|
||||
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// And also set [setting.CurrentUserScope] as the [setting.DefaultScope], so [GetString],
|
||||
// [GetVisibility] and similar functions would be returning a merged result
|
||||
// of the machine's and user's policies.
|
||||
if !setting.SetDefaultScope(setting.CurrentUserScope) {
|
||||
return errors.New("current scope already set")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// This file contains user-facing metrics that are used by multiple packages.
|
||||
// Use it to define more common metrics. Any changes to the registry and
|
||||
// metric types should be in usermetric.go.
|
||||
|
||||
package usermetric
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"tailscale.com/metrics"
|
||||
)
|
||||
|
||||
// Metrics contains user-facing metrics that are used by multiple packages.
|
||||
type Metrics struct {
|
||||
initOnce sync.Once
|
||||
|
||||
droppedPacketsInbound *metrics.MultiLabelMap[DropLabels]
|
||||
droppedPacketsOutbound *metrics.MultiLabelMap[DropLabels]
|
||||
}
|
||||
|
||||
// DropReason is the reason why a packet was dropped.
|
||||
type DropReason string
|
||||
|
||||
const (
|
||||
// ReasonACL means that the packet was not permitted by ACL.
|
||||
ReasonACL DropReason = "acl"
|
||||
|
||||
// ReasonError means that the packet was dropped because of an error.
|
||||
ReasonError DropReason = "error"
|
||||
)
|
||||
|
||||
// DropLabels contains common label(s) for dropped packet counters.
|
||||
type DropLabels struct {
|
||||
Reason DropReason
|
||||
}
|
||||
|
||||
// initOnce initializes the common metrics.
|
||||
func (r *Registry) initOnce() {
|
||||
r.m.initOnce.Do(func() {
|
||||
r.m.droppedPacketsInbound = NewMultiLabelMapWithRegistry[DropLabels](
|
||||
r,
|
||||
"tailscaled_inbound_dropped_packets_total",
|
||||
"counter",
|
||||
"Counts the number of dropped packets received by the node from other peers",
|
||||
)
|
||||
r.m.droppedPacketsOutbound = NewMultiLabelMapWithRegistry[DropLabels](
|
||||
r,
|
||||
"tailscaled_outbound_dropped_packets_total",
|
||||
"counter",
|
||||
"Counts the number of packets dropped while being sent to other peers",
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// DroppedPacketsOutbound returns the outbound dropped packet metric, creating it
|
||||
// if necessary.
|
||||
func (r *Registry) DroppedPacketsOutbound() *metrics.MultiLabelMap[DropLabels] {
|
||||
r.initOnce()
|
||||
return r.m.droppedPacketsOutbound
|
||||
}
|
||||
|
||||
// DroppedPacketsInbound returns the inbound dropped packet metric.
|
||||
func (r *Registry) DroppedPacketsInbound() *metrics.MultiLabelMap[DropLabels] {
|
||||
r.initOnce()
|
||||
return r.m.droppedPacketsInbound
|
||||
}
|
||||
@@ -19,9 +19,6 @@ import (
|
||||
// Registry tracks user-facing metrics of various Tailscale subsystems.
|
||||
type Registry struct {
|
||||
vars expvar.Map
|
||||
|
||||
// m contains common metrics owned by the registry.
|
||||
m Metrics
|
||||
}
|
||||
|
||||
// NewMultiLabelMapWithRegistry creates and register a new
|
||||
|
||||
@@ -158,10 +158,10 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int)
|
||||
} else {
|
||||
connectedToControl = c.health.GetInPollNetMap()
|
||||
}
|
||||
c.mu.Lock()
|
||||
myDerp := c.myDerp
|
||||
c.mu.Unlock()
|
||||
if !connectedToControl {
|
||||
c.mu.Lock()
|
||||
myDerp := c.myDerp
|
||||
c.mu.Unlock()
|
||||
if myDerp != 0 {
|
||||
metricDERPHomeNoChangeNoControl.Add(1)
|
||||
return myDerp
|
||||
@@ -178,11 +178,6 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int)
|
||||
// one.
|
||||
preferredDERP = c.pickDERPFallback()
|
||||
}
|
||||
if preferredDERP != myDerp {
|
||||
c.logf(
|
||||
"magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms]",
|
||||
c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds())
|
||||
}
|
||||
if !c.setNearestDERP(preferredDERP) {
|
||||
preferredDERP = 0
|
||||
}
|
||||
@@ -649,10 +644,9 @@ func (c *Conn) runDerpReader(ctx context.Context, regionID int, dc *derphttp.Cli
|
||||
}
|
||||
|
||||
type derpWriteRequest struct {
|
||||
addr netip.AddrPort
|
||||
pubKey key.NodePublic
|
||||
b []byte // copied; ownership passed to receiver
|
||||
isDisco bool
|
||||
addr netip.AddrPort
|
||||
pubKey key.NodePublic
|
||||
b []byte // copied; ownership passed to receiver
|
||||
}
|
||||
|
||||
// runDerpWriter runs in a goroutine for the life of a DERP
|
||||
@@ -674,10 +668,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan
|
||||
if err != nil {
|
||||
c.logf("magicsock: derp.Send(%v): %v", wr.addr, err)
|
||||
metricSendDERPError.Add(1)
|
||||
if !wr.isDisco {
|
||||
c.metrics.outboundPacketsDroppedErrors.Add(1)
|
||||
}
|
||||
} else if !wr.isDisco {
|
||||
} else {
|
||||
c.metrics.outboundPacketsDERPTotal.Add(1)
|
||||
c.metrics.outboundBytesDERPTotal.Add(int64(len(wr.b)))
|
||||
}
|
||||
@@ -700,6 +691,8 @@ func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint)
|
||||
// No data read occurred. Wait for another packet.
|
||||
continue
|
||||
}
|
||||
c.metrics.inboundPacketsDERPTotal.Add(1)
|
||||
c.metrics.inboundBytesDERPTotal.Add(int64(n))
|
||||
sizes[0] = n
|
||||
eps[0] = ep
|
||||
return 1, nil
|
||||
@@ -739,9 +732,6 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
|
||||
if stats := c.stats.Load(); stats != nil {
|
||||
stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, dm.n)
|
||||
}
|
||||
|
||||
c.metrics.inboundPacketsDERPTotal.Add(1)
|
||||
c.metrics.inboundBytesDERPTotal.Add(int64(n))
|
||||
return n, ep
|
||||
}
|
||||
|
||||
|
||||
@@ -983,8 +983,7 @@ func (de *endpoint) send(buffs [][]byte) error {
|
||||
allOk := true
|
||||
var txBytes int
|
||||
for _, buff := range buffs {
|
||||
const isDisco = false
|
||||
ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff, isDisco)
|
||||
ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff)
|
||||
txBytes += len(buff)
|
||||
if !ok {
|
||||
allOk = false
|
||||
@@ -992,7 +991,7 @@ func (de *endpoint) send(buffs [][]byte) error {
|
||||
}
|
||||
|
||||
if stats := de.c.stats.Load(); stats != nil {
|
||||
stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buffs), txBytes)
|
||||
stats.UpdateTxPhysical(de.nodeAddr, derpAddr, 1, txBytes)
|
||||
}
|
||||
if allOk {
|
||||
return nil
|
||||
|
||||
@@ -127,10 +127,6 @@ type metrics struct {
|
||||
outboundBytesIPv4Total expvar.Int
|
||||
outboundBytesIPv6Total expvar.Int
|
||||
outboundBytesDERPTotal expvar.Int
|
||||
|
||||
// outboundPacketsDroppedErrors is the total number of outbound packets
|
||||
// dropped due to errors.
|
||||
outboundPacketsDroppedErrors expvar.Int
|
||||
}
|
||||
|
||||
// A Conn routes UDP packets and actively manages a list of its endpoints.
|
||||
@@ -609,8 +605,6 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
"counter",
|
||||
"Counts the number of bytes sent to other peers",
|
||||
)
|
||||
outboundPacketsDroppedErrors := reg.DroppedPacketsOutbound()
|
||||
|
||||
m := new(metrics)
|
||||
|
||||
// Map clientmetrics to the usermetric counters.
|
||||
@@ -637,8 +631,6 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total)
|
||||
outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal)
|
||||
|
||||
outboundPacketsDroppedErrors.Set(usermetric.DropLabels{Reason: usermetric.ReasonError}, &m.outboundPacketsDroppedErrors)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -1210,13 +1202,8 @@ func (c *Conn) networkDown() bool { return !c.networkUp.Load() }
|
||||
// Send implements conn.Bind.
|
||||
//
|
||||
// See https://pkg.go.dev/golang.zx2c4.com/wireguard/conn#Bind.Send
|
||||
func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) (err error) {
|
||||
func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) error {
|
||||
n := int64(len(buffs))
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.metrics.outboundPacketsDroppedErrors.Add(n)
|
||||
}
|
||||
}()
|
||||
metricSendData.Add(n)
|
||||
if c.networkDown() {
|
||||
metricSendDataNetworkDown.Add(n)
|
||||
@@ -1369,7 +1356,7 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error)
|
||||
// An example of when they might be different: sending to an
|
||||
// IPv6 address when the local machine doesn't have IPv6 support
|
||||
// returns (false, nil); it's not an error, but nothing was sent.
|
||||
func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool) (sent bool, err error) {
|
||||
func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) {
|
||||
if addr.Addr() != tailcfg.DerpMagicIPAddr {
|
||||
return c.sendUDP(addr, b)
|
||||
}
|
||||
@@ -1392,7 +1379,7 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, is
|
||||
case <-c.donec:
|
||||
metricSendDERPErrorClosed.Add(1)
|
||||
return false, errConnClosed
|
||||
case ch <- derpWriteRequest{addr, pubKey, pkt, isDisco}:
|
||||
case ch <- derpWriteRequest{addr, pubKey, pkt}:
|
||||
metricSendDERPQueued.Add(1)
|
||||
return true, nil
|
||||
default:
|
||||
@@ -1590,8 +1577,7 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi
|
||||
|
||||
box := di.sharedKey.Seal(m.AppendMarshal(nil))
|
||||
pkt = append(pkt, box...)
|
||||
const isDisco = true
|
||||
sent, err = c.sendAddr(dst, dstKey, pkt, isDisco)
|
||||
sent, err = c.sendAddr(dst, dstKey, pkt)
|
||||
if sent {
|
||||
if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco()) {
|
||||
node := "?"
|
||||
|
||||
@@ -63,7 +63,6 @@ import (
|
||||
"tailscale.com/types/nettype"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/util/cibuild"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/racebuild"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/usermetric"
|
||||
@@ -3084,27 +3083,3 @@ func TestMaybeRebindOnError(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkDownSendErrors(t *testing.T) {
|
||||
netMon := must.Get(netmon.New(t.Logf))
|
||||
defer netMon.Close()
|
||||
|
||||
reg := new(usermetric.Registry)
|
||||
conn := must.Get(NewConn(Options{
|
||||
DisablePortMapper: true,
|
||||
Logf: t.Logf,
|
||||
NetMon: netMon,
|
||||
Metrics: reg,
|
||||
}))
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetNetworkUp(false)
|
||||
if err := conn.Send([][]byte{{00}}, &lazyEndpoint{}); err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
reg.Handler(resp, new(http.Request))
|
||||
if !strings.Contains(resp.Body.String(), `tailscaled_outbound_dropped_packets_total{reason="error"} 1`) {
|
||||
t.Errorf("expected NetworkDown to increment packet dropped metric; got %q", resp.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
"tailscale.com/drive"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/metrics"
|
||||
@@ -173,18 +174,19 @@ type Impl struct {
|
||||
// It can only be set before calling Start.
|
||||
ProcessSubnets bool
|
||||
|
||||
ipstack *stack.Stack
|
||||
linkEP *linkEndpoint
|
||||
tundev *tstun.Wrapper
|
||||
e wgengine.Engine
|
||||
pm *proxymap.Mapper
|
||||
mc *magicsock.Conn
|
||||
logf logger.Logf
|
||||
dialer *tsdial.Dialer
|
||||
ctx context.Context // alive until Close
|
||||
ctxCancel context.CancelFunc // called on Close
|
||||
lb *ipnlocal.LocalBackend // or nil
|
||||
dns *dns.Manager
|
||||
ipstack *stack.Stack
|
||||
linkEP *linkEndpoint
|
||||
tundev *tstun.Wrapper
|
||||
e wgengine.Engine
|
||||
pm *proxymap.Mapper
|
||||
mc *magicsock.Conn
|
||||
logf logger.Logf
|
||||
dialer *tsdial.Dialer
|
||||
ctx context.Context // alive until Close
|
||||
ctxCancel context.CancelFunc // called on Close
|
||||
lb *ipnlocal.LocalBackend // or nil
|
||||
dns *dns.Manager
|
||||
driveForLocal drive.FileSystemForLocal // or nil
|
||||
|
||||
// loopbackPort, if non-nil, will enable Impl to loop back (dnat to
|
||||
// <address-family-loopback>:loopbackPort) TCP & UDP flows originally
|
||||
@@ -286,7 +288,7 @@ func setTCPBufSizes(ipstack *stack.Stack) error {
|
||||
}
|
||||
|
||||
// Create creates and populates a new Impl.
|
||||
func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper) (*Impl, error) {
|
||||
func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper, driveForLocal drive.FileSystemForLocal) (*Impl, error) {
|
||||
if mc == nil {
|
||||
return nil, errors.New("nil magicsock.Conn")
|
||||
}
|
||||
@@ -380,6 +382,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
|
||||
connsInFlightByClient: make(map[netip.Addr]int),
|
||||
packetsInFlight: make(map[stack.TransportEndpointID]struct{}),
|
||||
dns: dns,
|
||||
driveForLocal: driveForLocal,
|
||||
}
|
||||
loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT")
|
||||
if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 {
|
||||
|
||||
@@ -65,7 +65,7 @@ func TestInjectInboundLeak(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper())
|
||||
ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl {
|
||||
tb.Cleanup(func() { eng.Close() })
|
||||
sys.Set(eng)
|
||||
|
||||
ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper())
|
||||
ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1236,7 +1236,7 @@ func (e *userspaceEngine) linkChange(delta *netmon.ChangeDelta) {
|
||||
// and Apple platforms.
|
||||
if changed {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android", "ios", "darwin", "openbsd":
|
||||
case "linux", "android", "ios", "darwin":
|
||||
e.wgLock.Lock()
|
||||
dnsCfg := e.lastDNSConfig
|
||||
e.wgLock.Unlock()
|
||||
|
||||
@@ -21,7 +21,6 @@ type Config struct {
|
||||
NodeID tailcfg.StableNodeID
|
||||
PrivateKey key.NodePrivate
|
||||
Addresses []netip.Prefix
|
||||
ListenPort uint16 // not used by Tailscale's conn.Bind implementation
|
||||
MTU uint16
|
||||
DNS []netip.Addr
|
||||
Peers []Peer
|
||||
|
||||
@@ -124,13 +124,7 @@ func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case k.EqualString("listen_port"):
|
||||
port, err := mem.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse listen_port: %w", err)
|
||||
}
|
||||
cfg.ListenPort = uint16(port)
|
||||
case k.EqualString("fwmark"):
|
||||
case k.EqualString("listen_port") || k.EqualString("fwmark"):
|
||||
// ignore
|
||||
default:
|
||||
return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy())
|
||||
|
||||
@@ -39,7 +39,6 @@ var _ConfigCloneNeedsRegeneration = Config(struct {
|
||||
NodeID tailcfg.StableNodeID
|
||||
PrivateKey key.NodePrivate
|
||||
Addresses []netip.Prefix
|
||||
ListenPort uint16
|
||||
MTU uint16
|
||||
DNS []netip.Addr
|
||||
Peers []Peer
|
||||
|
||||
@@ -42,9 +42,6 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
if !prev.PrivateKey.Equal(cfg.PrivateKey) {
|
||||
set("private_key", cfg.PrivateKey.UntypedHexString())
|
||||
}
|
||||
if prev.ListenPort != cfg.ListenPort {
|
||||
setUint16("listen_port", cfg.ListenPort)
|
||||
}
|
||||
|
||||
old := make(map[key.NodePublic]Peer)
|
||||
for _, p := range prev.Peers {
|
||||
@@ -90,9 +87,7 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
|
||||
// See corp issue 3016.
|
||||
logf("[unexpected] endpoint changed from %s to %s", oldPeer.WGEndpoint, p.PublicKey)
|
||||
}
|
||||
if cfg.NodeID != "" {
|
||||
set("endpoint", p.PublicKey.UntypedHexString())
|
||||
}
|
||||
set("endpoint", p.PublicKey.UntypedHexString())
|
||||
}
|
||||
|
||||
// TODO: replace_allowed_ips is expensive.
|
||||
|
||||
Reference in New Issue
Block a user