Compare commits

...

1 Commits

Author SHA1 Message Date
David Crawshaw
8b9e9c0786 ipnlocal, resolver, etc: add peer API DoH 2021-07-29 17:38:37 -07:00
8 changed files with 104 additions and 16 deletions

View File

@@ -2823,3 +2823,8 @@ func (b *LocalBackend) DERPMap() *tailcfg.DERPMap {
}
return b.netMap.DERPMap
}
// DNSManager returns the underlying DNSManager.
func (b *LocalBackend) DNSManager() *dns.Manager {
return b.e.DNSManager()
}

View File

@@ -12,6 +12,7 @@ import (
"html"
"io"
"io/fs"
"io/ioutil"
"net"
"net/http"
"net/url"
@@ -500,6 +501,10 @@ func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handlePeerPut(w, r)
return
}
if r.URL.Path == "/v0/doh" {
h.handleDoH(w, r)
return
}
if r.URL.Path == "/v0/goroutines" {
h.handleServeGoroutines(w, r)
return
@@ -710,3 +715,32 @@ func (h *peerAPIHandler) handleServeGoroutines(w http.ResponseWriter, r *http.Re
}
w.Write(buf)
}
func (h *peerAPIHandler) handleDoH(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "only POST is support for DNS-over-HTTP", http.StatusMethodNotAllowed)
return
}
const dohType = "application/dns-message"
if r.Header.Get("Content-Type") != dohType {
http.Error(w, fmt.Sprintf("Content-Type=%q; want %q", r.Header.Get("Content-Type"), dohType), http.StatusBadRequest)
return
}
r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
bs, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, "read body: "+err.Error(), http.StatusInternalServerError)
return
}
res, err := h.ps.b.DNSManager().Request(r.Context(), bs)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", dohType)
w.Header().Set("Content-Length", strconv.FormatInt(int64(len(res)), 10))
w.WriteHeader(200)
w.Write(res)
}

View File

@@ -6,6 +6,7 @@ package dns
import (
"bufio"
"context"
"runtime"
"time"
@@ -195,6 +196,10 @@ func (m *Manager) NextResponse() ([]byte, netaddr.IPPort, error) {
return m.resolver.NextResponse()
}
func (m *Manager) Request(ctx context.Context, bs []byte) ([]byte, error) {
return m.resolver.Request(ctx, bs)
}
func (m *Manager) Down() error {
if err := m.os.Close(); err != nil {
return err

View File

@@ -529,28 +529,56 @@ type forwardQuery struct {
// forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query packet) error {
domain, err := nameFromQuery(query.bs)
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
defer cancel()
v, err := f.forwardQuery(ctx, query.bs)
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case f.responses <- packet{v, query.addr}:
return nil
}
}
clampEDNSSize(query.bs, maxResponseBytes)
func (f *forwarder) Forward(ctx context.Context, bs []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
go func() {
select {
case <-f.ctx.Done():
cancel()
case <-ctx.Done():
}
}()
return f.forwardQuery(ctx, bs)
}
func (f *forwarder) forwardQuery(ctx context.Context, bs []byte) ([]byte, error) {
domain, err := nameFromQuery(bs)
if err != nil {
return nil, err
}
clampEDNSSize(bs, maxResponseBytes)
resolvers := f.resolvers(domain)
if len(resolvers) == 0 {
return errNoUpstreams
return nil, errNoUpstreams
}
fq := &forwardQuery{
txid: getTxID(query.bs),
packet: query.bs,
txid: getTxID(bs),
packet: bs,
closeOnCtxDone: new(closePool),
}
defer fq.closeOnCtxDone.Close()
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
defer cancel()
resc := make(chan []byte, 1)
var (
mu sync.Mutex
@@ -586,19 +614,14 @@ func (f *forwarder) forward(query packet) error {
select {
case v := <-resc:
select {
case <-ctx.Done():
return ctx.Err()
case f.responses <- packet{v, query.addr}:
return nil
}
return v, nil
case <-ctx.Done():
mu.Lock()
defer mu.Unlock()
if firstErr != nil {
return firstErr
return nil, firstErr
}
return ctx.Err()
return nil, ctx.Err()
}
}

View File

@@ -8,6 +8,7 @@ package resolver
import (
"bufio"
"context"
"encoding/hex"
"errors"
"fmt"
@@ -270,6 +271,15 @@ func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error)
}
}
// Request issues a DNS request and returns the result.
func (r *Resolver) Request(ctx context.Context, bs []byte) ([]byte, error) {
out, err := r.respond(bs)
if err == errNotOurName {
return r.forwarder.Forward(ctx, bs)
}
return out, err
}
// resolveLocal returns an IP for the given domain, if domain is in
// the local hosts map and has an IP corresponding to the requested
// typ (A, AAAA, ALL).

View File

@@ -377,6 +377,10 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
return e, nil
}
func (e *userspaceEngine) DNSManager() *dns.Manager {
return e.dns
}
// echoRespondToAll is an inbound post-filter responding to all echo requests.
func echoRespondToAll(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
if p.IsEchoRequest() {

View File

@@ -129,6 +129,10 @@ func (e *watchdogEngine) WhoIsIPPort(ipp netaddr.IPPort) (tsIP netaddr.IP, ok bo
e.watchdog("UnregisterIPPortIdentity", func() { tsIP, ok = e.wrap.WhoIsIPPort(ipp) })
return tsIP, ok
}
func (e *watchdogEngine) DNSManager() (m *dns.Manager) {
e.watchdog("DNSManager", func() { m = e.wrap.DNSManager() })
return m
}
func (e *watchdogEngine) Close() {
e.watchdog("Close", e.wrap.Close)
}

View File

@@ -149,4 +149,7 @@ type Engine interface {
// WhoIsIPPort looks up an IP:port in the temporary registrations,
// and returns a matching Tailscale IP, if it exists.
WhoIsIPPort(netaddr.IPPort) (netaddr.IP, bool)
// DNSManager returns the DNS manager for this engine.
DNSManager() *dns.Manager
}