Compare commits

...

2 Commits

Author SHA1 Message Date
James Tucker
1621f3aa6c cmd/tswrap,portlist: switch tswrap to portlist, add pid to portlist
Signed-off-by: James Tucker <james@tailscale.com>
2022-11-04 23:01:25 -07:00
David Anderson
2a619d3bcf cmd/tswrap: command to run a child process and make it accessible over Tailscale.
Signed-off-by: David Anderson <danderson@tailscale.com>
2022-11-04 20:13:16 -07:00
6 changed files with 333 additions and 6 deletions

314
cmd/tswrap/main.go Normal file
View File

@@ -0,0 +1,314 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The tswrap binary runs a child process and makes it accessible over
// Tailscale.
package main
import (
"context"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"os"
"os/exec"
"os/signal"
"sort"
"strconv"
"syscall"
"time"
"tailscale.com/client/tailscale"
"tailscale.com/ipn/ipnstate"
"tailscale.com/ipn/store/mem"
"tailscale.com/portlist"
"tailscale.com/syncs"
"tailscale.com/tsnet"
"tailscale.com/types/logger"
)
var (
tsDir = flag.String("state-dir", "", "Directory in which to store the Tailscale auth state")
verbose = flag.Bool("verbose", false, "Output tailscaled logs to stderr")
)
func main() {
sigch := make(chan os.Signal, 1)
signal.Notify(sigch, os.Interrupt, syscall.SIGTERM)
flag.Parse()
argv := flag.Args()
if len(argv) < 2 {
log.Fatalf("Usage: %s tailscale-host:port child-cmd...", os.Args[0])
}
p := proxy{
ListenAddr: argv[0],
Command: argv[1:],
AuthKey: os.Getenv("TS_AUTHKEY"),
Dir: *tsDir,
Verbose: *verbose,
}
if err := p.Start(); err != nil {
log.Fatalf("Failed to start tswrap: %v", err)
}
go func() {
<-sigch
p.Stop()
}()
p.Wait()
}
type proxy struct {
ListenAddr string
Command []string
AuthKey string
Dir string
Verbose bool
shutdownCtx context.Context
startShutdown context.CancelFunc
srv *tsnet.Server
client *tailscale.LocalClient
ln net.Listener
cmd *exec.Cmd
ports syncs.AtomicValue[[]int]
}
func (p *proxy) Start() error {
host, port, err := net.SplitHostPort(p.ListenAddr)
if err != nil {
return fmt.Errorf("parsing %q: %v", p.ListenAddr, err)
}
if _, err := strconv.Atoi(port); err != nil {
return fmt.Errorf("parsing port number %q: %v", port, err)
}
if p.Dir == "" && p.AuthKey == "" {
return errors.New("must provide either a TS_AUTHKEY or a state storage dir")
}
p.srv = &tsnet.Server{
Hostname: host,
AuthKey: p.AuthKey,
Logf: logger.Discard,
Dir: p.Dir,
}
if p.Dir == "" {
p.srv.Store = new(mem.Store)
p.srv.Ephemeral = true
}
if p.Verbose {
p.srv.Logf = log.Printf
}
p.shutdownCtx, p.startShutdown = context.WithCancel(context.Background())
p.client, err = p.srv.LocalClient()
if err != nil {
return fmt.Errorf("starting tsnet server failed: %v", err)
}
var (
looped = false
authURLShown = false
status *ipnstate.Status
)
loginLoop:
for {
if looped {
time.Sleep(100 * time.Millisecond)
}
looped = true
status, err = p.client.Status(context.Background())
if err != nil {
return fmt.Errorf("getting tsnet status: %v", err)
}
switch status.BackendState {
case "Running":
if status.Self == nil || status.Self.DNSName == "" {
// No known DNS name yet, keep going
continue
}
break loginLoop
case "NeedsLogin":
if status.AuthURL != "" && p.AuthKey != "" {
return errors.New("failed to auth with provided authkey")
}
if status.AuthURL != "" && !authURLShown {
log.Printf("To log into Tailscale, please visit: %s", status.AuthURL)
authURLShown = true
}
default:
// Just keep trying, eventually we should get into either
// NeedsLogin or Running.
}
}
addr := ":" + port
p.ln, err = p.srv.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("tailscale listen on %q: %v", addr, err)
}
log.Printf("Listening on %s:%s", status.Self.DNSName, port)
p.cmd = exec.Command(p.Command[0], p.Command[1:]...)
p.cmd.Stdin = os.Stdin
p.cmd.Stdout = os.Stdout
p.cmd.Stderr = os.Stderr
if err := p.cmd.Start(); err != nil {
return fmt.Errorf("starting child failed: %v", err)
}
go p.listen()
go p.waitForChildExit()
go p.monitorChildPorts()
return nil
}
func (p *proxy) Stop() {
p.startShutdown()
}
func (p *proxy) Wait() {
<-p.shutdownCtx.Done()
p.cmd.Process.Signal(syscall.SIGTERM)
p.ln.Close()
if p.srv.Ephemeral {
p.client.Logout(context.Background())
}
}
func (p *proxy) listen() {
for {
conn, err := p.ln.Accept()
if errors.Is(err, net.ErrClosed) {
return
} else if err != nil {
log.Printf("accept: %v", err)
p.startShutdown()
return
}
go func() {
if err := p.proxy(conn); err != nil {
log.Printf("proxying %s: %v", conn.RemoteAddr(), err)
}
}()
}
}
func (p *proxy) proxy(conn net.Conn) error {
defer conn.Close()
ports, err := p.getPorts()
if err != nil {
return err
}
if len(ports) > 1 {
log.Printf("warning: multiple listening ports found on child, proxying to lowest one (%d)", ports[0])
}
prox, err := net.Dial("tcp", net.JoinHostPort("localhost", strconv.Itoa(ports[0])))
if err != nil {
return fmt.Errorf("dialing child port %d: %v", ports[0], err)
}
defer prox.Close()
errc := make(chan error, 1)
go proxyCopy(errc, conn, prox)
go proxyCopy(errc, prox, conn)
<-errc
return nil
}
func (p *proxy) getPorts() ([]int, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for ctx.Err() == nil {
if ports := p.ports.Load(); len(ports) > 0 {
return ports, nil
}
time.Sleep(100 * time.Millisecond)
}
return nil, errors.New("timed out waiting for child listening ports")
}
func (p *proxy) waitForChildExit() {
if err := p.cmd.Wait(); err != nil {
log.Printf("child exited with error: %v", err)
} else {
log.Printf("child exited, shutting down")
}
p.startShutdown()
}
func (p *proxy) monitorChildPorts() {
for p.shutdownCtx.Err() == nil {
ports, err := portsOfCmd(p.cmd)
if err == nil {
p.ports.Store(ports)
}
select {
case <-time.After(time.Second):
case <-p.shutdownCtx.Done():
return
}
}
}
func proxyCopy(errc chan<- error, dst, src net.Conn) {
// TODO: still need the unwrap hack from tcpproxy? Or is io.Copy
// smart now?
_, err := io.Copy(dst, src)
if err != nil {
log.Print(err)
}
errc <- err
}
func portsOfCmd(cmd *exec.Cmd) (ports []int, err error) {
if cmd == nil || cmd.Process == nil {
return nil, errors.New("no process")
}
pid := cmd.Process.Pid
poller, err := portlist.NewPoller()
if err != nil {
return nil, fmt.Errorf("creating port poller: %w", err)
}
defer poller.Close()
// TODO(raggi): timeout?
go poller.Run(context.Background())
c := poller.Updates()
for list := range c {
for _, p := range list {
if p.Pid == pid {
ports = append(ports, int(p.Port))
}
}
if len(ports) > 0 {
break
}
}
if len(ports) == 0 {
return nil, errors.New("no listening ports found")
}
sort.Ints(ports)
return ports, nil
}

