Compare commits

...

5 Commits

Author SHA1 Message Date
Marwan Sulaiman
fb13df273a Ensure ticker stop 2023-05-25 13:52:41 -04:00
Marwan Sulaiman
bfda034655 Inline init as it's only called in one place 2023-05-25 10:04:16 -04:00
Marwan Sulaiman
9bbb2b0911 Make it even simpler by removing Init 2023-05-24 23:59:44 -04:00
Marwan Sulaiman
4b807ca95e portlist: refactor Poller to simplify external API
This change makes it so that Poller has only two external methods: Run and Close.
The Poller becomes harder to get wrong by external packages because Close is a no-op
if Run is not called and Run must be called to obtain updates and check an init-error.

The only 3rd method is an exported "Init" that preserves an old behavior whereby you can
pre-fetch a list of ports and check for errors _without_ starting Run's ticker.
However, I'd be happy to remove that if we can afford it.

Fixes #8171

Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
2023-05-24 19:37:32 -04:00
Marwan Sulaiman
27ea062078 portlist: remove NewPoller constructor
This is a follow up on PR #8172 and a breaking change that removes NewPoller.
The issue with the previous PR was that NewPoller immediately initializes the underlying os implementation
and therefore setting IncludeLocalhost as an exported field happened too late and cannot happen early enough.
Using the zero value of Poller was also not an option from outside of the package because we need to set initial
private fields

Fixes #8171

Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
2023-05-24 18:17:29 -04:00
3 changed files with 136 additions and 103 deletions

View File

