Compare commits
4 Commits
irbekrm/op
...
andrew/wor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8290d287d0 | ||
|
|
d7a4f9d31c | ||
|
|
0d6e71df70 | ||
|
|
dcb0f189cc |
@@ -46,6 +46,7 @@ var (
|
||||
backendAddr = flag.String("backend-addr", "", "Address of the Grafana server served over HTTP, in host:port format. Typically localhost:nnnn.")
|
||||
tailscaleDir = flag.String("state-dir", "./", "Alternate directory to use for Tailscale state storage. If empty, a default is used.")
|
||||
useHTTPS = flag.Bool("use-https", false, "Serve over HTTPS via your *.ts.net subdomain if enabled in Tailscale admin.")
|
||||
loginServer = flag.String("login-server", "", "URL to alternative control server. If empty, the default Tailscale control is used.")
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -57,8 +58,9 @@ func main() {
|
||||
log.Fatal("missing --backend-addr")
|
||||
}
|
||||
ts := &tsnet.Server{
|
||||
Dir: *tailscaleDir,
|
||||
Hostname: *hostname,
|
||||
Dir: *tailscaleDir,
|
||||
Hostname: *hostname,
|
||||
ControlURL: *loginServer,
|
||||
}
|
||||
|
||||
// TODO(bradfitz,maisem): move this to a method on tsnet.Server probably.
|
||||
|
||||
@@ -89,13 +89,19 @@ func (t timestampSource) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
type result struct {
|
||||
at time.Time
|
||||
// resultKey contains the stable dimensions and their values for a given
|
||||
// timeseries, i.e. not time and not rtt/timeout.
|
||||
type resultKey struct {
|
||||
meta nodeMeta
|
||||
timestampSource timestampSource
|
||||
connStability connStability
|
||||
dstPort int
|
||||
rtt *time.Duration // nil signifies failure, e.g. timeout
|
||||
}
|
||||
|
||||
type result struct {
|
||||
key resultKey
|
||||
at time.Time
|
||||
rtt *time.Duration // nil signifies failure, e.g. timeout
|
||||
}
|
||||
|
||||
func measureRTT(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error) {
|
||||
@@ -149,6 +155,10 @@ type nodeMeta struct {
|
||||
|
||||
type measureFn func(conn io.ReadWriteCloser, dst *net.UDPAddr) (rtt time.Duration, err error)
|
||||
|
||||
// probe measures STUN round trip time for the node described by meta over
|
||||
// conn against dstPort. It may return a nil duration and nil error if the
|
||||
// STUN request timed out. A non-nil error indicates an unrecoverable or
|
||||
// non-temporary error.
|
||||
func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) {
|
||||
ua := &net.UDPAddr{
|
||||
IP: net.IP(meta.addr.AsSlice()),
|
||||
@@ -162,10 +172,15 @@ func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*
|
||||
log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err)
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &rtt, nil
|
||||
}
|
||||
|
||||
// nodeMetaFromDERPMap parses the provided DERP map in order to update nodeMeta
|
||||
// in the provided nodeMetaByAddr. It returns a slice of nodeMeta containing
|
||||
// the nodes that are no longer seen in the DERP map, but were previously held
|
||||
// in nodeMetaByAddr.
|
||||
func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]nodeMeta, ipv6 bool) (stale []nodeMeta, err error) {
|
||||
// Parse the new derp map before making any state changes in nodeMetaByAddr.
|
||||
// If parse fails we just stick with the old state.
|
||||
@@ -271,10 +286,12 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad
|
||||
doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, dstPort int) {
|
||||
defer wg.Done()
|
||||
r := result{
|
||||
at: at,
|
||||
meta: meta,
|
||||
timestampSource: source,
|
||||
dstPort: dstPort,
|
||||
key: resultKey{
|
||||
meta: meta,
|
||||
timestampSource: source,
|
||||
dstPort: dstPort,
|
||||
},
|
||||
at: at,
|
||||
}
|
||||
if conn == nil {
|
||||
var err error
|
||||
@@ -293,7 +310,7 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[netip.Ad
|
||||
}
|
||||
defer conn.Close()
|
||||
} else {
|
||||
r.connStability = stableConn
|
||||
r.key.connStability = stableConn
|
||||
}
|
||||
fn := measureRTT
|
||||
if source == timestampSourceKernel {
|
||||
@@ -373,7 +390,12 @@ const (
|
||||
stableConn connStability = true
|
||||
)
|
||||
|
||||
func timeSeriesLabels(meta nodeMeta, instance string, source timestampSource, stability connStability, dstPort int) []prompb.Label {
|
||||
const (
|
||||
rttMetricName = "stunstamp_derp_stun_rtt_ns"
|
||||
timeoutsMetricName = "stunstamp_derp_stun_timeouts_total"
|
||||
)
|
||||
|
||||
func timeSeriesLabels(metricName string, meta nodeMeta, instance string, source timestampSource, stability connStability, dstPort int) []prompb.Label {
|
||||
addressFamily := "ipv4"
|
||||
if meta.addr.Is6() {
|
||||
addressFamily = "ipv6"
|
||||
@@ -409,7 +431,7 @@ func timeSeriesLabels(meta nodeMeta, instance string, source timestampSource, st
|
||||
})
|
||||
labels = append(labels, prompb.Label{
|
||||
Name: "__name__",
|
||||
Value: "stunstamp_derp_stun_rtt_ns",
|
||||
Value: metricName,
|
||||
})
|
||||
labels = append(labels, prompb.Label{
|
||||
Name: "timestamp_source",
|
||||
@@ -443,20 +465,36 @@ func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, dstPorts []int)
|
||||
},
|
||||
}
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(s, instance, timestampSourceUserspace, unstableConn, dstPort),
|
||||
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(s, instance, timestampSourceUserspace, stableConn, dstPort),
|
||||
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
if supportsKernelTS() {
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(s, instance, timestampSourceKernel, unstableConn, dstPort),
|
||||
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(s, instance, timestampSourceKernel, stableConn, dstPort),
|
||||
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
staleMarkers = append(staleMarkers, prompb.TimeSeries{
|
||||
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, dstPort),
|
||||
Samples: samples,
|
||||
})
|
||||
}
|
||||
@@ -465,25 +503,47 @@ func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, dstPorts []int)
|
||||
return staleMarkers
|
||||
}
|
||||
|
||||
func resultToPromTimeSeries(r result, instance string) prompb.TimeSeries {
|
||||
labels := timeSeriesLabels(r.meta, instance, r.timestampSource, r.connStability, r.dstPort)
|
||||
samples := make([]prompb.Sample, 1)
|
||||
samples[0].Timestamp = r.at.UnixMilli()
|
||||
if r.rtt != nil {
|
||||
samples[0].Value = float64(*r.rtt)
|
||||
} else {
|
||||
samples[0].Value = math.NaN()
|
||||
// TODO: timeout counter
|
||||
// resultsToPromTimeSeries returns a slice of prometheus TimeSeries for the
|
||||
// provided results and instance. timeouts is updated based on results, i.e.
|
||||
// all result.key's are added to timeouts if they do not exist, and removed
|
||||
// from timeouts if they are not present in results.
|
||||
func resultsToPromTimeSeries(results []result, instance string, timeouts map[resultKey]uint64) []prompb.TimeSeries {
|
||||
all := make([]prompb.TimeSeries, 0, len(results)*2)
|
||||
seenKeys := make(map[resultKey]bool)
|
||||
for _, r := range results {
|
||||
timeoutsCount := timeouts[r.key] // a non-existent key will return a zero val
|
||||
seenKeys[r.key] = true
|
||||
rttLabels := timeSeriesLabels(rttMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.dstPort)
|
||||
rttSamples := make([]prompb.Sample, 1)
|
||||
rttSamples[0].Timestamp = r.at.UnixMilli()
|
||||
if r.rtt != nil {
|
||||
rttSamples[0].Value = float64(*r.rtt)
|
||||
} else {
|
||||
rttSamples[0].Value = math.NaN()
|
||||
timeoutsCount++
|
||||
}
|
||||
rttTS := prompb.TimeSeries{
|
||||
Labels: rttLabels,
|
||||
Samples: rttSamples,
|
||||
}
|
||||
all = append(all, rttTS)
|
||||
timeouts[r.key] = timeoutsCount
|
||||
timeoutsLabels := timeSeriesLabels(timeoutsMetricName, r.key.meta, instance, r.key.timestampSource, r.key.connStability, r.key.dstPort)
|
||||
timeoutsSamples := make([]prompb.Sample, 1)
|
||||
timeoutsSamples[0].Timestamp = r.at.UnixMilli()
|
||||
timeoutsSamples[0].Value = float64(timeoutsCount)
|
||||
timeoutsTS := prompb.TimeSeries{
|
||||
Labels: timeoutsLabels,
|
||||
Samples: timeoutsSamples,
|
||||
}
|
||||
all = append(all, timeoutsTS)
|
||||
}
|
||||
ts := prompb.TimeSeries{
|
||||
Labels: labels,
|
||||
Samples: samples,
|
||||
for k := range timeouts {
|
||||
if !seenKeys[k] {
|
||||
delete(timeouts, k)
|
||||
}
|
||||
}
|
||||
slices.SortFunc(ts.Labels, func(a, b prompb.Label) int {
|
||||
// prometheus remote-write spec requires lexicographically sorted label names
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
return ts
|
||||
return all
|
||||
}
|
||||
|
||||
type remoteWriteClient struct {
|
||||
@@ -719,6 +779,10 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT
|
||||
// comes into play.
|
||||
stableConns := make(map[netip.Addr]map[int][2]io.ReadWriteCloser)
|
||||
|
||||
// timeouts holds counts of timeout events. Values are persisted for the
|
||||
// lifetime of the related node in the DERP map.
|
||||
timeouts := make(map[resultKey]uint64)
|
||||
|
||||
derpMapTicker := time.NewTicker(time.Minute * 5)
|
||||
defer derpMapTicker.Stop()
|
||||
probeTicker := time.NewTicker(*flagInterval)
|
||||
@@ -744,10 +808,7 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT
|
||||
shutdown()
|
||||
return
|
||||
}
|
||||
ts := make([]prompb.TimeSeries, 0, len(results))
|
||||
for _, r := range results {
|
||||
ts = append(ts, resultToPromTimeSeries(r, *flagInstance))
|
||||
}
|
||||
ts := resultsToPromTimeSeries(results, *flagInstance, timeouts)
|
||||
select {
|
||||
case tsCh <- ts:
|
||||
default:
|
||||
@@ -766,11 +827,11 @@ CREATE TABLE IF NOT EXISTS rtt(at_unix INT, region_id INT, hostname TEXT, af INT
|
||||
}
|
||||
for _, result := range results {
|
||||
af := 4
|
||||
if result.meta.addr.Is6() {
|
||||
if result.key.meta.addr.Is6() {
|
||||
af = 6
|
||||
}
|
||||
_, err = tx.Exec("INSERT INTO rtt(at_unix, region_id, hostname, af, address, timestamp_source, stable_conn, dst_port, rtt_ns) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
result.at.Unix(), result.meta.regionID, result.meta.hostname, af, result.meta.addr.String(), result.timestampSource, result.connStability, result.dstPort, result.rtt)
|
||||
result.at.Unix(), result.key.meta.regionID, result.key.meta.hostname, af, result.key.meta.addr.String(), result.key.timestampSource, result.key.connStability, result.key.dstPort, result.rtt)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
log.Printf("error adding result to tx: %v", err)
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/net/dns/resolver"
|
||||
@@ -122,6 +123,7 @@ func (m *Manager) Set(cfg Config) error {
|
||||
// The returned list is sorted by the first hostname in each entry.
|
||||
func compileHostEntries(cfg Config) (hosts []*HostEntry) {
|
||||
didLabel := make(map[string]bool, len(cfg.Hosts))
|
||||
hostsMap := make(map[netip.Addr]*HostEntry, len(cfg.Hosts))
|
||||
for _, sd := range cfg.SearchDomains {
|
||||
for h, ips := range cfg.Hosts {
|
||||
if !sd.Contains(h) || h.NumLabels() != (sd.NumLabels()+1) {
|
||||
@@ -136,15 +138,23 @@ func compileHostEntries(cfg Config) (hosts []*HostEntry) {
|
||||
if cfg.OnlyIPv6 && ip.Is4() {
|
||||
continue
|
||||
}
|
||||
hosts = append(hosts, &HostEntry{
|
||||
Addr: ip,
|
||||
Hosts: ipHosts,
|
||||
})
|
||||
if e := hostsMap[ip]; e != nil {
|
||||
e.Hosts = append(e.Hosts, ipHosts...)
|
||||
} else {
|
||||
hostsMap[ip] = &HostEntry{
|
||||
Addr: ip,
|
||||
Hosts: ipHosts,
|
||||
}
|
||||
}
|
||||
// Only add IPv4 or IPv6 per host, like we do in the resolver.
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(hostsMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
hosts = xmaps.Values(hostsMap)
|
||||
slices.SortFunc(hosts, func(a, b *HostEntry) int {
|
||||
if len(a.Hosts) == 0 && len(b.Hosts) == 0 {
|
||||
return 0
|
||||
|
||||
@@ -87,8 +87,7 @@ func TestCompileHostEntries(t *testing.T) {
|
||||
{Addr: netip.MustParseAddr("1.1.1.1"), Hosts: []string{"a.foo.ts.net.", "a"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.2"), Hosts: []string{"b.foo.ts.net.", "b"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.3"), Hosts: []string{"c.foo.ts.net.", "c"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.4"), Hosts: []string{"d.foo.beta.tailscale.net."}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.4"), Hosts: []string{"d.foo.ts.net.", "d"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.4"), Hosts: []string{"d.foo.ts.net.", "d", "d.foo.beta.tailscale.net."}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.5"), Hosts: []string{"e.foo.beta.tailscale.net.", "e"}},
|
||||
},
|
||||
},
|
||||
@@ -103,8 +102,7 @@ func TestCompileHostEntries(t *testing.T) {
|
||||
SearchDomains: []dnsname.FQDN{"foo.ts.net.", "foo.beta.tailscale.net."},
|
||||
},
|
||||
want: []*HostEntry{
|
||||
{Addr: netip.MustParseAddr("1.1.1.5"), Hosts: []string{"e.foo.beta.tailscale.net."}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.5"), Hosts: []string{"e.foo.ts.net.", "e"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.5"), Hosts: []string{"e.foo.ts.net.", "e", "e.foo.beta.tailscale.net."}},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -120,8 +118,7 @@ func TestCompileHostEntries(t *testing.T) {
|
||||
SearchDomains: []dnsname.FQDN{"foo.ts.net.", "foo.beta.tailscale.net."},
|
||||
},
|
||||
want: []*HostEntry{
|
||||
{Addr: netip.MustParseAddr("1.1.1.4"), Hosts: []string{"d.foo.beta.tailscale.net."}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.4"), Hosts: []string{"d.foo.ts.net.", "d"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.4"), Hosts: []string{"d.foo.ts.net.", "d", "d.foo.beta.tailscale.net."}},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -139,8 +136,7 @@ func TestCompileHostEntries(t *testing.T) {
|
||||
want: []*HostEntry{
|
||||
{Addr: netip.MustParseAddr("1.1.1.2"), Hosts: []string{"h1.foo.beta.tailscale.net."}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.3"), Hosts: []string{"h1.foo.ts.net.", "h1"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.1"), Hosts: []string{"h2.foo.beta.tailscale.net."}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.1"), Hosts: []string{"h2.foo.ts.net.", "h2"}},
|
||||
{Addr: netip.MustParseAddr("1.1.1.1"), Hosts: []string{"h2.foo.ts.net.", "h2", "h2.foo.beta.tailscale.net."}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
450
util/workgraph/workgraph.go
Normal file
450
util/workgraph/workgraph.go
Normal file
@@ -0,0 +1,450 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package workgraph contains a "workgraph"; a data structure that allows
|
||||
// defining individual jobs, dependencies between them, and then executing all
|
||||
// jobs to completion.
|
||||
package workgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
// ErrCyclic is returned when there is a cycle in the graph.
|
||||
var ErrCyclic = errors.New("graph is cyclic")
|
||||
|
||||
// Node is the interface that must be implemented by a node in a WorkGraph.
|
||||
type Node interface {
|
||||
// ID should return a unique ID for this node. IDs for each Node in a
|
||||
// WorkGraph must be unique.
|
||||
ID() string
|
||||
|
||||
// Run is called when this node in a WorkGraph is executed; it should
|
||||
// return an error if execution fails, which will cause all dependent
|
||||
// Nodes to fail to execute.
|
||||
Run(context.Context) error
|
||||
}
|
||||
|
||||
type nodeFunc struct {
|
||||
id string
|
||||
run func(context.Context) error
|
||||
}
|
||||
|
||||
func (n *nodeFunc) ID() string { return n.id }
|
||||
func (n *nodeFunc) Run(ctx context.Context) error { return n.run(ctx) }
|
||||
|
||||
// NodeFunc is a helper that returns a Node with the given ID that calls the
|
||||
// given function when Node.Run is called.
|
||||
func NodeFunc(id string, fn func(context.Context) error) Node {
|
||||
return &nodeFunc{id, fn}
|
||||
}
|
||||
|
||||
// WorkGraph is a directed acyclic graph of individual jobs to be executed,
|
||||
// each of which may have dependencies on other jobs. It supports adding a job
|
||||
// as a Node–a combination of a unique ID and the function to execute that
|
||||
// job–and then running all added Nodes while respecting dependencies.
|
||||
type WorkGraph struct {
|
||||
nodes map[string]Node // keyed by Node.ID
|
||||
edges edgeList[string] // keyed by Node.ID
|
||||
|
||||
// Concurrency is the number of concurrent goroutines to use to process
|
||||
// jobs. If zero, runtime.GOMAXPROCS will be used.
|
||||
//
|
||||
// This field must not be modified after Run has been called.
|
||||
Concurrency int
|
||||
}
|
||||
|
||||
// NewWorkGraph creates a new empty WorkGraph.
|
||||
func NewWorkGraph() *WorkGraph {
|
||||
ret := &WorkGraph{
|
||||
nodes: make(map[string]Node),
|
||||
edges: newEdgeList[string](),
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddNodeOpts contains options that can be passed to AddNode.
|
||||
type AddNodeOpts struct {
|
||||
// Dependencies are any Node IDs that must be completed before this
|
||||
// Node is started.
|
||||
Dependencies []string
|
||||
}
|
||||
|
||||
// AddNode adds a new Node to the WorkGraph with the provided options. It
|
||||
// returns an error if the given Node.ID was already added to the WorkGraph, or
|
||||
// if one of the options provided was invalid.
|
||||
func (g *WorkGraph) AddNode(n Node, opts *AddNodeOpts) error {
|
||||
id := n.ID()
|
||||
if _, found := g.nodes[id]; found {
|
||||
return fmt.Errorf("node %q already exists", id)
|
||||
}
|
||||
g.nodes[id] = n
|
||||
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create an edge from each dependency pointing to this node, forcing
|
||||
// that node to be evaluated first.
|
||||
for _, dep := range opts.Dependencies {
|
||||
if _, found := g.nodes[dep]; !found {
|
||||
return fmt.Errorf("dependency %q not found", dep)
|
||||
}
|
||||
g.edges.Add(dep, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type queueEntry struct {
|
||||
id string
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// Run will iterate through all Nodes in this WorkGraph, running them once all
|
||||
// their dependencies have been satisfied, and returning any errors that occur.
|
||||
func (g *WorkGraph) Run(ctx context.Context) error {
|
||||
groups, err := g.topoSortKahn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Create one goroutine that pushes jobs onto our queue...
|
||||
var wg sync.WaitGroup
|
||||
queue := make(chan queueEntry)
|
||||
publishCtx, publishCancel := context.WithCancel(ctx)
|
||||
defer publishCancel()
|
||||
|
||||
wg.Add(1)
|
||||
go g.runPublisher(publishCtx, &wg, queue, groups)
|
||||
|
||||
firstErr := make(chan error, 1)
|
||||
saveErr := func(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Tell the publisher to shut down
|
||||
publishCancel()
|
||||
|
||||
select {
|
||||
case firstErr <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ... and N goroutines that each work on an item from the queue.
|
||||
n := g.Concurrency
|
||||
if n == 0 {
|
||||
n = runtime.GOMAXPROCS(-1)
|
||||
}
|
||||
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go g.runWorker(ctx, &wg, queue, saveErr)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
select {
|
||||
case err := <-firstErr:
|
||||
return err
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *WorkGraph) runPublisher(ctx context.Context, wg *sync.WaitGroup, queue chan queueEntry, groups []set.Set[string]) {
|
||||
defer wg.Done()
|
||||
defer close(queue)
|
||||
|
||||
// For each parallel group...
|
||||
var dones []chan struct{}
|
||||
for _, group := range groups {
|
||||
dones = dones[:0] // re-use existing storage, if any
|
||||
|
||||
// Push all items in this group onto our queue
|
||||
for curr := range group {
|
||||
done := make(chan struct{})
|
||||
dones = append(dones, done)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case queue <- queueEntry{curr, done}:
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we've started everything, wait for them all
|
||||
// to complete.
|
||||
for _, done := range dones {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we've done this entire group, we can
|
||||
// continue with the next one.
|
||||
}
|
||||
}
|
||||
|
||||
func (g *WorkGraph) runWorker(ctx context.Context, wg *sync.WaitGroup, queue chan queueEntry, saveErr func(error)) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case ent, ok := <-queue:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := g.runEntry(ctx, ent); err != nil {
|
||||
saveErr(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *WorkGraph) runEntry(ctx context.Context, ent queueEntry) (retErr error) {
|
||||
defer close(ent.done)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Ensure that we wrap an existing error with %w so errors.Is works
|
||||
switch v := r.(type) {
|
||||
case error:
|
||||
retErr = fmt.Errorf("node %q: caught panic: %w", ent.id, v)
|
||||
default:
|
||||
retErr = fmt.Errorf("node %q: caught panic: %v", ent.id, v)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
node := g.nodes[ent.id]
|
||||
return node.Run(ctx)
|
||||
}
|
||||
|
||||
// Depth-first toplogical sort; used in tests
|
||||
//
|
||||
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
|
||||
func (g *WorkGraph) topoSortDFS() (sorted []string, err error) {
|
||||
const (
|
||||
markTemporary = 1
|
||||
markPermanent = 2
|
||||
)
|
||||
marks := make(map[string]int) // map[node.ID]markType
|
||||
|
||||
var visit func(string) error
|
||||
visit = func(n string) error {
|
||||
// "if n has a permanent mark then"
|
||||
if marks[n] == markPermanent {
|
||||
return nil
|
||||
}
|
||||
// "if n has a temporary mark then"
|
||||
if marks[n] == markTemporary {
|
||||
return ErrCyclic
|
||||
}
|
||||
|
||||
// "mark n with a temporary mark"
|
||||
marks[n] = markTemporary
|
||||
|
||||
// "for each node m with an edge from n to m do"
|
||||
for m := range g.edges.OutgoingNodes(n) {
|
||||
if err := visit(m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// "remove temporary mark from n"
|
||||
// "mark n with a permanent mark"
|
||||
//
|
||||
// NOTE: this is safe because if this node had a temporary
|
||||
// mark, we'd have returned above, and the only thing that adds
|
||||
// a mark to a node is this function.
|
||||
marks[n] = markPermanent
|
||||
|
||||
// "add n to head of L"; note that we append for performance
|
||||
// reasons and reverse later
|
||||
sorted = append(sorted, n)
|
||||
return nil
|
||||
}
|
||||
|
||||
// For all nodes, visit them. From the algorithm description:
|
||||
// while exists nodes without a permanent mark do
|
||||
// select an unmarked node n
|
||||
// visit(n)
|
||||
for nid := range g.nodes {
|
||||
if err := visit(nid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// We appended to the slice for performance reasons; reverse it to get
|
||||
// our final result.
|
||||
slices.Reverse(sorted)
|
||||
return sorted, nil
|
||||
}
|
||||
|
||||
// topoSortKahn runs a variant of Kahn's algorithm for topological sorting,
|
||||
// which not only returns a sort, but provides individual "groups" of nodes
|
||||
// that can be executed concurrently.
|
||||
//
|
||||
// See:
|
||||
// - https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
||||
// - https://stackoverflow.com/a/67267597
|
||||
func (g *WorkGraph) topoSortKahn() (sorted []set.Set[string], err error) {
|
||||
// We mutate the set of edges during this function, so copy it.
|
||||
edges := g.edges.Clone()
|
||||
|
||||
// Create S_0, the set of nodes with no incoming edge
|
||||
s0 := make(set.Set[string])
|
||||
for nid := range g.nodes {
|
||||
if !edges.HasIncoming(nid) {
|
||||
s0.Add(nid)
|
||||
}
|
||||
}
|
||||
|
||||
// Add this set to the returned set of nodes
|
||||
sorted = append(sorted, s0)
|
||||
|
||||
// Repeatedly iterate, starting from the initial set, until we have no
|
||||
// more nodes. The inner loop is essentially Kahn's algorithm.
|
||||
sCurr := s0
|
||||
for {
|
||||
// Initialize the next set
|
||||
sNext := make(set.Set[string])
|
||||
|
||||
// For each node 'n' in the current set...
|
||||
for n := range sCurr {
|
||||
// For each successor 'd' of the current node...
|
||||
for d := range edges.OutgoingNodes(n) {
|
||||
// Remove edge 'n -> d'
|
||||
edges.Remove(n, d)
|
||||
|
||||
// If this node 'd' has no incoming edges, we
|
||||
// can add it to the current set since it can
|
||||
// be processed.
|
||||
if !edges.HasIncoming(d) {
|
||||
sNext.Add(d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the current set is non-empty, then append it to the list
|
||||
// of returned sets, make it the current set, and continue.
|
||||
// Otherwise, we're done.
|
||||
if len(sNext) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
sorted = append(sorted, sNext)
|
||||
sCurr = sNext
|
||||
}
|
||||
|
||||
if edges.Len() > 0 {
|
||||
return nil, ErrCyclic
|
||||
}
|
||||
return sorted, nil
|
||||
}
|
||||
|
||||
// Graphviz prints a basic Graphviz representation of the WorkGraph. This is
|
||||
// primarily useful for debugging.
|
||||
func (g *WorkGraph) Graphviz() string {
|
||||
var buf strings.Builder
|
||||
buf.WriteString("digraph workgraph {\n")
|
||||
for from, edges := range g.edges.outgoing {
|
||||
for to := range edges {
|
||||
fmt.Fprintf(&buf, "\t%s -> %s;\n", from, to)
|
||||
}
|
||||
}
|
||||
buf.WriteString("}")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// edgeList is a helper type that is used to maintain a set of edges, tracking
|
||||
// both incoming and outgoing edges for a given node.
|
||||
type edgeList[K comparable] struct {
|
||||
incoming map[K]set.Set[K] // for edge A -> B, keyed by B
|
||||
outgoing map[K]set.Set[K] // for edge A -> B, keyed by A
|
||||
}
|
||||
|
||||
func newEdgeList[K comparable]() edgeList[K] {
|
||||
return edgeList[K]{
|
||||
incoming: make(map[K]set.Set[K]),
|
||||
outgoing: make(map[K]set.Set[K]),
|
||||
}
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) Clone() edgeList[K] {
|
||||
ret := edgeList[K]{
|
||||
incoming: make(map[K]set.Set[K], len(el.incoming)),
|
||||
outgoing: make(map[K]set.Set[K], len(el.outgoing)),
|
||||
}
|
||||
for k, v := range el.incoming {
|
||||
ret.incoming[k] = v.Clone()
|
||||
}
|
||||
for k, v := range el.outgoing {
|
||||
ret.outgoing[k] = v.Clone()
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) Len() int {
|
||||
i := 0
|
||||
for _, set := range el.incoming {
|
||||
i += set.Len()
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) Add(from, to K) {
|
||||
if _, found := el.incoming[to]; !found {
|
||||
el.incoming[to] = make(set.Set[K])
|
||||
}
|
||||
if _, found := el.outgoing[from]; !found {
|
||||
el.outgoing[from] = make(set.Set[K])
|
||||
}
|
||||
|
||||
el.incoming[to].Add(from)
|
||||
el.outgoing[from].Add(to)
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) Remove(from, to K) {
|
||||
if m, ok := el.incoming[to]; ok {
|
||||
delete(m, from)
|
||||
}
|
||||
if m, ok := el.outgoing[from]; ok {
|
||||
delete(m, to)
|
||||
}
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) HasIncoming(id K) bool {
|
||||
return el.incoming[id].Len() > 0
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) HasOutgoing(id K) bool {
|
||||
return el.outgoing[id].Len() > 0
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) Exists(from, to K) bool {
|
||||
return el.outgoing[from].Contains(to)
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) IncomingNodes(id K) set.Set[K] {
|
||||
return el.incoming[id]
|
||||
}
|
||||
|
||||
func (el *edgeList[K]) OutgoingNodes(id K) set.Set[K] {
|
||||
return el.outgoing[id]
|
||||
}
|
||||
353
util/workgraph/workgraph_test.go
Normal file
353
util/workgraph/workgraph_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package workgraph
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
func debugGraph(tb testing.TB, g *WorkGraph) {
|
||||
before := g.Graphviz()
|
||||
tb.Cleanup(func() {
|
||||
if !tb.Failed() {
|
||||
return
|
||||
}
|
||||
|
||||
after := g.Graphviz()
|
||||
tb.Logf("graphviz at start of test:\n%s", before)
|
||||
tb.Logf("graphviz at end of test:\n%s", after)
|
||||
})
|
||||
}
|
||||
|
||||
func makeTestGraph(tb testing.TB) *WorkGraph {
|
||||
logFunc := func(s string) func(context.Context) error {
|
||||
return func(_ context.Context) error {
|
||||
tb.Log(s)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
makeNode := func(s string) Node {
|
||||
return NodeFunc(s, logFunc(s+" called"))
|
||||
}
|
||||
withDeps := func(ss ...string) *AddNodeOpts {
|
||||
return &AddNodeOpts{Dependencies: ss}
|
||||
}
|
||||
|
||||
g := NewWorkGraph()
|
||||
|
||||
// Ensure we have at least 2 concurrent goroutines
|
||||
g.Concurrency = runtime.GOMAXPROCS(-1)
|
||||
if g.Concurrency < 2 {
|
||||
g.Concurrency = 2
|
||||
}
|
||||
|
||||
n1 := makeNode("one")
|
||||
n2 := makeNode("two")
|
||||
n3 := makeNode("three")
|
||||
n4 := makeNode("four")
|
||||
n5 := makeNode("five")
|
||||
n6 := makeNode("six")
|
||||
|
||||
must.Do(g.AddNode(n1, nil)) // can execute first
|
||||
must.Do(g.AddNode(n2, nil)) // can execute first
|
||||
must.Do(g.AddNode(n3, withDeps("one")))
|
||||
must.Do(g.AddNode(n4, withDeps("one", "two")))
|
||||
must.Do(g.AddNode(n5, withDeps("one")))
|
||||
must.Do(g.AddNode(n6, withDeps("four", "five")))
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func TestWorkGraph(t *testing.T) {
|
||||
g := makeTestGraph(t)
|
||||
debugGraph(t, g)
|
||||
|
||||
if err := g.Run(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkGroup_Error(t *testing.T) {
|
||||
g := NewWorkGraph()
|
||||
|
||||
terr := errors.New("test error")
|
||||
|
||||
returnsErr := func(_ context.Context) error { return terr }
|
||||
notCalled := func(_ context.Context) error { panic("unused") }
|
||||
|
||||
n1 := NodeFunc("one", returnsErr)
|
||||
n2 := NodeFunc("two", notCalled)
|
||||
n3 := NodeFunc("three", notCalled)
|
||||
|
||||
must.Do(g.AddNode(n1, nil))
|
||||
must.Do(g.AddNode(n2, &AddNodeOpts{Dependencies: []string{"one"}}))
|
||||
must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}}))
|
||||
|
||||
err := g.Run(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("wanted non-nil error")
|
||||
}
|
||||
if !errors.Is(err, terr) {
|
||||
t.Errorf("got %v, want %v", err, terr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkGroup_HandlesPanic(t *testing.T) {
|
||||
g := NewWorkGraph()
|
||||
|
||||
terr := errors.New("test error")
|
||||
n1 := NodeFunc("one", func(_ context.Context) error { panic(terr) })
|
||||
|
||||
must.Do(g.AddNode(n1, nil))
|
||||
err := g.Run(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("wanted non-nil error")
|
||||
}
|
||||
if !errors.Is(err, terr) {
|
||||
t.Errorf("got %v, want %v", err, terr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkGroup_Cancellation(t *testing.T) {
|
||||
g := NewWorkGraph()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var running atomic.Int64
|
||||
blocks := func(ctx context.Context) error {
|
||||
running.Add(1)
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
n1 := NodeFunc("one", blocks)
|
||||
n2 := NodeFunc("two", blocks)
|
||||
n3 := NodeFunc("three", blocks)
|
||||
|
||||
must.Do(g.AddNode(n1, nil))
|
||||
must.Do(g.AddNode(n2, nil))
|
||||
|
||||
// Ensure that we have a node with dependencies that's also waiting
|
||||
// since we want to verify that the queue publisher also properly
|
||||
// handles context cancellation.
|
||||
must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}}))
|
||||
|
||||
// call Run in a goroutine since it blocks
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- g.Run(ctx)
|
||||
}()
|
||||
|
||||
// after all goroutines are running, cancel the context to unblock
|
||||
for running.Load() != 2 {
|
||||
// wait
|
||||
}
|
||||
cancel()
|
||||
err := <-errCh
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("wanted non-nil error")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("got %v, want %v", err, context.Canceled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopoSortDFS(t *testing.T) {
|
||||
g := makeTestGraph(t)
|
||||
debugGraph(t, g)
|
||||
|
||||
sorted, err := g.topoSortDFS()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("DFS topological sort: %v", sorted)
|
||||
|
||||
validateTopologicalSortDFS(t, g, sorted)
|
||||
}
|
||||
|
||||
func validateTopologicalSortDFS(tb testing.TB, g *WorkGraph, order []string) {
|
||||
// A valid ordering is any one where a node ID later in the list does
|
||||
// not depend on a node ID earlier in the list.
|
||||
for i, node := range order {
|
||||
for j := 0; j < i; j++ {
|
||||
if g.edges.Exists(node, order[j]) {
|
||||
tb.Errorf("invalid edge: %v [%d] -> %v [%d]", node, i, order[j], j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopoSortKahn(t *testing.T) {
|
||||
g := makeTestGraph(t)
|
||||
debugGraph(t, g)
|
||||
|
||||
groups, err := g.topoSortKahn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("grouped topological sort: %v", groups)
|
||||
|
||||
validateTopologicalSortKahn(t, g, groups)
|
||||
}
|
||||
|
||||
func validateTopologicalSortKahn(tb testing.TB, g *WorkGraph, groups []set.Set[string]) {
|
||||
// A valid ordering is any one where a node ID later in the list does
|
||||
// not depend on a node ID earlier in the list.
|
||||
prev := make(map[string]bool)
|
||||
for i, group := range groups {
|
||||
for node := range group {
|
||||
for m := range prev {
|
||||
if g.edges.Exists(node, m) {
|
||||
tb.Errorf("group[%d]: invalid edge: %v -> %v", i, node, m)
|
||||
}
|
||||
}
|
||||
prev[node] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that our topologically sorted groups contain all nodes.
|
||||
for nid := range g.nodes {
|
||||
if !prev[nid] {
|
||||
tb.Errorf("topological sort missing node %v", nid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzTopSortKahn(f *testing.F) {
|
||||
// We can't pass a map[string][]string (or similar) into a fuzz
|
||||
// function, so instead let's create test data by using a combination
|
||||
// of 'n' nodes and an adjacency matrix of edges from node to node.
|
||||
//
|
||||
// We then need to filter this adjacency matrix in the Fuzz function,
|
||||
// since the fuzzer doesn't distinguish between "invalid fuzz inputs
|
||||
// due to logic bugs", and "invalid fuzz data that causes a real
|
||||
// error".
|
||||
f.Add(
|
||||
10, // number of nodes
|
||||
[]byte{
|
||||
1, 0, // 1 depends on 0
|
||||
6, 2, // 6 depends on 2
|
||||
9, 8, // 9 depends on 8
|
||||
},
|
||||
)
|
||||
f.Fuzz(func(t *testing.T, numNodes int, edges []byte) {
|
||||
g := createGraphFromFuzzInput(t, numNodes, edges)
|
||||
if g == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// This should not error
|
||||
groups, err := g.topoSortKahn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validateTopologicalSortKahn(t, g, groups)
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzTopSortDFS(f *testing.F) {
|
||||
// We can't pass a map[string][]string (or similar) into a fuzz
|
||||
// function, so instead let's create test data by using a combination
|
||||
// of 'n' nodes and an adjacency matrix of edges from node to node.
|
||||
//
|
||||
// We then need to filter this adjacency matrix in the Fuzz function,
|
||||
// since the fuzzer doesn't distinguish between "invalid fuzz inputs
|
||||
// due to logic bugs", and "invalid fuzz data that causes a real
|
||||
// error".
|
||||
f.Add(
|
||||
10, // number of nodes
|
||||
[]byte{
|
||||
1, 0, // 1 depends on 0
|
||||
6, 2, // 6 depends on 2
|
||||
9, 8, // 9 depends on 8
|
||||
},
|
||||
)
|
||||
f.Fuzz(func(t *testing.T, numNodes int, edges []byte) {
|
||||
g := createGraphFromFuzzInput(t, numNodes, edges)
|
||||
if g == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// This should not error
|
||||
sorted, err := g.topoSortDFS()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validateTopologicalSortDFS(t, g, sorted)
|
||||
})
|
||||
}
|
||||
|
||||
func createGraphFromFuzzInput(tb testing.TB, numNodes int, edges []byte) *WorkGraph {
|
||||
nodeName := func(i int) string {
|
||||
return fmt.Sprintf("node-%d", i)
|
||||
}
|
||||
|
||||
filterAdjacencyMatrix := func(numNodes int, edges []byte) map[string][]string {
|
||||
deps := make(map[string][]string)
|
||||
for i := 0; i < len(edges); i += 2 {
|
||||
node, dep := int(edges[i]), int(edges[i+1])
|
||||
if node >= numNodes || dep >= numNodes {
|
||||
// invalid node
|
||||
continue
|
||||
}
|
||||
if node == dep {
|
||||
// can't depend on self
|
||||
continue
|
||||
}
|
||||
|
||||
// We add nodes in incrementing order (0, 1, 2, etc.),
|
||||
// so an edge can't point 'forward' or it'll fail to be
|
||||
// added.
|
||||
if dep > node {
|
||||
continue
|
||||
}
|
||||
|
||||
nn := nodeName(node)
|
||||
deps[nn] = append(deps[nn], nodeName(dep))
|
||||
}
|
||||
return deps
|
||||
}
|
||||
|
||||
// Constrain the number of nodes
|
||||
if numNodes <= 0 || numNodes > 1000 {
|
||||
return nil
|
||||
}
|
||||
// Must have pairs of edges (from, to)
|
||||
if len(edges)%2 != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert list of edges into list of dependencies
|
||||
deps := filterAdjacencyMatrix(numNodes, edges)
|
||||
if len(deps) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Actually create graph.
|
||||
g := NewWorkGraph()
|
||||
doNothing := func(context.Context) error { return nil }
|
||||
for i := 0; i < numNodes; i++ {
|
||||
nn := nodeName(i)
|
||||
node := NodeFunc(nn, doNothing)
|
||||
if err := g.AddNode(node, &AddNodeOpts{
|
||||
Dependencies: deps[nn],
|
||||
}); err != nil {
|
||||
tb.Error(err) // shouldn't error after we filtered out bad edges above
|
||||
}
|
||||
}
|
||||
if tb.Failed() {
|
||||
return nil
|
||||
}
|
||||
return g
|
||||
}
|
||||
Reference in New Issue
Block a user