Compare commits

..

1 Commits

Author SHA1 Message Date
Andrew Dunham
fc4048014e control/keyfallback: add baked-in fallback for control key
Similar to how we bake in the DERPMap to ensure that we can reach the
DERP servers if DNS isn't working, also bake in the control key for the
default control server that we use if the control server is down.

Updates #13890

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I18ef0381e266bd3db10063685993bc3cb76b2f42
2024-10-24 11:35:48 -05:00
98 changed files with 1677 additions and 4567 deletions

View File

@@ -40,7 +40,6 @@ import (
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
"tailscale.com/types/tkatype"
"tailscale.com/util/syspolicy/setting"
)
// defaultLocalClient is the default LocalClient when using the legacy
@@ -815,33 +814,6 @@ func (lc *LocalClient) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn
return decodeJSON[*ipn.Prefs](body)
}
// GetEffectivePolicy returns the effective policy for the specified scope.
func (lc *LocalClient) GetEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) {
scopeID, err := scope.MarshalText()
if err != nil {
return nil, err
}
body, err := lc.get200(ctx, "/localapi/v0/policy/"+string(scopeID))
if err != nil {
return nil, err
}
return decodeJSON[*setting.Snapshot](body)
}
// ReloadEffectivePolicy reloads the effective policy for the specified scope
// by reading and merging policy settings from all applicable policy sources.
func (lc *LocalClient) ReloadEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) {
scopeID, err := scope.MarshalText()
if err != nil {
return nil, err
}
body, err := lc.send(ctx, "POST", "/localapi/v0/policy/"+string(scopeID), 200, http.NoBody)
if err != nil {
return nil, err
}
return decodeJSON[*setting.Snapshot](body)
}
// GetDNSOSConfig returns the system DNS configuration for the current device.
// That is, it returns the DNS configuration that the system would use if Tailscale weren't being used.
func (lc *LocalClient) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig, error) {

View File

@@ -164,16 +164,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
tailscale.com/util/slicesx from tailscale.com/cmd/derper+
tailscale.com/util/syspolicy from tailscale.com/ipn
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/testenv from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
tailscale.com/util/usermetric from tailscale.com/health
tailscale.com/util/vizerror from tailscale.com/tailcfg+
W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/version from tailscale.com/derp+
tailscale.com/version/distro from tailscale.com/envknob+
@@ -194,7 +189,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
golang.org/x/crypto/sha3 from crypto/internal/mlkem768+
W golang.org/x/exp/constraints from tailscale.com/util/winutil
golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting+
golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting
L golang.org/x/net/bpf from github.com/mdlayher/netlink+
golang.org/x/net/dns/dnsmessage from net+
golang.org/x/net/http/httpguts from net/http
@@ -255,7 +250,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
encoding/pem from crypto/tls+
errors from bufio+
expvar from github.com/prometheus/client_golang/prometheus+
flag from tailscale.com/cmd/derper+
flag from tailscale.com/cmd/derper
fmt from compress/flate+
go/token from google.golang.org/protobuf/internal/strs
hash from crypto+
@@ -289,7 +284,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
os from crypto/rand+
os/exec from github.com/coreos/go-iptables/iptables+
os/signal from tailscale.com/cmd/derper
W os/user from tailscale.com/util/winutil+
W os/user from tailscale.com/util/winutil
path from github.com/prometheus/client_golang/prometheus/internal+
path/filepath from crypto/x509+
reflect from crypto/x509+

View File

@@ -659,6 +659,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/control/controlclient from tailscale.com/ipn/ipnlocal+
tailscale.com/control/controlhttp from tailscale.com/control/controlclient
tailscale.com/control/controlknobs from tailscale.com/control/controlclient+
tailscale.com/control/keyfallback from tailscale.com/control/controlclient
tailscale.com/derp from tailscale.com/derp/derphttp+
tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+
tailscale.com/disco from tailscale.com/derp+
@@ -812,11 +813,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/util/slicesx from tailscale.com/appc+
tailscale.com/util/syspolicy from tailscale.com/control/controlclient+
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
tailscale.com/util/systemd from tailscale.com/control/controlclient+
tailscale.com/util/testenv from tailscale.com/control/controlclient+
@@ -826,7 +824,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/util/vizerror from tailscale.com/tailcfg+
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+

View File

@@ -1896,182 +1896,6 @@ spec:
Value is the taint value the toleration matches to.
If the operator is Exists, the value should be empty, otherwise just a regular string.
type: string
topologySpreadConstraints:
description: |-
Proxy Pod's topology spread constraints.
By default Tailscale Kubernetes operator does not apply any topology spread constraints.
https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/
type: array
items:
description: TopologySpreadConstraint specifies how to spread matching pods among the given topology.
type: object
required:
- maxSkew
- topologyKey
- whenUnsatisfiable
properties:
labelSelector:
description: |-
LabelSelector is used to find matching pods.
Pods that match this label selector are counted to determine the number of pods
in their corresponding topology domain.
type: object
properties:
matchExpressions:
description: matchExpressions is a list of label selector requirements. The requirements are ANDed.
type: array
items:
description: |-
A label selector requirement is a selector that contains values, a key, and an operator that
relates the key and values.
type: object
required:
- key
- operator
properties:
key:
description: key is the label key that the selector applies to.
type: string
operator:
description: |-
operator represents a key's relationship to a set of values.
Valid operators are In, NotIn, Exists and DoesNotExist.
type: string
values:
description: |-
values is an array of string values. If the operator is In or NotIn,
the values array must be non-empty. If the operator is Exists or DoesNotExist,
the values array must be empty. This array is replaced during a strategic
merge patch.
type: array
items:
type: string
x-kubernetes-list-type: atomic
x-kubernetes-list-type: atomic
matchLabels:
description: |-
matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels
map is equivalent to an element of matchExpressions, whose key field is "key", the
operator is "In", and the values array contains only "value". The requirements are ANDed.
type: object
additionalProperties:
type: string
x-kubernetes-map-type: atomic
matchLabelKeys:
description: |-
MatchLabelKeys is a set of pod label keys to select the pods over which
spreading will be calculated. The keys are used to lookup values from the
incoming pod labels, those key-value labels are ANDed with labelSelector
to select the group of existing pods over which spreading will be calculated
for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector.
MatchLabelKeys cannot be set when LabelSelector isn't set.
Keys that don't exist in the incoming pod labels will
be ignored. A null or empty list means only match against labelSelector.
This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default).
type: array
items:
type: string
x-kubernetes-list-type: atomic
maxSkew:
description: |-
MaxSkew describes the degree to which pods may be unevenly distributed.
When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference
between the number of matching pods in the target topology and the global minimum.
The global minimum is the minimum number of matching pods in an eligible domain
or zero if the number of eligible domains is less than MinDomains.
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
labelSelector spread as 2/2/1:
In this case, the global minimum is 1.
| zone1 | zone2 | zone3 |
| P P | P P | P |
- if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2;
scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2)
violate MaxSkew(1).
- if MaxSkew is 2, incoming pod can be scheduled onto any zone.
When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence
to topologies that satisfy it.
It's a required field. Default value is 1 and 0 is not allowed.
type: integer
format: int32
minDomains:
description: |-
MinDomains indicates a minimum number of eligible domains.
When the number of eligible domains with matching topology keys is less than minDomains,
Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed.
And when the number of eligible domains with matching topology keys equals or greater than minDomains,
this value has no effect on scheduling.
As a result, when the number of eligible domains is less than minDomains,
scheduler won't schedule more than maxSkew Pods to those domains.
If value is nil, the constraint behaves as if MinDomains is equal to 1.
Valid values are integers greater than 0.
When value is not nil, WhenUnsatisfiable must be DoNotSchedule.
For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same
labelSelector spread as 2/2/2:
| zone1 | zone2 | zone3 |
| P P | P P | P P |
The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0.
In this situation, new pod with the same labelSelector cannot be scheduled,
because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones,
it will violate MaxSkew.
type: integer
format: int32
nodeAffinityPolicy:
description: |-
NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector
when calculating pod topology spread skew. Options are:
- Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations.
- Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations.
If this value is nil, the behavior is equivalent to the Honor policy.
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
type: string
nodeTaintsPolicy:
description: |-
NodeTaintsPolicy indicates how we will treat node taints when calculating
pod topology spread skew. Options are:
- Honor: nodes without taints, along with tainted nodes for which the incoming pod
has a toleration, are included.
- Ignore: node taints are ignored. All nodes are included.
If this value is nil, the behavior is equivalent to the Ignore policy.
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
type: string
topologyKey:
description: |-
TopologyKey is the key of node labels. Nodes that have a label with this key
and identical values are considered to be in the same topology.
We consider each <key, value> as a "bucket", and try to put balanced number
of pods into each bucket.
We define a domain as a particular instance of a topology.
Also, we define an eligible domain as a domain whose nodes meet the requirements of
nodeAffinityPolicy and nodeTaintsPolicy.
e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology.
And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology.
It's a required field.
type: string
whenUnsatisfiable:
description: |-
WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy
the spread constraint.
- DoNotSchedule (default) tells the scheduler not to schedule it.
- ScheduleAnyway tells the scheduler to schedule the pod in any location,
but giving higher precedence to topologies that would help reduce the
skew.
A constraint is considered "Unsatisfiable" for an incoming pod
if and only if every possible node assignment for that pod would violate
"MaxSkew" on some topology.
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
labelSelector spread as 3/1/1:
| zone1 | zone2 | zone3 |
| P P P | P | P |
If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled
to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies
MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler
won't make it *more* imbalanced.
It's a required field.
type: string
tailscale:
description: |-
TailscaleConfig contains options to configure the tailscale-specific

View File

@@ -2323,182 +2323,6 @@ spec:
type: string
type: object
type: array
topologySpreadConstraints:
description: |-
Proxy Pod's topology spread constraints.
By default Tailscale Kubernetes operator does not apply any topology spread constraints.
https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/
items:
description: TopologySpreadConstraint specifies how to spread matching pods among the given topology.
properties:
labelSelector:
description: |-
LabelSelector is used to find matching pods.
Pods that match this label selector are counted to determine the number of pods
in their corresponding topology domain.
properties:
matchExpressions:
description: matchExpressions is a list of label selector requirements. The requirements are ANDed.
items:
description: |-
A label selector requirement is a selector that contains values, a key, and an operator that
relates the key and values.
properties:
key:
description: key is the label key that the selector applies to.
type: string
operator:
description: |-
operator represents a key's relationship to a set of values.
Valid operators are In, NotIn, Exists and DoesNotExist.
type: string
values:
description: |-
values is an array of string values. If the operator is In or NotIn,
the values array must be non-empty. If the operator is Exists or DoesNotExist,
the values array must be empty. This array is replaced during a strategic
merge patch.
items:
type: string
type: array
x-kubernetes-list-type: atomic
required:
- key
- operator
type: object
type: array
x-kubernetes-list-type: atomic
matchLabels:
additionalProperties:
type: string
description: |-
matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels
map is equivalent to an element of matchExpressions, whose key field is "key", the
operator is "In", and the values array contains only "value". The requirements are ANDed.
type: object
type: object
x-kubernetes-map-type: atomic
matchLabelKeys:
description: |-
MatchLabelKeys is a set of pod label keys to select the pods over which
spreading will be calculated. The keys are used to lookup values from the
incoming pod labels, those key-value labels are ANDed with labelSelector
to select the group of existing pods over which spreading will be calculated
for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector.
MatchLabelKeys cannot be set when LabelSelector isn't set.
Keys that don't exist in the incoming pod labels will
be ignored. A null or empty list means only match against labelSelector.
This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default).
items:
type: string
type: array
x-kubernetes-list-type: atomic
maxSkew:
description: |-
MaxSkew describes the degree to which pods may be unevenly distributed.
When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference
between the number of matching pods in the target topology and the global minimum.
The global minimum is the minimum number of matching pods in an eligible domain
or zero if the number of eligible domains is less than MinDomains.
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
labelSelector spread as 2/2/1:
In this case, the global minimum is 1.
| zone1 | zone2 | zone3 |
| P P | P P | P |
- if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2;
scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2)
violate MaxSkew(1).
- if MaxSkew is 2, incoming pod can be scheduled onto any zone.
When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence
to topologies that satisfy it.
It's a required field. Default value is 1 and 0 is not allowed.
format: int32
type: integer
minDomains:
description: |-
MinDomains indicates a minimum number of eligible domains.
When the number of eligible domains with matching topology keys is less than minDomains,
Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed.
And when the number of eligible domains with matching topology keys equals or greater than minDomains,
this value has no effect on scheduling.
As a result, when the number of eligible domains is less than minDomains,
scheduler won't schedule more than maxSkew Pods to those domains.
If value is nil, the constraint behaves as if MinDomains is equal to 1.
Valid values are integers greater than 0.
When value is not nil, WhenUnsatisfiable must be DoNotSchedule.
For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same
labelSelector spread as 2/2/2:
| zone1 | zone2 | zone3 |
| P P | P P | P P |
The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0.
In this situation, new pod with the same labelSelector cannot be scheduled,
because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones,
it will violate MaxSkew.
format: int32
type: integer
nodeAffinityPolicy:
description: |-
NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector
when calculating pod topology spread skew. Options are:
- Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations.
- Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations.
If this value is nil, the behavior is equivalent to the Honor policy.
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
type: string
nodeTaintsPolicy:
description: |-
NodeTaintsPolicy indicates how we will treat node taints when calculating
pod topology spread skew. Options are:
- Honor: nodes without taints, along with tainted nodes for which the incoming pod
has a toleration, are included.
- Ignore: node taints are ignored. All nodes are included.
If this value is nil, the behavior is equivalent to the Ignore policy.
This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag.
type: string
topologyKey:
description: |-
TopologyKey is the key of node labels. Nodes that have a label with this key
and identical values are considered to be in the same topology.
We consider each <key, value> as a "bucket", and try to put balanced number
of pods into each bucket.
We define a domain as a particular instance of a topology.
Also, we define an eligible domain as a domain whose nodes meet the requirements of
nodeAffinityPolicy and nodeTaintsPolicy.
e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology.
And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology.
It's a required field.
type: string
whenUnsatisfiable:
description: |-
WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy
the spread constraint.
- DoNotSchedule (default) tells the scheduler not to schedule it.
- ScheduleAnyway tells the scheduler to schedule the pod in any location,
but giving higher precedence to topologies that would help reduce the
skew.
A constraint is considered "Unsatisfiable" for an incoming pod
if and only if every possible node assignment for that pod would violate
"MaxSkew" on some topology.
For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same
labelSelector spread as 3/1/1:
| zone1 | zone2 | zone3 |
| P P P | P | P |
If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled
to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies
MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler
won't make it *more* imbalanced.
It's a required field.
type: string
required:
- maxSkew
- topologyKey
- whenUnsatisfiable
type: object
type: array
type: object
type: object
tailscale:

View File

