net/udprelay: start of UDP relay server implementation
Updates tailscale/corp#27101 Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
129
disco/disco.go
129
disco/disco.go
@@ -41,9 +41,12 @@ const NonceLen = 24
|
||||
// MessageType is the one-byte identifier of a disco message kind.
type MessageType byte

// Known disco message types. The three BindUDPRelayEndpoint* types carry the
// 3-way bind handshake between a UDP relay client and a UDP relay server.
const (
	TypePing                          = MessageType(0x01)
	TypePong                          = MessageType(0x02)
	TypeCallMeMaybe                   = MessageType(0x03)
	TypeBindUDPRelayEndpoint          = MessageType(0x04)
	TypeBindUDPRelayEndpointChallenge = MessageType(0x05)
	TypeBindUDPRelayEndpointAnswer    = MessageType(0x06)
)
|
||||
|
||||
// v0 is the version byte the AppendMarshal methods in this file hand to
// appendMsgHeader.
const v0 byte = 0
|
||||
@@ -83,6 +86,12 @@ func Parse(p []byte) (Message, error) {
|
||||
return parsePong(ver, p)
|
||||
case TypeCallMeMaybe:
|
||||
return parseCallMeMaybe(ver, p)
|
||||
case TypeBindUDPRelayEndpoint:
|
||||
return parseBindUDPRelayEndpoint(ver, p)
|
||||
case TypeBindUDPRelayEndpointChallenge:
|
||||
return parseBindUDPRelayEndpointChallenge(ver, p)
|
||||
case TypeBindUDPRelayEndpointAnswer:
|
||||
return parseBindUDPRelayEndpointAnswer(ver, p)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type 0x%02x", byte(t))
|
||||
}
|
||||
@@ -266,7 +275,121 @@ func MessageSummary(m Message) string {
|
||||
return fmt.Sprintf("pong tx=%x", m.TxID[:6])
|
||||
case *CallMeMaybe:
|
||||
return "call-me-maybe"
|
||||
case *BindUDPRelayEndpoint:
|
||||
return "bind-udp-relay-endpoint"
|
||||
case *BindUDPRelayEndpointChallenge:
|
||||
return "bind-udp-relay-endpoint-challenge"
|
||||
case *BindUDPRelayEndpointAnswer:
|
||||
return "bind-udp-relay-endpoint-answer"
|
||||
default:
|
||||
return fmt.Sprintf("%#v", m)
|
||||
}
|
||||
}
|
||||
|
||||
// BindUDPRelayHandshakeState tracks progress through the 3-way bind handshake
// that a UDP relay client performs with a UDP relay server. The same type
// serves both sides, so only a subset of its values applies to either
// participant. See net/udprelay for a UDP relay server implementation. This
// is currently considered experimental.
type BindUDPRelayHandshakeState int

