Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02acaa00ee | ||
|
|
088d78591c | ||
|
|
24929f6b61 | ||
|
|
78c8f7ec58 | ||
|
|
f4d76fb46d | ||
|
|
b6852d5357 | ||
|
|
51fb4ce517 |
@@ -1 +1 @@
|
||||
1.75.0
|
||||
1.76.3
|
||||
|
||||
@@ -487,6 +487,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
|
||||
defer hres.Body.Close()
|
||||
if hres.StatusCode != 200 {
|
||||
metricDNSFwdDoHErrorStatus.Add(1)
|
||||
if hres.StatusCode/100 == 5 {
|
||||
// Translate 5xx HTTP server errors into SERVFAIL DNS responses.
|
||||
return nil, fmt.Errorf("%w: %s", errServerFailure, hres.Status)
|
||||
}
|
||||
return nil, errors.New(hres.Status)
|
||||
}
|
||||
if ct := hres.Header.Get("Content-Type"); ct != dohType {
|
||||
@@ -916,10 +920,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
||||
metricDNSFwdDropBonjour.Add(1)
|
||||
res, err := nxDomainResponse(query)
|
||||
if err != nil {
|
||||
f.logf("error parsing bonjour query: %v", err)
|
||||
// Returning an error will cause an internal retry, there is
|
||||
// nothing we can do if parsing failed. Just drop the packet.
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -951,10 +952,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
||||
|
||||
res, err := servfailResponse(query)
|
||||
if err != nil {
|
||||
f.logf("building servfail response: %v", err)
|
||||
// Returning an error will cause an internal retry, there is
|
||||
// nothing we can do if parsing failed. Just drop the packet.
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -1053,6 +1051,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
||||
if verboseDNSForward() {
|
||||
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -450,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
|
||||
return
|
||||
}
|
||||
|
||||
func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) {
|
||||
func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
|
||||
netMon, err := netmon.New(tb.Logf)
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
@@ -464,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
|
||||
modify(fwd)
|
||||
}
|
||||
|
||||
rr := resolverAndDelay{
|
||||
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
|
||||
resolvers := make([]resolverAndDelay, len(ports))
|
||||
for i, port := range ports {
|
||||
resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}
|
||||
}
|
||||
|
||||
rpkt := packet{
|
||||
@@ -477,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
|
||||
rchan := make(chan packet, 1)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
tb.Cleanup(cancel)
|
||||
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr)
|
||||
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...)
|
||||
select {
|
||||
case res := <-rchan:
|
||||
return res.bs, err
|
||||
@@ -486,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
|
||||
}
|
||||
}
|
||||
|
||||
func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte {
|
||||
resp, err := runTestQuery(tb, port, request, modify)
|
||||
// makeTestRequest returns a new TypeA request for the given domain.
|
||||
func makeTestRequest(tb testing.TB, domain string) []byte {
|
||||
tb.Helper()
|
||||
name := dns.MustNewName(domain)
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
request, err := builder.Finish()
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return request
|
||||
}
|
||||
|
||||
// makeTestResponse returns a new Type A response for the given domain,
|
||||
// with the specified status code and zero or more addresses.
|
||||
func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
|
||||
tb.Helper()
|
||||
name := dns.MustNewName(domain)
|
||||
builder := dns.NewBuilder(nil, dns.Header{
|
||||
Response: true,
|
||||
Authoritative: true,
|
||||
RCode: code,
|
||||
})
|
||||
builder.StartQuestions()
|
||||
q := dns.Question{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
builder.Question(q)
|
||||
if len(addrs) > 0 {
|
||||
builder.StartAnswers()
|
||||
for _, addr := range addrs {
|
||||
builder.AResource(dns.ResourceHeader{
|
||||
Name: q.Name,
|
||||
Class: q.Class,
|
||||
TTL: 120,
|
||||
}, dns.AResource{
|
||||
A: addr.As4(),
|
||||
})
|
||||
}
|
||||
}
|
||||
response, err := builder.Finish()
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte {
|
||||
resp, err := runTestQuery(tb, request, modify, ports...)
|
||||
if err != nil {
|
||||
tb.Fatalf("error making request: %v", err)
|
||||
}
|
||||
@@ -516,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
resp := mustRunTestQuery(t, port, request, nil)
|
||||
resp := mustRunTestQuery(t, request, nil, port)
|
||||
if !bytes.Equal(resp, largeResponse) {
|
||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
||||
}
|
||||
@@ -554,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
resp := mustRunTestQuery(t, port, request, nil)
|
||||
resp := mustRunTestQuery(t, request, nil, port)
|
||||
if !bytes.Equal(resp, largeResponse) {
|
||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
||||
}
|
||||
@@ -585,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) {
|
||||
resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
|
||||
// Disable retries for this test.
|
||||
fwd.controlKnobs = &controlknobs.Knobs{}
|
||||
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
|
||||
})
|
||||
}, port)
|
||||
|
||||
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
|
||||
|
||||
@@ -613,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) {
|
||||
const domain = "error-response.tailscale.com."
|
||||
|
||||
// Our response is a SERVFAIL
|
||||
response := func() []byte {
|
||||
name := dns.MustNewName(domain)
|
||||
|
||||
builder := dns.NewBuilder(nil, dns.Header{
|
||||
Response: true,
|
||||
RCode: dns.RCodeServerFailure,
|
||||
})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
response, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return response
|
||||
}()
|
||||
response := makeTestResponse(t, domain, dns.RCodeServerFailure)
|
||||
|
||||
// Our request is a single A query for the domain in the answer, above.
|
||||
request := func() []byte {
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: dns.MustNewName(domain),
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
request, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return request
|
||||
}()
|
||||
request := makeTestRequest(t, domain)
|
||||
|
||||
var sawRequest atomic.Bool
|
||||
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
||||
@@ -657,14 +680,141 @@ func TestForwarderTCPFallbackError(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
_, err := runTestQuery(t, port, request, nil)
|
||||
resp, err := runTestQuery(t, request, nil, port)
|
||||
if !sawRequest.Load() {
|
||||
t.Error("did not see DNS request")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("wanted error, got nil")
|
||||
} else if !errors.Is(err, errServerFailure) {
|
||||
t.Errorf("wanted errServerFailure, got: %v", err)
|
||||
if err != nil {
|
||||
t.Fatalf("wanted nil, got %v", err)
|
||||
}
|
||||
var parser dns.Parser
|
||||
respHeader, err := parser.Start(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("parser.Start() failed: %v", err)
|
||||
}
|
||||
if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want {
|
||||
t.Errorf("wanted %v, got %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Test to ensure that if we have more than one resolver, and at least one of them
|
||||
// returns a successful response, we propagate it.
|
||||
func TestForwarderWithManyResolvers(t *testing.T) {
|
||||
enableDebug(t)
|
||||
|
||||
const domain = "example.com."
|
||||
request := makeTestRequest(t, domain)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responses [][]byte // upstream responses
|
||||
wantResponses [][]byte // we should receive one of these from the forwarder
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
responses: [][]byte{ // All upstream servers returned successful, but different, response.
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
|
||||
},
|
||||
wantResponses: [][]byte{ // We may forward whichever response is received first.
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ServFail",
|
||||
responses: [][]byte{ // All upstream servers returned a SERVFAIL.
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
},
|
||||
wantResponses: [][]byte{
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ServFail+Success",
|
||||
responses: [][]byte{ // All upstream servers fail except for one.
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
},
|
||||
wantResponses: [][]byte{ // We should forward the successful response.
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NXDomain",
|
||||
responses: [][]byte{ // All upstream servers returned NXDOMAIN.
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
},
|
||||
wantResponses: [][]byte{
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NXDomain+Success",
|
||||
responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one.
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
},
|
||||
wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Refused",
|
||||
responses: [][]byte{ // All upstream servers return different failures.
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
},
|
||||
wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded.
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MixFail",
|
||||
responses: [][]byte{ // All upstream servers return different failures.
|
||||
makeTestResponse(t, domain, dns.RCodeServerFailure),
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
},
|
||||
wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded.
|
||||
makeTestResponse(t, domain, dns.RCodeNameError),
|
||||
makeTestResponse(t, domain, dns.RCodeRefused),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ports := make([]uint16, len(tt.responses))
|
||||
for i := range tt.responses {
|
||||
ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
|
||||
}
|
||||
gotResponse, err := runTestQuery(t, request, nil, ports...)
|
||||
if err != nil {
|
||||
t.Fatalf("wanted nil, got %v", err)
|
||||
}
|
||||
responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool {
|
||||
return slices.Equal(gotResponse, wantResponse)
|
||||
})
|
||||
if !responseOk {
|
||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -713,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
|
||||
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
||||
})
|
||||
|
||||
res, err := runTestQuery(t, port, request, nil)
|
||||
res, err := runTestQuery(t, request, nil, port)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -321,15 +321,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net
|
||||
defer cancel()
|
||||
err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses)
|
||||
if err != nil {
|
||||
select {
|
||||
// Best effort: use any error response sent by forwardWithDestChan.
|
||||
// This is present in some errors paths, such as when all upstream
|
||||
// DNS servers replied with an error.
|
||||
case resp := <-responses:
|
||||
return resp.bs, err
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return (<-responses).bs, nil
|
||||
}
|
||||
|
||||
@@ -1503,8 +1503,8 @@ func TestServfail(t *testing.T) {
|
||||
r.SetConfig(cfg)
|
||||
|
||||
pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns))
|
||||
if !errors.Is(err, errServerFailure) {
|
||||
t.Errorf("err = %v, want %v", err, errServerFailure)
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
|
||||
wantPkt := []byte{
|
||||
|
||||
@@ -940,7 +940,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
|
||||
}
|
||||
}
|
||||
if len(need) > 0 {
|
||||
if !opts.OnlyTCP443 {
|
||||
if opts == nil || !opts.OnlyTCP443 {
|
||||
// Kick off ICMP in parallel to HTTPS checks; we don't
|
||||
// reuse the same WaitGroup for those probes because we
|
||||
// need to close the underlying Pinger after a timeout
|
||||
|
||||
@@ -210,8 +210,6 @@ type incubatorArgs struct {
|
||||
debugTest bool
|
||||
isSELinuxEnforcing bool
|
||||
encodedEnv string
|
||||
allowListEnvKeys string
|
||||
forwardedEnviron []string
|
||||
}
|
||||
|
||||
func parseIncubatorArgs(args []string) (incubatorArgs, error) {
|
||||
@@ -246,31 +244,35 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) {
|
||||
ia.gids = append(ia.gids, gid)
|
||||
}
|
||||
|
||||
ia.forwardedEnviron = os.Environ()
|
||||
return ia, nil
|
||||
}
|
||||
|
||||
func (ia incubatorArgs) forwadedEnviron() ([]string, string, error) {
|
||||
environ := os.Environ()
|
||||
// pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding
|
||||
ia.allowListEnvKeys = "SSH_AUTH_SOCK"
|
||||
allowListKeys := "SSH_AUTH_SOCK"
|
||||
|
||||
if ia.encodedEnv != "" {
|
||||
unquoted, err := strconv.Unquote(ia.encodedEnv)
|
||||
if err != nil {
|
||||
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
|
||||
return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
|
||||
}
|
||||
|
||||
var extraEnviron []string
|
||||
|
||||
err = json.Unmarshal([]byte(unquoted), &extraEnviron)
|
||||
if err != nil {
|
||||
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
|
||||
return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
|
||||
}
|
||||
|
||||
ia.forwardedEnviron = append(ia.forwardedEnviron, extraEnviron...)
|
||||
environ = append(environ, extraEnviron...)
|
||||
|
||||
for _, v := range extraEnviron {
|
||||
ia.allowListEnvKeys = fmt.Sprintf("%s,%s", ia.allowListEnvKeys, strings.Split(v, "=")[0])
|
||||
allowListKeys = fmt.Sprintf("%s,%s", allowListKeys, strings.Split(v, "=")[0])
|
||||
}
|
||||
}
|
||||
|
||||
return ia, nil
|
||||
return environ, allowListKeys, nil
|
||||
}
|
||||
|
||||
// beIncubator is the entrypoint to the `tailscaled be-child ssh` subcommand.
|
||||
@@ -450,8 +452,13 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error {
|
||||
loginArgs := ia.loginArgs(loginCmdPath)
|
||||
dlogf("logging in with %+v", loginArgs)
|
||||
|
||||
environ, _, err := ia.forwadedEnviron()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If Exec works, the Go code will not proceed past this:
|
||||
err = unix.Exec(loginCmdPath, loginArgs, ia.forwardedEnviron)
|
||||
err = unix.Exec(loginCmdPath, loginArgs, environ)
|
||||
|
||||
// If we made it here, Exec failed.
|
||||
return err
|
||||
@@ -484,9 +491,14 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
|
||||
defer sessionCloser()
|
||||
}
|
||||
|
||||
environ, allowListEnvKeys, err := ia.forwadedEnviron()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
loginArgs := []string{
|
||||
su,
|
||||
"-w", ia.allowListEnvKeys,
|
||||
"-w", allowListEnvKeys,
|
||||
"-l",
|
||||
ia.localUser,
|
||||
}
|
||||
@@ -498,7 +510,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
|
||||
dlogf("logging in with %+v", loginArgs)
|
||||
|
||||
// If Exec works, the Go code will not proceed past this:
|
||||
err = unix.Exec(su, loginArgs, ia.forwardedEnviron)
|
||||
err = unix.Exec(su, loginArgs, environ)
|
||||
|
||||
// If we made it here, Exec failed.
|
||||
return true, err
|
||||
@@ -527,11 +539,16 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
_, allowListEnvKeys, err := ia.forwadedEnviron()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// First try to execute su -w <allow listed env> -l <user> -c true
|
||||
// to make sure su supports the necessary arguments.
|
||||
err = exec.Command(
|
||||
su,
|
||||
"-w", ia.allowListEnvKeys,
|
||||
"-w", allowListEnvKeys,
|
||||
"-l",
|
||||
ia.localUser,
|
||||
"-c", "true",
|
||||
@@ -558,10 +575,15 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error {
|
||||
return err
|
||||
}
|
||||
|
||||
environ, _, err := ia.forwadedEnviron()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
args := shellArgs(ia.isShell, ia.cmd)
|
||||
dlogf("running %s %q", ia.loginShell, args)
|
||||
cmd := newCommand(ia.hasTTY, ia.loginShell, ia.forwardedEnviron, args)
|
||||
err := cmd.Run()
|
||||
cmd := newCommand(ia.hasTTY, ia.loginShell, environ, args)
|
||||
err = cmd.Run()
|
||||
if ee, ok := err.(*exec.ExitError); ok {
|
||||
ps := ee.ProcessState
|
||||
code := ps.ExitCode()
|
||||
|
||||
Reference in New Issue
Block a user