@@ -432,148 +432,6 @@ func TestTailnetTargetIPAnnotation(t *testing.T) {
expectMissing[corev1.Secret](t, fc, "operator-ns", fullName)
}
func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) {
fc := fake.NewFakeClient()
ft := &fakeTSClient{}
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatal(err)
}
clock := tstest.NewClock(tstest.ClockOpts{})
sr := &ServiceReconciler{
Client: fc,
ssr: &tailscaleSTSReconciler{
Client: fc,
tsClient: ft,
defaultTags: []string{"tag:k8s"},
operatorNamespace: "operator-ns",
proxyImage: "tailscale/tailscale",
},
logger: zl.Sugar(),
clock: clock,
recorder: record.NewFakeRecorder(100),
}
tailnetTargetIP := "invalid-ip"
mustCreate(t, fc, &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "default",
UID: types.UID("1234-UID"),
Annotations: map[string]string{
AnnotationTailnetTargetIP: tailnetTargetIP,
},
},
Spec: corev1.ServiceSpec{
ClusterIP: "10.20.30.40",
Type: corev1.ServiceTypeLoadBalancer,
LoadBalancerClass: ptr.To("tailscale"),
},
})
expectReconciled(t, sr, "default", "test")
t0 := conditionTime(clock)
want := &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "default",
UID: types.UID("1234-UID"),
Annotations: map[string]string{
AnnotationTailnetTargetIP: tailnetTargetIP,
},
},
Spec: corev1.ServiceSpec{
ClusterIP: "10.20.30.40",
Type: corev1.ServiceTypeLoadBalancer,
LoadBalancerClass: ptr.To("tailscale"),
},
Status: corev1.ServiceStatus{
Conditions: []metav1.Condition{{
Type: string(tsapi.ProxyReady),
Status: metav1.ConditionFalse,
LastTransitionTime: t0,
Reason: reasonProxyInvalid,
Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "invalid-ip" could not be parsed as a valid IP Address, error: ParseAddr("invalid-ip"): unable to parse IP`,
}},
},
}
expectEqual(t, fc, want, nil)
}
func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) {
fc := fake.NewFakeClient()
ft := &fakeTSClient{}
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatal(err)
}
clock := tstest.NewClock(tstest.ClockOpts{})
sr := &ServiceReconciler{
Client: fc,
ssr: &tailscaleSTSReconciler{
Client: fc,
tsClient: ft,
defaultTags: []string{"tag:k8s"},
operatorNamespace: "operator-ns",
proxyImage: "tailscale/tailscale",
},
logger: zl.Sugar(),
clock: clock,
recorder: record.NewFakeRecorder(100),
}
tailnetTargetIP := "999.999.999.999"
mustCreate(t, fc, &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "default",
UID: types.UID("1234-UID"),
Annotations: map[string]string{
AnnotationTailnetTargetIP: tailnetTargetIP,
},
},
Spec: corev1.ServiceSpec{
ClusterIP: "10.20.30.40",
Type: corev1.ServiceTypeLoadBalancer,
LoadBalancerClass: ptr.To("tailscale"),
},
})
expectReconciled(t, sr, "default", "test")
t0 := conditionTime(clock)
want := &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "default",
UID: types.UID("1234-UID"),
Annotations: map[string]string{
AnnotationTailnetTargetIP: tailnetTargetIP,
},
},
Spec: corev1.ServiceSpec{
ClusterIP: "10.20.30.40",
Type: corev1.ServiceTypeLoadBalancer,
LoadBalancerClass: ptr.To("tailscale"),
},
Status: corev1.ServiceStatus{
Conditions: []metav1.Condition{{
Type: string(tsapi.ProxyReady),
Status: metav1.ConditionFalse,
LastTransitionTime: t0,
Reason: reasonProxyInvalid,
Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "999.999.999.999" could not be parsed as a valid IP Address, error: ParseAddr("999.999.999.999"): IPv4 field has value >255`,
}},
},
}
expectEqual(t, fc, want, nil)
}
func TestAnnotations(t *testing.T) {
fc := fake.NewFakeClient()
ft := &fakeTSClient{}

View File

@@ -718,7 +718,6 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet,
ss.Spec.Template.Spec.NodeSelector = wantsPod.NodeSelector
ss.Spec.Template.Spec.Affinity = wantsPod.Affinity
ss.Spec.Template.Spec.Tolerations = wantsPod.Tolerations
ss.Spec.Template.Spec.TopologySpreadConstraints = wantsPod.TopologySpreadConstraints
// Update containers.
updateContainer := func(overlay *tsapi.Container, base corev1.Container) corev1.Container {

View File

@@ -18,7 +18,6 @@ import (
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/yaml"
tsapi "tailscale.com/k8s-operator/apis/v1alpha1"
"tailscale.com/types/ptr"
@@ -74,16 +73,6 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) {
NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"},
Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}},
Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}},
TopologySpreadConstraints: []corev1.TopologySpreadConstraint{
{
WhenUnsatisfiable: "DoNotSchedule",
TopologyKey: "kubernetes.io/hostname",
MaxSkew: 3,
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{"foo": "bar"},
},
},
},
TailscaleContainer: &tsapi.Container{
SecurityContext: &corev1.SecurityContext{
Privileged: ptr.To(true),
@@ -170,7 +159,6 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) {
wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector
wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity
wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations
wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints
wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext
wantSS.Spec.Template.Spec.InitContainers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleInitContainer.SecurityContext
wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources
@@ -213,7 +201,6 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) {
wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector
wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity
wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations
wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints
wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext
wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources
wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, []corev1.EnvVar{{Name: "foo", Value: "bar"}, {Name: "TS_USERSPACE", Value: "true"}, {Name: "bar"}}...)

View File

@@ -358,14 +358,9 @@ func validateService(svc *corev1.Service) []string {
violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q does not appear to be a valid MagicDNS name", AnnotationTailnetTargetFQDN, fqdn))
}
}
if ipStr := svc.Annotations[AnnotationTailnetTargetIP]; ipStr != "" {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q could not be parsed as a valid IP Address, error: %s", AnnotationTailnetTargetIP, ipStr, err))
} else if !ip.IsValid() {
violations = append(violations, fmt.Sprintf("parsed IP address in annotation %s: %q is not valid", AnnotationTailnetTargetIP, ipStr))
}
}
// TODO(irbekrm): validate that tailscale.com/tailnet-ip annotation is a
// valid IP address (tailscale/tailscale#13671).
svcName := nameForService(svc)
if err := dnsname.ValidLabel(svcName); err != nil {

View File

@@ -1,39 +0,0 @@
# Tailscale LOPOWER
"Little Opinionated Proxy Over Wireguard-encrypted Routes"
**STATUS**: in-development alpha (as of 2024-11-03)
## Background
Some small devices such as ESP32 microcontrollers [support WireGuard](https://github.com/ciniml/WireGuard-ESP32-Arduino) but are too small to run Tailscale.
Tailscale LOPOWER is a proxy that you run nearby that bridges a low-power WireGuard-speaking device on one side to Tailscale on the other side. That way network traffic from the low-powered device never hits the network unencrypted but is still able to communicate to/from other Tailscale devices on your Tailnet.
## Diagram
<img src="./lopower.svg">
## Features
* Runs separate Wireguard server with separate keys (unknown to the Tailscale control plane) that proxy on to Tailscale
* Outputs WireGuard-standard configuration to enrolls devices, including in QR code form.
* embeds `tsnet`, with an identity on which the device(s) behind the proxy appear on your Tailnet
* optional IPv4 support. IPv6 is always enabled, as it never conflicts with anything. But IPv4 (or CGNAT) might already be in use on your client's network.
* includes a DNS server (at `fd7a:115c:a1e0:9909::1` by default and optionally also at `10.90.0.1`) to serve both MagicDNS names as well as forwarding non-Tailscale DNS names onwards
* if IPv4 is disabled, MagicDNS `A` records are filtered out, and only `AAAA` records are served.
## Limitations
* this runs in userspace using gVisor's netstack. That means it's portable (and doesn't require kernel/system configuration), but that does mean it doesn't operate at a packet level but rather it stitches together two separate TCP (or UDP) flows and doesn't support IP protocols such as SCTP or other things that aren't TCP or UDP.
* the standard WireGuard configuration doesn't support specifying DNS search domains, so resolving bare names like the `go` in `http://go/foo` won't work and you need to resolve names using the fully qualified `go.your-tailnet.ts.net` names.
* since it's based on userspace tsnet mode, it doesn't pick up your system DNS configuration (yet?) and instead resolves non-tailnet DNS names using either your "Override DNS" tailnet settings for the global DNS resolver, or else defaults to `8.8.8.8` and `1.1.1.1` (using DoH) if that isn't set.
## TODO
* provisioning more than one low-powered device is possible, but requires manual config file edits. It should be possible to enroll multiple devices (including QR code support) easily.
* incoming connections (from Tailscale to `lopower`) don't yet forward to the low-powered devices. When there's only one low-powered device, the mapping policy is obvious. When there are multiple, it's not as obvious. Maybe the answer is supporting [4via6 subnet routers](https://tailscale.com/kb/1201/4via6-subnets).
## Installing
* git clone this repo, switch to `lp` branch, `go install ./cmd/lopower` and see `lopower --help`.

View File

@@ -1,872 +0,0 @@
// The lopower server is a "Little Opinionated Proxy Over
// Wireguard-Encrypted Route". It bridges a static WireGuard
// client into a Tailscale network.
package main
import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"math/rand/v2"
"net"
"net/http"
"net/netip"
"os"
"os/signal"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
qrcode "github.com/skip2/go-qrcode"
"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
"tailscale.com/net/packet"
"tailscale.com/net/tsaddr"
"tailscale.com/syncs"
"tailscale.com/tsnet"
"tailscale.com/types/dnstype"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/must"
"tailscale.com/wgengine/wgcfg"
)
var (
wgListenPort = flag.Int("wg-port", 51820, "port number to listen on for WireGuard from the client")
confDir = flag.String("dir", filepath.Join(os.Getenv("HOME"), ".config/lopower"), "directory to store configuration in")
wgPubHost = flag.String("wg-host", "0.0.0.1", "IP address of lopower's WireGuard server that's accessible from the client")
qrListenAddr = flag.String("qr-listen", "127.0.0.1:8014", "HTTP address to serve a QR code for client's WireGuard configuration, or empty for none")
printConfig = flag.Bool("print-config", true, "print the client's WireGuard configuration to stdout on startup")
includeV4 = flag.Bool("include-v4", true, "include IPv4 (CGNAT) in the WireGuard configuration; incompatible with some carriers. IPv6 is always included.")
verbosePackets = flag.Bool("verbose-packets", false, "log packet contents")
)
type config struct {
PrivKey key.NodePrivate // the proxy server's key
Peers []Peer
// V4 and V6 are the local IPs.
V4 netip.Addr
V6 netip.Addr
// CIDRs are used to allocate IPs to peers.
V4CIDR netip.Prefix
V6CIDR netip.Prefix
}
// IsLocalIP reports whether ip is one of the local IPs.
func (c *config) IsLocalIP(ip netip.Addr) bool {
return ip.IsValid() && (ip == c.V4 || ip == c.V6)
}
type Peer struct {
PrivKey key.NodePrivate // e.g. proxy client's
V4 netip.Addr
V6 netip.Addr
}
func (lp *lpServer) storeConfigLocked() {
path := filepath.Join(lp.dir, "config.json")
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
log.Fatalf("os.MkdirAll(%q): %v", filepath.Dir(path), err)
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
log.Fatalf("os.OpenFile(%q): %v", path, err)
}
defer f.Close()
must.Do(json.NewEncoder(f).Encode(lp.c))
if err := f.Close(); err != nil {
log.Fatalf("f.Close: %v", err)
}
}
func (lp *lpServer) loadConfig() {
path := filepath.Join(lp.dir, "config.json")
f, err := os.Open(path)
if err == nil {
defer f.Close()
var cfg *config
must.Do(json.NewDecoder(f).Decode(&cfg))
if len(cfg.Peers) > 0 { // as early version didn't set this
lp.mu.Lock()
defer lp.mu.Unlock()
lp.c = cfg
}
return
}
if !os.IsNotExist(err) {
log.Fatalf("os.OpenFile(%q): %v", path, err)
}
const defaultV4CIDR = "10.90.0.0/24"
const defaultV6CIDR = "fd7a:115c:a1e0:9909::/64" // 9909 = above QWERTY "LOPO"(wer)
c := &config{
PrivKey: key.NewNode(),
V4CIDR: netip.MustParsePrefix(defaultV4CIDR),
V6CIDR: netip.MustParsePrefix(defaultV6CIDR),
}
c.V4 = c.V4CIDR.Addr().Next()
c.V6 = c.V6CIDR.Addr().Next()
c.Peers = append(c.Peers, Peer{
PrivKey: key.NewNode(),
V4: c.V4.Next(),
V6: c.V6.Next(),
})
lp.mu.Lock()
defer lp.mu.Unlock()
lp.c = c
lp.storeConfigLocked()
return
}
func (lp *lpServer) reconfig() {
lp.mu.Lock()
wc := &wgcfg.Config{
Name: "lopower0",
PrivateKey: lp.c.PrivKey,
ListenPort: uint16(*wgListenPort),
Addresses: []netip.Prefix{
netip.PrefixFrom(lp.c.V4, 32),
netip.PrefixFrom(lp.c.V6, 128),
},
}
for _, p := range lp.c.Peers {
wc.Peers = append(wc.Peers, wgcfg.Peer{
PublicKey: p.PrivKey.Public(),
AllowedIPs: []netip.Prefix{
netip.PrefixFrom(p.V4, 32),
netip.PrefixFrom(p.V6, 128),
},
})
}
lp.mu.Unlock()
must.Do(wgcfg.ReconfigDevice(lp.d, wc, log.Printf))
}
func newLP(ctx context.Context) *lpServer {
logf := log.Printf
deviceLogger := &device.Logger{
Verbosef: logger.Discard,
Errorf: logf,
}
lp := &lpServer{
ctx: ctx,
dir: *confDir,
readCh: make(chan *stack.PacketBuffer, 16),
}
lp.loadConfig()
lp.initNetstack(ctx)
nst := &nsTUN{
lp: lp,
closeCh: make(chan struct{}),
evChan: make(chan tun.Event),
}
wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger)
lp.d = wgdev
must.Do(wgdev.Up())
lp.reconfig()
if *printConfig {
log.Printf("Device Wireguard config is:\n%s", lp.wgConfigForQR())
}
lp.startTSNet(ctx)
return lp
}
type lpServer struct {
dir string
tsnet *tsnet.Server
d *device.Device
ns *stack.Stack
ctx context.Context // canceled on shutdown
linkEP *channel.Endpoint
readCh chan *stack.PacketBuffer // from gvisor/dns server => out to network
// protocolConns tracks the number of active connections for each connection.
// It is used to add and remove protocol addresses from netstack as needed.
protocolConns syncs.Map[tcpip.ProtocolAddress, *atomic.Int32]
mu sync.Mutex // protects following
c *config
}
// MaxPacketSize is the maximum size (in bytes)
// of a packet that can be injected into lpServer.
const MaxPacketSize = device.MaxContentSize
const nicID = 1
func (lp *lpServer) initNetstack(ctx context.Context) error {
ns := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
icmp.NewProtocol4,
udp.NewProtocol,
},
})
lp.ns = ns
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
if tcpipErr := ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt); tcpipErr != nil {
return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr)
}
lp.linkEP = channel.New(512, 1280, "")
if tcpipProblem := ns.CreateNIC(nicID, lp.linkEP); tcpipProblem != nil {
return fmt.Errorf("CreateNIC: %v", tcpipProblem)
}
ns.SetPromiscuousMode(nicID, true)
lp.mu.Lock()
v4, v6 := lp.c.V4, lp.c.V6
lp.mu.Unlock()
prefix := tcpip.AddrFrom4Slice(v4.AsSlice()).WithPrefix()
if *includeV4 {
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: prefix,
}, stack.AddressProperties{}); tcpProb != nil {
return errors.New(tcpProb.String())
}
}
prefix = tcpip.AddrFrom16Slice(v6.AsSlice()).WithPrefix()
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: prefix,
}, stack.AddressProperties{}); tcpProb != nil {
return errors.New(tcpProb.String())
}
ipv4Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4)))
if err != nil {
return fmt.Errorf("could not create IPv4 subnet: %v", err)
}
ipv6Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 16)), tcpip.MaskFromBytes(make([]byte, 16)))
if err != nil {
return fmt.Errorf("could not create IPv6 subnet: %v", err)
}
routes := []tcpip.Route{{
Destination: ipv4Subnet,
NIC: nicID,
}, {
Destination: ipv6Subnet,
NIC: nicID,
}}
if !*includeV4 {
routes = routes[1:]
}
ns.SetRouteTable(routes)
const tcpReceiveBufferSize = 0 // default
const maxInFlightConnectionAttempts = 8192
tcpFwd := tcp.NewForwarder(ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, lp.acceptTCP)
udpFwd := udp.NewForwarder(ns, lp.acceptUDP)
ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) {
return tcpFwd.HandlePacket(tei, pb)
})
ns.SetTransportProtocolHandler(udp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) {
return udpFwd.HandlePacket(tei, pb)
})
go func() {
for {
pkt := lp.linkEP.ReadContext(ctx)
if pkt == nil {
if ctx.Err() != nil {
// Return without logging.
log.Printf("linkEP.ReadContext: %v", ctx.Err())
return
}
continue
}
size := pkt.Size()
if size > MaxPacketSize || size == 0 {
pkt.DecRef()
continue
}
select {
case lp.readCh <- pkt:
case <-ctx.Done():
}
}
}()
return nil
}
func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr {
switch s.Len() {
case 4:
return netip.AddrFrom4(s.As4())
case 16:
return netip.AddrFrom16(s.As16()).Unmap()
}
return netip.Addr{}
}
func (lp *lpServer) trackProtocolAddr(destIP netip.Addr) (untrack func()) {
pa := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddrFromSlice(destIP.AsSlice()).WithPrefix(),
}
if destIP.Is4() {
pa.Protocol = ipv4.ProtocolNumber
} else if destIP.Is6() {
pa.Protocol = ipv6.ProtocolNumber
}
addrConns, _ := lp.protocolConns.LoadOrInit(pa, func() *atomic.Int32 { return new(atomic.Int32) })
if addrConns.Add(1) == 1 {
lp.ns.AddProtocolAddress(nicID, pa, stack.AddressProperties{
PEB: stack.CanBePrimaryEndpoint, // zero value default
ConfigType: stack.AddressConfigStatic, // zero value default
})
}
return func() {
if addrConns.Add(-1) == 0 {
lp.ns.RemoveAddress(nicID, pa.AddressWithPrefix.Address)
}
}
}
func (lp *lpServer) acceptUDP(r *udp.ForwarderRequest) {
log.Printf("acceptUDP: %v", r.ID())
destIP := netaddrIPFromNetstackIP(r.ID().LocalAddress)
untrack := lp.trackProtocolAddr(destIP)
var wq waiter.Queue
ep, udpErr := r.CreateEndpoint(&wq)
if udpErr != nil {
log.Printf("CreateEndpoint: %v", udpErr)
return
}
go func() {
defer untrack()
defer ep.Close()
reqDetails := r.ID()
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
destPort := reqDetails.LocalPort
if !clientRemoteIP.IsValid() {
log.Printf("acceptUDP: invalid remote IP %v", reqDetails.RemoteAddress)
return
}
randPort := rand.IntN(65536-1024) + 1024
v4, v6 := lp.tsnet.TailscaleIPs()
var listenAddr netip.Addr
if destIP.Is4() {
listenAddr = v4
} else {
listenAddr = v6
}
backendConn, err := lp.tsnet.ListenPacket("udp", fmt.Sprintf("%s:%d", listenAddr, randPort))
if err != nil {
log.Printf("ListenPacket: %v", err)
return
}
defer backendConn.Close()
clientConn := gonet.NewUDPConn(&wq, ep)
defer clientConn.Close()
errCh := make(chan error, 2)
go func() (err error) {
defer func() { errCh <- err }()
var buf [64]byte
for {
n, _, err := backendConn.ReadFrom(buf[:])
if err != nil {
log.Printf("UDP read: %v", err)
return err
}
_, err = clientConn.Write(buf[:n])
if err != nil {
return err
}
}
}()
dstAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", destIP, destPort))
if err != nil {
log.Printf("ResolveUDPAddr: %v", err)
return
}
go func() (err error) {
defer func() { errCh <- err }()
var buf [2048]byte
for {
n, err := clientConn.Read(buf[:])
if err != nil {
log.Printf("UDP read: %v", err)
return err
}
_, err = backendConn.WriteTo(buf[:n], dstAddr)
if err != nil {
return err
}
}
}()
err = <-errCh
if err != nil {
log.Printf("io.Copy: %v", err)
}
}()
}
func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) {
log.Printf("acceptTCP: %v", r.ID())
reqDetails := r.ID()
destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
destPort := reqDetails.LocalPort
if !clientRemoteIP.IsValid() {
log.Printf("acceptTCP: invalid remote IP %v", reqDetails.RemoteAddress)
r.Complete(true) // sends a RST
return
}
untrack := lp.trackProtocolAddr(destIP)
defer untrack()
var wq waiter.Queue
ep, tcpErr := r.CreateEndpoint(&wq)
if tcpErr != nil {
log.Printf("CreateEndpoint: %v", tcpErr)
r.Complete(true)
return
}
defer ep.Close()
ep.SocketOptions().SetKeepAlive(true)
if destPort == 53 && lp.c.IsLocalIP(destIP) {
tc := gonet.NewTCPConn(&wq, ep)
defer tc.Close()
r.Complete(false) // accept TCP connection
lp.handleTCPDNSQuery(tc, netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort))
return
}
dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort))
cancel()
if err != nil {
log.Printf("Dial(%s:%d): %v", destIP, destPort, err)
r.Complete(true) // sends a RST
return
}
defer c.Close()
tc := gonet.NewTCPConn(&wq, ep)
defer tc.Close()
r.Complete(false) // accept TCP connection
errc := make(chan error, 2)
go func() { _, err := io.Copy(tc, c); errc <- err }()
go func() { _, err := io.Copy(c, tc); errc <- err }()
err = <-errc
if err != nil {
log.Printf("io.Copy: %v", err)
}
}
func (lp *lpServer) wgConfigForQR() string {
var b strings.Builder
p := lp.c.Peers[0]
privHex, _ := p.PrivKey.MarshalText()
privHex = bytes.TrimPrefix(privHex, []byte("privkey:"))
priv := make([]byte, 32)
got, err := hex.Decode(priv, privHex)
if err != nil || got != 32 {
log.Printf("marshal text was: %q", privHex)
log.Fatalf("bad private key: %v, % bytes", err, got)
}
privb64 := base64.StdEncoding.EncodeToString(priv)
fmt.Fprintf(&b, "[Interface]\nPrivateKey = %s\n", privb64)
fmt.Fprintf(&b, "Address = %v,%v\n", p.V6, p.V4)
pubBin, _ := lp.c.PrivKey.Public().MarshalBinary()
if len(pubBin) != 34 {
log.Fatalf("bad pubkey length: %d", len(pubBin))
}
pubBin = pubBin[2:] // trim off "np"
pubb64 := base64.StdEncoding.EncodeToString(pubBin)
fmt.Fprintf(&b, "\n[Peer]\nPublicKey = %v\n", pubb64)
if *includeV4 {
fmt.Fprintf(&b, "AllowedIPs = %v/32,%v/128,%v,%v\n", lp.c.V4, lp.c.V6, tsaddr.TailscaleULARange(), tsaddr.CGNATRange())
} else {
fmt.Fprintf(&b, "AllowedIPs = %v/128,%v\n", lp.c.V6, tsaddr.TailscaleULARange())
}
fmt.Fprintf(&b, "Endpoint = %v\n", net.JoinHostPort(*wgPubHost, fmt.Sprint(*wgListenPort)))
return b.String()
}
func (lp *lpServer) serveQR() {
ln, err := net.Listen("tcp", *qrListenAddr)
if err != nil {
log.Fatalf("qr: %v", err)
}
log.Printf("# Serving QR code at http://%s/", ln.Addr())
hs := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "image/png")
conf := lp.wgConfigForQR()
v, err := qrcode.Encode(conf, qrcode.Medium, 512)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(v)
}),
}
if err := hs.Serve(ln); err != nil {
log.Fatalf("qr: %v", err)
}
}
type nsTUN struct {
lp *lpServer
closeCh chan struct{}
evChan chan tun.Event
}
func (t *nsTUN) File() *os.File {
panic("nsTUN.File() called, which makes no sense")
}
func (t *nsTUN) Close() error {
close(t.closeCh)
close(t.evChan)
return nil
}
// Read reads packets from gvisor (or the DNS server) to send out to the network.
func (t *nsTUN) Read(out [][]byte, sizes []int, offset int) (int, error) {
select {
case <-t.closeCh:
return 0, io.EOF
case resPacket := <-t.lp.readCh:
defer resPacket.DecRef()
pkt := out[0][offset:]
n := copy(pkt, resPacket.NetworkHeader().Slice())
n += copy(pkt[n:], resPacket.TransportHeader().Slice())
n += copy(pkt[n:], resPacket.Data().AsRange().ToSlice())
if *verbosePackets {
log.Printf("[v] nsTUN.Read (out): % 02x", pkt[:n])
}
sizes[0] = n
return 1, nil
}
}
// Write accepts incoming packets. The packets begin at buffs[:][offset:],
// like wireguard-go/tun.Device.Write. Write is called per-peer via
// wireguard-go/device.Peer.RoutineSequentialReceiver, so it MUST be
// thread-safe.
func (t *nsTUN) Write(buffs [][]byte, offset int) (int, error) {
var pkt packet.Parsed
for _, buff := range buffs {
raw := buff[offset:]
pkt.Decode(raw)
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(slices.Clone(raw)),
})
if *verbosePackets {
log.Printf("[v] nsTUN.Write (in): % 02x", raw)
}
if pkt.IPProto == ipproto.UDP && pkt.Dst.Port() == 53 && t.lp.c.IsLocalIP(pkt.Dst.Addr()) {
// Handle DNS queries before sending to gvisor.
t.lp.handleDNSUDPQuery(raw)
continue
}
if pkt.IPVersion == 4 {
t.lp.linkEP.InjectInbound(ipv4.ProtocolNumber, packetBuf)
} else if pkt.IPVersion == 6 {
t.lp.linkEP.InjectInbound(ipv6.ProtocolNumber, packetBuf)
}
}
return len(buffs), nil
}
func (t *nsTUN) Flush() error { return nil }
func (t *nsTUN) MTU() (int, error) { return 1500, nil }
func (t *nsTUN) Name() (string, error) { return "nstun", nil }
func (t *nsTUN) Events() <-chan tun.Event { return t.evChan }
func (t *nsTUN) BatchSize() int { return 1 }
func (lp *lpServer) startTSNet(ctx context.Context) {
hostname, err := os.Hostname()
if err != nil {
log.Fatal(err)
}
ts := &tsnet.Server{
Dir: filepath.Join(lp.dir, "tsnet"),
Hostname: hostname,
UserLogf: log.Printf,
Ephemeral: false,
}
lp.tsnet = ts
ts.PreStart = func() error {
dnsMgr := ts.Sys().DNSManager.Get()
dnsMgr.SetForceAAAA(true)
// Force fallback resolvers to Google and Cloudflare as an ultimate
// fallback in case the Tailnet DNS servers are not set/forced. Normally
// tailscaled would resort to using the OS DNS resolvers, but
// tsnet/userspace binaries don't do that (yet?), so this is the
// "Opionated" part of the "LOPOWER" name. The opinion is just using
// big providers known to work. (Normally stock tailscaled never
// makes such opinions and never defaults to any big provider, unless
// you're already running on that big provider's network so have
// already indicated you're fine with them.))
dnsMgr.SetForceFallbackResolvers([]*dnstype.Resolver{
{Addr: "8.8.8.8"},
{Addr: "1.1.1.1"},
})
return nil
}
if _, err := ts.Up(ctx); err != nil {
log.Fatal(err)
}
}
// filteredDNSQuery wraps the MagicDNS server response but filters out A record responses
// for *.ts.net if IPv4 is not enabled. This is so the e.g. a phone on a CGNAT-using
// network doesn't prefer the "A" record over AAAA when dialing and dial into the
// the carrier's CGNAT range into of the AAAA record into the Tailscale IPv6 ULA range.
func (lp *lpServer) filteredDNSQuery(ctx context.Context, q []byte, family string, from netip.AddrPort) ([]byte, error) {
m, ok := lp.tsnet.Sys().DNSManager.GetOK()
if !ok {
return nil, errors.New("DNSManager not ready")
}
origRes, err := m.Query(ctx, q, family, from)
if err != nil {
return nil, err
}
if *includeV4 {
return origRes, nil
}
// Filter out *.ts.net A records.
var msg dnsmessage.Message
if err := msg.Unpack(origRes); err != nil {
return nil, err
}
newAnswers := msg.Answers[:0]
for _, a := range msg.Answers {
name := a.Header.Name.String()
if a.Header.Type == dnsmessage.TypeA && strings.HasSuffix(name, ".ts.net.") {
// Drop.
continue
}
newAnswers = append(newAnswers, a)
}
if len(newAnswers) == len(msg.Answers) {
// Nothing was filtered. No need to reencode it.
return origRes, nil
}
msg.Answers = newAnswers
return msg.Pack()
}
func (lp *lpServer) handleTCPDNSQuery(c net.Conn, src netip.AddrPort) {
defer c.Close()
var lenBuf [2]byte
for {
c.SetReadDeadline(time.Now().Add(30 * time.Second))
_, err := io.ReadFull(c, lenBuf[:])
if err != nil {
return
}
n := binary.BigEndian.Uint16(lenBuf[:])
buf := make([]byte, n)
c.SetReadDeadline(time.Now().Add(30 * time.Second))
_, err = io.ReadFull(c, buf[:])
if err != nil {
return
}
res, err := lp.filteredDNSQuery(context.Background(), buf, "tcp", src)
if err != nil {
log.Printf("TCP DNS query error: %v", err)
return
}
binary.BigEndian.PutUint16(lenBuf[:], uint16(len(res)))
c.SetWriteDeadline(time.Now().Add(30 * time.Second))
_, err = c.Write(lenBuf[:])
if err != nil {
return
}
c.SetWriteDeadline(time.Now().Add(30 * time.Second))
_, err = c.Write(res)
if err != nil {
return
}
}
}
// caller owns the raw memory.
func (lp *lpServer) handleDNSUDPQuery(raw []byte) {
var pkt packet.Parsed
pkt.Decode(raw)
if pkt.IPProto != ipproto.UDP || pkt.Dst.Port() != 53 || !lp.c.IsLocalIP(pkt.Dst.Addr()) {
panic("caller error")
}
dnsRes, err := lp.filteredDNSQuery(context.Background(), pkt.Payload(), "udp", pkt.Src)
if err != nil {
log.Printf("DNS query error: %v", err)
return
}
ipLayer := mkIPLayer(layers.IPProtocolUDP, pkt.Dst.Addr(), pkt.Src.Addr())
udpLayer := &layers.UDP{
SrcPort: 53,
DstPort: layers.UDPPort(pkt.Src.Port()),
}
resPkt, err := mkPacket(ipLayer, udpLayer, gopacket.Payload(dnsRes))
if err != nil {
log.Printf("mkPacket: %v", err)
return
}
pktBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(resPkt),
})
select {
case lp.readCh <- pktBuf:
case <-lp.ctx.Done():
}
}
type serializableNetworkLayer interface {
gopacket.SerializableLayer
gopacket.NetworkLayer
}
func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetworkLayer {
if src.Is4() {
return &layers.IPv4{
Protocol: proto,
SrcIP: src.AsSlice(),
DstIP: dst.AsSlice(),
}
}
if src.Is6() {
return &layers.IPv6{
NextHeader: proto,
SrcIP: src.AsSlice(),
DstIP: dst.AsSlice(),
}
}
panic("invalid src IP")
}
// mkPacket is a serializes a number of layers into a packet.
//
// It's a convenience wrapper around gopacket.SerializeLayers
// that does some things automatically:
//
// * layers.IPv4/IPv6 Version is set to 4/6 if not already set
// * layers.IPv4/IPv6 TTL/HopLimit is set to 64 if not already set
// * the TCP/UDP/ICMPv6 checksum is set based on the network layer
//
// The provided layers in ll must be sorted from lowest (e.g. *layers.Ethernet)
// to highest. (Depending on the need, the first layer will be either *layers.Ethernet
// or *layers.IPv4/IPv6).
func mkPacket(ll ...gopacket.SerializableLayer) ([]byte, error) {
var nl gopacket.NetworkLayer
for _, la := range ll {
switch la := la.(type) {
case *layers.IPv4:
nl = la
if la.Version == 0 {
la.Version = 4
}
if la.TTL == 0 {
la.TTL = 64
}
case *layers.IPv6:
nl = la
if la.Version == 0 {
la.Version = 6
}
if la.HopLimit == 0 {
la.HopLimit = 64
}
}
}
for _, la := range ll {
switch la := la.(type) {
case *layers.TCP:
la.SetNetworkLayerForChecksum(nl)
case *layers.UDP:
la.SetNetworkLayerForChecksum(nl)
case *layers.ICMPv6:
la.SetNetworkLayerForChecksum(nl)
}
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true}
if err := gopacket.SerializeLayers(buf, opts, ll...); err != nil {
return nil, fmt.Errorf("serializing packet: %v", err)
}
return buf.Bytes(), nil
}
func main() {
flag.Parse()
log.Printf("lopower starting")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lp := newLP(ctx)
if *qrListenAddr != "" {
go lp.serveQR()
}
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, unix.SIGTERM, os.Interrupt)
<-sigCh
}

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 114 KiB

View File

@@ -185,12 +185,10 @@ change in the future.
logoutCmd,
switchCmd,
configureCmd,
syspolicyCmd,
netcheckCmd,
ipCmd,
dnsCmd,
statusCmd,
metricsCmd,
pingCmd,
ncCmd,
sshCmd,

View File

@@ -1,88 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package cli
import (
"context"
"errors"
"fmt"
"strings"
"github.com/peterbourgon/ff/v3/ffcli"
"tailscale.com/atomicfile"
)
var metricsCmd = &ffcli.Command{
Name: "metrics",
ShortHelp: "Show Tailscale metrics",
LongHelp: strings.TrimSpace(`
The 'tailscale metrics' command shows Tailscale user-facing metrics (as opposed
to internal metrics printed by 'tailscale debug metrics').
For more information about Tailscale metrics, refer to
https://tailscale.com/s/client-metrics
`),
ShortUsage: "tailscale metrics <subcommand> [flags]",
UsageFunc: usageFuncNoDefaultValues,
Exec: runMetricsNoSubcommand,
Subcommands: []*ffcli.Command{
{
Name: "print",
ShortUsage: "tailscale metrics print",
Exec: runMetricsPrint,
ShortHelp: "Prints current metric values in the Prometheus text exposition format",
},
{
Name: "write",
ShortUsage: "tailscale metrics write <path>",
Exec: runMetricsWrite,
ShortHelp: "Writes metric values to a file",
LongHelp: strings.TrimSpace(`
The 'tailscale metrics write' command writes metric values to a text file provided as its
only argument. It's meant to be used alongside Prometheus node exporter, allowing Tailscale
metrics to be consumed and exported by the textfile collector.
As an example, to export Tailscale metrics on an Ubuntu system running node exporter, you
can regularly run 'tailscale metrics write /var/lib/prometheus/node-exporter/tailscaled.prom'
using cron or a systemd timer.
`),
},
},
}
// runMetricsNoSubcommand prints metric values if no subcommand is specified.
func runMetricsNoSubcommand(ctx context.Context, args []string) error {
if len(args) > 0 {
return fmt.Errorf("tailscale metrics: unknown subcommand: %s", args[0])
}
return runMetricsPrint(ctx, args)
}
// runMetricsPrint prints metric values to stdout.
func runMetricsPrint(ctx context.Context, args []string) error {
out, err := localClient.UserMetrics(ctx)
if err != nil {
return err
}
Stdout.Write(out)
return nil
}
// runMetricsWrite writes metric values to a file.
func runMetricsWrite(ctx context.Context, args []string) error {
if len(args) != 1 {
return errors.New("usage: tailscale metrics write <path>")
}
path := args[0]
out, err := localClient.UserMetrics(ctx)
if err != nil {
return err
}
return atomicfile.WriteFile(path, out, 0644)
}

View File

@@ -1,110 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package cli
import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"slices"
"text/tabwriter"
"github.com/peterbourgon/ff/v3/ffcli"
"tailscale.com/util/syspolicy/setting"
)
var syspolicyArgs struct {
json bool // JSON output mode
}
var syspolicyCmd = &ffcli.Command{
Name: "syspolicy",
ShortHelp: "Diagnose the MDM and system policy configuration",
LongHelp: "The 'tailscale syspolicy' command provides tools for diagnosing the MDM and system policy configuration.",
ShortUsage: "tailscale syspolicy <subcommand>",
UsageFunc: usageFuncNoDefaultValues,
Subcommands: []*ffcli.Command{
{
Name: "list",
ShortUsage: "tailscale syspolicy list",
Exec: runSysPolicyList,
ShortHelp: "Prints effective policy settings",
LongHelp: "The 'tailscale syspolicy list' subcommand displays the effective policy settings and their sources (e.g., MDM or environment variables).",
FlagSet: (func() *flag.FlagSet {
fs := newFlagSet("syspolicy list")
fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format")
return fs
})(),
},
{
Name: "reload",
ShortUsage: "tailscale syspolicy reload",
Exec: runSysPolicyReload,
ShortHelp: "Forces a reload of policy settings, even if no changes are detected, and prints the result",
LongHelp: "The 'tailscale syspolicy reload' subcommand forces a reload of policy settings, even if no changes are detected, and prints the result.",
FlagSet: (func() *flag.FlagSet {
fs := newFlagSet("syspolicy reload")
fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format")
return fs
})(),
},
},
}
func runSysPolicyList(ctx context.Context, args []string) error {
policy, err := localClient.GetEffectivePolicy(ctx, setting.DefaultScope())
if err != nil {
return err
}
printPolicySettings(policy)
return nil
}
func runSysPolicyReload(ctx context.Context, args []string) error {
policy, err := localClient.ReloadEffectivePolicy(ctx, setting.DefaultScope())
if err != nil {
return err
}
printPolicySettings(policy)
return nil
}
func printPolicySettings(policy *setting.Snapshot) {
if syspolicyArgs.json {
json, err := json.MarshalIndent(policy, "", "\t")
if err != nil {
errf("syspolicy marshalling error: %v", err)
} else {
outln(string(json))
}
return
}
if policy.Len() == 0 {
outln("No policy settings")
return
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "Name\tOrigin\tValue\tError")
fmt.Fprintln(w, "----\t------\t-----\t-----")
for _, k := range slices.Sorted(policy.Keys()) {
setting, _ := policy.GetSetting(k)
var origin string
if o := setting.Origin(); o != nil {
origin = o.String()
}
if err := setting.Error(); err != nil {
fmt.Fprintf(w, "%s\t%s\t\t{%s}\n", k, origin, err)
} else {
fmt.Fprintf(w, "%s\t%s\t%s\t\n", k, origin, setting.Value())
}
}
w.Flush()
fmt.Println()
return
}

View File

@@ -174,18 +174,14 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
tailscale.com/util/syspolicy from tailscale.com/ipn
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli
tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli
tailscale.com/util/usermetric from tailscale.com/health
tailscale.com/util/vizerror from tailscale.com/tailcfg+
W 💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/version from tailscale.com/client/web+
tailscale.com/version/distro from tailscale.com/client/web+

View File

@@ -250,6 +250,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/control/controlclient from tailscale.com/cmd/tailscaled+
tailscale.com/control/controlhttp from tailscale.com/control/controlclient
tailscale.com/control/controlknobs from tailscale.com/control/controlclient+
tailscale.com/control/keyfallback from tailscale.com/control/controlclient
tailscale.com/derp from tailscale.com/derp/derphttp+
tailscale.com/derp/derphttp from tailscale.com/cmd/tailscaled+
tailscale.com/disco from tailscale.com/derp+
@@ -401,11 +402,8 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
tailscale.com/util/systemd from tailscale.com/control/controlclient+
tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+
@@ -415,7 +413,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/util/vizerror from tailscale.com/tailcfg+
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+

View File