View File

@@ -54,12 +54,12 @@ udp46 0 0 *.146 *.*
func TestParsePortsNetstat(t *testing.T) {
want := List{
Port{"tcp", 23, ""},
Port{"tcp", 24, ""},
Port{"udp", 104, ""},
Port{"udp", 106, ""},
Port{"udp", 146, ""},
Port{"tcp", 8185, ""}, // but not 8186, 8187, 8188 on localhost
Port{"tcp", 23, "", 0},
Port{"tcp", 24, "", 0},
Port{"udp", 104, "", 0},
Port{"udp", 106, "", 0},
Port{"udp", 146, "", 0},
Port{"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost
}
pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)))

View File

@@ -19,6 +19,7 @@ type Port struct {
Proto string // "tcp" or "udp"
Port uint16 // port number
Process string // optional process name, if found
Pid int // process id, if known
}
// List is a list of Ports.

View File

@@ -40,6 +40,7 @@ type linuxImpl struct {
type portMeta struct {
port Port
pid int
keep bool
needsProcName bool
}
@@ -326,6 +327,9 @@ func (li *linuxImpl) findProcessNames(need map[string]*portMeta) error {
}
argv := strings.Split(strings.TrimSuffix(string(bs), "\x00"), "\x00")
if p, err := strconv.Atoi(pid); err == nil {
pe.pid = p
}
pe.port.Process = argvSubject(argv...)
pe.needsProcName = false
delete(need, string(targetBuf[:n]))

View File

@@ -12,6 +12,7 @@ import (
"fmt"
"log"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"time"
@@ -162,6 +163,7 @@ func (im *macOSImpl) addProcesses() error {
im.br.Reset(outPipe)
var cmd, proto string
var pid int
for {
line, err := im.br.ReadBytes('\n')
if err != nil {
@@ -176,6 +178,10 @@ func (im *macOSImpl) addProcesses() error {
// starting a new process
cmd = ""
proto = ""
pid = 0
if p, err := strconv.Atoi(string(val)); err == nil {
pid = p
}
case 'c':
cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs?
case 'P':
@@ -194,6 +200,7 @@ func (im *macOSImpl) addProcesses() error {
switch {
case m != nil:
m.port.Process = cmd
m.port.Pid = pid
default:
// ignore: processes and ports come and go
}

View File

@@ -82,6 +82,7 @@ func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) {
Proto: "tcp",
Port: e.Local.Port(),
Process: procNameOfPid(e.Pid),
Pid: e.Pid,
},
}
im.known[fp] = pm