const (
	// BindUDPRelayHandshakeStateInit is the starting state: nothing has been
	// transmitted yet.
	BindUDPRelayHandshakeStateInit BindUDPRelayHandshakeState = iota
	// BindUDPRelayHandshakeStateBindSent is a client-side state, entered
	// after the UDP relay client has transmitted a BindUDPRelayEndpoint
	// message towards the UDP relay server.
	BindUDPRelayHandshakeStateBindSent
	// BindUDPRelayHandshakeStateChallengeSent is a server-side state, entered
	// after the UDP relay server has answered a BindUDPRelayEndpoint message
	// by transmitting a BindUDPRelayEndpointChallenge message towards the UDP
	// relay client.
	BindUDPRelayHandshakeStateChallengeSent
	// BindUDPRelayHandshakeStateAnswerSent is a client-side state, entered
	// after the UDP relay client has answered a BindUDPRelayEndpointChallenge
	// message by transmitting a BindUDPRelayEndpointAnswer message towards
	// the UDP relay server.
	BindUDPRelayHandshakeStateAnswerSent
	// BindUDPRelayHandshakeStateAnswerReceived is a server-side state,
	// entered after the UDP relay server has received a valid/correct
	// BindUDPRelayEndpointAnswer message from the UDP relay client in
	// response to a BindUDPRelayEndpointChallenge message.
	BindUDPRelayHandshakeStateAnswerReceived
)
|
||||
|
||||
// bindUDPRelayEndpointLen is the length of a marshalled BindUDPRelayEndpoint
|
||||
// message, without the message header.
|
||||
const bindUDPRelayEndpointLen = BindUDPRelayEndpointChallengeLen
|
||||
|
||||
// BindUDPRelayEndpoint is the first messaged transmitted from UDP relay client
|
||||
// towards UDP relay server as part of the 3-way bind handshake. It is padded to
|
||||
// match the length of BindUDPRelayEndpointChallenge. This message type is
|
||||
// currently considered experimental and is not yet tied to a
|
||||
// tailcfg.CapabilityVersion.
|
||||
type BindUDPRelayEndpoint struct {
|
||||
padding [bindUDPRelayEndpointLen]byte
|
||||
}
|
||||
|
||||
func (m *BindUDPRelayEndpoint) AppendMarshal(b []byte) []byte {
|
||||
ret, _ := appendMsgHeader(b, TypeBindUDPRelayEndpoint, v0, 0)
|
||||
return ret
|
||||
}
|
||||
|
||||
func parseBindUDPRelayEndpoint(ver uint8, p []byte) (m *BindUDPRelayEndpoint, err error) {
|
||||
m = new(BindUDPRelayEndpoint)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// BindUDPRelayEndpointChallengeLen is the size, excluding the message header,
// of a marshalled BindUDPRelayEndpointChallenge message.
const BindUDPRelayEndpointChallengeLen = 32

// BindUDPRelayEndpointChallenge is the second message of the 3-way bind
// handshake. A UDP relay server transmits it towards a UDP relay client in
// response to a BindUDPRelayEndpoint message. This message type is currently
// considered experimental and is not yet tied to a tailcfg.CapabilityVersion.
type BindUDPRelayEndpointChallenge struct {
	Challenge [BindUDPRelayEndpointChallengeLen]byte
}
|
||||
|
||||
func (m *BindUDPRelayEndpointChallenge) AppendMarshal(b []byte) []byte {
|
||||
ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointChallenge, v0, BindUDPRelayEndpointChallengeLen)
|
||||
copy(d, m.Challenge[:])
|
||||
return ret
|
||||
}
|
||||
|
||||
func parseBindUDPRelayEndpointChallenge(ver uint8, p []byte) (m *BindUDPRelayEndpointChallenge, err error) {
|
||||
if len(p) < BindUDPRelayEndpointChallengeLen {
|
||||
return nil, errShort
|
||||
}
|
||||
m = new(BindUDPRelayEndpointChallenge)
|
||||
copy(m.Challenge[:], p[:])
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// bindUDPRelayEndpointAnswerLen is the length of a marshalled
|
||||
// BindUDPRelayEndpointAnswer message, without the message header.
|
||||
const bindUDPRelayEndpointAnswerLen = BindUDPRelayEndpointChallengeLen
|
||||
|
||||
// BindUDPRelayEndpointAnswer is transmitted from UDP relay client to UDP relay
|
||||
// server in response to a BindUDPRelayEndpointChallenge message. This message
|
||||
// type is currently considered experimental and is not yet tied to a
|
||||
// tailcfg.CapabilityVersion.
|
||||
type BindUDPRelayEndpointAnswer struct {
|
||||
Answer [bindUDPRelayEndpointAnswerLen]byte
|
||||
}
|
||||
|
||||
func (m *BindUDPRelayEndpointAnswer) AppendMarshal(b []byte) []byte {
|
||||
ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointAnswer, v0, bindUDPRelayEndpointAnswerLen)
|
||||
copy(d, m.Answer[:])
|
||||
return ret
|
||||
}
|
||||
|
||||
func parseBindUDPRelayEndpointAnswer(ver uint8, p []byte) (m *BindUDPRelayEndpointAnswer, err error) {
|
||||
if len(p) < bindUDPRelayEndpointAnswerLen {
|
||||
return nil, errShort
|
||||
}
|
||||
m = new(BindUDPRelayEndpointAnswer)
|
||||
copy(m.Answer[:], p[:])
|
||||
return m, nil
|
||||
}
|
||||
|
||||
510
net/udprelay/server.go
Normal file
510
net/udprelay/server.go
Normal file
@@ -0,0 +1,510 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package udprelay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// Default endpoint lifetimes installed on a Server by NewServer.
const (
	// defaultBindLifetime is how long an allocated endpoint may remain
	// unbound before it is considered expired.
	defaultBindLifetime = 5 * time.Second
	// defaultSteadyStateLifetime is how long a bound endpoint may go without
	// traffic from either party before it is considered expired.
	defaultSteadyStateLifetime = 5 * time.Minute
)
|
||||
|
||||
// Server implements an experimental UDP relay server.
|
||||
type Server struct {
|
||||
// disco keypair used as part of 3-way bind handshake
|
||||
disco key.DiscoPrivate
|
||||
discoPublic key.DiscoPublic
|
||||
|
||||
bindLifetime time.Duration
|
||||
steadyStateLifetime time.Duration
|
||||
|
||||
// addrPorts contains the ip:port pairs returned as candidate server
|
||||
// endpoints in response to an allocation request.
|
||||
addrPorts []netip.AddrPort
|
||||
|
||||
uc *net.UDPConn
|
||||
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
closeCh chan struct{}
|
||||
closed bool
|
||||
|
||||
mu sync.Mutex // guards the following fields
|
||||
lamportID uint64
|
||||
vniPool []uint32 // the pool of available VNIs
|
||||
byVNI map[uint32]*serverEndpoint
|
||||
byDisco map[pairOfDiscoPubKeys]*serverEndpoint
|
||||
}
|
||||
|
||||
// pairOfDiscoPubKeys is a pair of key.DiscoPublic. It must be constructed via
|
||||
// newPairOfDiscoPubKeys to ensure lexicographical ordering.
|
||||
type pairOfDiscoPubKeys [2]key.DiscoPublic
|
||||
|
||||
func (p pairOfDiscoPubKeys) String() string {
|
||||
return fmt.Sprintf("%s <=> %s", p[0].ShortString(), p[1].ShortString())
|
||||
}
|
||||
|
||||
func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys {
|
||||
var pair pairOfDiscoPubKeys
|
||||
cmp := discoA.Compare(discoB)
|
||||
if cmp == 1 {
|
||||
pair[0] = discoB
|
||||
pair[1] = discoA
|
||||
} else {
|
||||
pair[0] = discoA
|
||||
pair[1] = discoB
|
||||
}
|
||||
return pair
|
||||
}
|
||||
|
||||
// ServerEndpoint contains the Server's endpoint details.
|
||||
type ServerEndpoint struct {
|
||||
// ServerDisco is the Server's Disco public key used as part of the 3-way
|
||||
// bind handshake. Server will use the same ServerDisco for its lifetime.
|
||||
// ServerDisco value in combination with LamportID value represents a
|
||||
// unique ServerEndpoint allocation.
|
||||
ServerDisco key.DiscoPublic
|
||||
|
||||
// LamportID is unique and monotonically increasing across ServerEndpoint
|
||||
// allocations. It enables clients to dedup and resolve allocation event
|
||||
// order. Clients may race to allocate on the same Server, and signal
|
||||
// ServerEndpoint details via alternative channels, e.g. DERP. Additionally,
|
||||
// Server.AllocateEndpoint() requests may not result in a new allocation
|
||||
// depending on existing server-side endpoint state. Therefore, where
|
||||
// clients have local, existing state that contains ServerDisco and
|
||||
// LamportID values matching a newly learned endpoint, these can be
|
||||
// considered one and the same. If ServerDisco is equal, but LamportID is
|
||||
// unequal, LamportID comparison determines which ServerEndpoint was
|
||||
// allocated most recently.
|
||||
LamportID uint64
|
||||
|
||||
// AddrPorts are the IP:Port candidate pairs the Server may be reachable
|
||||
// over.
|
||||
AddrPorts []netip.AddrPort
|
||||
|
||||
// VNI (Virtual Network Identifier) is the Geneve header VNI the Server
|
||||
// will use for transmitted packets, and expects for received packets
|
||||
// associated with this endpoint.
|
||||
VNI uint32
|
||||
|
||||
// BindLifetime is amount of time post-allocation the Server will consider
|
||||
// the endpoint active while it has yet to be bound via 3-way bind handshake
|
||||
// from both client parties.
|
||||
BindLifetime time.Duration
|
||||
|
||||
// SteadyStateLifetime is the amount of time post 3-way bind handshake from
|
||||
// both client parties the Server will consider the endpoint active lacking
|
||||
// bidirectional data flow.
|
||||
SteadyStateLifetime time.Duration
|
||||
}
|
||||
|
||||
type serverEndpoint struct {
|
||||
discoPubKeys pairOfDiscoPubKeys
|
||||
discoSharedSecrets [2]key.DiscoShared
|
||||
handeshakeState [2]disco.BindUDPRelayHandshakeState
|
||||
addrPorts [2]netip.AddrPort
|
||||
lastSeen [2]time.Time
|
||||
challenge [2][disco.BindUDPRelayEndpointChallengeLen]byte
|
||||
lamportID uint64
|
||||
vni uint32
|
||||
allocatedAt time.Time
|
||||
}
|
||||
|
||||
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, uw udpWriter, serverDisco key.DiscoPublic) {
|
||||
handshakeState := e.handeshakeState[senderIndex]
|
||||
if handshakeState == disco.BindUDPRelayHandshakeStateAnswerReceived {
|
||||
// this sender is already bound
|
||||
return
|
||||
}
|
||||
switch discoMsg := discoMsg.(type) {
|
||||
case *disco.BindUDPRelayEndpoint:
|
||||
switch handshakeState {
|
||||
case disco.BindUDPRelayHandshakeStateInit:
|
||||
// set sender addr
|
||||
e.addrPorts[senderIndex] = from
|
||||
fallthrough
|
||||
case disco.BindUDPRelayHandshakeStateChallengeSent:
|
||||
if from != e.addrPorts[senderIndex] {
|
||||
// this is a later arriving bind from a different source, or
|
||||
// a retransmit and the sender's source has changed, discard
|
||||
return
|
||||
}
|
||||
m := new(disco.BindUDPRelayEndpointChallenge)
|
||||
copy(m.Challenge[:], e.challenge[senderIndex][:])
|
||||
reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
|
||||
gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
|
||||
err := gh.Encode(reply)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reply = append(reply, disco.Magic...)
|
||||
reply = serverDisco.AppendTo(reply)
|
||||
box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
|
||||
reply = append(reply, box...)
|
||||
uw.WriteMsgUDPAddrPort(reply, nil, from)
|
||||
// set new state
|
||||
e.handeshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent
|
||||
return
|
||||
default:
|
||||
// disco.BindUDPRelayEndpoint is unexpected in all other handshake states
|
||||
return
|
||||
}
|
||||
case *disco.BindUDPRelayEndpointAnswer:
|
||||
switch handshakeState {
|
||||
case disco.BindUDPRelayHandshakeStateChallengeSent:
|
||||
if from != e.addrPorts[senderIndex] {
|
||||
// sender source has changed
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) {
|
||||
// bad answer
|
||||
return
|
||||
}
|
||||
// sender is now bound
|
||||
// TODO: Consider installing a fast path via netfilter or similar to
|
||||
// relay (NAT) data packets for this serverEndpoint.
|
||||
e.handeshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived
|
||||
// record last seen as bound time
|
||||
e.lastSeen[senderIndex] = time.Now()
|
||||
return
|
||||
default:
|
||||
// disco.BindUDPRelayEndpointAnswer is unexpected in all other handshake
|
||||
// states, or we've already handled it
|
||||
return
|
||||
}
|
||||
default:
|
||||
// unexpected Disco message type
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
|
||||
senderRaw, isDiscoMsg := disco.Source(b)
|
||||
if !isDiscoMsg {
|
||||
// Not a Disco message
|
||||
return
|
||||
}
|
||||
sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
|
||||
senderIndex := -1
|
||||
switch {
|
||||
case sender.Compare(e.discoPubKeys[0]) == 0:
|
||||
senderIndex = 0
|
||||
case sender.Compare(e.discoPubKeys[1]) == 0:
|
||||
senderIndex = 1
|
||||
default:
|
||||
// unknown Disco public key
|
||||
return
|
||||
}
|
||||
|
||||
const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
|
||||
discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:])
|
||||
if !ok {
|
||||
// unable to decrypt the Disco payload
|
||||
return
|
||||
}
|
||||
|
||||
discoMsg, err := disco.Parse(discoPayload)
|
||||
if err != nil {
|
||||
// unable to parse the Disco payload
|
||||
return
|
||||
}
|
||||
|
||||
e.handleDiscoControlMsg(from, senderIndex, discoMsg, uw, serverDisco)
|
||||
}
|
||||
|
||||
// udpWriter is the write half of a UDP socket as needed by the packet
// handlers: a single method that sends b, with optional out-of-band data
// oob, to addr. Server passes its *net.UDPConn wherever a udpWriter is
// required.
type udpWriter interface {
	WriteMsgUDPAddrPort(b []byte, oob []byte, addr netip.AddrPort) (n, oobn int, err error)
}
|
||||
|
||||
func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, uw udpWriter, serverDisco key.DiscoPublic) {
|
||||
if !gh.Control {
|
||||
if !e.isBound() {
|
||||
// not a control packet, but serverEndpoint isn't bound
|
||||
return
|
||||
}
|
||||
var to netip.AddrPort
|
||||
switch {
|
||||
case from == e.addrPorts[0]:
|
||||
e.lastSeen[0] = time.Now()
|
||||
to = e.addrPorts[1]
|
||||
case from == e.addrPorts[1]:
|
||||
e.lastSeen[1] = time.Now()
|
||||
to = e.addrPorts[0]
|
||||
default:
|
||||
// unrecognized source
|
||||
return
|
||||
}
|
||||
// relay packet
|
||||
uw.WriteMsgUDPAddrPort(b, nil, to)
|
||||
return
|
||||
}
|
||||
|
||||
if e.isBound() {
|
||||
// control packet, but serverEndpoint is already bound
|
||||
return
|
||||
}
|
||||
|
||||
if gh.Protocol != packet.GeneveProtocolDisco {
|
||||
// control packet, but not Disco
|
||||
return
|
||||
}
|
||||
|
||||
msg := b[packet.GeneveFixedHeaderLength:]
|
||||
e.handleSealedDiscoControlMsg(from, msg, uw, serverDisco)
|
||||
}
|
||||
|
||||
func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
|
||||
if !e.isBound() {
|
||||
if now.Sub(e.allocatedAt) > bindLifetime {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
if now.Sub(e.lastSeen[0]) > steadyStateLifetime || now.Sub(e.lastSeen[1]) > steadyStateLifetime {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isBound returns true if both clients have completed their 3-way handshake,
|
||||
// otherwise false.
|
||||
func (e *serverEndpoint) isBound() bool {
|
||||
return e.handeshakeState[0] == disco.BindUDPRelayHandshakeStateAnswerReceived &&
|
||||
e.handeshakeState[1] == disco.BindUDPRelayHandshakeStateAnswerReceived
|
||||
}
|
||||
|
||||
// NewServer constructs a Server listening on 0.0.0.0:'port'. IPv6 is not yet
|
||||
// supported. Port may be 0, and what ultimately gets bound is returned as
|
||||
// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as
|
||||
// ServerEndpoint.AddrPorts in response to Server.AllocateEndpoint() requests.
|
||||
//
|
||||
// TODO: IPv6 support
|
||||
// TODO: dynamic addrs:port discovery
|
||||
func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) {
|
||||
s = &Server{
|
||||
disco: key.NewDisco(),
|
||||
bindLifetime: defaultBindLifetime,
|
||||
steadyStateLifetime: defaultSteadyStateLifetime,
|
||||
closeCh: make(chan struct{}),
|
||||
byDisco: make(map[pairOfDiscoPubKeys]*serverEndpoint),
|
||||
byVNI: make(map[uint32]*serverEndpoint),
|
||||
}
|
||||
s.discoPublic = s.disco.Public()
|
||||
// TODO: instead of allocating 10s of MBs for the full pool, allocate
|
||||
// smaller chunks and increase as needed
|
||||
s.vniPool = make([]uint32, 0, 1<<24-1)
|
||||
for i := 1; i < 1<<24; i++ {
|
||||
s.vniPool = append(s.vniPool, uint32(i))
|
||||
}
|
||||
boundPort, err = s.listenOn(port)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
addrPorts := make([]netip.AddrPort, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort)))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
addrPorts = append(addrPorts, addrPort)
|
||||
}
|
||||
s.addrPorts = addrPorts
|
||||
s.wg.Add(2)
|
||||
go s.packetReadLoop()
|
||||
go s.endpointGCLoop()
|
||||
return s, boundPort, nil
|
||||
}
|
||||
|
||||
func (s *Server) listenOn(port int) (int, error) {
|
||||
uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// TODO: set IP_PKTINFO sockopt
|
||||
_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
|
||||
if err != nil {
|
||||
s.uc.Close()
|
||||
return 0, err
|
||||
}
|
||||
boundPort, err := strconv.Atoi(boundPortStr)
|
||||
if err != nil {
|
||||
s.uc.Close()
|
||||
return 0, err
|
||||
}
|
||||
s.uc = uc
|
||||
return boundPort, nil
|
||||
}
|
||||
|
||||
// Close closes the server.
|
||||
func (s *Server) Close() error {
|
||||
s.closeOnce.Do(func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.uc.Close()
|
||||
close(s.closeCh)
|
||||
s.wg.Wait()
|
||||
clear(s.byVNI)
|
||||
clear(s.byDisco)
|
||||
s.vniPool = nil
|
||||
s.closed = true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) endpointGCLoop() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.bindLifetime)
|
||||
defer ticker.Stop()
|
||||
|
||||
gc := func() {
|
||||
now := time.Now()
|
||||
// TODO: consider performance implications of scanning all endpoints and
|
||||
// holding s.mu for the duration. Keep it simple (and slow) for now.
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, v := range s.byDisco {
|
||||
if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
|
||||
delete(s.byDisco, k)
|
||||
delete(s.byVNI, v.vni)
|
||||
s.vniPool = append(s.vniPool, v.vni)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
gc()
|
||||
case <-s.closeCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
|
||||
gh := packet.GeneveHeader{}
|
||||
err := gh.Decode(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// TODO: consider performance implications of holding s.mu for the remainder
|
||||
// of this method, which does a bunch of disco/crypto work depending. Keep
|
||||
// it simple (and slow) for now.
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
e, ok := s.byVNI[gh.VNI]
|
||||
if !ok {
|
||||
// unknown VNI
|
||||
return
|
||||
}
|
||||
|
||||
e.handlePacket(from, gh, b, uw, s.discoPublic)
|
||||
}
|
||||
|
||||
func (s *Server) packetReadLoop() {
|
||||
defer func() {
|
||||
s.wg.Done()
|
||||
s.Close()
|
||||
}()
|
||||
b := make([]byte, 1<<16-1)
|
||||
for {
|
||||
// TODO: extract laddr from IP_PKTINFO for use in reply
|
||||
n, from, err := s.uc.ReadFromUDPAddrPort(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handlePacket(from, b[:n], s.uc)
|
||||
}
|
||||
}
|
||||
|
||||
// ErrServerClosed is returned by Server.AllocateEndpoint once the Server has
// been closed.
var ErrServerClosed = errors.New("server closed")
|
||||
|
||||
// AllocateEndpoint allocates a ServerEndpoint for the provided pair of
|
||||
// key.DiscoPublic's. It returns an error (ErrServerClosed) if the server has
|
||||
// been closed.
|
||||
func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoint, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return ServerEndpoint{}, ErrServerClosed
|
||||
}
|
||||
|
||||
pair := newPairOfDiscoPubKeys(discoA, discoB)
|
||||
e, ok := s.byDisco[pair]
|
||||
if ok {
|
||||
if !e.isBound() {
|
||||
// If the endpoint is not yet bound this is likely an allocation
|
||||
// race between two clients utilizing the same relay. Instead of
|
||||
// re-allocating we return the existing allocation. We do not reset
|
||||
// e.allocatedAt in case a client is "stuck" in an allocation
|
||||
// loop and will not be able to complete a handshake, for whatever
|
||||
// reason. Once the endpoint expires a new endpoint will be
|
||||
// allocated. Clients can resolve duplicate ServerEndpoint details
|
||||
// via ServerEndpoint.LamportID.
|
||||
//
|
||||
// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt
|
||||
// to give the client a more accurate picture of the bind window.
|
||||
// Or, some threshold to trigger re-allocation if too much time has
|
||||
// already passed since it was originally allocated.
|
||||
return ServerEndpoint{
|
||||
ServerDisco: s.discoPublic,
|
||||
AddrPorts: s.addrPorts,
|
||||
VNI: e.vni,
|
||||
LamportID: e.lamportID,
|
||||
BindLifetime: s.bindLifetime,
|
||||
SteadyStateLifetime: s.steadyStateLifetime,
|
||||
}, nil
|
||||
}
|
||||
// If an endpoint exists for the pair of key.DiscoPublic's, and is
|
||||
// already bound, delete it. We will re-allocate a new endpoint. Chances
|
||||
// are clients cannot make use of the existing, bound allocation if
|
||||
// they are requesting a new one.
|
||||
delete(s.byDisco, pair)
|
||||
delete(s.byVNI, e.vni)
|
||||
s.vniPool = append(s.vniPool, e.vni)
|
||||
}
|
||||
|
||||
if len(s.vniPool) == 0 {
|
||||
return ServerEndpoint{}, errors.New("VNI pool exhausted")
|
||||
}
|
||||
|
||||
s.lamportID++
|
||||
e = &serverEndpoint{
|
||||
discoPubKeys: pair,
|
||||
lamportID: s.lamportID,
|
||||
allocatedAt: time.Now(),
|
||||
}
|
||||
e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
|
||||
e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
|
||||
e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
|
||||
rand.Read(e.challenge[0][:])
|
||||
rand.Read(e.challenge[1][:])
|
||||
|
||||
s.byDisco[pair] = e
|
||||
s.byVNI[e.vni] = e
|
||||
|
||||
return ServerEndpoint{
|
||||
ServerDisco: s.discoPublic,
|
||||
AddrPorts: s.addrPorts,
|
||||
VNI: e.vni,
|
||||
LamportID: e.lamportID,
|
||||
BindLifetime: defaultBindLifetime,
|
||||
SteadyStateLifetime: defaultSteadyStateLifetime,
|
||||
}, nil
|
||||
}
|
||||
201
net/udprelay/server_test.go
Normal file
201
net/udprelay/server_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package udprelay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
type testClient struct {
|
||||
vni uint32
|
||||
local key.DiscoPrivate
|
||||
server key.DiscoPublic
|
||||
uc *net.UDPConn
|
||||
}
|
||||
|
||||
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient {
|
||||
rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())}
|
||||
uc, err := net.DialUDP("udp4", nil, rAddr)
|
||||
if err != nil {
|
||||
t.Fatal(t)
|
||||
}
|
||||
return &testClient{
|
||||
vni: vni,
|
||||
local: local,
|
||||
server: server,
|
||||
uc: uc,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testClient) write(t *testing.T, b []byte) {
|
||||
_, err := c.uc.Write(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testClient) read(t *testing.T) []byte {
|
||||
c.uc.SetReadDeadline(time.Now().Add(time.Second))
|
||||
b := make([]byte, 1<<16-1)
|
||||
n, err := c.uc.Read(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return b[:n]
|
||||
}
|
||||
|
||||
func (c *testClient) writeDataPkt(t *testing.T, b []byte) {
|
||||
pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b))
|
||||
gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard}
|
||||
err := gh.Encode(pkt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pkt = append(pkt, b...)
|
||||
c.write(t, pkt)
|
||||
}
|
||||
|
||||
func (c *testClient) readDataPkt(t *testing.T) []byte {
|
||||
b := c.read(t)
|
||||
gh := packet.GeneveHeader{}
|
||||
err := gh.Decode(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if gh.Protocol != packet.GeneveProtocolWireGuard {
|
||||
t.Fatal("unexpected geneve protocol")
|
||||
}
|
||||
if gh.Control {
|
||||
t.Fatal("unexpected control")
|
||||
}
|
||||
if gh.VNI != c.vni {
|
||||
t.Fatal("unexpected vni")
|
||||
}
|
||||
return b[packet.GeneveFixedHeaderLength:]
|
||||
}
|
||||
|
||||
func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) {
|
||||
pkt := make([]byte, packet.GeneveFixedHeaderLength, 512)
|
||||
gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco}
|
||||
err := gh.Encode(pkt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pkt = append(pkt, disco.Magic...)
|
||||
pkt = c.local.Public().AppendTo(pkt)
|
||||
box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil))
|
||||
pkt = append(pkt, box...)
|
||||
c.write(t, pkt)
|
||||
}
|
||||
|
||||
func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
|
||||
b := c.read(t)
|
||||
gh := packet.GeneveHeader{}
|
||||
err := gh.Decode(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if gh.Protocol != packet.GeneveProtocolDisco {
|
||||
t.Fatal("unexpected geneve protocol")
|
||||
}
|
||||
if !gh.Control {
|
||||
t.Fatal("unexpected non-control")
|
||||
}
|
||||
if gh.VNI != c.vni {
|
||||
t.Fatal("unexpected vni")
|
||||
}
|
||||
b = b[packet.GeneveFixedHeaderLength:]
|
||||
headerLen := len(disco.Magic) + key.DiscoPublicRawLen
|
||||
if len(b) < headerLen {
|
||||
t.Fatal("disco message too short")
|
||||
}
|
||||
sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen]))
|
||||
if sender.Compare(c.server) != 0 {
|
||||
t.Fatal("unknown disco public key")
|
||||
}
|
||||
payload, ok := c.local.Shared(c.server).Open(b[headerLen:])
|
||||
if !ok {
|
||||
t.Fatal("failed to open sealed disco msg")
|
||||
}
|
||||
msg, err := disco.Parse(payload)
|
||||
if err != nil {
|
||||
t.Fatal("failed to parse disco payload")
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func (c *testClient) handshake(t *testing.T) {
|
||||
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{})
|
||||
msg := c.readControlDiscoMsg(t)
|
||||
challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge)
|
||||
if !ok {
|
||||
t.Fatal("unexepcted disco message type")
|
||||
}
|
||||
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge})
|
||||
}
|
||||
|
||||
func (c *testClient) close() {
|
||||
c.uc.Close()
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
discoA := key.NewDisco()
|
||||
discoB := key.NewDisco()
|
||||
|
||||
ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1")
|
||||
|
||||
server, _, err := NewServer(0, []netip.Addr{ipv4LoopbackAddr})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We expect the same endpoint details as the 3-way bind handshake has not
|
||||
// yet been completed for both relay client parties.
|
||||
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
||||
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
|
||||
}
|
||||
|
||||
if len(endpoint.AddrPorts) != 1 {
|
||||
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
|
||||
}
|
||||
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco)
|
||||
defer tcA.close()
|
||||
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco)
|
||||
defer tcB.close()
|
||||
|
||||
tcA.handshake(t)
|
||||
tcB.handshake(t)
|
||||
|
||||
txToB := []byte{1, 2, 3}
|
||||
tcA.writeDataPkt(t, txToB)
|
||||
rxFromA := tcB.readDataPkt(t)
|
||||
if !bytes.Equal(txToB, rxFromA) {
|
||||
t.Fatal("unexpected msg A->B")
|
||||
}
|
||||
|
||||
txToA := []byte{4, 5, 6}
|
||||
tcB.writeDataPkt(t, txToA)
|
||||
rxFromB := tcA.readDataPkt(t)
|
||||
if !bytes.Equal(txToA, rxFromB) {
|
||||
t.Fatal("unexpected msg B->A")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user