@@ -788,6 +788,7 @@ func runDebugServer(mux *http.ServeMux, addr string) {
}
func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) {
tfs, _ := sys.DriveForLocal.GetOK()
ret, err := netstack.Create(logf,
sys.Tun.Get(),
sys.Engine.Get(),
@@ -795,6 +796,7 @@ func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) {
sys.Dialer.Get(),
sys.DNSManager.Get(),
sys.ProxyMapper(),
tfs,
)
if err != nil {
return nil, err

View File

@@ -42,7 +42,6 @@ type testAttempt struct {
testName string // "TestFoo"
outcome string // "pass", "fail", "skip"
logs bytes.Buffer
start, end time.Time
isMarkedFlaky bool // set if the test is marked as flaky
issueURL string // set if the test is marked as flaky
@@ -133,17 +132,11 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te
}
pkg := goOutput.Package
pkgTests := resultMap[pkg]
if pkgTests == nil {
pkgTests = make(map[string]*testAttempt)
resultMap[pkg] = pkgTests
}
if goOutput.Test == "" {
switch goOutput.Action {
case "start":
pkgTests[""] = &testAttempt{start: goOutput.Time}
case "fail", "pass", "skip":
for _, test := range pkgTests {
if test.testName != "" && test.outcome == "" {
if test.outcome == "" {
test.outcome = "fail"
ch <- test
}
@@ -151,13 +144,15 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te
ch <- &testAttempt{
pkg: goOutput.Package,
outcome: goOutput.Action,
start: pkgTests[""].start,
end: goOutput.Time,
pkgFinished: true,
}
}
continue
}
if pkgTests == nil {
pkgTests = make(map[string]*testAttempt)
resultMap[pkg] = pkgTests
}
testName := goOutput.Test
if test, _, isSubtest := strings.Cut(goOutput.Test, "/"); isSubtest {
testName = test
@@ -173,10 +168,8 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te
pkgTests[testName] = &testAttempt{
pkg: pkg,
testName: testName,
start: goOutput.Time,
}
case "skip", "pass", "fail":
pkgTests[testName].end = goOutput.Time
pkgTests[testName].outcome = goOutput.Action
ch <- pkgTests[testName]
case "output":
@@ -220,7 +213,7 @@ func main() {
firstRun.tests = append(firstRun.tests, &packageTests{Pattern: pkg})
}
toRun := []*nextRun{firstRun}
printPkgOutcome := func(pkg, outcome string, attempt int, runtime time.Duration) {
printPkgOutcome := func(pkg, outcome string, attempt int) {
if outcome == "skip" {
fmt.Printf("?\t%s [skipped/no tests] \n", pkg)
return
@@ -232,10 +225,10 @@ func main() {
outcome = "FAIL"
}
if attempt > 1 {
fmt.Printf("%s\t%s\t%.3fs\t[attempt=%d]\n", outcome, pkg, runtime.Seconds(), attempt)
fmt.Printf("%s\t%s [attempt=%d]\n", outcome, pkg, attempt)
return
}
fmt.Printf("%s\t%s\t%.3fs\n", outcome, pkg, runtime.Seconds())
fmt.Printf("%s\t%s\n", outcome, pkg)
}
// Check for -coverprofile argument and filter it out
@@ -314,7 +307,7 @@ func main() {
// when a package times out.
failed = true
}
printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt, tr.end.Sub(tr.start))
printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt)
continue
}
if testingVerbose || tr.outcome == "fail" {

View File

@@ -10,7 +10,6 @@ import (
"os"
"os/exec"
"path/filepath"
"regexp"
"sync"
"testing"
)
@@ -77,10 +76,7 @@ func TestFlakeRun(t *testing.T) {
t.Fatalf("go run . %s: %s with output:\n%s", testfile, err, out)
}
// Replace the unpredictable timestamp with "0.00s".
out = regexp.MustCompile(`\t\d+\.\d\d\ds\t`).ReplaceAll(out, []byte("\t0.00s\t"))
want := []byte("ok\t" + testfile + "\t0.00s\t[attempt=2]")
want := []byte("ok\t" + testfile + " [attempt=2]")
if !bytes.Contains(out, want) {
t.Fatalf("wanted output containing %q but got:\n%s", want, out)
}

View File

@@ -150,7 +150,6 @@ func runEsbuildServe(buildOptions esbuild.BuildOptions) {
log.Fatalf("Cannot start esbuild server: %v", err)
}
log.Printf("Listening on http://%s:%d\n", result.Host, result.Port)
select {}
}
func runEsbuild(buildOptions esbuild.BuildOptions) esbuild.BuildResult {

View File

@@ -115,7 +115,7 @@ func newIPN(jsConfig js.Value) map[string]any {
}
sys.Set(eng)
ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper())
ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
if err != nil {
log.Fatalf("netstack.Create: %v", err)
}

View File

@@ -29,9 +29,11 @@ import (
"go4.org/mem"
"tailscale.com/control/controlknobs"
"tailscale.com/control/keyfallback"
"tailscale.com/envknob"
"tailscale.com/health"
"tailscale.com/hostinfo"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnstate"
"tailscale.com/logtail"
"tailscale.com/net/dnscache"
@@ -87,9 +89,10 @@ type Direct struct {
dialPlan ControlDialPlanner // can be nil
mu sync.Mutex // mutex guards the following fields
serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now
serverNoiseKey key.MachinePublic
mu sync.Mutex // mutex guards the following fields
serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now
serverNoiseKey key.MachinePublic
usedFallbackNoiseKey bool // true if we used the baked-in fallback key
sfGroup singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation.
noiseClient *NoiseClient
@@ -498,6 +501,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
tryingNewKey := c.tryingNewKey
serverKey := c.serverLegacyKey
serverNoiseKey := c.serverNoiseKey
usedFallback := c.usedFallbackNoiseKey
authKey, isWrapped, wrappedSig, wrappedKey := tka.DecodeWrappedAuthkey(c.authKey, c.logf)
hi := c.hostInfoLocked()
backendLogID := hi.BackendLogID
@@ -528,7 +532,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
}
c.logf("doLogin(regen=%v, hasUrl=%v)", regen, opt.URL != "")
if serverKey.IsZero() {
if serverKey.IsZero() || usedFallback {
keys, err := loadServerPubKeys(ctx, c.httpc, c.serverURL)
if err != nil && c.interceptedDial != nil && c.interceptedDial.Load() {
c.health.SetUnhealthy(macOSScreenTime, nil)
@@ -536,13 +540,21 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
c.health.SetHealthy(macOSScreenTime)
}
if err != nil {
return regen, opt.URL, nil, err
if k2, err := c.getFallbackServerPubKeys(); err == nil {
keys = k2
usedFallback = true
} else {
return regen, opt.URL, nil, err
}
} else {
usedFallback = false
c.logf("control server key from %s: ts2021=%s", c.serverURL, keys.PublicKey.ShortString())
}
c.logf("control server key from %s: ts2021=%s, legacy=%v", c.serverURL, keys.PublicKey.ShortString(), keys.LegacyPublicKey.ShortString())
c.mu.Lock()
c.serverLegacyKey = keys.LegacyPublicKey
c.serverNoiseKey = keys.PublicKey
c.usedFallbackNoiseKey = usedFallback
c.mu.Unlock()
serverKey = keys.LegacyPublicKey
serverNoiseKey = keys.PublicKey
@@ -751,6 +763,22 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
return false, resp.AuthURL, nil, nil
}
func (c *Direct) getFallbackServerPubKeys() (*tailcfg.OverTLSPublicKeyResponse, error) {
// If we saw an error, try to use the fallback key if
// we're dialing the default control server.
if ipn.IsLoginServerSynonym(c.serverURL) {
return nil, errors.New("not using default control server")
}
kf, err := keyfallback.Get()
if err != nil {
return nil, err
}
c.logf("using fallback server key: ts2021=%s", kf.PublicKey.ShortString())
return kf, nil
}
// newEndpoints acquires c.mu and sets the local port and endpoints and reports
// whether they've changed.
//

View File

@@ -0,0 +1,4 @@
{
"legacyPublicKey": "mkey:9e5156a4c65121306dd2d8ed8f92cb8d738e2533011344b522c5d28409bc4970",
"publicKey": "mkey:7d2792f9c98d753d2042471536801949104c247f95eac770f8fb321595e2173b"
}

View File

@@ -0,0 +1,32 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package keyfallback contains a fallback mechanism for starting up Tailscale
// when the control server cannot be reached to obtain the primary Noise key.
//
// The data is backed by a JSON file `control-key.json` that is updated by
// `update.go`:
//
// (cd control/keyfallback; go run update.go)
package keyfallback
import (
_ "embed"
"encoding/json"
"tailscale.com/tailcfg"
)
// Get returns the fallback control server public key that was baked into the
// binary at compile time. It is only valid for the main Tailscale control
// server instance.
func Get() (*tailcfg.OverTLSPublicKeyResponse, error) {
out := &tailcfg.OverTLSPublicKeyResponse{}
if err := json.Unmarshal(controlKeyJSON, out); err != nil {
return nil, err
}
return out, nil
}
//go:embed control-key.json
var controlKeyJSON []byte

View File

@@ -0,0 +1,77 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package keyfallback
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"reflect"
"testing"
"time"
"tailscale.com/ipn"
"tailscale.com/tailcfg"
"tailscale.com/tstest/nettest"
"tailscale.com/util/must"
)
func TestHasValidControlKey(t *testing.T) {
t.Parallel()
keys, err := Get()
if err != nil {
t.Fatalf("Get: %v", err)
}
if keys.PublicKey.IsZero() {
t.Fatalf("zero key")
}
}
// TestKeyIsUpToDate fetches the control key from the control server and
// compares it to the baked-in key, to verify that it's up-to-date. If the
// control server is unreachable, the test is skipped.
func TestKeyIsUpToDate(t *testing.T) {
nettest.SkipIfNoNetwork(t)
// Optimistically fetch the control key and check if it's up to date,
// but ignore if we don't have network access (e.g. running tests on an
// airplane).
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
keyURL := fmt.Sprintf("%v/key?v=%d", ipn.DefaultControlURL, tailcfg.CurrentCapabilityVersion)
req := must.Get(http.NewRequestWithContext(ctx, "GET", keyURL, nil))
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Logf("fetch control key: %v", err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
t.Fatalf("fetch control key: bad status; got %v, want 200", res.Status)
}
b, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("read control key: %v", err)
}
// Verify that the key is up to date and matches the baked-in key.
out := &tailcfg.OverTLSPublicKeyResponse{}
if err := json.Unmarshal(b, out); err != nil {
t.Fatalf("unmarshal control key: %v", err)
}
keys, err := Get()
if err != nil {
t.Fatalf("Get: %v", err)
}
if !reflect.DeepEqual(keys, out) {
t.Errorf("control key is out of date")
t.Logf("old key: %v", keys)
t.Logf("new key: %v", out)
}
}

View File

@@ -0,0 +1,47 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build ignore
package main
import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"tailscale.com/ipn"
"tailscale.com/tailcfg"
)
func main() {
keyURL := fmt.Sprintf("%v/key?v=%d", ipn.DefaultControlURL, tailcfg.CurrentCapabilityVersion)
res, err := http.Get(keyURL)
if err != nil {
log.Fatalf("fetch control key: %v", err)
}
defer res.Body.Close()
b, err := io.ReadAll(io.LimitReader(res.Body, 64<<10))
if err != nil {
log.Fatalf("read control key: %v", err)
}
if res.StatusCode != 200 {
log.Fatalf("fetch control key: bad status; got %v, want 200", res.Status)
}
// Unmarshal to make sure it's valid.
var out tailcfg.OverTLSPublicKeyResponse
if err := json.Unmarshal(b, &out); err != nil {
log.Fatalf("unmarshal control key: %v", err)
}
if out.PublicKey.IsZero() {
log.Fatalf("control key is zero")
}
if err := os.WriteFile("control-key.json", b, 0644); err != nil {
log.Fatalf("write control key: %v", err)
}
}

View File

@@ -32,8 +32,6 @@ type ConfigVAlpha struct {
AdvertiseRoutes []netip.Prefix `json:",omitempty"`
DisableSNAT opt.Bool `json:",omitempty"`
AppConnector *AppConnectorPrefs `json:",omitempty"` // advertise app connector; defaults to false (if nil or explicitly set to false)
NetfilterMode *string `json:",omitempty"` // "on", "off", "nodivert"
NoStatefulFiltering opt.Bool `json:",omitempty"`
@@ -139,9 +137,5 @@ func (c *ConfigVAlpha) ToPrefs() (MaskedPrefs, error) {
mp.AutoUpdate = *c.AutoUpdate
mp.AutoUpdateSet = AutoUpdatePrefsMask{ApplySet: true, CheckSet: true}
}
if c.AppConnector != nil {
mp.AppConnector = *c.AppConnector
mp.AppConnectorSet = true
}
return mp, nil
}

View File

@@ -332,10 +332,12 @@ func handleC2NPostureIdentityGet(b *LocalBackend, w http.ResponseWriter, r *http
}
if choice.ShouldEnable(b.Prefs().PostureChecking()) {
res.SerialNumbers, err = posture.GetSerialNumbers(b.logf)
sns, err := posture.GetSerialNumbers(b.logf)
if err != nil {
b.logf("c2n: GetSerialNumbers returned error: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
res.SerialNumbers = sns
// TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release
// and looks good in client metrics, remove this parameter and always report MAC

View File

@@ -399,6 +399,11 @@ type metrics struct {
// approvedRoutes is a metric that reports the number of network routes served by the local node and approved
// by the control server.
approvedRoutes *usermetric.Gauge
// primaryRoutes is a metric that reports the number of primary network routes served by the local node.
// A route being a primary route implies that the route is currently served by this node, and not by another
// subnet router in a high availability configuration.
primaryRoutes *usermetric.Gauge
}
// clientGen is a func that creates a control plane client.
@@ -449,6 +454,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
"tailscaled_advertised_routes", "Number of advertised network routes (e.g. by a subnet router)"),
approvedRoutes: sys.UserMetricsRegistry().NewGauge(
"tailscaled_approved_routes", "Number of approved network routes (e.g. by a subnet router)"),
primaryRoutes: sys.UserMetricsRegistry().NewGauge(
"tailscaled_primary_routes", "Number of network routes for which this node is a primary router (in high availability configuration)"),
}
b := &LocalBackend{
@@ -479,7 +486,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
mConn.SetNetInfoCallback(b.setNetInfo)
if sys.InitialConfig != nil {
if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil {
if err := b.setConfigLocked(sys.InitialConfig); err != nil {
return nil, err
}
}
@@ -712,8 +719,8 @@ func (b *LocalBackend) SetDirectFileRoot(dir string) {
// It returns (false, nil) if not running in declarative mode, (true, nil) on
// success, or (false, error) on failure.
func (b *LocalBackend) ReloadConfig() (ok bool, err error) {
unlock := b.lockAndGetUnlock()
defer unlock()
b.mu.Lock()
defer b.mu.Unlock()
if b.conf == nil {
return false, nil
}
@@ -721,21 +728,18 @@ func (b *LocalBackend) ReloadConfig() (ok bool, err error) {
if err != nil {
return false, err
}
if err := b.setConfigLockedOnEntry(conf, unlock); err != nil {
if err := b.setConfigLocked(conf); err != nil {
return false, fmt.Errorf("error setting config: %w", err)
}
return true, nil
}
// initPrefsFromConfig initializes the backend's prefs from the provided config.
// This should only be called once, at startup. For updates at runtime, use
// [LocalBackend.setConfigLocked].
func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error {
// TODO(maisem,bradfitz): combine this with setConfigLocked. This is called
// before anything is running, so there's no need to lock and we don't
// update any subsystems. At runtime, we both need to lock and update
// subsystems with the new prefs.
func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error {
// TODO(irbekrm): notify the relevant components to consume any prefs
// updates. Currently only initial configfile settings are applied
// immediately.
p := b.pm.CurrentPrefs().AsStruct()
mp, err := conf.Parsed.ToPrefs()
if err != nil {
@@ -745,14 +749,13 @@ func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error {
if err := b.pm.SetPrefs(p.View(), ipn.NetworkProfile{}); err != nil {
return err
}
b.setStaticEndpointsFromConfigLocked(conf)
b.conf = conf
return nil
}
func (b *LocalBackend) setStaticEndpointsFromConfigLocked(conf *conffile.Config) {
defer func() {
b.conf = conf
}()
if conf.Parsed.StaticEndpoints == nil && (b.conf == nil || b.conf.Parsed.StaticEndpoints == nil) {
return
return nil
}
// Ensure that magicsock conn has the up to date static wireguard
@@ -766,22 +769,6 @@ func (b *LocalBackend) setStaticEndpointsFromConfigLocked(conf *conffile.Config)
ms.SetStaticEndpoints(views.SliceOf(conf.Parsed.StaticEndpoints))
}
}
}
// setConfigLockedOnEntry uses the provided config to update the backend's prefs
// and other state.
func (b *LocalBackend) setConfigLockedOnEntry(conf *conffile.Config, unlock unlockOnce) error {
defer unlock()
p := b.pm.CurrentPrefs().AsStruct()
mp, err := conf.Parsed.ToPrefs()
if err != nil {
return fmt.Errorf("error parsing config to prefs: %w", err)
}
p.ApplyEdits(&mp)
b.setStaticEndpointsFromConfigLocked(conf)
b.setPrefsLockedOnEntry(p, unlock)
b.conf = conf
return nil
}
@@ -4194,11 +4181,7 @@ func (b *LocalBackend) authReconfig() {
disableSubnetsIfPAC := nm.HasCap(tailcfg.NodeAttrDisableSubnetsIfPAC)
userDialUseRoutes := nm.HasCap(tailcfg.NodeAttrUserDialUseRoutes)
dohURL, dohURLOK := exitNodeCanProxyDNS(nm, b.peers, prefs.ExitNodeID())
var forceAAAA bool
if dm, ok := b.sys.DNSManager.GetOK(); ok {
forceAAAA = dm.GetForceAAAA()
}
dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, forceAAAA, b.logf, version.OS())
dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, b.logf, version.OS())
// If the current node is an app connector, ensure the app connector machine is started
b.reconfigAppConnectorLocked(nm, prefs)
b.mu.Unlock()
@@ -4298,7 +4281,7 @@ func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs,
//
// The versionOS is a Tailscale-style version ("iOS", "macOS") and not
// a runtime.GOOS.
func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired, forceAAAA bool, logf logger.Logf, versionOS string) *dns.Config {
func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config {
if nm == nil {
return nil
}
@@ -4361,7 +4344,7 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.
// https://github.com/tailscale/tailscale/issues/1152
// tracks adding the right capability reporting to
// enable AAAA in MagicDNS.
if addr.Addr().Is6() && have4 && !forceAAAA {
if addr.Addr().Is6() && have4 {
continue
}
ips = append(ips, addr.Addr())
@@ -5494,6 +5477,7 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) {
// If there is no netmap, the client is going into a "turned off"
// state so reset the metrics.
b.metrics.approvedRoutes.Set(0)
b.metrics.primaryRoutes.Set(0)
return
}
@@ -5522,6 +5506,7 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) {
}
}
b.metrics.approvedRoutes.Set(approved)
b.metrics.primaryRoutes.Set(float64(tsaddr.WithoutExitRoute(nm.SelfNode.PrimaryRoutes()).Len()))
}
for _, p := range nm.Peers {
addNode(p)

View File

@@ -13,7 +13,6 @@ import (
"net/http"
"net/netip"
"os"
"path/filepath"
"reflect"
"slices"
"strings"
@@ -33,7 +32,6 @@ import (
"tailscale.com/health"
"tailscale.com/hostinfo"
"tailscale.com/ipn"
"tailscale.com/ipn/conffile"
"tailscale.com/ipn/ipnauth"
"tailscale.com/ipn/store/mem"
"tailscale.com/net/netcheck"
@@ -56,8 +54,6 @@ import (
"tailscale.com/util/must"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/wgcfg"
@@ -434,25 +430,16 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) {
}
func newTestLocalBackend(t testing.TB) *LocalBackend {
return newTestLocalBackendWithSys(t, new(tsd.System))
}
// newTestLocalBackendWithSys creates a new LocalBackend with the given tsd.System.
// If the state store or engine are not set in sys, they will be set to a new
// in-memory store and fake userspace engine, respectively.
func newTestLocalBackendWithSys(t testing.TB, sys *tsd.System) *LocalBackend {
var logf logger.Logf = logger.Discard
if _, ok := sys.StateStore.GetOK(); !ok {
sys.Set(new(mem.Store))
}
if _, ok := sys.Engine.GetOK(); !ok {
eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry())
if err != nil {
t.Fatalf("NewFakeUserspaceEngine: %v", err)
}
t.Cleanup(eng.Close)
sys.Set(eng)
sys := new(tsd.System)
store := new(mem.Store)
sys.Set(store)
eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry())
if err != nil {
t.Fatalf("NewFakeUserspaceEngine: %v", err)
}
t.Cleanup(eng.Close)
sys.Set(eng)
lb, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0)
if err != nil {
t.Fatalf("NewLocalBackend: %v", err)
@@ -1298,7 +1285,7 @@ func TestDNSConfigForNetmapForExitNodeConfigs(t *testing.T) {
}
prefs := &ipn.Prefs{ExitNodeID: tc.exitNode, CorpDNS: true}
got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, false, t.Logf, "")
got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, t.Logf, "")
if !resolversEqual(t, got.DefaultResolvers, tc.wantDefaultResolvers) {
t.Errorf("DefaultResolvers: got %#v, want %#v", got.DefaultResolvers, tc.wantDefaultResolvers)
}
@@ -1572,6 +1559,94 @@ func dnsResponse(domain, address string) []byte {
return must.Get(b.Finish())
}
type errorSyspolicyHandler struct {
t *testing.T
err error
key syspolicy.Key
allowKeys map[syspolicy.Key]*string
}
func (h *errorSyspolicyHandler) ReadString(key string) (string, error) {
sk := syspolicy.Key(key)
if _, ok := h.allowKeys[sk]; !ok {
h.t.Errorf("ReadString: %q is not in list of permitted keys", h.key)
}
if sk == h.key {
return "", h.err
}
return "", syspolicy.ErrNoSuchKey
}
func (h *errorSyspolicyHandler) ReadUInt64(key string) (uint64, error) {
h.t.Errorf("ReadUInt64(%q) unexpectedly called", key)
return 0, syspolicy.ErrNoSuchKey
}
func (h *errorSyspolicyHandler) ReadBoolean(key string) (bool, error) {
h.t.Errorf("ReadBoolean(%q) unexpectedly called", key)
return false, syspolicy.ErrNoSuchKey
}
func (h *errorSyspolicyHandler) ReadStringArray(key string) ([]string, error) {
h.t.Errorf("ReadStringArray(%q) unexpectedly called", key)
return nil, syspolicy.ErrNoSuchKey
}
type mockSyspolicyHandler struct {
t *testing.T
// stringPolicies is the collection of policies that we expect to see
// queried by the current test. If the policy is expected but unset, then
// use nil, otherwise use a string equal to the policy's desired value.
stringPolicies map[syspolicy.Key]*string
// stringArrayPolicies is the collection of policies that we expected to see
// queries by the current test, that return policy string arrays.
stringArrayPolicies map[syspolicy.Key][]string
// failUnknownPolicies is set if policies other than those in stringPolicies
// (uint64 or bool policies are not supported by mockSyspolicyHandler yet)
// should be considered a test failure if they are queried.
failUnknownPolicies bool
}
func (h *mockSyspolicyHandler) ReadString(key string) (string, error) {
if s, ok := h.stringPolicies[syspolicy.Key(key)]; ok {
if s == nil {
return "", syspolicy.ErrNoSuchKey
}
return *s, nil
}
if h.failUnknownPolicies {
h.t.Errorf("ReadString(%q) unexpectedly called", key)
}
return "", syspolicy.ErrNoSuchKey
}
func (h *mockSyspolicyHandler) ReadUInt64(key string) (uint64, error) {
if h.failUnknownPolicies {
h.t.Errorf("ReadUInt64(%q) unexpectedly called", key)
}
return 0, syspolicy.ErrNoSuchKey
}
func (h *mockSyspolicyHandler) ReadBoolean(key string) (bool, error) {
if h.failUnknownPolicies {
h.t.Errorf("ReadBoolean(%q) unexpectedly called", key)
}
return false, syspolicy.ErrNoSuchKey
}
func (h *mockSyspolicyHandler) ReadStringArray(key string) ([]string, error) {
if h.failUnknownPolicies {
h.t.Errorf("ReadStringArray(%q) unexpectedly called", key)
}
if s, ok := h.stringArrayPolicies[syspolicy.Key(key)]; ok {
if s == nil {
return []string{}, syspolicy.ErrNoSuchKey
}
return s, nil
}
return nil, syspolicy.ErrNoSuchKey
}
func TestSetExitNodeIDPolicy(t *testing.T) {
pfx := netip.MustParsePrefix
tests := []struct {
@@ -1781,18 +1856,23 @@ func TestSetExitNodeIDPolicy(t *testing.T) {
},
}
syspolicy.RegisterWellKnownSettingsForTest(t)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b := newTestBackend(t)
policyStore := source.NewTestStoreOf(t,
source.TestSettingOf(syspolicy.ExitNodeID, test.exitNodeID),
source.TestSettingOf(syspolicy.ExitNodeIP, test.exitNodeIP),
)
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
msh := &mockSyspolicyHandler{
t: t,
stringPolicies: map[syspolicy.Key]*string{
syspolicy.ExitNodeID: nil,
syspolicy.ExitNodeIP: nil,
},
}
if test.exitNodeIDKey {
msh.stringPolicies[syspolicy.ExitNodeID] = &test.exitNodeID
}
if test.exitNodeIPKey {
msh.stringPolicies[syspolicy.ExitNodeIP] = &test.exitNodeIP
}
syspolicy.SetHandlerForTest(t, msh)
if test.nm == nil {
test.nm = new(netmap.NetworkMap)
}
@@ -1914,13 +1994,13 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) {
report: report,
},
}
syspolicy.RegisterWellKnownSettingsForTest(t)
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
syspolicy.ExitNodeID, "auto:any",
))
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
msh := &mockSyspolicyHandler{
t: t,
stringPolicies: map[syspolicy.Key]*string{
syspolicy.ExitNodeID: ptr.To("auto:any"),
},
}
syspolicy.SetHandlerForTest(t, msh)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := newTestLocalBackend(t)
@@ -1969,11 +2049,13 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) {
}
cc = newClient(t, opts)
b.cc = cc
syspolicy.RegisterWellKnownSettingsForTest(t)
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
syspolicy.ExitNodeID, "auto:any",
))
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
msh := &mockSyspolicyHandler{
t: t,
stringPolicies: map[syspolicy.Key]*string{
syspolicy.ExitNodeID: ptr.To("auto:any"),
},
}
syspolicy.SetHandlerForTest(t, msh)
peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes())
peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes())
selfNode := tailcfg.Node{
@@ -2078,11 +2160,13 @@ func TestSetControlClientStatusAutoExitNode(t *testing.T) {
DERPMap: derpMap,
}
b := newTestLocalBackend(t)
syspolicy.RegisterWellKnownSettingsForTest(t)
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
syspolicy.ExitNodeID, "auto:any",
))
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
msh := &mockSyspolicyHandler{
t: t,
stringPolicies: map[syspolicy.Key]*string{
syspolicy.ExitNodeID: ptr.To("auto:any"),
},
}
syspolicy.SetHandlerForTest(t, msh)
b.netMap = nm
b.lastSuggestedExitNode = peer1.StableID()
b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, report)
@@ -2316,16 +2400,17 @@ func TestApplySysPolicy(t *testing.T) {
},
}
syspolicy.RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
settings := make([]source.TestSetting[string], 0, len(tt.stringPolicies))
for p, v := range tt.stringPolicies {
settings = append(settings, source.TestSettingOf(p, v))
msh := &mockSyspolicyHandler{
t: t,
stringPolicies: make(map[syspolicy.Key]*string, len(tt.stringPolicies)),
}
policyStore := source.NewTestStoreOf(t, settings...)
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
for p, v := range tt.stringPolicies {
v := v // construct a unique pointer for each policy value
msh.stringPolicies[p] = &v
}
syspolicy.SetHandlerForTest(t, msh)
t.Run("unit", func(t *testing.T) {
prefs := tt.prefs.Clone()
@@ -2461,19 +2546,35 @@ func TestPreferencePolicyInfo(t *testing.T) {
},
}
syspolicy.RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for _, pp := range preferencePolicies {
t.Run(string(pp.key), func(t *testing.T) {
s := source.TestSetting[string]{
Key: pp.key,
Error: tt.policyError,
Value: tt.policyValue,
var h syspolicy.Handler
allPolicies := make(map[syspolicy.Key]*string, len(preferencePolicies)+1)
allPolicies[syspolicy.ControlURL] = nil
for _, pp := range preferencePolicies {
allPolicies[pp.key] = nil
}
policyStore := source.NewTestStoreOf(t, s)
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
if tt.policyError != nil {
h = &errorSyspolicyHandler{
t: t,
err: tt.policyError,
key: pp.key,
allowKeys: allPolicies,
}
} else {
msh := &mockSyspolicyHandler{
t: t,
stringPolicies: allPolicies,
failUnknownPolicies: true,
}
msh.stringPolicies[pp.key] = &tt.policyValue
h = msh
}
syspolicy.SetHandlerForTest(t, h)
prefs := defaultPrefs.AsStruct()
pp.set(prefs, tt.initialValue)
@@ -3724,16 +3825,15 @@ func TestShouldAutoExitNode(t *testing.T) {
expectedBool: false,
},
}
syspolicy.RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
syspolicy.ExitNodeID, tt.exitNodeIDPolicyValue,
))
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
msh := &mockSyspolicyHandler{
t: t,
stringPolicies: map[syspolicy.Key]*string{
syspolicy.ExitNodeID: ptr.To(tt.exitNodeIDPolicyValue),
},
}
syspolicy.SetHandlerForTest(t, msh)
got := shouldAutoExitNode()
if got != tt.expectedBool {
t.Fatalf("expected %v got %v for %v policy value", tt.expectedBool, got, tt.exitNodeIDPolicyValue)
@@ -3871,13 +3971,17 @@ func TestFillAllowedSuggestions(t *testing.T) {
want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"},
},
}
syspolicy.RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
syspolicy.AllowedSuggestedExitNodes, tt.allowPolicy,
))
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
mh := mockSyspolicyHandler{
t: t,
}
if tt.allowPolicy != nil {
mh.stringArrayPolicies = map[syspolicy.Key][]string{
syspolicy.AllowedSuggestedExitNodes: tt.allowPolicy,
}
}
syspolicy.SetHandlerForTest(t, &mh)
got := fillAllowedSuggestions()
if got == nil {
@@ -4434,35 +4538,3 @@ func TestLoginNotifications(t *testing.T) {
})
}
}
// TestConfigFileReload tests that the LocalBackend reloads its configuration
// when the configuration file changes.
func TestConfigFileReload(t *testing.T) {
cfg1 := `{"Hostname": "foo", "Version": "alpha0"}`
f := filepath.Join(t.TempDir(), "cfg")
must.Do(os.WriteFile(f, []byte(cfg1), 0600))
sys := new(tsd.System)
sys.InitialConfig = must.Get(conffile.Load(f))
lb := newTestLocalBackendWithSys(t, sys)
must.Do(lb.Start(ipn.Options{}))
lb.mu.Lock()
hn := lb.hostinfo.Hostname
lb.mu.Unlock()
if hn != "foo" {
t.Fatalf("got %q; want %q", hn, "foo")
}
cfg2 := `{"Hostname": "bar", "Version": "alpha0"}`
must.Do(os.WriteFile(f, []byte(cfg2), 0600))
if !must.Get(lb.ReloadConfig()) {
t.Fatal("reload failed")
}
lb.mu.Lock()
hn = lb.hostinfo.Hostname
lb.mu.Unlock()
if hn != "bar" {
t.Fatalf("got %q; want %q", hn, "bar")
}
}