@@ -146,9 +146,8 @@ type LocalBackend struct {
backendLogID logid.PublicID
unregisterNetMon func()
unregisterHealthWatch func()
portpoll *portlist.Poller // may be nil
portpollOnce sync.Once // guards starting readPoller
gotPortPollRes chan struct{} // closed upon first readPoller result
portpollOnce sync.Once // guards starting readPoller
gotPortPollRes chan struct{} // closed upon first readPoller result
newDecompressor func() (controlclient.Decompressor, error)
varRoot string // or empty if SetVarRoot never called
logFlushFunc func() // or nil if SetLogFlusher wasn't called
@@ -292,10 +291,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
osshare.SetFileSharingEnabled(false, logf)
ctx, cancel := context.WithCancel(context.Background())
portpoll, err := portlist.NewPoller()
if err != nil {
logf("skipping portlist: %s", err)
}
b := &LocalBackend{
ctx: ctx,
@@ -310,7 +305,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
pm: pm,
backendLogID: logID,
state: ipn.NoState,
portpoll: portpoll,
em: newExpiryManager(logf),
gotPortPollRes: make(chan struct{}),
loginFlags: loginFlags,
@@ -1375,26 +1369,32 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
b.updateFilterLocked(nil, ipn.PrefsView{})
b.mu.Unlock()
if b.portpoll != nil {
b.portpollOnce.Do(func() {
go b.portpoll.Run(b.ctx)
go b.readPoller()
b.portpollOnce.Do(func() {
var p portlist.Poller
updates, err := p.Run(b.ctx)
if err != nil {
b.logf("skipping portlist: %s", err)
return
}
go func() {
defer p.Close()
b.readPoller(updates)
}()
// Give the poller a second to get results to
// prevent it from restarting our map poll
// HTTP request (via doSetHostinfoFilterServices >
// cli.SetHostinfo). In practice this is very quick.
t0 := time.Now()
timer := time.NewTimer(time.Second)
select {
case <-b.gotPortPollRes:
b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond))
timer.Stop()
case <-timer.C:
b.logf("timeout waiting for initial portlist")
}
})
}
// Give the poller a second to get results to
// prevent it from restarting our map poll
// HTTP request (via doSetHostinfoFilterServices >
// cli.SetHostinfo). In practice this is very quick.
t0 := time.Now()
timer := time.NewTimer(time.Second)
select {
case <-b.gotPortPollRes:
b.logf("[v1] got initial portlist info in %v", time.Since(t0).Round(time.Millisecond))
timer.Stop()
case <-timer.C:
b.logf("timeout waiting for initial portlist")
}
})
discoPublic := b.e.DiscoPublicKey()
@@ -1811,15 +1811,15 @@ func dnsMapsEqual(new, old *netmap.NetworkMap) bool {
// readPoller is a goroutine that receives service lists from
// b.portpoll and propagates them into the controlclient's HostInfo.
func (b *LocalBackend) readPoller() {
n := 0
for {
ports, ok := <-b.portpoll.Updates()
if !ok {
return
func (b *LocalBackend) readPoller(updates chan portlist.Update) {
firstResults := true
for update := range updates {
if update.Error != nil {
b.logf("error polling os ports: %v", update.Error)
return // preserve old behavior, though we can just continue and try again?
}
sl := []tailcfg.Service{}
for _, p := range ports {
for _, p := range update.List {
s := tailcfg.Service{
Proto: tailcfg.ServiceProto(p.Proto),
Port: p.Port,
@@ -1840,8 +1840,8 @@ func (b *LocalBackend) readPoller() {
b.doSetHostinfoFilterServices(hi)
n++
if n == 1 {
if firstResults {
firstResults = false
close(b.gotPortPollRes)
}
}

View File

@@ -9,6 +9,7 @@ package portlist
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
@@ -29,7 +30,14 @@ type Poller struct {
// This field should only be changed before calling Run.
IncludeLocalhost bool
c chan List // unbuffered
// Interval sets the polling interval for probing the underlying
// os for port updates.
Interval time.Duration
c chan Update // unbuffered
initOnce sync.Once // guards init of private fields
initErr error
// os, if non-nil, is an OS-specific implementation of the portlist getting
// code. When non-nil, it's responsible for getting the complete list of
@@ -37,8 +45,7 @@ type Poller struct {
// addProcesses is not used.
A nil value means we don't have code for getting the list on the current
// operating system.
os osImpl
osOnce sync.Once // guards init of os
os osImpl
// closeCtx is the context that's canceled on Close.
closeCtx context.Context
@@ -52,6 +59,19 @@ type Poller struct {
prev List // most recent data, not aliasing scratch
}
// Update is a container for a portlist update event.
// When Poller polls the underlying OS for an update,
// it either returns a new list of open ports,
// or an error that happened in the process.
//
// Note that it is up to the caller to act upon the error,
// such as closing the Poller. Otherwise, the Poller will continue
// to try and get a list for every interval.
type Update struct {
List List
Error error
}
// osImpl is the OS-specific implementation of getting the open listening ports.
type osImpl interface {
Close() error
@@ -71,54 +91,54 @@ var newOSImpl func(includeLocalhost bool) osImpl
var errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS)
// NewPoller returns a new portlist Poller. It returns an error
// if the portlist couldn't be obtained.
func NewPoller() (*Poller, error) {
if debugDisablePortlist() {
return nil, errors.New("portlist disabled by envknob")
}
p := &Poller{
c: make(chan List),
runDone: make(chan struct{}),
}
p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background())
p.osOnce.Do(p.initOSField)
if p.os == nil {
return nil, errUnimplemented
}
// Do one initial poll synchronously so we can return an error
// early.
if pl, err := p.getList(); err != nil {
return nil, err
} else {
p.setPrev(pl)
}
return p, nil
}
func (p *Poller) setPrev(pl List) {
Make a copy, as the passed-in pl slice aliases p.scratch and we don't want
that to escape to the caller.
p.prev = slices.Clone(pl)
}
func (p *Poller) initOSField() {
if newOSImpl != nil {
p.os = newOSImpl(p.IncludeLocalhost)
}
}
// Updates return the channel that receives port list updates.
// init makes sure the Poller is enabled
and the underlying OS implementation is working properly.
//
// The channel is closed when the Poller is closed.
func (p *Poller) Updates() <-chan List { return p.c }
// An error returned from init is non-fatal and means
// that it's been administratively disabled or the underlying
// OS is not implemented.
func (p *Poller) init() error {
if debugDisablePortlist() {
return errors.New("portlist disabled by envknob")
}
if newOSImpl == nil {
return errUnimplemented
}
p.os = newOSImpl(p.IncludeLocalhost)
// Do one initial poll synchronously so we can return an error
// early.
if pl, err := p.getList(); err != nil {
return err
} else {
p.setPrev(pl)
}
if p.Interval == 0 {
p.Interval = pollInterval
}
p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background())
p.c = make(chan Update)
p.runDone = make(chan struct{})
return nil
}
// Close closes the Poller.
// Run will return with a nil error.
func (p *Poller) Close() error {
if p.os == nil {
return nil
}
p.closeCtxCancel()
<-p.runDone
<-p.runDone // if caller of Close never called Run, this can hang.
if p.os != nil {
p.os.Close()
}
@@ -126,14 +146,14 @@ func (p *Poller) Close() error {
}
// send sends pl to p.c and returns whether it was successfully sent.
func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) {
func (p *Poller) send(ctx context.Context, pl List, plErr error) (sent bool) {
select {
case p.c <- pl:
return true, nil
case p.c <- Update{pl, plErr}:
return true
case <-ctx.Done():
return false, ctx.Err()
return false
case <-p.closeCtx.Done():
return false, nil
return false
}
}
@@ -141,19 +161,28 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) {
// is done, or the Close is called.
//
// Run may only be called once.
func (p *Poller) Run(ctx context.Context) error {
tick := time.NewTicker(pollInterval)
defer tick.Stop()
return p.runWithTickChan(ctx, tick.C)
func (p *Poller) Run(ctx context.Context) (chan Update, error) {
p.initOnce.Do(func() {
p.initErr = p.init()
})
if p.initErr != nil {
return nil, fmt.Errorf("error initializing poller: %w", p.initErr)
}
tick := time.NewTicker(p.Interval)
go func() {
defer tick.Stop()
p.runWithTickChan(ctx, tick.C)
}()
return p.c, nil
}
func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) error {
func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) {
defer close(p.runDone)
defer close(p.c)
// Send out the pre-generated initial value.
if sent, err := p.send(ctx, p.prev); !sent {
return err
if sent := p.send(ctx, p.prev, nil); !sent {
return
}
for {
@@ -161,28 +190,27 @@ func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time)
case <-tickChan:
pl, err := p.getList()
if err != nil {
return err
if !p.send(ctx, nil, err) {
return
}
continue
}
if pl.equal(p.prev) {
continue
}
p.setPrev(pl)
if sent, err := p.send(ctx, p.prev); !sent {
return err
if !p.send(ctx, p.prev, nil) {
return
}
case <-ctx.Done():
return ctx.Err()
return
case <-p.closeCtx.Done():
return nil
return
}
}
}
func (p *Poller) getList() (List, error) {
if debugDisablePortlist() {
return nil, nil
}
p.osOnce.Do(p.initOSField)
var err error
p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0])
return p.scratch, err

View File

@@ -17,6 +17,7 @@ func TestGetList(t *testing.T) {
tstest.ResourceCheck(t)
var p Poller
p.os = newOSImpl(false)
pl, err := p.getList()
if err != nil {
t.Fatal(err)
@@ -38,6 +39,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) {
ta := ln.Addr().(*net.TCPAddr)
port := ta.Port
var p Poller
p.os = newOSImpl(false)
pl, err := p.getList()
if err != nil {
t.Fatal(err)
@@ -51,7 +53,7 @@ func TestIgnoreLocallyBoundPorts(t *testing.T) {
func TestChangesOverTime(t *testing.T) {
var p Poller
p.IncludeLocalhost = true
p.os = newOSImpl(true)
get := func(t *testing.T) []Port {
t.Helper()
s, err := p.getList()
@@ -176,7 +178,8 @@ func TestEqualLessThan(t *testing.T) {
}
func TestPoller(t *testing.T) {
p, err := NewPoller()
var p Poller
err := p.init()
if err != nil {
t.Skipf("not running test: %v", err)
}
@@ -189,10 +192,14 @@ func TestPoller(t *testing.T) {
go func() {
defer wg.Done()
for pl := range p.Updates() {
for update := range p.c {
if update.Error != nil {
t.Errorf("error polling ports: %v", err)
return
}
// Look at all the pl slice memory to maximize
// chance of race detector seeing violations.
for _, v := range pl {
for _, v := range update.List {
if v == (Port{}) {
// Force use
panic("empty port")
@@ -208,9 +215,7 @@ func TestPoller(t *testing.T) {
tick := make(chan time.Time, 16)
go func() {
defer wg.Done()
if err := p.runWithTickChan(context.Background(), tick); err != nil {
t.Error("runWithTickChan:", err)
}
p.runWithTickChan(context.Background(), tick)
}()
for i := 0; i < 10; i++ {
ln, err := net.Listen("tcp", ":0")