View File

@@ -62,8 +62,7 @@ import (
"tailscale.com/util/osdiag"
"tailscale.com/util/progresstracking"
"tailscale.com/util/rands"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/testenv"
"tailscale.com/version"
"tailscale.com/wgengine/magicsock"
)
@@ -78,7 +77,6 @@ var handler = map[string]localAPIHandler{
"cert/": (*Handler).serveCert,
"file-put/": (*Handler).serveFilePut,
"files/": (*Handler).serveFiles,
"policy/": (*Handler).servePolicy,
"profiles/": (*Handler).serveProfiles,
// The other /localapi/v0/NAME handlers are exact matches and contain only NAME
@@ -572,9 +570,15 @@ func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) {
clientmetric.WritePrometheusExpositionFormat(w)
}
// serveUserMetrics returns user-facing metrics in Prometheus text
// exposition format.
// TODO(kradalby): Remove this once we have landed on a final set of
// metrics to export to clients and consider the metrics stable.
var debugUsermetricsEndpoint = envknob.RegisterBool("TS_DEBUG_USER_METRICS")
func (h *Handler) serveUserMetrics(w http.ResponseWriter, r *http.Request) {
if !testenv.InTest() && !debugUsermetricsEndpoint() {
http.Error(w, "usermetrics debug flag not enabled", http.StatusForbidden)
return
}
h.b.UserMetricsRegistry().Handler(w, r)
}
@@ -1335,53 +1339,6 @@ func (h *Handler) servePrefs(w http.ResponseWriter, r *http.Request) {
e.Encode(prefs)
}
func (h *Handler) servePolicy(w http.ResponseWriter, r *http.Request) {
if !h.PermitRead {
http.Error(w, "policy access denied", http.StatusForbidden)
return
}
suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/policy/")
if !ok {
http.Error(w, "misconfigured", http.StatusInternalServerError)
return
}
var scope setting.PolicyScope
if suffix == "" {
scope = setting.DefaultScope()
} else if err := scope.UnmarshalText([]byte(suffix)); err != nil {
http.Error(w, fmt.Sprintf("%q is not a valid scope", suffix), http.StatusBadRequest)
return
}
policy, err := rsop.PolicyFor(scope)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var effectivePolicy *setting.Snapshot
switch r.Method {
case "GET":
effectivePolicy = policy.Get()
case "POST":
effectivePolicy, err = policy.Reload()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
default:
http.Error(w, "unsupported method", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w)
e.SetIndent("", "\t")
e.Encode(effectivePolicy)
}
type resJSON struct {
Error string `json:",omitempty"`
}

View File

@@ -13,27 +13,19 @@ import (
"time"
"tailscale.com/ipn"
"tailscale.com/ipn/store/mem"
"tailscale.com/kube/kubeapi"
"tailscale.com/kube/kubeclient"
"tailscale.com/types/logger"
)
// TODO(irbekrm): should we bump this? should we have retries? See tailscale/tailscale#13024
const timeout = 5 * time.Second
// Store is an ipn.StateStore that uses a Kubernetes Secret for persistence.
type Store struct {
client kubeclient.Client
canPatch bool
secretName string
// memory holds the latest tailscale state. Writes write state to a kube Secret and memory, Reads read from
// memory.
memory mem.Store
}
// New returns a new Store that persists to the named Secret.
// New returns a new Store that persists to the named secret.
func New(_ logger.Logf, secretName string) (*Store, error) {
c, err := kubeclient.New()
if err != nil {
@@ -47,16 +39,11 @@ func New(_ logger.Logf, secretName string) (*Store, error) {
if err != nil {
return nil, err
}
s := &Store{
return &Store{
client: c,
canPatch: canPatch,
secretName: secretName,
}
// Load latest state from kube Secret if it already exists.
if err := s.loadState(); err != nil && err != ipn.ErrStateNotExist {
return nil, fmt.Errorf("error loading state from kube Secret: %w", err)
}
return s, nil
}, nil
}
func (s *Store) SetDialer(d func(ctx context.Context, network, address string) (net.Conn, error)) {
@@ -67,17 +54,37 @@ func (s *Store) String() string { return "kube.Store" }
// ReadState implements the StateStore interface.
func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) {
return s.memory.ReadState(ipn.StateKey(sanitizeKey(id)))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
secret, err := s.client.GetSecret(ctx, s.secretName)
if err != nil {
if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 {
return nil, ipn.ErrStateNotExist
}
return nil, err
}
b, ok := secret.Data[sanitizeKey(id)]
if !ok {
return nil, ipn.ErrStateNotExist
}
return b, nil
}
func sanitizeKey(k ipn.StateKey) string {
// The only valid characters in a Kubernetes secret key are alphanumeric, -,
// _, and .
return strings.Map(func(r rune) rune {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' {
return r
}
return '_'
}, string(k))
}
// WriteState implements the StateStore interface.
func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) {
defer func() {
if err == nil {
s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
func (s *Store) WriteState(id ipn.StateKey, bs []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
secret, err := s.client.GetSecret(ctx, s.secretName)
@@ -130,29 +137,3 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) {
}
return err
}
func (s *Store) loadState() error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
secret, err := s.client.GetSecret(ctx, s.secretName)
if err != nil {
if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 {
return ipn.ErrStateNotExist
}
return err
}
s.memory.LoadFromMap(secret.Data)
return nil
}
func sanitizeKey(k ipn.StateKey) string {
// The only valid characters in a Kubernetes secret key are alphanumeric, -,
// _, and .
return strings.Map(func(r rune) rune {
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' {
return r
}
return '_'
}, string(k))
}

View File

@@ -9,10 +9,8 @@ import (
"encoding/json"
"sync"
xmaps "golang.org/x/exp/maps"
"tailscale.com/ipn"
"tailscale.com/types/logger"
"tailscale.com/util/mak"
)
// New returns a new Store.
@@ -30,7 +28,6 @@ type Store struct {
func (s *Store) String() string { return "mem.Store" }
// ReadState implements the StateStore interface.
// It returns ipn.ErrStateNotExist if the state does not exist.
func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -42,7 +39,6 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) {
}
// WriteState implements the StateStore interface.
// It never returns an error.
func (s *Store) WriteState(id ipn.StateKey, bs []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -53,19 +49,6 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) error {
return nil
}
// LoadFromMap loads the in-memory cache from the provided map.
// Any existing content is cleared, and the provided map is
// copied into the cache.
func (s *Store) LoadFromMap(m map[string][]byte) {
s.mu.Lock()
defer s.mu.Unlock()
xmaps.Clear(s.cache)
for k, v := range m {
mak.Set(&s.cache, ipn.StateKey(k), v)
}
return
}
// LoadFromJSON attempts to unmarshal json content into the
// in-memory cache.
func (s *Store) LoadFromJSON(data []byte) error {

View File

@@ -381,7 +381,6 @@ _Appears in:_
| `nodeName` _string_ | Proxy Pod's node name.<br />https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | |
| `nodeSelector` _object (keys:string, values:string)_ | Proxy Pod's node selector.<br />By default Tailscale Kubernetes operator does not apply any node<br />selector.<br />https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | |
| `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | Proxy Pod's tolerations.<br />By default Tailscale Kubernetes operator does not apply any<br />tolerations.<br />https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | |
| `topologySpreadConstraints` _[TopologySpreadConstraint](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#topologyspreadconstraint-v1-core) array_ | Proxy Pod's topology spread constraints.<br />By default Tailscale Kubernetes operator does not apply any topology spread constraints.<br />https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ | | |
#### ProxyClass

View File

@@ -154,11 +154,7 @@ type Pod struct {
// https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling
// +optional
Tolerations []corev1.Toleration `json:"tolerations,omitempty"`
// Proxy Pod's topology spread constraints.
// By default Tailscale Kubernetes operator does not apply any topology spread constraints.
// https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/
// +optional
TopologySpreadConstraints []corev1.TopologySpreadConstraint `json:"topologySpreadConstraints,omitempty"`
}
type Metrics struct {

View File

@@ -392,13 +392,6 @@ func (in *Pod) DeepCopyInto(out *Pod) {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
if in.TopologySpreadConstraints != nil {
in, out := &in.TopologySpreadConstraints, &out.TopologySpreadConstraints
*out = make([]corev1.TopologySpreadConstraint, len(*in))
for i := range *in {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Pod.

View File

@@ -73,13 +73,13 @@ See also the dependencies in the [Tailscale CLI][].
- [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE))
- [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE))
- [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE))
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE))
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
- [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE))
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE))
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE))
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE))
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE))
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
- [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE))
- [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE))
- [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE))

View File

@@ -65,15 +65,15 @@ Windows][]. See also the dependencies in the [Tailscale CLI][].
- [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE))
- [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE))
- [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE))
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.28.0:LICENSE))
- [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE))
- [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE))
- [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE))
- [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE))
- [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE))
- [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE))
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.26.0:LICENSE))
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.25.0:LICENSE))
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.19.0:LICENSE))
- [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE))
- [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE))
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE))
- [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2))
- [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3))
- [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE))

View File

@@ -63,9 +63,7 @@ type Manager struct {
mu sync.Mutex // guards following
// config is the last configuration we successfully compiled or nil if there
// was any failure applying the last configuration.
config *Config
forceAAAA bool // whether client wants MagicDNS AAAA even if unsure of host's IPv6 status
forceFallbackResolvers []*dnstype.Resolver
config *Config
}
// NewManagers created a new manager from the given config.
@@ -130,28 +128,6 @@ func (m *Manager) GetBaseConfig() (OSConfig, error) {
return m.os.GetBaseConfig()
}
func (m *Manager) GetForceAAAA() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.forceAAAA
}
func (m *Manager) SetForceAAAA(v bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.forceAAAA = v
}
// SetForceFallbackResolvers sets the resolvers to use to override
// the fallback resolvers if the control plane doesn't send any.
//
// It takes ownership of the provided slice.
func (m *Manager) SetForceFallbackResolvers(resolvers []*dnstype.Resolver) {
m.mu.Lock()
defer m.mu.Unlock()
m.forceFallbackResolvers = resolvers
}
// setLocked sets the DNS configuration.
//
// m.mu must be held.
@@ -170,10 +146,6 @@ func (m *Manager) setLocked(cfg Config) error {
return err
}
if _, ok := rcfg.Routes["."]; !ok && len(m.forceFallbackResolvers) > 0 {
rcfg.Routes["."] = m.forceFallbackResolvers
}
m.logf("Resolvercfg: %v", logger.ArgWriter(func(w *bufio.Writer) {
rcfg.WriteToBufioWriter(w)
}))

View File

@@ -57,7 +57,6 @@ func (m *resolvdManager) SetDNS(config OSConfig) error {
if len(newSearch) > 1 {
newResolvConf = append(newResolvConf, []byte(strings.Join(newSearch, " "))...)
newResolvConf = append(newResolvConf, '\n')
}
err = m.fs.WriteFile(resolvConf, newResolvConf, 0644)
@@ -124,6 +123,6 @@ func (m resolvdManager) readResolvConf() (config OSConfig, err error) {
}
func removeSearchLines(orig []byte) []byte {
re := regexp.MustCompile(`(?ms)^search\s+.+$`)
re := regexp.MustCompile(`(?m)^search\s+.+$`)
return re.ReplaceAll(orig, []byte(""))
}

View File

@@ -27,7 +27,6 @@ import (
"tailscale.com/health"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tstest"
"tailscale.com/types/dnstype"
)
@@ -277,8 +276,6 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
tb.Fatal("cannot skip both UDP and TCP servers")
}
logf := tstest.WhileTestRunningLogger(tb)
tcpResponse := make([]byte, len(response)+2)
binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
copy(tcpResponse[2:], response)
@@ -332,13 +329,13 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
// Read the length header, then the buffer
var length uint16
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
logf("error reading length header: %v", err)
tb.Logf("error reading length header: %v", err)
return
}
req := make([]byte, length)
n, err := io.ReadFull(conn, req)
if err != nil {
logf("error reading query: %v", err)
tb.Logf("error reading query: %v", err)
return
}
req = req[:n]
@@ -346,7 +343,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
// Write response
if _, err := conn.Write(tcpResponse); err != nil {
logf("error writing response: %v", err)
tb.Logf("error writing response: %v", err)
return
}
}
@@ -370,7 +367,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
handleUDP := func(addr netip.AddrPort, req []byte) {
onRequest(false, req)
if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
logf("error writing response: %v", err)
tb.Logf("error writing response: %v", err)
}
}
@@ -393,7 +390,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
tb.Cleanup(func() {
tcpLn.Close()
udpLn.Close()
logf("waiting for listeners to finish...")
tb.Logf("waiting for listeners to finish...")
wg.Wait()
})
return
@@ -453,8 +450,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
}
func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
logf := tstest.WhileTestRunningLogger(tb)
netMon, err := netmon.New(logf)
netMon, err := netmon.New(tb.Logf)
if err != nil {
tb.Fatal(err)
}
@@ -462,7 +458,7 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports
var dialer tsdial.Dialer
dialer.SetNetMon(netMon)
fwd := newForwarder(logf, netMon, nil, &dialer, new(health.Tracker), nil)
fwd := newForwarder(tb.Logf, netMon, nil, &dialer, new(health.Tracker), nil)
if modify != nil {
modify(fwd)
}

View File

@@ -392,11 +392,10 @@ type probePlan map[string][]probe
// sortRegions returns the regions of dm first sorted
// from fastest to slowest (based on the 'last' report),
// end in regions that have no data.
func sortRegions(dm *tailcfg.DERPMap, last *Report, preferredDERP int) (prev []*tailcfg.DERPRegion) {
func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) {
prev = make([]*tailcfg.DERPRegion, 0, len(dm.Regions))
for _, reg := range dm.Regions {
// include an otherwise avoid region if it is the current preferred region
if reg.Avoid && reg.RegionID != preferredDERP {
if reg.Avoid {
continue
}
prev = append(prev, reg)
@@ -421,19 +420,9 @@ func sortRegions(dm *tailcfg.DERPMap, last *Report, preferredDERP int) (prev []*
// a full report, all regions are scanned.)
const numIncrementalRegions = 3
// makeProbePlan generates the probe plan for a DERPMap, given the most recent
// report and the current home DERP. preferredDERP is passed independently of
// last (report) because last is currently nil'd to indicate a desire for a full
// netcheck.
//
// TODO(raggi,jwhited): refactor the callers and this function to be more clear
// about full vs. incremental netchecks, and remove the need for the history
// hiding. This was avoided in an incremental change due to exactly this kind of
// distant coupling.
// TODO(raggi): change from "preferred DERP" from a historical report to "home
// DERP" as in what DERP is the current home connection, this would further
// reduce flap events.
func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, preferredDERP int) (plan probePlan) {
// makeProbePlan generates the probe plan for a DERPMap, given the most
// recent report and whether IPv6 is configured on an interface.
func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (plan probePlan) {
if last == nil || len(last.RegionLatency) == 0 {
return makeProbePlanInitial(dm, ifState)
}
@@ -444,34 +433,9 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, pre
had4 := len(last.RegionV4Latency) > 0
had6 := len(last.RegionV6Latency) > 0
hadBoth := have6if && had4 && had6
// #13969 ensure that the home region is always probed.
// If a netcheck has unstable latency, such as a user with large amounts of
// bufferbloat or a highly congested connection, there are cases where a full
// netcheck may observe a one-off high latency to the current home DERP. Prior
// to the forced inclusion of the home DERP, this would result in an
// incremental netcheck following such an event to cause a home DERP move, with
// restoration back to the home DERP on the next full netcheck ~5 minutes later
// - which is highly disruptive when it causes shifts in geo routed subnet
// routers. By always including the home DERP in the incremental netcheck, we
// ensure that the home DERP is always probed, even if it observed a recenet
// poor latency sample. This inclusion enables the latency history checks in
// home DERP selection to still take effect.
// planContainsHome indicates whether the home DERP has been added to the probePlan,
// if there is no prior home, then there's no home to additionally include.
planContainsHome := preferredDERP == 0
for ri, reg := range sortRegions(dm, last, preferredDERP) {
regIsHome := reg.RegionID == preferredDERP
if ri >= numIncrementalRegions {
// planned at least numIncrementalRegions regions and that includes the
// last home region (or there was none), plan complete.
if planContainsHome {
break
}
// planned at least numIncrementalRegions regions, but not the home region,
// check if this is the home region, if not, skip it.
if !regIsHome {
continue
}
for ri, reg := range sortRegions(dm, last) {
if ri == numIncrementalRegions {
break
}
var p4, p6 []probe
do4 := have4if
@@ -482,7 +446,7 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, pre
tries := 1
isFastestTwo := ri < 2
if isFastestTwo || regIsHome {
if isFastestTwo {
tries = 2
} else if hadBoth {
// For dual stack machines, make the 3rd & slower nodes alternate
@@ -493,15 +457,14 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, pre
do4, do6 = false, true
}
}
if !regIsHome && !isFastestTwo && !had6 {
if !isFastestTwo && !had6 {
do6 = false
}
if regIsHome {
if reg.RegionID == last.PreferredDERP {
// But if we already had a DERP home, try extra hard to
// make sure it's there so we don't flip flop around.
tries = 4
planContainsHome = true
}
for try := 0; try < tries; try++ {
@@ -826,10 +789,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
c.curState = rs
last := c.last
// Extract preferredDERP from the last report, if available. This will be used
// in captive portal detection and DERP flapping suppression. Ideally this would
// be the current active home DERP rather than the last report preferred DERP,
// but only the latter is presently available.
// Even if we're doing a non-incremental update, we may want to try our
// preferred DERP region for captive portal detection. Save that, if we
// have it.
var preferredDERP int
if last != nil {
preferredDERP = last.PreferredDERP
@@ -886,7 +848,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
var plan probePlan
if opts == nil || !opts.OnlyTCP443 {
plan = makeProbePlan(dm, ifState, last, preferredDERP)
plan = makeProbePlan(dm, ifState, last)
}
// If we're doing a full probe, also check for a captive portal. We

View File

@@ -357,15 +357,6 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) {
wantPrevLen: 3,
wantDERP: 2, // moved to d2 since d1 is gone
},
{
name: "preferred_derp_hysteresis_no_switch_pct",
steps: []step{
{0 * time.Second, report("d1", 34*time.Millisecond, "d2", 35*time.Millisecond)},
{1 * time.Second, report("d1", 34*time.Millisecond, "d2", 23*time.Millisecond)},
},
wantPrevLen: 2,
wantDERP: 1, // diff is 11ms, but d2 is greater than 2/3s of d1
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -599,40 +590,6 @@ func TestMakeProbePlan(t *testing.T) {
"region-3-v4": []probe{p("3a", 4)},
},
},
{
// #13969: ensure that the prior/current home region is always included in
// probe plans, so that we don't flap between regions due to a single major
// netcheck having excluded the home region due to a spuriously high sample.
name: "ensure_home_region_inclusion",
dm: basicMap,
have6if: true,
last: &Report{
RegionLatency: map[int]time.Duration{
1: 50 * time.Millisecond,
2: 20 * time.Millisecond,
3: 30 * time.Millisecond,
4: 40 * time.Millisecond,
},
RegionV4Latency: map[int]time.Duration{
1: 50 * time.Millisecond,
2: 20 * time.Millisecond,
},
RegionV6Latency: map[int]time.Duration{
3: 30 * time.Millisecond,
4: 40 * time.Millisecond,
},
PreferredDERP: 1,
},
want: probePlan{
"region-1-v4": []probe{p("1a", 4), p("1a", 4, 60*ms), p("1a", 4, 220*ms), p("1a", 4, 330*ms)},
"region-1-v6": []probe{p("1a", 6), p("1a", 6, 60*ms), p("1a", 6, 220*ms), p("1a", 6, 330*ms)},
"region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)},
"region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)},
"region-3-v4": []probe{p("3a", 4), p("3b", 4, 36*ms)},
"region-3-v6": []probe{p("3a", 6), p("3b", 6, 36*ms)},
"region-4-v4": []probe{p("4a", 4)},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -640,11 +597,7 @@ func TestMakeProbePlan(t *testing.T) {
HaveV6: tt.have6if,
HaveV4: !tt.no4,
}
preferredDERP := 0
if tt.last != nil {
preferredDERP = tt.last.PreferredDERP
}
got := makeProbePlan(tt.dm, ifState, tt.last, preferredDERP)
got := makeProbePlan(tt.dm, ifState, tt.last)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want)
}
@@ -817,7 +770,7 @@ func TestSortRegions(t *testing.T) {
report.RegionLatency[3] = time.Second * time.Duration(6)
report.RegionLatency[4] = time.Second * time.Duration(0)
report.RegionLatency[5] = time.Second * time.Duration(2)
sortedMap := sortRegions(unsortedMap, report, 0)
sortedMap := sortRegions(unsortedMap, report)
// Sorting by latency this should result in rid: 5, 2, 1, 3
// rid 4 with latency 0 should be at the end

View File

@@ -81,12 +81,6 @@ const (
addrTypeNotSupported replyCode = 8
)
// UDP conn default buffer size and read timeout.
const (
bufferSize = 8 * 1024
readTimeout = 5 * time.Second
)
// Server is a SOCKS5 proxy server.
type Server struct {
// Logf optionally specifies the logger to use.
@@ -149,8 +143,7 @@ type Conn struct {
clientConn net.Conn
request *request
udpClientAddr net.Addr
udpTargetConns map[socksAddr]net.Conn
udpClientAddr net.Addr
}
// Run starts the new connection.
@@ -283,6 +276,15 @@ func (c *Conn) handleUDP() error {
}
defer clientUDPConn.Close()
serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
if err != nil {
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer serverUDPConn.Close()
bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
if err != nil {
return err
@@ -303,32 +305,25 @@ func (c *Conn) handleUDP() error {
}
c.clientConn.Write(buf)
return c.transferUDP(c.clientConn, clientUDPConn)
return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
}
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error {
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
const bufferSize = 8 * 1024
const readTimeout = 5 * time.Second
// client -> target
go func() {
defer cancel()
c.udpTargetConns = make(map[socksAddr]net.Conn)
// close all target udp connections when the client connection is closed
defer func() {
for _, conn := range c.udpTargetConns {
_ = conn.Close()
}
}()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPRequest(ctx, clientConn, buf)
err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
if err != nil {
if isTimeout(err) {
continue
@@ -342,6 +337,29 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
}
}()
// target -> client
go func() {
defer cancel()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) {
return
}
c.logf("udp transfer: handle udp response fail: %v", err)
}
}
}
}()
// A UDP association terminates when the TCP connection that the UDP
// ASSOCIATE request arrived on terminates. RFC1928
_, err := io.Copy(io.Discard, associatedTCP)
@@ -351,50 +369,11 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
return err
}
func (c *Conn) getOrDialTargetConn(
ctx context.Context,
clientConn net.PacketConn,
targetAddr socksAddr,
) (net.Conn, error) {
conn, exist := c.udpTargetConns[targetAddr]
if exist {
return conn, nil
}
conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
if err != nil {
return nil, err
}
c.udpTargetConns[targetAddr] = conn
// target -> client
go func() {
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return
}
c.logf("udp transfer: handle udp response fail: %v", err)
}
}
}
}()
return conn, nil
}
func (c *Conn) handleUDPRequest(
ctx context.Context,
clientConn net.PacketConn,
targetConn net.PacketConn,
buf []byte,
readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
@@ -407,35 +386,38 @@ func (c *Conn) handleUDPRequest(
if err != nil {
return fmt.Errorf("parse udp request: %w", err)
}
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
if err != nil {
return fmt.Errorf("dial target %s fail: %w", req.addr, err)
c.logf("resolve target addr fail: %v", err)
}
nn, err := targetConn.Write(data)
nn, err := targetConn.WriteTo(data, targetAddr)
if err != nil {
return fmt.Errorf("write to target %s fail: %w", req.addr, err)
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
}
if nn != len(data) {
return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
}
return nil
}
func (c *Conn) handleUDPResponse(
targetConn net.PacketConn,
clientConn net.PacketConn,
targetAddr socksAddr,
targetConn net.Conn,
buf []byte,
readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
n, err := targetConn.Read(buf)
n, addr, err := targetConn.ReadFrom(buf)
if err != nil {
return fmt.Errorf("read from target: %w", err)
}
hdr := udpRequest{addr: targetAddr}
host, port, err := splitHostPort(addr.String())
if err != nil {
return fmt.Errorf("split host port: %w", err)
}
hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
pkt, err := hdr.marshal()
if err != nil {
return fmt.Errorf("marshal udp request: %w", err)
@@ -645,15 +627,10 @@ func (s socksAddr) marshal() ([]byte, error) {
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
return pkt, nil
}
func (s socksAddr) hostPort() string {
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
}
func (s socksAddr) String() string {
return s.hostPort()
}
// response contains the contents of
// a response packet sent from the proxy
// to the client.

View File

@@ -169,25 +169,12 @@ func TestReadPassword(t *testing.T) {
func TestUDP(t *testing.T) {
// backend UDP server which we'll use SOCKS5 to connect to
newUDPEchoServer := func() net.PacketConn {
listener, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatal(err)
}
go udpEchoServer(listener)
return listener
listener, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatal(err)
}
const echoServerNumber = 3
echoServerListener := make([]net.PacketConn, echoServerNumber)
for i := 0; i < echoServerNumber; i++ {
echoServerListener[i] = newUDPEchoServer()
}
defer func() {
for i := 0; i < echoServerNumber; i++ {
_ = echoServerListener[i].Close()
}
}()
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
go udpEchoServer(listener)
// SOCKS5 server
socks5, err := net.Listen("tcp", ":0")
@@ -197,93 +184,84 @@ func TestUDP(t *testing.T) {
socks5Port := socks5.Addr().(*net.TCPAddr).Port
go socks5Server(socks5)
// make a socks5 udpAssociate conn
newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) {
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
if err != nil {
t.Fatal(err)
}
_, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 1024)
n, err := conn.Read(buf) // server hello
if err != nil {
t.Fatal(err)
}
if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired {
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
}
targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
targetAddrPkt, err := targetAddr.marshal()
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust
if err != nil {
t.Fatal(err)
}
n, err = conn.Read(buf) // server response
if err != nil {
t.Fatal(err)
}
if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) {
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
}
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
if err != nil {
t.Fatal(err)
}
return conn, udpProxySocksAddr
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
if err != nil {
t.Fatal(err)
}
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 1024)
n, err := conn.Read(buf) // server hello
if err != nil {
t.Fatal(err)
}
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
}
conn, udpProxySocksAddr := newUdpAssociateConn()
defer conn.Close()
targetAddr := socksAddr{
addrType: domainName,
addr: "localhost",
port: uint16(backendServerPort),
}
targetAddrPkt, err := targetAddr.marshal()
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
if err != nil {
t.Fatal(err)
}
sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) {
udpPayload, err := (&udpRequest{addr: addr}).marshal()
if err != nil {
t.Fatal(err)
}
udpPayload = append(udpPayload, body...)
_, err = socks5UDPConn.Write(udpPayload)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 1024)
n, err := socks5UDPConn.Read(buf)
if err != nil {
t.Fatal(err)
}
_, responseBody, err = parseUDPRequest(buf[:n])
if err != nil {
t.Fatal(err)
}
return responseBody
n, err = conn.Read(buf) // server response
if err != nil {
t.Fatal(err)
}
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
}
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
if err != nil {
t.Fatal(err)
}
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
if err != nil {
t.Fatal(err)
}
socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr)
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
if err != nil {
t.Fatal(err)
}
defer socks5UDPConn.Close()
for i := 0; i < echoServerNumber; i++ {
port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port
addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)}
requestBody := []byte(fmt.Sprintf("Test %d", i))
responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody)
if !bytes.Equal(requestBody, responseBody) {
t.Fatalf("got: %q want: %q", responseBody, requestBody)
}
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
if err != nil {
t.Fatal(err)
}
udpPayload = append(udpPayload, []byte("Test")...)
_, err = udpConn.Write(udpPayload) // send udp package
if err != nil {
t.Fatal(err)
}
n, _, err = udpConn.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
if err != nil {
t.Fatal(err)
}
if string(responseBody) != "Test" {
t.Fatalf("got: %q want: Test", responseBody)
}
err = udpConn.Close()
if err != nil {
t.Fatal(err)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -279,13 +279,7 @@ func setNetMon(netMon *netmon.Monitor) {
if ifName == "" {
return
}
// DefaultRouteInterface and Interface are gathered at different points in time.
// Check for existence first, to avoid a nil pointer dereference.
iface, ok := state.Interface[ifName]
if !ok {
return
}
ifIndex := iface.Index
ifIndex := state.Interface[ifName].Index
sockStats.mu.Lock()
defer sockStats.mu.Unlock()
// Ignore changes to unknown interfaces -- it would require

View File

@@ -213,14 +213,24 @@ type Wrapper struct {
}
type metrics struct {
inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels]
outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels]
inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel]
outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel]
}
func registerMetrics(reg *usermetric.Registry) *metrics {
return &metrics{
inboundDroppedPacketsTotal: reg.DroppedPacketsInbound(),
outboundDroppedPacketsTotal: reg.DroppedPacketsOutbound(),
inboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel](
reg,
"tailscaled_inbound_dropped_packets_total",
"counter",
"Counts the number of dropped packets received by the node from other peers",
),
outboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel](
reg,
"tailscaled_outbound_dropped_packets_total",
"counter",
"Counts the number of packets dropped while being sent to other peers",
),
}
}
@@ -861,22 +871,7 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf
return res, gro
}
}
if resp := t.filtRunOut(p, pc); resp != filter.Accept {
return resp, gro
}
if t.PostFilterPacketOutboundToWireGuard != nil {
if res := t.PostFilterPacketOutboundToWireGuard(p, t); res.IsDrop() {
return res, gro
}
}
return filter.Accept, gro
}
// filtRunOut runs the outbound packet filter on p.
// It uses pc to determine if the packet is to a jailed peer and should be
// filtered with the jailed filter.
func (t *Wrapper) filtRunOut(p *packet.Parsed, pc *peerConfigTable) filter.Response {
// If the outbound packet is to a jailed peer, use our jailed peer
// packet filter.
var filt *filter.Filter
@@ -886,17 +881,23 @@ func (t *Wrapper) filtRunOut(p *packet.Parsed, pc *peerConfigTable) filter.Respo
filt = t.filter.Load()
}
if filt == nil {
return filter.Drop
return filter.Drop, gro
}
if filt.RunOut(p, t.filterFlags) != filter.Accept {
metricPacketOutDropFilter.Add(1)
t.metrics.outboundDroppedPacketsTotal.Add(usermetric.DropLabels{
Reason: usermetric.ReasonACL,
t.metrics.outboundDroppedPacketsTotal.Add(dropPacketLabel{
Reason: DropReasonACL,
}, 1)
return filter.Drop
return filter.Drop, gro
}
return filter.Accept
if t.PostFilterPacketOutboundToWireGuard != nil {
if res := t.PostFilterPacketOutboundToWireGuard(p, t); res.IsDrop() {
return res, gro
}
}
return filter.Accept, gro
}
// noteActivity records that there was a read or write at the current time.
@@ -1060,11 +1061,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p)
p.Decode(pkt)
response, _ := t.filterPacketOutboundToWireGuard(p, pc, nil)
if response != filter.Accept {
metricPacketOutDrop.Add(1)
return
}
invertGSOChecksum(pkt, gso)
pc.snat(p)
@@ -1162,8 +1158,8 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
if outcome != filter.Accept {
metricPacketInDropFilter.Add(1)
t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{
Reason: usermetric.ReasonACL,
t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{
Reason: DropReasonACL,
}, 1)
// Tell them, via TSMP, we're dropping them due to the ACL.
@@ -1243,8 +1239,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
t.noteActivity()
_, err := t.tdevWrite(buffs, offset)
if err != nil {
t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{
Reason: usermetric.ReasonError,
t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{
Reason: DropReasonError,
}, int64(len(buffs)))
}
return len(buffs), err
@@ -1486,6 +1482,20 @@ var (
metricPacketOutDropSelfDisco = clientmetric.NewCounter("tstun_out_to_wg_drop_self_disco")
)
type DropReason string
const (
DropReasonACL DropReason = "acl"
DropReasonError DropReason = "error"
)
type dropPacketLabel struct {
// Reason indicates what we have done with the packet, and has the following values:
// - acl (rejected packets because of ACL)
// - error (rejected packets because of an error)
Reason DropReason
}
func (t *Wrapper) InstallCaptureHook(cb capture.Callback) {
t.captureHook.Store(cb)
}

View File

@@ -441,13 +441,13 @@ func TestFilter(t *testing.T) {
}
var metricInboundDroppedPacketsACL, metricInboundDroppedPacketsErr, metricOutboundDroppedPacketsACL int64
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok {
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok {
metricInboundDroppedPacketsACL = m.Value()
}
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonError}).(*expvar.Int); ok {
if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonError}).(*expvar.Int); ok {
metricInboundDroppedPacketsErr = m.Value()
}
if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok {
if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok {
metricOutboundDroppedPacketsACL = m.Value()
}

View File

@@ -74,74 +74,25 @@ import (
crand "crypto/rand"
"fmt"
"log"
"maps"
"net"
"net/http"
"net/url"
"path"
"slices"
"strings"
"github.com/gorilla/csrf"
)
// CSP is the value of a Content-Security-Policy header. Keys are CSP
// directives (like "default-src") and values are source expressions (like
// "'self'" or "https://tailscale.com"). A nil slice value is allowed for some
// directives like "upgrade-insecure-requests" that don't expect a list of
// source definitions.
type CSP map[string][]string
// DefaultCSP is the recommended CSP to use when not loading resources from
// other domains and not embedding the current website. If you need to tweak
// the CSP, it is recommended to extend DefaultCSP instead of writing your own
// from scratch.
func DefaultCSP() CSP {
return CSP{
"default-src": {"self"}, // origin is the only valid source for all content types
"frame-ancestors": {"none"}, // disallow framing of the page
"form-action": {"self"}, // disallow form submissions to other origins
"base-uri": {"self"}, // disallow base URIs from other origins
// TODO(awly): consider upgrade-insecure-requests in SecureContext
// instead, as this is deprecated.
"block-all-mixed-content": nil, // disallow mixed content when serving over HTTPS
}
}
// Set sets the values for a given directive. Empty values are allowed, if the
// directive doesn't expect any (like "upgrade-insecure-requests").
func (csp CSP) Set(directive string, values ...string) {
csp[directive] = values
}
// Add adds a source expression to an existing directive.
func (csp CSP) Add(directive, value string) {
csp[directive] = append(csp[directive], value)
}
// Del deletes a directive and all its values.
func (csp CSP) Del(directive string) {
delete(csp, directive)
}
func (csp CSP) String() string {
keys := slices.Collect(maps.Keys(csp))
slices.Sort(keys)
var s strings.Builder
for _, k := range keys {
s.WriteString(k)
for _, v := range csp[k] {
// Special values like 'self', 'none', 'unsafe-inline', etc., must
// be quoted. Do it implicitly as a convenience here.
if !strings.Contains(v, ".") && len(v) > 1 && v[0] != '\'' && v[len(v)-1] != '\'' {
v = "'" + v + "'"
}
s.WriteString(" " + v)
}
s.WriteString("; ")
}
return strings.TrimSpace(s.String())
}
// The default Content-Security-Policy header.
var defaultCSP = strings.Join([]string{
`default-src 'self'`, // origin is the only valid source for all content types
`script-src 'self'`, // disallow inline javascript
`frame-ancestors 'none'`, // disallow framing of the page
`form-action 'self'`, // disallow form submissions to other origins
`base-uri 'self'`, // disallow base URIs from other origins
`block-all-mixed-content`, // disallow mixed content when serving over HTTPS
`object-src 'self'`, // disallow embedding of resources from other origins
}, "; ")
// The default Strict-Transport-Security header. This header tells the browser
// to exclusively use HTTPS for all requests to the origin for the next year.
@@ -179,9 +130,6 @@ type Config struct {
// startup.
CSRFSecret []byte
// CSP is the Content-Security-Policy header to return with BrowserMux
// responses.
CSP CSP
// CSPAllowInlineStyles specifies whether to include `style-src:
// unsafe-inline` in the Content-Security-Policy header to permit the use of
// inline CSS.
@@ -220,10 +168,6 @@ func (c *Config) setDefaults() error {
}
}
if c.CSP == nil {
c.CSP = DefaultCSP()
}
return nil
}
@@ -255,20 +199,16 @@ func NewServer(config Config) (*Server, error) {
if config.CookiesSameSiteLax {
sameSite = csrf.SameSiteLaxMode
}
if config.CSPAllowInlineStyles {
if _, ok := config.CSP["style-src"]; ok {
config.CSP.Add("style-src", "unsafe-inline")
} else {
config.CSP.Set("style-src", "self", "unsafe-inline")
}
}
s := &Server{
Config: config,
csp: config.CSP.String(),
csp: defaultCSP,
// only set Secure flag on CSRF cookies if we are in a secure context
// as otherwise the browser will reject the cookie
csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext), csrf.SameSite(sameSite)),
}
if config.CSPAllowInlineStyles {
s.csp = defaultCSP + `; style-src 'self' 'unsafe-inline'`
}
s.h = cmp.Or(config.HTTPServer, &http.Server{})
if s.h.Handler != nil {
return nil, fmt.Errorf("use safeweb.Config.APIMux and safeweb.Config.BrowserMux instead of http.Server.Handler")
@@ -285,27 +225,12 @@ const (
browserHandler
)
func (h handlerType) String() string {
switch h {
case browserHandler:
return "browser"
case apiHandler:
return "api"
default:
return "unknown"
}
}
// checkHandlerType returns either apiHandler or browserHandler, depending on
// whether apiPattern or browserPattern is more specific (i.e. which pattern
// contains more pathname components). If they are equally specific, it returns
// unknownHandler.
func checkHandlerType(apiPattern, browserPattern string) handlerType {
apiPattern, browserPattern = path.Clean(apiPattern), path.Clean(browserPattern)
c := cmp.Compare(strings.Count(apiPattern, "/"), strings.Count(browserPattern, "/"))
if apiPattern == "/" || browserPattern == "/" {
c = cmp.Compare(len(apiPattern), len(browserPattern))
}
c := cmp.Compare(strings.Count(path.Clean(apiPattern), "/"), strings.Count(path.Clean(browserPattern), "/"))
switch {
case c > 0:
return apiHandler

View File

@@ -241,26 +241,18 @@ func TestCSRFProtection(t *testing.T) {
func TestContentSecurityPolicyHeader(t *testing.T) {
tests := []struct {
name string
csp CSP
apiRoute bool
wantCSP string
wantCSP bool
}{
{
name: "default CSP",
wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`,
},
{
name: "custom CSP",
csp: CSP{
"default-src": {"'self'", "https://tailscale.com"},
"upgrade-insecure-requests": nil,
},
wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`,
name: "default routes get CSP headers",
apiRoute: false,
wantCSP: true,
},
{
name: "`/api/*` routes do not get CSP headers",
apiRoute: true,
wantCSP: "",
wantCSP: false,
},
}
@@ -273,9 +265,9 @@ func TestContentSecurityPolicyHeader(t *testing.T) {
var s *Server
var err error
if tt.apiRoute {
s, err = NewServer(Config{APIMux: h, CSP: tt.csp})
s, err = NewServer(Config{APIMux: h})
} else {
s, err = NewServer(Config{BrowserMux: h, CSP: tt.csp})
s, err = NewServer(Config{BrowserMux: h})
}
if err != nil {
t.Fatal(err)
@@ -287,8 +279,8 @@ func TestContentSecurityPolicyHeader(t *testing.T) {
s.h.Handler.ServeHTTP(w, req)
resp := w.Result()
if got := resp.Header.Get("Content-Security-Policy"); got != tt.wantCSP {
t.Fatalf("content security policy want: %q; got: %q", tt.wantCSP, got)
if (resp.Header.Get("Content-Security-Policy") == "") == tt.wantCSP {
t.Fatalf("content security policy want: %v; got: %v", tt.wantCSP, resp.Header.Get("Content-Security-Policy"))
}
})
}
@@ -405,7 +397,7 @@ func TestCSPAllowInlineStyles(t *testing.T) {
csp := resp.Header.Get("Content-Security-Policy")
allowsStyles := strings.Contains(csp, "style-src 'self' 'unsafe-inline'")
if allowsStyles != allow {
t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp)
t.Fatalf("CSP inline styles want: %v; got: %v", allow, allowsStyles)
}
})
}
@@ -535,13 +527,13 @@ func TestGetMoreSpecificPattern(t *testing.T) {
{
desc: "same prefix",
a: "/foo/bar/quux",
b: "/foo/bar/", // path.Clean will strip the trailing slash.
b: "/foo/bar/",
want: apiHandler,
},
{
desc: "almost same prefix, but not a path component",
a: "/goat/sheep/cheese",
b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash.
b: "/goat/sheepcheese/",
want: apiHandler,
},
{
@@ -562,12 +554,6 @@ func TestGetMoreSpecificPattern(t *testing.T) {
b: "///////",
want: unknownHandler,
},
{
desc: "root-level",
a: "/latest",
b: "/", // path.Clean will NOT strip the trailing slash.
want: apiHandler,
},
} {
t.Run(tt.desc, func(t *testing.T) {
got := checkHandlerType(tt.a, tt.b)

View File

@@ -149,8 +149,7 @@ type CapabilityVersion int
// - 104: 2024-08-03: SelfNodeV6MasqAddrForThisPeer now works
// - 105: 2024-08-05: Fixed SSH behavior on systems that use busybox (issue #12849)
// - 106: 2024-09-03: fix panic regression from cryptokey routing change (65fe0ba7b5)
// - 107: 2024-10-30: add App Connector to conffile (PR #13942)
const CurrentCapabilityVersion CapabilityVersion = 107
const CurrentCapabilityVersion CapabilityVersion = 106
type StableID string

View File

@@ -35,7 +35,7 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ
cc = "cc"
targetOS = cmp.Or(env.Get("GOOS", ""), nativeGOOS)
targetArch = cmp.Or(env.Get("GOARCH", ""), nativeGOARCH)
buildFlags = []string{}
buildFlags = []string{"-trimpath"}
cgoCflags = []string{"-O3", "-std=gnu11", "-g"}
cgoLdflags []string
ldflags []string
@@ -47,10 +47,6 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ
subcommand = argv[1]
}
if subcommand != "test" {
buildFlags = append(buildFlags, "-trimpath")
}
switch subcommand {
case "build", "env", "install", "run", "test", "list":
default:

View File

@@ -163,6 +163,7 @@ GOTOOLCHAIN=local (was <nil>)
TS_LINK_FAIL_REFLECT=0 (was <nil>)`,
wantArgv: []string{
"gocross", "test",
"-trimpath",
"-tags=tailscale_go,osusergo,netgo",
"-ldflags", "-X tailscale.com/version.longStamp=1.2.3-long -X tailscale.com/version.shortStamp=1.2.3 -X tailscale.com/version.gitCommitStamp=abcd -X tailscale.com/version.extraGitCommitStamp=defg '-extldflags=-static'",
"-race",

View File

@@ -121,17 +121,11 @@ type Server struct {
// field at zero unless you know what you are doing.
Port uint16
// PreStart is an optional hook to run just before LocalBackend.Start,
// to reconfigure internals. If it returns an error, Server.Start
// will return that error, wrapper.
PreStart func() error
getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error)
initOnce sync.Once
initErr error
lb *ipnlocal.LocalBackend
sys *tsd.System
netstack *netstack.Impl
netMon *netmon.Monitor
rootPath string // the state directory
@@ -524,7 +518,6 @@ func (s *Server) start() (reterr error) {
}
sys := new(tsd.System)
s.sys = sys
if err := s.startLogger(&closePool, sys.HealthTracker(), tsLogf); err != nil {
return err
}
@@ -553,7 +546,7 @@ func (s *Server) start() (reterr error) {
sys.HealthTracker().SetMetricsRegistry(sys.UserMetricsRegistry())
// TODO(oxtoacart): do we need to support Taildrive on tsnet, and if so, how?
ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper())
ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
if err != nil {
return fmt.Errorf("netstack.Create: %w", err)
}
@@ -621,13 +614,6 @@ func (s *Server) start() (reterr error) {
prefs.ControlURL = s.ControlURL
prefs.RunWebClient = s.RunWebClient
authKey := s.getAuthKey()
if f := s.PreStart; f != nil {
if err := f(); err != nil {
return fmt.Errorf("PreStart: %w", err)
}
}
err = lb.Start(ipn.Options{
UpdatePrefs: prefs,
AuthKey: authKey,
@@ -1241,13 +1227,6 @@ func (s *Server) CapturePcap(ctx context.Context, pcapFile string) error {
return nil
}
// Sys returns a handle to the Tailscale subsystems of this node.
//
// This is not a stable API, nor are the APIs of the returned subsystems.
func (s *Server) Sys() *tsd.System {
return s.sys
}
type listenKey struct {
network string
host netip.Addr // or zero value for unspecified

View File

@@ -1080,6 +1080,13 @@ func TestUserMetrics(t *testing.T) {
t.Errorf("metrics1, tailscaled_health_messages: got %v, want %v", got, want)
}
// The node is the primary subnet router for 2 routes:
// - 192.0.2.0/24
// - 192.0.5.1/32
if got, want := parsedMetrics1["tailscaled_primary_routes"], wantRoutes; got != want {
t.Errorf("metrics1, tailscaled_primary_routes: got %v, want %v", got, want)
}
// Verify that the amount of data recorded in bytes is higher or equal to the
// 10 megabytes sent.
inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`]
@@ -1124,6 +1131,11 @@ func TestUserMetrics(t *testing.T) {
t.Errorf("metrics2, tailscaled_health_messages: got %v, want %v", got, want)
}
// The node is the primary subnet router for 0 routes
if got, want := parsedMetrics2["tailscaled_primary_routes"], 0.0; got != want {
t.Errorf("metrics2, tailscaled_primary_routes: got %v, want %v", got, want)
}
// Verify that the amount of data recorded in bytes is higher or equal than the
// 10 megabytes sent.
outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`]

View File

@@ -10,7 +10,6 @@ import (
"net/netip"
"os"
"slices"
"time"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcapgo"
@@ -280,28 +279,10 @@ type Network struct {
svcs set.Set[NetworkService]
latency time.Duration // latency applied to interface writes
lossRate float64 // chance of packet loss (0.0 to 1.0)
// ...
err error // carried error
}
// SetLatency sets the simulated network latency for this network.
func (n *Network) SetLatency(d time.Duration) {
n.latency = d
}
// SetPacketLoss sets the packet loss rate for this network 0.0 (no loss) to 1.0 (total loss).
func (n *Network) SetPacketLoss(rate float64) {
if rate < 0 {
rate = 0
} else if rate > 1 {
rate = 1
}
n.lossRate = rate
}
// SetBlackholedIPv4 sets whether the network should blackhole all IPv4 traffic
// out to the Internet. (DHCP etc continues to work on the LAN.)
func (n *Network) SetBlackholedIPv4(v bool) {
@@ -380,8 +361,6 @@ func (s *Server) initFromConfig(c *Config) error {
wanIP4: conf.wanIP4,
lanIP4: conf.lanIP4,
breakWAN4: conf.breakWAN4,
latency: conf.latency,
lossRate: conf.lossRate,
nodesByIP4: map[netip.Addr]*node{},
nodesByMAC: map[MAC]*node{},
logf: logger.WithPrefix(s.logf, fmt.Sprintf("[net-%v] ", conf.mac)),

View File

@@ -3,10 +3,7 @@
package vnet
import (
"testing"
"time"
)
import "testing"
func TestConfig(t *testing.T) {
tests := []struct {
@@ -21,16 +18,6 @@ func TestConfig(t *testing.T) {
c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT))
},
},
{
name: "latency-and-loss",
setup: func(c *Config) {
n1 := c.AddNetwork("2.1.1.1", "192.168.1.1/24", EasyNAT, NATPMP)
n1.SetLatency(time.Second)
n1.SetPacketLoss(0.1)
c.AddNode(n1)
c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT))
},
},
{
name: "indirect",
setup: func(c *Config) {

View File

@@ -515,8 +515,6 @@ type network struct {
wanIP4 netip.Addr // router's LAN IPv4, if any
lanIP4 netip.Prefix // router's LAN IP + CIDR (e.g. 192.168.2.1/24)
breakWAN4 bool // break WAN IPv4 connectivity
latency time.Duration // latency applied to interface writes
lossRate float64 // probability of dropping a packet (0.0 to 1.0)
nodesByIP4 map[netip.Addr]*node // by LAN IPv4
nodesByMAC map[MAC]*node
logf func(format string, args ...any)
@@ -979,7 +977,7 @@ func (n *network) writeEth(res []byte) bool {
for mac, nw := range n.writers.All() {
if mac != srcMAC {
num++
n.conditionedWrite(nw, res)
nw.write(res)
}
}
return num > 0
@@ -989,7 +987,7 @@ func (n *network) writeEth(res []byte) bool {
return false
}
if nw, ok := n.writers.Load(dstMAC); ok {
n.conditionedWrite(nw, res)
nw.write(res)
return true
}
@@ -1002,23 +1000,6 @@ func (n *network) writeEth(res []byte) bool {
return false
}
func (n *network) conditionedWrite(nw networkWriter, packet []byte) {
if n.lossRate > 0 && rand.Float64() < n.lossRate {
// packet lost
return
}
if n.latency > 0 {
// copy the packet as there's no guarantee packet is owned long enough.
// TODO(raggi): this could be optimized substantially if necessary,
// a pool of buffers and a cheaper delay mechanism are both obvious improvements.
var pkt = make([]byte, len(packet))
copy(pkt, packet)
time.AfterFunc(n.latency, func() { nw.write(pkt) })
} else {
nw.write(packet)
}
}
var (
macAllNodes = MAC{0: 0x33, 1: 0x33, 5: 0x01}
macAllRouters = MAC{0: 0x33, 1: 0x33, 5: 0x02}

View File

@@ -14,7 +14,6 @@ class Config: Codable {
var mac = "52:cc:cc:cc:cc:01"
var ethermac = "52:cc:cc:cc:ce:01"
var port: UInt32 = 51009
var sharedDir: String?
// The virtual machines ID. Also double as the directory name under which
// we will store configuration, block device, etc.

View File

@@ -141,18 +141,5 @@ struct TailMacConfigHelper {
func createKeyboardConfiguration() -> VZKeyboardConfiguration {
return VZMacKeyboardConfiguration()
}
func createDirectoryShareConfiguration(tag: String) -> VZDirectorySharingDeviceConfiguration? {
guard let dir = config.sharedDir else { return nil }
let sharedDir = VZSharedDirectory(url: URL(fileURLWithPath: dir), readOnly: false)
let share = VZSingleDirectoryShare(directory: sharedDir)
// Create the VZVirtioFileSystemDeviceConfiguration and assign it a unique tag.
let sharingConfiguration = VZVirtioFileSystemDeviceConfiguration(tag: tag)
sharingConfiguration.share = share
return sharingConfiguration
}
}

View File

@@ -19,12 +19,10 @@ var config: Config = Config()
extension HostCli {
struct Run: ParsableCommand {
@Option var id: String
@Option var share: String?
mutating func run() {
print("Running vm with identifier \(id)")
config = Config(id)
config.sharedDir = share
print("Running vm with identifier \(id) and sharedDir \(share ?? "<none>")")
_ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv)
}
}

View File

@@ -95,13 +95,6 @@ class VMController: NSObject, VZVirtualMachineDelegate {
virtualMachineConfiguration.keyboards = [helper.createKeyboardConfiguration()]
virtualMachineConfiguration.socketDevices = [helper.createSocketDeviceConfiguration()]
if let dir = config.sharedDir, let shareConfig = helper.createDirectoryShareConfiguration(tag: "vmshare") {
print("Sharing \(dir) as vmshare. Use: mount_virtiofs vmshare <path> in the guest to mount.")
virtualMachineConfiguration.directorySharingDevices = [shareConfig]
} else {
print("No shared directory created. \(config.sharedDir ?? "none") was requested.")
}
try! virtualMachineConfiguration.validate()
try! virtualMachineConfiguration.validateSaveRestoreSupport()

View File

@@ -95,16 +95,12 @@ extension Tailmac {
extension Tailmac {
struct Run: ParsableCommand {
@Option(help: "The vm identifier") var id: String
@Option(help: "Optional share directory") var share: String?
@Flag(help: "Tail the TailMac log output instead of returning immediatly") var tail
mutating func run() {
let process = Process()
let stdOutPipe = Pipe()
let executablePath = CommandLine.arguments[0]
let executableDirectory = (executablePath as NSString).deletingLastPathComponent
let appPath = executableDirectory + "/Host.app/Contents/MacOS/Host"
let appPath = "./Host.app/Contents/MacOS/Host"
process.executableURL = URL(
fileURLWithPath: appPath,
@@ -113,15 +109,10 @@ extension Tailmac {
)
if !FileManager.default.fileExists(atPath: appPath) {
fatalError("Could not find Host.app at \(appPath). This must be co-located with the tailmac utility")
fatalError("Could not find Host.app. This must be co-located with the tailmac utility")
}
var args = ["run", "--id", id]
if let share {
args.append("--share")
args.append(share)
}
process.arguments = args
process.arguments = ["run", "--id", id]
do {
process.standardOutput = stdOutPipe
@@ -130,18 +121,26 @@ extension Tailmac {
fatalError("Unable to launch the vm process")
}
// This doesn't print until we exit which is not ideal, but at least we
// get the output
if tail != 0 {
// (jonathan)TODO: How do we get the process output in real time?
// The child process only seems to flush to stdout on completion
let outHandle = stdOutPipe.fileHandleForReading
outHandle.readabilityHandler = { handle in
let data = handle.availableData
let queue = OperationQueue()
NotificationCenter.default.addObserver(
forName: NSNotification.Name.NSFileHandleDataAvailable,
object: outHandle, queue: queue)
{
notification -> Void in
let data = outHandle.availableData
if data.count > 0 {
if let str = String(data: data, encoding: String.Encoding.utf8) {
print(str)
}
}
outHandle.waitForDataInBackgroundAndNotify()
}
outHandle.waitForDataInBackgroundAndNotify()
process.waitUntilExit()
}
}

View File

@@ -36,7 +36,7 @@ func ValueOf[T any](v T) Value[T] {
}
// String implements [fmt.Stringer].
func (o Value[T]) String() string {
func (o *Value[T]) String() string {
if !o.set {
return fmt.Sprintf("(empty[%T])", o.value)
}

View File

@@ -0,0 +1,122 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"errors"
"sync"
)
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
// otherwise the actual error is returned and the next read for that key will retry using the handler.
type CachingHandler struct {
mu sync.Mutex
strings map[string]string
uint64s map[string]uint64
bools map[string]bool
strArrs map[string][]string
notFound map[string]bool
handler Handler
}
// NewCachingHandler creates a CachingHandler given a handler.
func NewCachingHandler(handler Handler) *CachingHandler {
return &CachingHandler{
handler: handler,
strings: make(map[string]string),
uint64s: make(map[string]uint64),
bools: make(map[string]bool),
strArrs: make(map[string][]string),
notFound: make(map[string]bool),
}
}
// ReadString reads the policy settings value string given the key.
// ReadString first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadString(key string) (string, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.strings[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return "", ErrNoSuchKey
}
val, err := ch.handler.ReadString(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return "", err
} else if err != nil {
return "", err
}
ch.strings[key] = val
return val, nil
}
// ReadUInt64 reads the policy settings uint64 value given the key.
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.uint64s[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return 0, ErrNoSuchKey
}
val, err := ch.handler.ReadUInt64(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return 0, err
} else if err != nil {
return 0, err
}
ch.uint64s[key] = val
return val, nil
}
// ReadBoolean reads the policy settings boolean value given the key.
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.bools[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return false, ErrNoSuchKey
}
val, err := ch.handler.ReadBoolean(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return false, err
} else if err != nil {
return false, err
}
ch.bools[key] = val
return val, nil
}
// ReadBoolean reads the policy settings boolean value given the key.
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.strArrs[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return nil, ErrNoSuchKey
}
val, err := ch.handler.ReadStringArray(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return nil, err
} else if err != nil {
return nil, err
}
ch.strArrs[key] = val
return val, nil
}

View File

@@ -0,0 +1,262 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"testing"
)
func TestHandlerReadString(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue string
handlerError error
preserveHandler bool
wantValue string
wantErr error
strings map[string]string
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
strings: map[string]string{"test": "foo"},
wantValue: "foo",
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: "foo",
wantValue: "foo",
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
s: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.strings != nil {
cache.strings = tt.strings
}
got, err := cache.ReadString(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadString(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}
func TestHandlerReadUint64(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue uint64
handlerError error
preserveHandler bool
wantValue uint64
wantErr error
uint64s map[string]uint64
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
uint64s: map[string]uint64{"test": 1},
wantValue: 1,
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: 1,
wantValue: 1,
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
u64: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.uint64s != nil {
cache.uint64s = tt.uint64s
}
got, err := cache.ReadUInt64(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadUInt64(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}
func TestHandlerReadBool(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue bool
handlerError error
preserveHandler bool
wantValue bool
wantErr error
bools map[string]bool
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
bools: map[string]bool{"test": true},
wantValue: true,
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: true,
wantValue: true,
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
b: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.bools != nil {
cache.bools = tt.bools
}
got, err := cache.ReadBoolean(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadBoolean(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}

View File

@@ -4,17 +4,16 @@
package syspolicy
import (
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
"errors"
"sync/atomic"
)
// TODO(nickkhyl): delete this file once other repos are updated.
var (
handlerUsed atomic.Bool
handler Handler = defaultHandler{}
)
// Handler reads system policies from OS-specific storage.
//
// Deprecated: implementing a [source.Store] should be preferred.
type Handler interface {
// ReadString reads the policy setting's string value for the given key.
// It should return ErrNoSuchKey if the key does not have a value set.
@@ -30,88 +29,55 @@ type Handler interface {
ReadStringArray(key string) ([]string, error)
}
// RegisterHandler wraps and registers the specified handler as the device's
// policy [source.Store] for the program's lifetime.
//
// Deprecated: using [RegisterStore] should be preferred.
// ErrNoSuchKey is returned by a Handler when the specified key does not have a
// value set.
var ErrNoSuchKey = errors.New("no such key")
// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple.
type defaultHandler struct{}
func (defaultHandler) ReadString(_ string) (string, error) {
return "", ErrNoSuchKey
}
func (defaultHandler) ReadUInt64(_ string) (uint64, error) {
return 0, ErrNoSuchKey
}
func (defaultHandler) ReadBoolean(_ string) (bool, error) {
return false, ErrNoSuchKey
}
func (defaultHandler) ReadStringArray(_ string) ([]string, error) {
return nil, ErrNoSuchKey
}
// markHandlerInUse is called before handler methods are called.
func markHandlerInUse() {
handlerUsed.Store(true)
}
// RegisterHandler initializes the policy handler and ensures registration will happen once.
func RegisterHandler(h Handler) {
rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h))
// Technically this assignment is not concurrency safe, but in the
// event that there was any risk of a data race, we will panic due to
// the CompareAndSwap failing.
handler = h
if !handlerUsed.CompareAndSwap(false, true) {
panic("handler was already used before registration")
}
}
// TB is a subset of testing.TB that we use to set up test helpers.
// It's defined here to avoid pulling in the testing package.
type TB = internal.TB
type TB interface {
Helper()
Cleanup(func())
}
// SetHandlerForTest wraps and sets the specified handler as the device's policy
// [source.Store] for the duration of tb.
//
// Deprecated: using [MustRegisterStoreForTest] should be preferred.
func SetHandlerForTest(tb TB, h Handler) {
RegisterWellKnownSettingsForTest(tb)
MustRegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.DefaultScope(), WrapHandler(h))
}
var _ source.Store = (*handlerStore)(nil)
// handlerStore is a [source.Store] that calls the underlying [Handler].
//
// TODO(nickkhyl): remove it when the corp and android repos are updated.
type handlerStore struct {
h Handler
}
// WrapHandler returns a [source.Store] that wraps the specified [Handler].
func WrapHandler(h Handler) source.Store {
return handlerStore{h}
}
// Lock implements [source.Lockable].
func (s handlerStore) Lock() error {
if lockable, ok := s.h.(source.Lockable); ok {
return lockable.Lock()
}
return nil
}
// Unlock implements [source.Lockable].
func (s handlerStore) Unlock() {
if lockable, ok := s.h.(source.Lockable); ok {
lockable.Unlock()
}
}
// RegisterChangeCallback implements [source.Changeable].
func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
if changeable, ok := s.h.(source.Changeable); ok {
return changeable.RegisterChangeCallback(callback)
}
return func() {}, nil
}
// ReadString implements [source.Store].
func (s handlerStore) ReadString(key setting.Key) (string, error) {
return s.h.ReadString(string(key))
}
// ReadUInt64 implements [source.Store].
func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) {
return s.h.ReadUInt64(string(key))
}
// ReadBoolean implements [source.Store].
func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) {
return s.h.ReadBoolean(string(key))
}
// ReadStringArray implements [source.Store].
func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) {
return s.h.ReadStringArray(string(key))
}
// Done implements [source.Expirable].
func (s handlerStore) Done() <-chan struct{} {
if expirable, ok := s.h.(source.Expirable); ok {
return expirable.Done()
}
return nil
tb.Helper()
oldHandler := handler
handler = h
tb.Cleanup(func() { handler = oldHandler })
}

View File

@@ -0,0 +1,19 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import "testing"
func TestDefaultHandlerReadValues(t *testing.T) {
var h defaultHandler
got, err := h.ReadString(string(AdminConsoleVisibility))
if got != "" || err != ErrNoSuchKey {
t.Fatalf("got %v err %v", got, err)
}
result, err := h.ReadUInt64(string(LogSCMInteractions))
if result != 0 || err != ErrNoSuchKey {
t.Fatalf("got %v err %v", result, err)
}
}

View File

@@ -0,0 +1,105 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"errors"
"fmt"
"tailscale.com/util/clientmetric"
"tailscale.com/util/winutil"
)
var (
windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors")
windowsAny = clientmetric.NewGauge("windows_syspolicy_any")
)
type windowsHandler struct{}
func init() {
RegisterHandler(NewCachingHandler(windowsHandler{}))
keyList := []struct {
isSet func(Key) bool
keys []Key
}{
{
isSet: func(k Key) bool {
_, err := handler.ReadString(string(k))
return err == nil
},
keys: stringKeys,
},
{
isSet: func(k Key) bool {
_, err := handler.ReadBoolean(string(k))
return err == nil
},
keys: boolKeys,
},
{
isSet: func(k Key) bool {
_, err := handler.ReadUInt64(string(k))
return err == nil
},
keys: uint64Keys,
},
}
var anySet bool
for _, l := range keyList {
for _, k := range l.keys {
if !l.isSet(k) {
continue
}
clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1)
anySet = true
}
}
if anySet {
windowsAny.Set(1)
}
}
func (windowsHandler) ReadString(key string) (string, error) {
s, err := winutil.GetPolicyString(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return s, err
}
func (windowsHandler) ReadUInt64(key string) (uint64, error) {
value, err := winutil.GetPolicyInteger(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return value, err
}
func (windowsHandler) ReadBoolean(key string) (bool, error) {
value, err := winutil.GetPolicyInteger(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return value != 0, err
}
func (windowsHandler) ReadStringArray(key string) ([]string, error) {
value, err := winutil.GetPolicyStringArray(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return value, err
}

View File

@@ -284,7 +284,7 @@ func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) {
}
func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric {
name := strings.ReplaceAll(string(key), string(setting.KeyPathSeparator), "_")
name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_")
return newMetric([]string{name, metricScopeName(scope), suffix}, typ)
}

View File

@@ -3,24 +3,10 @@
package syspolicy
import (
"tailscale.com/types/lazy"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/testenv"
)
import "tailscale.com/util/syspolicy/setting"
// Key is a string that uniquely identifies a policy and must remain unchanged
// once established and documented for a given policy setting. It may contain
// alphanumeric characters and zero or more [KeyPathSeparator]s to group
// individual policy settings into categories.
type Key = setting.Key
// The const block below lists known policy keys.
// When adding a key to this list, remember to add a corresponding
// [setting.Definition] to [implicitDefinitions] below.
// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder.
const (
// Keys with a string value
ControlURL Key = "LoginURL" // default ""; if blank, ipn uses ipn.DefaultControlURL.
@@ -77,9 +63,6 @@ const (
// SuggestedExitNodeVisibility controls the visibility of suggested exit nodes in the client GUI.
// When this system policy is set to 'hide', an exit node suggestion won't be presented to the user as part of the exit nodes picker.
SuggestedExitNodeVisibility Key = "SuggestedExitNode"
// OnboardingFlowVisibility controls the visibility of the onboarding flow in the client GUI.
// When this system policy is set to 'hide', the onboarding flow is never shown to the user.
OnboardingFlowVisibility Key = "OnboardingFlow"
// Keys with a string value formatted for use with time.ParseDuration().
KeyExpirationNoticeTime Key = "KeyExpirationNotice" // default 24 hours
@@ -127,91 +110,3 @@ const (
// AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes.
AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes"
)
// implicitDefinitions is a list of [setting.Definition] that will be registered
// automatically when the policy setting definitions are first used by the syspolicy package hierarchy.
// This includes the first time a policy needs to be read from any source.
var implicitDefinitions = []*setting.Definition{
// Device policy settings (can only be configured on a per-device basis):
setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue),
setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(AuthKey, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(MachineCertificateSubject, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue),
// User policy settings (can be configured on a user- or device-basis):
setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue),
setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue),
setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue),
setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue),
setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(OnboardingFlowVisibility, setting.UserSetting, setting.VisibilityValue),
}
func init() {
internal.Init.MustDefer(func() error {
// Avoid implicit [setting.Definition] registration during tests.
// Each test should control which policy settings to register.
// Use [setting.SetDefinitionsForTest] to specify necessary definitions,
// or [setWellKnownSettingsForTest] to set implicit definitions for the test duration.
if testenv.InTest() {
return nil
}
for _, d := range implicitDefinitions {
setting.RegisterDefinition(d)
}
return nil
})
}
var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap]
// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key,
// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist
// among implicit policy definitions.
func WellKnownSettingDefinition(k Key) (*setting.Definition, error) {
m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) {
return setting.DefinitionMapOf(implicitDefinitions)
})
if err != nil {
return nil, err
}
if d, ok := m[k]; ok {
return d, nil
}
return nil, ErrNoSuchKey
}
// RegisterWellKnownSettingsForTest registers all implicit setting definitions
// for the duration of the test.
func RegisterWellKnownSettingsForTest(tb TB) {
tb.Helper()
err := setting.SetDefinitionsForTest(tb, implicitDefinitions...)
if err != nil {
tb.Fatalf("Failed to register well-known settings: %v", err)
}
}

View File

@@ -1,95 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"go/types"
"os"
"reflect"
"strconv"
"testing"
"tailscale.com/util/syspolicy/setting"
)
func TestKnownKeysRegistered(t *testing.T) {
keyConsts, err := listStringConsts[Key]("policy_keys.go")
if err != nil {
t.Fatalf("listStringConsts failed: %v", err)
}
m, err := setting.DefinitionMapOf(implicitDefinitions)
if err != nil {
t.Fatalf("definitionMapOf failed: %v", err)
}
for _, key := range keyConsts {
t.Run(string(key), func(t *testing.T) {
d := m[key]
if d == nil {
t.Fatalf("%q was not registered", key)
}
if d.Key() != key {
t.Fatalf("d.Key got: %s, want %s", d.Key(), key)
}
})
}
}
func TestNotAWellKnownSetting(t *testing.T) {
d, err := WellKnownSettingDefinition("TestSettingDoesNotExist")
if d != nil || err == nil {
t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey)
}
}
func listStringConsts[T ~string](filename string) (map[string]T, error) {
fset := token.NewFileSet()
src, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
f, err := parser.ParseFile(fset, filename, src, 0)
if err != nil {
return nil, err
}
consts := make(map[string]T)
typeName := reflect.TypeFor[T]().Name()
for _, d := range f.Decls {
g, ok := d.(*ast.GenDecl)
if !ok || g.Tok != token.CONST {
continue
}
for _, s := range g.Specs {
vs, ok := s.(*ast.ValueSpec)
if !ok || len(vs.Names) != len(vs.Values) {
continue
}
if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName {
continue
}
for i, n := range vs.Names {
lit, ok := vs.Values[i].(*ast.BasicLit)
if !ok {
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i]))
}
val, err := strconv.Unquote(lit.Value)
if err != nil {
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value)
}
consts[n.Name] = T(val)
}
}
}
return consts, nil
}

View File

@@ -0,0 +1,38 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
var stringKeys = []Key{
ControlURL,
LogTarget,
Tailnet,
ExitNodeID,
ExitNodeIP,
EnableIncomingConnections,
EnableServerMode,
ExitNodeAllowLANAccess,
EnableTailscaleDNS,
EnableTailscaleSubnets,
AdminConsoleVisibility,
NetworkDevicesVisibility,
TestMenuVisibility,
UpdateMenuVisibility,
RunExitNodeVisibility,
PreferencesMenuVisibility,
ExitNodeMenuVisibility,
AutoUpdateVisibility,
ResetToDefaultsVisibility,
KeyExpirationNoticeTime,
PostureChecking,
ManagedByOrganizationName,
ManagedByCaption,
ManagedByURL,
}
var boolKeys = []Key{
LogSCMInteractions,
FlushDNSOnSessionUnlock,
}
var uint64Keys = []Key{}

View File

@@ -10,4 +10,4 @@ package setting
type Key string
// KeyPathSeparator allows logical grouping of policy settings into categories.
const KeyPathSeparator = '/'
const KeyPathSeparator = "/"

View File

@@ -5,11 +5,7 @@ package setting
import (
"fmt"
"reflect"
jsonv2 "github.com/go-json-experiment/json"
"github.com/go-json-experiment/json/jsontext"
"tailscale.com/types/opt"
"tailscale.com/types/structs"
)
@@ -21,15 +17,10 @@ import (
// or converted from strings, these setting types predate the typed policy
// hierarchies, and must be supported at this layer.
type RawItem struct {
_ structs.Incomparable
data rawItemJSON
}
// rawItemJSON holds JSON-marshallable data for [RawItem].
type rawItemJSON struct {
Value RawValue `json:",omitzero"`
Error *ErrorText `json:",omitzero"` // or nil
Origin *Origin `json:",omitzero"` // or nil
_ structs.Incomparable
value any
err *ErrorText
origin *Origin // or nil
}
// RawItemOf returns a [RawItem] with the specified value.
@@ -39,20 +30,20 @@ func RawItemOf(value any) RawItem {
// RawItemWith returns a [RawItem] with the specified value, error and origin.
func RawItemWith(value any, err *ErrorText, origin *Origin) RawItem {
return RawItem{data: rawItemJSON{Value: RawValue{opt.ValueOf(value)}, Error: err, Origin: origin}}
return RawItem{value: value, err: err, origin: origin}
}
// Value returns the value of the policy setting, or nil if the policy setting
// is not configured, or an error occurred while reading it.
func (i RawItem) Value() any {
return i.data.Value.Get()
return i.value
}
// Error returns the error that occurred when reading the policy setting,
// or nil if no error occurred.
func (i RawItem) Error() error {
if i.data.Error != nil {
return i.data.Error
if i.err != nil {
return i.err
}
return nil
}
@@ -60,103 +51,17 @@ func (i RawItem) Error() error {
// Origin returns an optional [Origin] indicating where the policy setting is
// configured.
func (i RawItem) Origin() *Origin {
return i.data.Origin
return i.origin
}
// String implements [fmt.Stringer].
func (i RawItem) String() string {
var suffix string
if i.data.Origin != nil {
suffix = fmt.Sprintf(" - {%v}", i.data.Origin)
if i.origin != nil {
suffix = fmt.Sprintf(" - {%v}", i.origin)
}
if i.data.Error != nil {
return fmt.Sprintf("Error{%q}%s", i.data.Error.Error(), suffix)
if i.err != nil {
return fmt.Sprintf("Error{%q}%s", i.err.Error(), suffix)
}
return fmt.Sprintf("%v%s", i.data.Value.Value, suffix)
return fmt.Sprintf("%v%s", i.value, suffix)
}
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
func (i RawItem) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
return jsonv2.MarshalEncode(out, &i.data, opts)
}
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
func (i *RawItem) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
return jsonv2.UnmarshalDecode(in, &i.data, opts)
}
// MarshalJSON implements [json.Marshaler].
func (i RawItem) MarshalJSON() ([]byte, error) {
return jsonv2.Marshal(i) // uses MarshalJSONV2
}
// UnmarshalJSON implements [json.Unmarshaler].
func (i *RawItem) UnmarshalJSON(b []byte) error {
return jsonv2.Unmarshal(b, i) // uses UnmarshalJSONV2
}
// RawValue represents a raw policy setting value read from a policy store.
// It is JSON-marshallable and facilitates unmarshalling of JSON values
// into corresponding policy setting types, with special handling for JSON numbers
// (unmarshalled as float64) and JSON string arrays (unmarshalled as []string).
// See also [RawValue.UnmarshalJSONV2].
type RawValue struct {
opt.Value[any]
}
// RawValueType is a constraint that permits raw setting value types.
type RawValueType interface {
bool | uint64 | string | []string
}
// RawValueOf returns a new [RawValue] holding the specified value.
func RawValueOf[T RawValueType](v T) RawValue {
return RawValue{opt.ValueOf[any](v)}
}
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
func (v RawValue) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
return jsonv2.MarshalEncode(out, v.Value, opts)
}
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2] by attempting to unmarshal
// a JSON value as one of the supported policy setting value types (bool, string, uint64, or []string),
// based on the JSON value type. It fails if the JSON value is an object, if it's a JSON number that
// cannot be represented as a uint64, or if a JSON array contains anything other than strings.
func (v *RawValue) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
var valPtr any
switch k := in.PeekKind(); k {
case 't', 'f':
valPtr = new(bool)
case '"':
valPtr = new(string)
case '0':
valPtr = new(uint64) // unmarshal JSON numbers as uint64
case '[', 'n':
valPtr = new([]string) // unmarshal arrays as string slices
case '{':
return fmt.Errorf("unexpected token: %v", k)
default:
panic("unreachable")
}
if err := jsonv2.UnmarshalDecode(in, valPtr, opts); err != nil {
v.Value.Clear()
return err
}
value := reflect.ValueOf(valPtr).Elem().Interface()
v.Value = opt.ValueOf(value)
return nil
}
// MarshalJSON implements [json.Marshaler].
func (v RawValue) MarshalJSON() ([]byte, error) {
return jsonv2.Marshal(v) // uses MarshalJSONV2
}
// UnmarshalJSON implements [json.Unmarshaler].
func (v *RawValue) UnmarshalJSON(b []byte) error {
return jsonv2.Unmarshal(b, v) // uses UnmarshalJSONV2
}
// RawValues is a map of keyed setting values that can be read from a JSON.
type RawValues map[Key]RawValue

View File

@@ -1,101 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
"math"
"reflect"
"strconv"
"testing"
jsonv2 "github.com/go-json-experiment/json"
)
func TestMarshalUnmarshalRawValue(t *testing.T) {
tests := []struct {
name string
json string
want RawValue
wantErr bool
}{
{
name: "Bool/True",
json: `true`,
want: RawValueOf(true),
},
{
name: "Bool/False",
json: `false`,
want: RawValueOf(false),
},
{
name: "String/Empty",
json: `""`,
want: RawValueOf(""),
},
{
name: "String/NonEmpty",
json: `"Test"`,
want: RawValueOf("Test"),
},
{
name: "StringSlice/Null",
json: `null`,
want: RawValueOf([]string(nil)),
},
{
name: "StringSlice/Empty",
json: `[]`,
want: RawValueOf([]string{}),
},
{
name: "StringSlice/NonEmpty",
json: `["A", "B", "C"]`,
want: RawValueOf([]string{"A", "B", "C"}),
},
{
name: "StringSlice/NonStrings",
json: `[1, 2, 3]`,
wantErr: true,
},
{
name: "Number/Integer/0",
json: `0`,
want: RawValueOf(uint64(0)),
},
{
name: "Number/Integer/1",
json: `1`,
want: RawValueOf(uint64(1)),
},
{
name: "Number/Integer/MaxUInt64",
json: strconv.FormatUint(math.MaxUint64, 10),
want: RawValueOf(uint64(math.MaxUint64)),
},
{
name: "Number/Integer/Negative",
json: `-1`,
wantErr: true,
},
{
name: "Object",
json: `{}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got RawValue
gotErr := jsonv2.Unmarshal([]byte(tt.json), &got)
if (gotErr != nil) != tt.wantErr {
t.Fatalf("Error: got %v; want %v", gotErr, tt.wantErr)
}
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
t.Fatalf("Value: got %v; want %v", got, tt.want)
}
})
}
}

View File

@@ -4,14 +4,11 @@
package setting
import (
"errors"
"iter"
"maps"
"slices"
"strings"
jsonv2 "github.com/go-json-experiment/json"
"github.com/go-json-experiment/json/jsontext"
xmaps "golang.org/x/exp/maps"
"tailscale.com/util/deephash"
)
@@ -68,9 +65,6 @@ func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) {
// Equal reports whether s and s2 are equal.
func (s *Snapshot) Equal(s2 *Snapshot) bool {
if s == s2 {
return true
}
if !s.EqualItems(s2) {
return false
}
@@ -141,45 +135,6 @@ func (s *Snapshot) String() string {
return sb.String()
}
// snapshotJSON holds JSON-marshallable data for [Snapshot].
type snapshotJSON struct {
Summary Summary `json:",omitzero"`
Settings map[Key]RawItem `json:",omitempty"`
}
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
func (s *Snapshot) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
data := &snapshotJSON{}
if s != nil {
data.Summary = s.summary
data.Settings = s.m
}
return jsonv2.MarshalEncode(out, data, opts)
}
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
func (s *Snapshot) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
if s == nil {
return errors.New("s must not be nil")
}
data := &snapshotJSON{}
if err := jsonv2.UnmarshalDecode(in, data, opts); err != nil {
return err
}
*s = Snapshot{m: data.Settings, sig: deephash.Hash(&data.Settings), summary: data.Summary}
return nil
}
// MarshalJSON implements [json.Marshaler].
func (s *Snapshot) MarshalJSON() ([]byte, error) {
return jsonv2.Marshal(s) // uses MarshalJSONV2
}
// UnmarshalJSON implements [json.Unmarshaler].
func (s *Snapshot) UnmarshalJSON(b []byte) error {
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
}
// MergeSnapshots returns a [Snapshot] that contains all [RawItem]s
// from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope].
// If there's a conflict between policy settings in the two snapshots,

View File

@@ -4,13 +4,8 @@
package setting
import (
"cmp"
"encoding/json"
"testing"
"time"
jsonv2 "github.com/go-json-experiment/json"
"tailscale.com/util/syspolicy/internal"
)
func TestMergeSnapshots(t *testing.T) {
@@ -35,134 +30,134 @@ func TestMergeSnapshots(t *testing.T) {
name: "first-nil",
s1: nil,
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
},
{
name: "first-empty",
s1: NewSnapshot(map[Key]RawItem{}),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
},
{
name: "second-nil",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
s2: nil,
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
},
{
name: "second-empty",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
s2: NewSnapshot(map[Key]RawItem{}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
},
{
name: "no-conflicts",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
s2: NewSnapshot(map[Key]RawItem{
"Setting4": RawItemOf(2 * time.Hour),
"Setting5": RawItemOf(VisibleByPolicy),
"Setting6": RawItemOf(ShowChoiceByPolicy),
"Setting4": {value: 2 * time.Hour},
"Setting5": {value: VisibleByPolicy},
"Setting6": {value: ShowChoiceByPolicy},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting4": RawItemOf(2 * time.Hour),
"Setting5": RawItemOf(VisibleByPolicy),
"Setting6": RawItemOf(ShowChoiceByPolicy),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
"Setting5": {value: VisibleByPolicy},
"Setting6": {value: ShowChoiceByPolicy},
}),
},
{
name: "with-conflicts",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(456),
"Setting3": RawItemOf(false),
"Setting4": RawItemOf(2 * time.Hour),
"Setting1": {value: 456},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(456),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting4": RawItemOf(2 * time.Hour),
"Setting1": {value: 456},
"Setting2": {value: "String"},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}),
},
{
name: "with-scope-first-wins",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, DeviceScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(456),
"Setting3": RawItemOf(false),
"Setting4": RawItemOf(2 * time.Hour),
"Setting1": {value: 456},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}, CurrentUserScope),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting4": RawItemOf(2 * time.Hour),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
"Setting4": {value: 2 * time.Hour},
}, CurrentUserScope),
},
{
name: "with-scope-second-wins",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(456),
"Setting3": RawItemOf(false),
"Setting4": RawItemOf(2 * time.Hour),
"Setting1": {value: 456},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}, DeviceScope),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(456),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting4": RawItemOf(2 * time.Hour),
"Setting1": {value: 456},
"Setting2": {value: "String"},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}, CurrentUserScope),
},
{
@@ -175,27 +170,28 @@ func TestMergeSnapshots(t *testing.T) {
name: "with-scope-first-empty",
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true)}, DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true}},
DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope, NewNamedOrigin("TestPolicy", DeviceScope)),
},
{
name: "with-scope-second-empty",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope),
s2: NewSnapshot(map[Key]RawItem{}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope),
},
}
@@ -248,9 +244,9 @@ func TestSnapshotEqual(t *testing.T) {
name: "first-nil",
s1: nil,
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
wantEqual: false,
wantEqualItems: false,
@@ -259,9 +255,9 @@ func TestSnapshotEqual(t *testing.T) {
name: "first-empty",
s1: NewSnapshot(map[Key]RawItem{}),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
wantEqual: false,
wantEqualItems: false,
@@ -269,9 +265,9 @@ func TestSnapshotEqual(t *testing.T) {
{
name: "second-nil",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(true),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
s2: nil,
wantEqual: false,
@@ -280,9 +276,9 @@ func TestSnapshotEqual(t *testing.T) {
{
name: "second-empty",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
s2: NewSnapshot(map[Key]RawItem{}),
wantEqual: false,
@@ -291,14 +287,14 @@ func TestSnapshotEqual(t *testing.T) {
{
name: "same-items-same-order-no-scope",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
wantEqual: true,
wantEqualItems: true,
@@ -306,14 +302,14 @@ func TestSnapshotEqual(t *testing.T) {
{
name: "same-items-same-order-same-scope",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
wantEqual: true,
wantEqualItems: true,
@@ -321,14 +317,14 @@ func TestSnapshotEqual(t *testing.T) {
{
name: "same-items-different-order-same-scope",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting3": RawItemOf(false),
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": {value: false},
"Setting1": {value: 123},
"Setting2": {value: "String"},
}, DeviceScope),
wantEqual: true,
wantEqualItems: true,
@@ -336,14 +332,14 @@ func TestSnapshotEqual(t *testing.T) {
{
name: "same-items-same-order-different-scope",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, CurrentUserScope),
wantEqual: false,
wantEqualItems: true,
@@ -351,14 +347,14 @@ func TestSnapshotEqual(t *testing.T) {
{
name: "different-items-same-scope",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(123),
"Setting2": RawItemOf("String"),
"Setting3": RawItemOf(false),
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting4": RawItemOf(2 * time.Hour),
"Setting5": RawItemOf(VisibleByPolicy),
"Setting6": RawItemOf(ShowChoiceByPolicy),
"Setting4": {value: 2 * time.Hour},
"Setting5": {value: VisibleByPolicy},
"Setting6": {value: ShowChoiceByPolicy},
}, DeviceScope),
wantEqual: false,
wantEqualItems: false,
@@ -405,9 +401,9 @@ func TestSnapshotString(t *testing.T) {
{
name: "non-empty",
snapshot: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemOf(2 * time.Hour),
"Setting2": RawItemOf(VisibleByPolicy),
"Setting3": RawItemOf(ShowChoiceByPolicy),
"Setting1": {value: 2 * time.Hour},
"Setting2": {value: VisibleByPolicy},
"Setting3": {value: ShowChoiceByPolicy},
}, NewNamedOrigin("Test Policy", DeviceScope)),
wantString: `{Test Policy (Device)}
Setting1 = 2h0m0s
@@ -417,14 +413,14 @@ Setting3 = user-decides`,
{
name: "non-empty-with-item-origin",
snapshot: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemWith(42, nil, NewNamedOrigin("Test Policy", DeviceScope)),
"Setting1": {value: 42, origin: NewNamedOrigin("Test Policy", DeviceScope)},
}),
wantString: `Setting1 = 42 - {Test Policy (Device)}`,
},
{
name: "non-empty-with-item-error",
snapshot: NewSnapshot(map[Key]RawItem{
"Setting1": RawItemWith(nil, NewErrorText("bang!"), nil),
"Setting1": {err: NewErrorText("bang!")},
}),
wantString: `Setting1 = Error{"bang!"}`,
},
@@ -437,133 +433,3 @@ Setting3 = user-decides`,
})
}
}
func TestMarshalUnmarshalSnapshot(t *testing.T) {
tests := []struct {
name string
snapshot *Snapshot
wantJSON string
wantBack *Snapshot
}{
{
name: "Nil",
snapshot: (*Snapshot)(nil),
wantJSON: "null",
wantBack: NewSnapshot(nil),
},
{
name: "Zero",
snapshot: &Snapshot{},
wantJSON: "{}",
},
{
name: "Bool/True",
snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(true)}),
wantJSON: `{"Settings": {"BoolPolicy": {"Value": true}}}`,
},
{
name: "Bool/False",
snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(false)}),
wantJSON: `{"Settings": {"BoolPolicy": {"Value": false}}}`,
},
{
name: "String/Non-Empty",
snapshot: NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("StringValue")}),
wantJSON: `{"Settings": {"StringPolicy": {"Value": "StringValue"}}}`,
},
{
name: "String/Empty",
snapshot: NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("")}),
wantJSON: `{"Settings": {"StringPolicy": {"Value": ""}}}`,
},
{
name: "Integer/NonZero",
snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(42))}),
wantJSON: `{"Settings": {"IntPolicy": {"Value": 42}}}`,
},
{
name: "Integer/Zero",
snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(0))}),
wantJSON: `{"Settings": {"IntPolicy": {"Value": 0}}}`,
},
{
name: "String-List",
snapshot: NewSnapshot(map[Key]RawItem{"ListPolicy": RawItemOf([]string{"Value1", "Value2"})}),
wantJSON: `{"Settings": {"ListPolicy": {"Value": ["Value1", "Value2"]}}}`,
},
{
name: "Empty/With-Summary",
snapshot: NewSnapshot(
map[Key]RawItem{},
SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)),
),
wantJSON: `{"Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}}`,
},
{
name: "Setting/With-Summary",
snapshot: NewSnapshot(
map[Key]RawItem{"PolicySetting": RawItemOf(uint64(42))},
SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)),
),
wantJSON: `{
"Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"},
"Settings": {"PolicySetting": {"Value": 42}}
}`,
},
{
name: "Settings/With-Origins",
snapshot: NewSnapshot(
map[Key]RawItem{
"SettingA": RawItemWith(uint64(42), nil, NewNamedOrigin("SourceA", DeviceScope)),
"SettingB": RawItemWith("B", nil, NewNamedOrigin("SourceB", CurrentProfileScope)),
"SettingC": RawItemWith(true, nil, NewNamedOrigin("SourceC", CurrentUserScope)),
},
),
wantJSON: `{
"Settings": {
"SettingA": {"Value": 42, "Origin": {"Name": "SourceA", "Scope": "Device"}},
"SettingB": {"Value": "B", "Origin": {"Name": "SourceB", "Scope": "Profile"}},
"SettingC": {"Value": true, "Origin": {"Name": "SourceC", "Scope": "User"}}
}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
doTest := func(t *testing.T, useJSONv2 bool) {
var gotJSON []byte
var err error
if useJSONv2 {
gotJSON, err = jsonv2.Marshal(tt.snapshot)
} else {
gotJSON, err = json.Marshal(tt.snapshot)
}
if err != nil {
t.Fatal(err)
}
if got, want, equal := internal.EqualJSONForTest(t, gotJSON, []byte(tt.wantJSON)); !equal {
t.Errorf("JSON: got %s; want %s", got, want)
}
gotBack := &Snapshot{}
if useJSONv2 {
err = jsonv2.Unmarshal(gotJSON, &gotBack)
} else {
err = json.Unmarshal(gotJSON, &gotBack)
}
if err != nil {
t.Fatal(err)
}
if wantBack := cmp.Or(tt.wantBack, tt.snapshot); !gotBack.Equal(wantBack) {
t.Errorf("Snapshot: got %+v; want %+v", gotBack, wantBack)
}
}
t.Run("json", func(t *testing.T) { doTest(t, false) })
t.Run("jsonv2", func(t *testing.T) { doTest(t, true) })
})
}
}

View File

@@ -1,159 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"errors"
"fmt"
"os"
"strconv"
"strings"
"unicode/utf8"
"tailscale.com/util/syspolicy/setting"
)
var lookupEnv = os.LookupEnv // test hook
var _ Store = (*EnvPolicyStore)(nil)
// EnvPolicyStore is a [Store] that reads policy settings from environment variables.
type EnvPolicyStore struct{}
// ReadString implements [Store].
func (s *EnvPolicyStore) ReadString(key setting.Key) (string, error) {
_, str, err := s.lookupSettingVariable(key)
if err != nil {
return "", err
}
return str, nil
}
// ReadUInt64 implements [Store].
func (s *EnvPolicyStore) ReadUInt64(key setting.Key) (uint64, error) {
name, str, err := s.lookupSettingVariable(key)
if err != nil {
return 0, err
}
if str == "" {
return 0, setting.ErrNotConfigured
}
value, err := strconv.ParseUint(str, 0, 64)
if err != nil {
return 0, fmt.Errorf("%s: %w: %q is not a valid uint64", name, setting.ErrTypeMismatch, str)
}
return value, nil
}
// ReadBoolean implements [Store].
func (s *EnvPolicyStore) ReadBoolean(key setting.Key) (bool, error) {
name, str, err := s.lookupSettingVariable(key)
if err != nil {
return false, err
}
if str == "" {
return false, setting.ErrNotConfigured
}
value, err := strconv.ParseBool(str)
if err != nil {
return false, fmt.Errorf("%s: %w: %q is not a valid bool", name, setting.ErrTypeMismatch, str)
}
return value, nil
}
// ReadStringArray implements [Store].
func (s *EnvPolicyStore) ReadStringArray(key setting.Key) ([]string, error) {
_, str, err := s.lookupSettingVariable(key)
if err != nil || str == "" {
return nil, err
}
var dst int
res := strings.Split(str, ",")
for src := range res {
res[dst] = strings.TrimSpace(res[src])
if res[dst] != "" {
dst++
}
}
return res[0:dst], nil
}
func (s *EnvPolicyStore) lookupSettingVariable(key setting.Key) (name, value string, err error) {
name, err = keyToEnvVarName(key)
if err != nil {
return "", "", err
}
value, ok := lookupEnv(name)
if !ok {
return name, "", setting.ErrNotConfigured
}
return name, value, nil
}
var (
errEmptyKey = errors.New("key must not be empty")
errInvalidKey = errors.New("key must consist of alphanumeric characters and slashes")
)
// keyToEnvVarName returns the environment variable name for a given policy
// setting key, or an error if the key is invalid. It converts CamelCase keys into
// underscore-separated words and prepends the variable name with the TS prefix.
// For example: AuthKey => TS_AUTH_KEY, ExitNodeAllowLANAccess => TS_EXIT_NODE_ALLOW_LAN_ACCESS, etc.
//
// It's fine to use this in [EnvPolicyStore] without caching variable names since it's not a hot path.
// [EnvPolicyStore] is not a [Changeable] policy store, so the conversion will only happen once.
func keyToEnvVarName(key setting.Key) (string, error) {
if len(key) == 0 {
return "", errEmptyKey
}
isLower := func(c byte) bool { return 'a' <= c && c <= 'z' }
isUpper := func(c byte) bool { return 'A' <= c && c <= 'Z' }
isLetter := func(c byte) bool { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') }
isDigit := func(c byte) bool { return '0' <= c && c <= '9' }
words := make([]string, 0, 8)
words = append(words, "TS_DEBUGSYSPOLICY")
var currentWord strings.Builder
for i := 0; i < len(key); i++ {
c := key[i]
if c >= utf8.RuneSelf {
return "", errInvalidKey
}
var split bool
switch {
case isLower(c):
c -= 'a' - 'A' // make upper
split = currentWord.Len() > 0 && !isLetter(key[i-1])
case isUpper(c):
if currentWord.Len() > 0 {
prevUpper := isUpper(key[i-1])
nextLower := i < len(key)-1 && isLower(key[i+1])
split = !prevUpper || nextLower // split on case transition
}
case isDigit(c):
split = currentWord.Len() > 0 && !isDigit(key[i-1])
case c == setting.KeyPathSeparator:
words = append(words, currentWord.String())
currentWord.Reset()
continue
default:
return "", errInvalidKey
}
if split {
words = append(words, currentWord.String())
currentWord.Reset()
}
currentWord.WriteByte(c)
}
if currentWord.Len() > 0 {
words = append(words, currentWord.String())
}
return strings.Join(words, "_"), nil
}

View File

@@ -1,359 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"cmp"
"errors"
"math"
"reflect"
"strconv"
"testing"
"tailscale.com/util/syspolicy/setting"
)
func TestKeyToEnvVarName(t *testing.T) {
tests := []struct {
name string
key setting.Key
want string // suffix after "TS_DEBUGSYSPOLICY_"
wantErr error
}{
{
name: "empty",
key: "",
wantErr: errEmptyKey,
},
{
name: "lowercase",
key: "tailnet",
want: "TAILNET",
},
{
name: "CamelCase",
key: "AuthKey",
want: "AUTH_KEY",
},
{
name: "LongerCamelCase",
key: "ManagedByOrganizationName",
want: "MANAGED_BY_ORGANIZATION_NAME",
},
{
name: "UPPERCASE",
key: "UPPERCASE",
want: "UPPERCASE",
},
{
name: "WithAbbrev/Front",
key: "DNSServer",
want: "DNS_SERVER",
},
{
name: "WithAbbrev/Middle",
key: "ExitNodeAllowLANAccess",
want: "EXIT_NODE_ALLOW_LAN_ACCESS",
},
{
name: "WithAbbrev/Back",
key: "ExitNodeID",
want: "EXIT_NODE_ID",
},
{
name: "WithDigits/Single/Front",
key: "0TestKey",
want: "0_TEST_KEY",
},
{
name: "WithDigits/Multi/Front",
key: "64TestKey",
want: "64_TEST_KEY",
},
{
name: "WithDigits/Single/Middle",
key: "Test0Key",
want: "TEST_0_KEY",
},
{
name: "WithDigits/Multi/Middle",
key: "Test64Key",
want: "TEST_64_KEY",
},
{
name: "WithDigits/Single/Back",
key: "TestKey0",
want: "TEST_KEY_0",
},
{
name: "WithDigits/Multi/Back",
key: "TestKey64",
want: "TEST_KEY_64",
},
{
name: "WithDigits/Multi/Back",
key: "TestKey64",
want: "TEST_KEY_64",
},
{
name: "WithPathSeparators/Single",
key: "Key/Subkey",
want: "KEY_SUBKEY",
},
{
name: "WithPathSeparators/Multi",
key: "Root/Level1/Level2",
want: "ROOT_LEVEL_1_LEVEL_2",
},
{
name: "Mixed",
key: "Network/DNSServer/IPAddress",
want: "NETWORK_DNS_SERVER_IP_ADDRESS",
},
{
name: "Non-Alphanumeric/NonASCII/1",
key: "ж",
wantErr: errInvalidKey,
},
{
name: "Non-Alphanumeric/NonASCII/2",
key: "KeyжName",
wantErr: errInvalidKey,
},
{
name: "Non-Alphanumeric/Space",
key: "Key Name",
wantErr: errInvalidKey,
},
{
name: "Non-Alphanumeric/Punct",
key: "Key!Name",
wantErr: errInvalidKey,
},
{
name: "Non-Alphanumeric/Backslash",
key: `Key\Name`,
wantErr: errInvalidKey,
},
}
for _, tt := range tests {
t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) {
got, err := keyToEnvVarName(tt.key)
checkError(t, err, tt.wantErr, true)
want := tt.want
if want != "" {
want = "TS_DEBUGSYSPOLICY_" + want
}
if got != want {
t.Fatalf("got %q; want %q", got, want)
}
})
}
}
func TestEnvPolicyStore(t *testing.T) {
blankEnv := func(string) (string, bool) { return "", false }
makeEnv := func(wantName, value string) func(string) (string, bool) {
wantName = "TS_DEBUGSYSPOLICY_" + wantName
return func(gotName string) (string, bool) {
if gotName != wantName {
return "", false
}
return value, true
}
}
tests := []struct {
name string
key setting.Key
lookup func(string) (string, bool)
want any
wantErr error
}{
{
name: "NotConfigured/String",
key: "AuthKey",
lookup: blankEnv,
wantErr: setting.ErrNotConfigured,
want: "",
},
{
name: "Configured/String/Empty",
key: "AuthKey",
lookup: makeEnv("AUTH_KEY", ""),
want: "",
},
{
name: "Configured/String/NonEmpty",
key: "AuthKey",
lookup: makeEnv("AUTH_KEY", "ABC123"),
want: "ABC123",
},
{
name: "NotConfigured/UInt64",
key: "IntegerSetting",
lookup: blankEnv,
wantErr: setting.ErrNotConfigured,
want: uint64(0),
},
{
name: "Configured/UInt64/Empty",
key: "IntegerSetting",
lookup: makeEnv("INTEGER_SETTING", ""),
wantErr: setting.ErrNotConfigured,
want: uint64(0),
},
{
name: "Configured/UInt64/Zero",
key: "IntegerSetting",
lookup: makeEnv("INTEGER_SETTING", "0"),
want: uint64(0),
},
{
name: "Configured/UInt64/NonZero",
key: "IntegerSetting",
lookup: makeEnv("INTEGER_SETTING", "12345"),
want: uint64(12345),
},
{
name: "Configured/UInt64/MaxUInt64",
key: "IntegerSetting",
lookup: makeEnv("INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)),
want: uint64(math.MaxUint64),
},
{
name: "Configured/UInt64/Negative",
key: "IntegerSetting",
lookup: makeEnv("INTEGER_SETTING", "-1"),
wantErr: setting.ErrTypeMismatch,
want: uint64(0),
},
{
name: "Configured/UInt64/Hex",
key: "IntegerSetting",
lookup: makeEnv("INTEGER_SETTING", "0xDEADBEEF"),
want: uint64(0xDEADBEEF),
},
{
name: "NotConfigured/Bool",
key: "LogSCMInteractions",
lookup: blankEnv,
wantErr: setting.ErrNotConfigured,
want: false,
},
{
name: "Configured/Bool/Empty",
key: "LogSCMInteractions",
lookup: makeEnv("LOG_SCM_INTERACTIONS", ""),
wantErr: setting.ErrNotConfigured,
want: false,
},
{
name: "Configured/Bool/True",
key: "LogSCMInteractions",
lookup: makeEnv("LOG_SCM_INTERACTIONS", "true"),
want: true,
},
{
name: "Configured/Bool/False",
key: "LogSCMInteractions",
lookup: makeEnv("LOG_SCM_INTERACTIONS", "False"),
want: false,
},
{
name: "Configured/Bool/1",
key: "LogSCMInteractions",
lookup: makeEnv("LOG_SCM_INTERACTIONS", "1"),
want: true,
},
{
name: "Configured/Bool/0",
key: "LogSCMInteractions",
lookup: makeEnv("LOG_SCM_INTERACTIONS", "0"),
want: false,
},
{
name: "Configured/Bool/Invalid",
key: "IntegerSetting",
lookup: makeEnv("INTEGER_SETTING", "NotABool"),
wantErr: setting.ErrTypeMismatch,
want: false,
},
{
name: "NotConfigured/StringArray",
key: "AllowedSuggestedExitNodes",
lookup: blankEnv,
wantErr: setting.ErrNotConfigured,
want: []string(nil),
},
{
name: "Configured/StringArray/Empty",
key: "AllowedSuggestedExitNodes",
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", ""),
want: []string(nil),
},
{
name: "Configured/StringArray/Spaces",
key: "AllowedSuggestedExitNodes",
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", " \t "),
want: []string{},
},
{
name: "Configured/StringArray/Single",
key: "AllowedSuggestedExitNodes",
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"),
want: []string{"NodeA"},
},
{
name: "Configured/StringArray/Multi",
key: "AllowedSuggestedExitNodes",
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"),
want: []string{"NodeA", "NodeB", "NodeC"},
},
{
name: "Configured/StringArray/WithBlank",
key: "AllowedSuggestedExitNodes",
lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"),
want: []string{"NodeA", "NodeB"},
},
}
for _, tt := range tests {
t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) {
oldLookupEnv := lookupEnv
t.Cleanup(func() { lookupEnv = oldLookupEnv })
lookupEnv = tt.lookup
var got any
var err error
var store EnvPolicyStore
switch tt.want.(type) {
case string:
got, err = store.ReadString(tt.key)
case uint64:
got, err = store.ReadUInt64(tt.key)
case bool:
got, err = store.ReadBoolean(tt.key)
case []string:
got, err = store.ReadStringArray(tt.key)
}
checkError(t, err, tt.wantErr, false)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %v; want %v", got, tt.want)
}
})
}
}
func checkError(tb testing.TB, got, want error, fatal bool) {
tb.Helper()
f := tb.Errorf
if fatal {
f = tb.Fatalf
}
if (want == nil && got != nil) ||
(want != nil && got == nil) ||
(want != nil && got != nil && !errors.Is(got, want) && want.Error() != got.Error()) {
f("gotErr: %v; wantErr: %v", got, want)
}
}

View File

@@ -319,9 +319,9 @@ func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error
// If there are no [setting.KeyPathSeparator]s in the key, the policy setting value
// is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale.
func splitSettingKey(key setting.Key) (path, valueName string) {
if idx := strings.LastIndexByte(string(key), setting.KeyPathSeparator); idx != -1 {
path = strings.ReplaceAll(string(key[:idx]), string(setting.KeyPathSeparator), `\`)
valueName = string(key[idx+1:])
if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 {
path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`)
valueName = string(key[idx+len(setting.KeyPathSeparator):])
return path, valueName
}
return "", string(key)

View File

@@ -1,82 +1,51 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package syspolicy facilitates retrieval of the current policy settings
// applied to the device or user and receiving notifications when the policy
// changes.
//
// It provides functions that return specific policy settings by their unique
// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString],
// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration].
// Package syspolicy provides functions to retrieve system settings of a device.
package syspolicy
import (
"errors"
"fmt"
"reflect"
"time"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
var (
// ErrNotConfigured is returned when the requested policy setting is not configured.
ErrNotConfigured = setting.ErrNotConfigured
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
// of the setting value and the expected type.
ErrTypeMismatch = setting.ErrTypeMismatch
// ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting
// has been registered with the specified key.
//
// This error is also returned by a (now deprecated) [Handler] when the specified
// key does not have a value set. While the package maintains compatibility with this
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
// [source.Store] implementations.
ErrNoSuchKey = setting.ErrNoSuchKey
)
// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
//
// It is a shorthand for [rsop.RegisterStore].
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*rsop.StoreRegistration, error) {
return rsop.RegisterStore(name, scope, store)
}
// MustRegisterStoreForTest is like [rsop.RegisterStoreForTest], but it fails the test if the store could not be registered.
func MustRegisterStoreForTest(tb TB, name string, scope setting.PolicyScope, store source.Store) *rsop.StoreRegistration {
tb.Helper()
reg, err := rsop.RegisterStoreForTest(tb, name, scope, store)
if err != nil {
tb.Fatalf("Failed to register policy store %q as a %v policy source: %v", name, scope, err)
}
return reg
}
// GetString returns a string policy setting with the specified key,
// or defaultValue if it does not exist.
func GetString(key Key, defaultValue string) (string, error) {
return getCurrentPolicySettingValue(key, defaultValue)
markHandlerInUse()
v, err := handler.ReadString(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
}
// GetUint64 returns a numeric policy setting with the specified key,
// or defaultValue if it does not exist.
func GetUint64(key Key, defaultValue uint64) (uint64, error) {
return getCurrentPolicySettingValue(key, defaultValue)
markHandlerInUse()
v, err := handler.ReadUInt64(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
}
// GetBoolean returns a boolean policy setting with the specified key,
// or defaultValue if it does not exist.
func GetBoolean(key Key, defaultValue bool) (bool, error) {
return getCurrentPolicySettingValue(key, defaultValue)
markHandlerInUse()
v, err := handler.ReadBoolean(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
}
// GetStringArray returns a multi-string policy setting with the specified key,
// or defaultValue if it does not exist.
func GetStringArray(key Key, defaultValue []string) ([]string, error) {
return getCurrentPolicySettingValue(key, defaultValue)
markHandlerInUse()
v, err := handler.ReadStringArray(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
}
// GetPreferenceOption loads a policy from the registry that can be
@@ -86,7 +55,13 @@ func GetStringArray(key Key, defaultValue []string) ([]string, error) {
// "always" and "never" remove the user's ability to make a selection. If not
// present or set to a different value, "user-decides" is the default.
func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy)
s, err := GetString(name, "user-decides")
if err != nil {
return setting.ShowChoiceByPolicy, err
}
var opt setting.PreferenceOption
err = opt.UnmarshalText([]byte(s))
return opt, err
}
// GetVisibility loads a policy from the registry that can be managed
@@ -95,7 +70,13 @@ func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
// true) or "hide" (return true). If not present or set to a different value,
// "show" (return false) is the default.
func GetVisibility(name Key) (setting.Visibility, error) {
return getCurrentPolicySettingValue(name, setting.VisibleByPolicy)
s, err := GetString(name, "show")
if err != nil {
return setting.VisibleByPolicy, err
}
var visibility setting.Visibility
visibility.UnmarshalText([]byte(s))
return visibility, nil
}
// GetDuration loads a policy from the registry that can be managed
@@ -104,58 +85,15 @@ func GetVisibility(name Key) (setting.Visibility, error) {
// understands. If the registry value is "" or can not be processed,
// defaultValue is returned instead.
func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) {
d, err := getCurrentPolicySettingValue(name, defaultValue)
if err != nil {
return d, err
opt, err := GetString(name, "")
if opt == "" || err != nil {
return defaultValue, err
}
if d < 0 {
v, err := time.ParseDuration(opt)
if err != nil || v < 0 {
return defaultValue, nil
}
return d, nil
}
// RegisterChangeCallback adds a function that will be called whenever the effective policy
// for the default scope changes. The returned function can be used to unregister the callback.
func RegisterChangeCallback(cb rsop.PolicyChangeCallback) (unregister func(), err error) {
effective, err := rsop.PolicyFor(setting.DefaultScope())
if err != nil {
return nil, err
}
return effective.RegisterChangeCallback(cb), nil
}
// getCurrentPolicySettingValue returns the value of the policy setting
// specified by its key from the [rsop.Policy] of the [setting.DefaultScope]. It
// returns def if the policy setting is not configured, or an error if it has
// an error or could not be converted to the specified type T.
func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) {
effective, err := rsop.PolicyFor(setting.DefaultScope())
if err != nil {
return def, err
}
value, err := effective.Get().GetErr(key)
if err != nil {
if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) {
return def, nil
}
return def, err
}
if res, ok := value.(T); ok {
return res, nil
}
return convertPolicySettingValueTo(value, def)
}
func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) {
// Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string
// if someone requests a string instead of the actual setting's value.
// TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests.
if reflect.TypeFor[T]().Kind() == reflect.String {
if str, ok := value.(fmt.Stringer); ok {
return any(str.String()).(T), nil
}
}
return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def)
return v, nil
}
// SelectControlURL returns the ControlURL to use based on a value in

View File

@@ -9,15 +9,57 @@ import (
"testing"
"time"
"tailscale.com/types/logger"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/internal/metrics"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
// testHandler encompasses all data types returned when testing any of the syspolicy
// methods that involve getting a policy value.
// For keys and the corresponding values, check policy_keys.go.
type testHandler struct {
t *testing.T
key Key
s string
u64 uint64
b bool
sArr []string
err error
calls int // used for testing reads from cache vs. handler
}
var someOtherError = errors.New("error other than not found")
func (th *testHandler) ReadString(key string) (string, error) {
if key != string(th.key) {
th.t.Errorf("ReadString(%q) want %q", key, th.key)
}
th.calls++
return th.s, th.err
}
func (th *testHandler) ReadUInt64(key string) (uint64, error) {
if key != string(th.key) {
th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
}
th.calls++
return th.u64, th.err
}
func (th *testHandler) ReadBoolean(key string) (bool, error) {
if key != string(th.key) {
th.t.Errorf("ReadBool(%q) want %q", key, th.key)
}
th.calls++
return th.b, th.err
}
func (th *testHandler) ReadStringArray(key string) ([]string, error) {
if key != string(th.key) {
th.t.Errorf("ReadStringArray(%q) want %q", key, th.key)
}
th.calls++
return th.sArr, th.err
}
func TestGetString(t *testing.T) {
tests := []struct {
name string
@@ -27,28 +69,23 @@ func TestGetString(t *testing.T) {
defaultValue string
wantValue string
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: AdminConsoleVisibility,
handlerValue: "hide",
wantValue: "hide",
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AdminConsole", Value: 1},
},
},
{
name: "read non-existing value",
key: EnableServerMode,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantError: nil,
},
{
name: "read non-existing value, non-blank default",
key: EnableServerMode,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
defaultValue: "test",
wantValue: "test",
wantError: nil,
@@ -58,43 +95,24 @@ func TestGetString(t *testing.T) {
key: NetworkDevicesVisibility,
handlerError: someOtherError,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_NetworkDevices_error", Value: 1},
},
},
}
RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
s := source.TestSetting[string]{
Key: tt.key,
Value: tt.handlerValue,
Error: tt.handlerError,
}
registerSingleSettingStoreForTest(t, s)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
s: tt.handlerValue,
err: tt.handlerError,
})
value, err := GetString(tt.key, tt.defaultValue)
if !errorsMatchForTest(err, tt.wantError) {
if err != tt.wantError {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-09-04, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@@ -111,7 +129,7 @@ func TestGetUint64(t *testing.T) {
}{
{
name: "read existing value",
key: LogSCMInteractions,
key: KeyExpirationNoticeTime,
handlerValue: 1,
wantValue: 1,
},
@@ -119,14 +137,14 @@ func TestGetUint64(t *testing.T) {
name: "read non-existing value",
key: LogSCMInteractions,
handlerValue: 0,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantValue: 0,
},
{
name: "read non-existing value, non-zero default",
key: LogSCMInteractions,
defaultValue: 2,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantValue: 2,
},
{
@@ -139,23 +157,14 @@ func TestGetUint64(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// None of the policy settings tested here are integers.
// In fact, we don't have any integer policies as of 2024-10-08.
// However, we can register each of them as an integer policy setting
// for the duration of the test, providing us with something to test against.
if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
s := source.TestSetting[uint64]{
Key: tt.key,
Value: tt.handlerValue,
Error: tt.handlerError,
}
registerSingleSettingStoreForTest(t, s)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
u64: tt.handlerValue,
err: tt.handlerError,
})
value, err := GetUint64(tt.key, tt.defaultValue)
if !errorsMatchForTest(err, tt.wantError) {
if err != tt.wantError {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
@@ -174,69 +183,45 @@ func TestGetBoolean(t *testing.T) {
defaultValue bool
wantValue bool
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: FlushDNSOnSessionUnlock,
handlerValue: true,
wantValue: true,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1},
},
},
{
name: "read non-existing value",
key: LogSCMInteractions,
handlerValue: false,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantValue: false,
},
{
name: "reading value returns other error",
key: FlushDNSOnSessionUnlock,
handlerError: someOtherError,
wantError: someOtherError, // expect error...
wantError: someOtherError,
defaultValue: true,
wantValue: true, // ...AND default value if the handler fails.
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1},
},
wantValue: false,
},
}
RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
s := source.TestSetting[bool]{
Key: tt.key,
Value: tt.handlerValue,
Error: tt.handlerError,
}
registerSingleSettingStoreForTest(t, s)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
b: tt.handlerValue,
err: tt.handlerError,
})
value, err := GetBoolean(tt.key, tt.defaultValue)
if !errorsMatchForTest(err, tt.wantError) {
if err != tt.wantError {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-09-04, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@@ -249,42 +234,29 @@ func TestGetPreferenceOption(t *testing.T) {
handlerError error
wantValue setting.PreferenceOption
wantError error
wantMetrics []metrics.TestState
}{
{
name: "always by policy",
key: EnableIncomingConnections,
handlerValue: "always",
wantValue: setting.AlwaysByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
},
},
{
name: "never by policy",
key: EnableIncomingConnections,
handlerValue: "never",
wantValue: setting.NeverByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
},
},
{
name: "use default",
key: EnableIncomingConnections,
handlerValue: "",
wantValue: setting.ShowChoiceByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
},
},
{
name: "read non-existing value",
key: EnableIncomingConnections,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantValue: setting.ShowChoiceByPolicy,
},
{
@@ -293,43 +265,24 @@ func TestGetPreferenceOption(t *testing.T) {
handlerError: someOtherError,
wantValue: setting.ShowChoiceByPolicy,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1},
},
},
}
RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
s := source.TestSetting[string]{
Key: tt.key,
Value: tt.handlerValue,
Error: tt.handlerError,
}
registerSingleSettingStoreForTest(t, s)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
s: tt.handlerValue,
err: tt.handlerError,
})
option, err := GetPreferenceOption(tt.key)
if !errorsMatchForTest(err, tt.wantError) {
if err != tt.wantError {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if option != tt.wantValue {
t.Errorf("option=%v, want %v", option, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-09-04, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@@ -342,33 +295,24 @@ func TestGetVisibility(t *testing.T) {
handlerError error
wantValue setting.Visibility
wantError error
wantMetrics []metrics.TestState
}{
{
name: "hidden by policy",
key: AdminConsoleVisibility,
handlerValue: "hide",
wantValue: setting.HiddenByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AdminConsole", Value: 1},
},
},
{
name: "visibility default",
key: AdminConsoleVisibility,
handlerValue: "show",
wantValue: setting.VisibleByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AdminConsole", Value: 1},
},
},
{
name: "read non-existing value",
key: AdminConsoleVisibility,
handlerValue: "show",
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantValue: setting.VisibleByPolicy,
},
{
@@ -378,43 +322,24 @@ func TestGetVisibility(t *testing.T) {
handlerError: someOtherError,
wantValue: setting.VisibleByPolicy,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_AdminConsole_error", Value: 1},
},
},
}
RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
s := source.TestSetting[string]{
Key: tt.key,
Value: tt.handlerValue,
Error: tt.handlerError,
}
registerSingleSettingStoreForTest(t, s)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
s: tt.handlerValue,
err: tt.handlerError,
})
visibility, err := GetVisibility(tt.key)
if !errorsMatchForTest(err, tt.wantError) {
if err != tt.wantError {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if visibility != tt.wantValue {
t.Errorf("visibility=%v, want %v", visibility, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-09-04, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@@ -428,7 +353,6 @@ func TestGetDuration(t *testing.T) {
defaultValue time.Duration
wantValue time.Duration
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
@@ -436,34 +360,25 @@ func TestGetDuration(t *testing.T) {
handlerValue: "2h",
wantValue: 2 * time.Hour,
defaultValue: 24 * time.Hour,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_KeyExpirationNotice", Value: 1},
},
},
{
name: "invalid duration value",
key: KeyExpirationNoticeTime,
handlerValue: "-20",
wantValue: 24 * time.Hour,
wantError: errors.New(`time: missing unit in duration "-20"`),
defaultValue: 24 * time.Hour,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
},
},
{
name: "read non-existing value",
key: KeyExpirationNoticeTime,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantValue: 24 * time.Hour,
defaultValue: 24 * time.Hour,
},
{
name: "read non-existing value different default",
key: KeyExpirationNoticeTime,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantValue: 0 * time.Second,
defaultValue: 0 * time.Second,
},
@@ -474,43 +389,24 @@ func TestGetDuration(t *testing.T) {
wantValue: 24 * time.Hour,
wantError: someOtherError,
defaultValue: 24 * time.Hour,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
},
},
}
RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
s := source.TestSetting[string]{
Key: tt.key,
Value: tt.handlerValue,
Error: tt.handlerError,
}
registerSingleSettingStoreForTest(t, s)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
s: tt.handlerValue,
err: tt.handlerError,
})
duration, err := GetDuration(tt.key, tt.defaultValue)
if !errorsMatchForTest(err, tt.wantError) {
if err != tt.wantError {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if duration != tt.wantValue {
t.Errorf("duration=%v, want %v", duration, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-09-04, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@@ -524,28 +420,23 @@ func TestGetStringArray(t *testing.T) {
defaultValue []string
wantValue []string
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: AllowedSuggestedExitNodes,
handlerValue: []string{"foo", "bar"},
wantValue: []string{"foo", "bar"},
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1},
},
},
{
name: "read non-existing value",
key: AllowedSuggestedExitNodes,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
wantError: nil,
},
{
name: "read non-existing value, non nil default",
key: AllowedSuggestedExitNodes,
handlerError: ErrNotConfigured,
handlerError: ErrNoSuchKey,
defaultValue: []string{"foo", "bar"},
wantValue: []string{"foo", "bar"},
wantError: nil,
@@ -555,68 +446,28 @@ func TestGetStringArray(t *testing.T) {
key: AllowedSuggestedExitNodes,
handlerError: someOtherError,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1},
},
},
}
RegisterWellKnownSettingsForTest(t)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
s := source.TestSetting[[]string]{
Key: tt.key,
Value: tt.handlerValue,
Error: tt.handlerError,
}
registerSingleSettingStoreForTest(t, s)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
sArr: tt.handlerValue,
err: tt.handlerError,
})
value, err := GetStringArray(tt.key, tt.defaultValue)
if !errorsMatchForTest(err, tt.wantError) {
if err != tt.wantError {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if !slices.Equal(tt.wantValue, value) {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-09-04, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
func registerSingleSettingStoreForTest[T source.TestValueType](tb TB, s source.TestSetting[T]) {
policyStore := source.NewTestStoreOf(tb, s)
MustRegisterStoreForTest(tb, "TestStore", setting.DeviceScope, policyStore)
}
func BenchmarkGetString(b *testing.B) {
loggerx.SetForTest(b, logger.Discard, logger.Discard)
RegisterWellKnownSettingsForTest(b)
wantControlURL := "https://login.tailscale.com"
registerSingleSettingStoreForTest(b, source.TestSettingOf(ControlURL, wantControlURL))
b.ResetTimer()
for i := 0; i < b.N; i++ {
gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com")
if gotControlURL != wantControlURL {
b.Fatalf("got %v; want %v", gotControlURL, wantControlURL)
}
}
}
func TestSelectControlURL(t *testing.T) {
tests := []struct {
reg, disk, want string
@@ -648,13 +499,3 @@ func TestSelectControlURL(t *testing.T) {
}
}
}
func errorsMatchForTest(got, want error) bool {
if got == nil && want == nil {
return true
}
if got == nil || want == nil {
return false
}
return errors.Is(got, want) || got.Error() == want.Error()
}

View File

@@ -1,92 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"errors"
"fmt"
"os/user"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
"tailscale.com/util/testenv"
)
func init() {
// On Windows, we should automatically register the Registry-based policy
// store for the device. If we are running in a user's security context
// (e.g., we're the GUI), we should also register the Registry policy store for
// the user. In the future, we should register (and unregister) user policy
// stores whenever a user connects to (or disconnects from) the local backend.
// This ensures the backend is aware of the user's policy settings and can send
// them to the GUI/CLI/Web clients on demand or whenever they change.
//
// Other platforms, such as macOS, iOS and Android, should register their
// platform-specific policy stores via [RegisterStore]
// (or [RegisterHandler] until they implement the [source.Store] interface).
//
// External code, such as the ipnlocal package, may choose to register
// additional policy stores, such as config files and policies received from
// the control plane.
internal.Init.MustDefer(func() error {
// Do not register or use default policy stores during tests.
// Each test should set up its own necessary configurations.
if testenv.InTest() {
return nil
}
return configureSyspolicy(nil)
})
}
// configureSyspolicy configures syspolicy for use on Windows,
// either in test or regular builds depending on whether tb has a non-nil value.
func configureSyspolicy(tb internal.TB) error {
const localSystemSID = "S-1-5-18"
// Always create and register a machine policy store that reads
// policy settings from the HKEY_LOCAL_MACHINE registry hive.
machineStore, err := source.NewMachinePlatformPolicyStore()
if err != nil {
return fmt.Errorf("failed to create the machine policy store: %v", err)
}
if tb == nil {
_, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore)
} else {
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore)
}
if err != nil {
return err
}
// Check whether the current process is running as Local System or not.
u, err := user.Current()
if err != nil {
return err
}
if u.Uid == localSystemSID {
return nil
}
// If it's not a Local System's process (e.g., it's the GUI rather than the tailscaled service),
// we should create and use a policy store for the current user that reads
// policy settings from that user's registry hive (HKEY_CURRENT_USER).
userStore, err := source.NewUserPlatformPolicyStore(0)
if err != nil {
return fmt.Errorf("failed to create the current user's policy store: %v", err)
}
if tb == nil {
_, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore)
} else {
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore)
}
if err != nil {
return err
}
// And also set [setting.CurrentUserScope] as the [setting.DefaultScope], so [GetString],
// [GetVisibility] and similar functions would be returning a merged result
// of the machine's and user's policies.
if !setting.SetDefaultScope(setting.CurrentUserScope) {
return errors.New("current scope already set")
}
return nil
}

View File

@@ -1,69 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// This file contains user-facing metrics that are used by multiple packages.
// Use it to define more common metrics. Any changes to the registry and
// metric types should be in usermetric.go.
package usermetric
import (
"sync"
"tailscale.com/metrics"
)
// Metrics contains user-facing metrics that are used by multiple packages.
type Metrics struct {
initOnce sync.Once
droppedPacketsInbound *metrics.MultiLabelMap[DropLabels]
droppedPacketsOutbound *metrics.MultiLabelMap[DropLabels]
}
// DropReason is the reason why a packet was dropped.
type DropReason string
const (
// ReasonACL means that the packet was not permitted by ACL.
ReasonACL DropReason = "acl"
// ReasonError means that the packet was dropped because of an error.
ReasonError DropReason = "error"
)
// DropLabels contains common label(s) for dropped packet counters.
type DropLabels struct {
Reason DropReason
}
// initOnce initializes the common metrics.
func (r *Registry) initOnce() {
r.m.initOnce.Do(func() {
r.m.droppedPacketsInbound = NewMultiLabelMapWithRegistry[DropLabels](
r,
"tailscaled_inbound_dropped_packets_total",
"counter",
"Counts the number of dropped packets received by the node from other peers",
)
r.m.droppedPacketsOutbound = NewMultiLabelMapWithRegistry[DropLabels](
r,
"tailscaled_outbound_dropped_packets_total",
"counter",
"Counts the number of packets dropped while being sent to other peers",
)
})
}
// DroppedPacketsOutbound returns the outbound dropped packet metric, creating it
// if necessary.
func (r *Registry) DroppedPacketsOutbound() *metrics.MultiLabelMap[DropLabels] {
r.initOnce()
return r.m.droppedPacketsOutbound
}
// DroppedPacketsInbound returns the inbound dropped packet metric.
func (r *Registry) DroppedPacketsInbound() *metrics.MultiLabelMap[DropLabels] {
r.initOnce()
return r.m.droppedPacketsInbound
}

View File

@@ -19,9 +19,6 @@ import (
// Registry tracks user-facing metrics of various Tailscale subsystems.
type Registry struct {
vars expvar.Map
// m contains common metrics owned by the registry.
m Metrics
}
// NewMultiLabelMapWithRegistry creates and register a new

View File

@@ -158,10 +158,10 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int)
} else {
connectedToControl = c.health.GetInPollNetMap()
}
c.mu.Lock()
myDerp := c.myDerp
c.mu.Unlock()
if !connectedToControl {
c.mu.Lock()
myDerp := c.myDerp
c.mu.Unlock()
if myDerp != 0 {
metricDERPHomeNoChangeNoControl.Add(1)
return myDerp
@@ -178,11 +178,6 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int)
// one.
preferredDERP = c.pickDERPFallback()
}
if preferredDERP != myDerp {
c.logf(
"magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms]",
c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds())
}
if !c.setNearestDERP(preferredDERP) {
preferredDERP = 0
}
@@ -649,10 +644,9 @@ func (c *Conn) runDerpReader(ctx context.Context, regionID int, dc *derphttp.Cli
}
type derpWriteRequest struct {
addr netip.AddrPort
pubKey key.NodePublic
b []byte // copied; ownership passed to receiver
isDisco bool
addr netip.AddrPort
pubKey key.NodePublic
b []byte // copied; ownership passed to receiver
}
// runDerpWriter runs in a goroutine for the life of a DERP
@@ -674,10 +668,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan
if err != nil {
c.logf("magicsock: derp.Send(%v): %v", wr.addr, err)
metricSendDERPError.Add(1)
if !wr.isDisco {
c.metrics.outboundPacketsDroppedErrors.Add(1)
}
} else if !wr.isDisco {
} else {
c.metrics.outboundPacketsDERPTotal.Add(1)
c.metrics.outboundBytesDERPTotal.Add(int64(len(wr.b)))
}
@@ -700,6 +691,8 @@ func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint)
// No data read occurred. Wait for another packet.
continue
}
c.metrics.inboundPacketsDERPTotal.Add(1)
c.metrics.inboundBytesDERPTotal.Add(int64(n))
sizes[0] = n
eps[0] = ep
return 1, nil
@@ -739,9 +732,6 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
if stats := c.stats.Load(); stats != nil {
stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, dm.n)
}
c.metrics.inboundPacketsDERPTotal.Add(1)
c.metrics.inboundBytesDERPTotal.Add(int64(n))
return n, ep
}

View File

@@ -983,8 +983,7 @@ func (de *endpoint) send(buffs [][]byte) error {
allOk := true
var txBytes int
for _, buff := range buffs {
const isDisco = false
ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff, isDisco)
ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff)
txBytes += len(buff)
if !ok {
allOk = false
@@ -992,7 +991,7 @@ func (de *endpoint) send(buffs [][]byte) error {
}
if stats := de.c.stats.Load(); stats != nil {
stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buffs), txBytes)
stats.UpdateTxPhysical(de.nodeAddr, derpAddr, 1, txBytes)
}
if allOk {
return nil

View File

@@ -127,10 +127,6 @@ type metrics struct {
outboundBytesIPv4Total expvar.Int
outboundBytesIPv6Total expvar.Int
outboundBytesDERPTotal expvar.Int
// outboundPacketsDroppedErrors is the total number of outbound packets
// dropped due to errors.
outboundPacketsDroppedErrors expvar.Int
}
// A Conn routes UDP packets and actively manages a list of its endpoints.
@@ -609,8 +605,6 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
"counter",
"Counts the number of bytes sent to other peers",
)
outboundPacketsDroppedErrors := reg.DroppedPacketsOutbound()
m := new(metrics)
// Map clientmetrics to the usermetric counters.
@@ -637,8 +631,6 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total)
outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal)
outboundPacketsDroppedErrors.Set(usermetric.DropLabels{Reason: usermetric.ReasonError}, &m.outboundPacketsDroppedErrors)
return m
}
@@ -1210,13 +1202,8 @@ func (c *Conn) networkDown() bool { return !c.networkUp.Load() }
// Send implements conn.Bind.
//
// See https://pkg.go.dev/golang.zx2c4.com/wireguard/conn#Bind.Send
func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) (err error) {
func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) error {
n := int64(len(buffs))
defer func() {
if err != nil {
c.metrics.outboundPacketsDroppedErrors.Add(n)
}
}()
metricSendData.Add(n)
if c.networkDown() {
metricSendDataNetworkDown.Add(n)
@@ -1369,7 +1356,7 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error)
// An example of when they might be different: sending to an
// IPv6 address when the local machine doesn't have IPv6 support
// returns (false, nil); it's not an error, but nothing was sent.
func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool) (sent bool, err error) {
func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) {
if addr.Addr() != tailcfg.DerpMagicIPAddr {
return c.sendUDP(addr, b)
}
@@ -1392,7 +1379,7 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, is
case <-c.donec:
metricSendDERPErrorClosed.Add(1)
return false, errConnClosed
case ch <- derpWriteRequest{addr, pubKey, pkt, isDisco}:
case ch <- derpWriteRequest{addr, pubKey, pkt}:
metricSendDERPQueued.Add(1)
return true, nil
default:
@@ -1590,8 +1577,7 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi
box := di.sharedKey.Seal(m.AppendMarshal(nil))
pkt = append(pkt, box...)
const isDisco = true
sent, err = c.sendAddr(dst, dstKey, pkt, isDisco)
sent, err = c.sendAddr(dst, dstKey, pkt)
if sent {
if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco()) {
node := "?"

View File

@@ -63,7 +63,6 @@ import (
"tailscale.com/types/nettype"
"tailscale.com/types/ptr"
"tailscale.com/util/cibuild"
"tailscale.com/util/must"
"tailscale.com/util/racebuild"
"tailscale.com/util/set"
"tailscale.com/util/usermetric"
@@ -3084,27 +3083,3 @@ func TestMaybeRebindOnError(t *testing.T) {
}
})
}
func TestNetworkDownSendErrors(t *testing.T) {
netMon := must.Get(netmon.New(t.Logf))
defer netMon.Close()
reg := new(usermetric.Registry)
conn := must.Get(NewConn(Options{
DisablePortMapper: true,
Logf: t.Logf,
NetMon: netMon,
Metrics: reg,
}))
defer conn.Close()
conn.SetNetworkUp(false)
if err := conn.Send([][]byte{{00}}, &lazyEndpoint{}); err == nil {
t.Error("expected error, got nil")
}
resp := httptest.NewRecorder()
reg.Handler(resp, new(http.Request))
if !strings.Contains(resp.Body.String(), `tailscaled_outbound_dropped_packets_total{reason="error"} 1`) {
t.Errorf("expected NetworkDown to increment packet dropped metric; got %q", resp.Body.String())
}
}

View File

@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
"tailscale.com/drive"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/metrics"
@@ -173,18 +174,19 @@ type Impl struct {
// It can only be set before calling Start.
ProcessSubnets bool
ipstack *stack.Stack
linkEP *linkEndpoint
tundev *tstun.Wrapper
e wgengine.Engine
pm *proxymap.Mapper
mc *magicsock.Conn
logf logger.Logf
dialer *tsdial.Dialer
ctx context.Context // alive until Close
ctxCancel context.CancelFunc // called on Close
lb *ipnlocal.LocalBackend // or nil
dns *dns.Manager
ipstack *stack.Stack
linkEP *linkEndpoint
tundev *tstun.Wrapper
e wgengine.Engine
pm *proxymap.Mapper
mc *magicsock.Conn
logf logger.Logf
dialer *tsdial.Dialer
ctx context.Context // alive until Close
ctxCancel context.CancelFunc // called on Close
lb *ipnlocal.LocalBackend // or nil
dns *dns.Manager
driveForLocal drive.FileSystemForLocal // or nil
// loopbackPort, if non-nil, will enable Impl to loop back (dnat to
// <address-family-loopback>:loopbackPort) TCP & UDP flows originally
@@ -286,7 +288,7 @@ func setTCPBufSizes(ipstack *stack.Stack) error {
}
// Create creates and populates a new Impl.
func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper) (*Impl, error) {
func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper, driveForLocal drive.FileSystemForLocal) (*Impl, error) {
if mc == nil {
return nil, errors.New("nil magicsock.Conn")
}
@@ -380,6 +382,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
connsInFlightByClient: make(map[netip.Addr]int),
packetsInFlight: make(map[stack.TransportEndpointID]struct{}),
dns: dns,
driveForLocal: driveForLocal,
}
loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT")
if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 {

View File

@@ -65,7 +65,7 @@ func TestInjectInboundLeak(t *testing.T) {
t.Fatal(err)
}
ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper())
ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
if err != nil {
t.Fatal(err)
}
@@ -116,7 +116,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl {
tb.Cleanup(func() { eng.Close() })
sys.Set(eng)
ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper())
ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil)
if err != nil {
tb.Fatal(err)
}

View File

@@ -1236,7 +1236,7 @@ func (e *userspaceEngine) linkChange(delta *netmon.ChangeDelta) {
// and Apple platforms.
if changed {
switch runtime.GOOS {
case "linux", "android", "ios", "darwin", "openbsd":
case "linux", "android", "ios", "darwin":
e.wgLock.Lock()
dnsCfg := e.lastDNSConfig
e.wgLock.Unlock()

View File

@@ -21,7 +21,6 @@ type Config struct {
NodeID tailcfg.StableNodeID
PrivateKey key.NodePrivate
Addresses []netip.Prefix
ListenPort uint16 // not used by Tailscale's conn.Bind implementation
MTU uint16
DNS []netip.Addr
Peers []Peer

View File

@@ -124,13 +124,7 @@ func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error {
if err != nil {
return err
}
case k.EqualString("listen_port"):
port, err := mem.ParseUint(value, 10, 16)
if err != nil {
return fmt.Errorf("failed to parse listen_port: %w", err)
}
cfg.ListenPort = uint16(port)
case k.EqualString("fwmark"):
case k.EqualString("listen_port") || k.EqualString("fwmark"):
// ignore
default:
return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy())

View File

@@ -39,7 +39,6 @@ var _ConfigCloneNeedsRegeneration = Config(struct {
NodeID tailcfg.StableNodeID
PrivateKey key.NodePrivate
Addresses []netip.Prefix
ListenPort uint16
MTU uint16
DNS []netip.Addr
Peers []Peer

View File

@@ -42,9 +42,6 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
if !prev.PrivateKey.Equal(cfg.PrivateKey) {
set("private_key", cfg.PrivateKey.UntypedHexString())
}
if prev.ListenPort != cfg.ListenPort {
setUint16("listen_port", cfg.ListenPort)
}
old := make(map[key.NodePublic]Peer)
for _, p := range prev.Peers {
@@ -90,9 +87,7 @@ func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error {
// See corp issue 3016.
logf("[unexpected] endpoint changed from %s to %s", oldPeer.WGEndpoint, p.PublicKey)
}
if cfg.NodeID != "" {
set("endpoint", p.PublicKey.UntypedHexString())
}
set("endpoint", p.PublicKey.UntypedHexString())
}
// TODO: replace_allowed_ips is expensive.