refactor and add test

ssh-host-2
Jesse Duffield 2 years ago
parent afe3d23cbd
commit 16848ef29b

@ -23,6 +23,7 @@ import (
"github.com/docker/docker/api/types"
"github.com/docker/docker/client"
"github.com/imdario/mergo"
"github.com/jesseduffield/lazydocker/pkg/commands/ssh"
"github.com/jesseduffield/lazydocker/pkg/config"
"github.com/jesseduffield/lazydocker/pkg/i18n"
"github.com/jesseduffield/lazydocker/pkg/utils"
@ -54,6 +55,8 @@ type DockerCommand struct {
Closers []io.Closer
}
var _ io.Closer = &DockerCommand{}
// LimitedDockerCommand is a stripped-down DockerCommand with just the methods the container/service/image might need
type LimitedDockerCommand interface {
NewCommandObject(CommandObject) CommandObject
@ -198,7 +201,7 @@ func clientBuilder(c *client.Client) error {
// NewDockerCommand it runs docker commands
func NewDockerCommand(log *logrus.Entry, osCommand *OSCommand, tr *i18n.TranslationSet, config *config.AppConfig, errorChan chan error) (*DockerCommand, error) {
tunnelCloser, err := handleSSHDockerHost()
tunnelCloser, err := ssh.NewSSHHandler().HandleSSHDockerHost()
if err != nil {
ogLog.Fatal(err)
}

@ -0,0 +1,156 @@
package ssh
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
"os"
"os/exec"
"path"
"syscall"
"time"
)
type dependencies struct {
// storing all these dependencies as fields for the sake of testing
dialContext func(ctx context.Context, network, addr string) (io.Closer, error)
startCmd func(*exec.Cmd) error
tempDir func(dir string, pattern string) (name string, err error)
getenv func(key string) string
setenv func(key, value string) error
}
type SSHHandler struct {
deps dependencies
}
func NewSSHHandler() *SSHHandler {
return &SSHHandler{
deps: dependencies{
dialContext: func(ctx context.Context, network, addr string) (io.Closer, error) {
return (&net.Dialer{}).DialContext(ctx, network, addr)
},
startCmd: func(cmd *exec.Cmd) error { return cmd.Start() },
tempDir: ioutil.TempDir,
getenv: os.Getenv,
setenv: os.Setenv,
},
}
}
// HandleSSHDockerHost overrides the DOCKER_HOST environment variable
// to point towards a local unix socket tunneled over SSH to the specified ssh host.
func (self *SSHHandler) HandleSSHDockerHost() (io.Closer, error) {
const key = "DOCKER_HOST"
ctx := context.Background()
u, err := url.Parse(self.deps.getenv(key))
if err != nil {
// if no or an invalid docker host is specified, continue nominally
return noopCloser{}, nil
}
// if the docker host scheme is "ssh", forward the docker socket before creating the client
if u.Scheme == "ssh" {
tunnel, err := self.createDockerHostTunnel(ctx, u.Host)
if err != nil {
return noopCloser{}, fmt.Errorf("tunnel ssh docker host: %w", err)
}
err = self.deps.setenv(key, tunnel.socketPath)
if err != nil {
return noopCloser{}, fmt.Errorf("override DOCKER_HOST to tunneled socket: %w", err)
}
return tunnel, nil
}
return noopCloser{}, nil
}
type noopCloser struct{}
func (noopCloser) Close() error { return nil }
type tunneledDockerHost struct {
socketPath string
cmd *exec.Cmd
}
var _ io.Closer = (*tunneledDockerHost)(nil)
func (t *tunneledDockerHost) Close() error {
return syscall.Kill(-t.cmd.Process.Pid, syscall.SIGKILL)
}
func (self *SSHHandler) createDockerHostTunnel(ctx context.Context, remoteHost string) (*tunneledDockerHost, error) {
socketDir, err := self.deps.tempDir("/tmp", "lazydocker-sshtunnel-")
if err != nil {
return nil, fmt.Errorf("create ssh tunnel tmp file: %w", err)
}
localSocket := path.Join(socketDir, "dockerhost.sock")
cmd, err := self.tunnelSSH(ctx, remoteHost, localSocket)
if err != nil {
return nil, fmt.Errorf("tunnel docker host over ssh: %w", err)
}
// set a reasonable timeout, then wait for the socket to dial successfully
// before attempting to create a new docker client
const socketTunnelTimeout = 8 * time.Second
ctx, cancel := context.WithTimeout(ctx, socketTunnelTimeout)
defer cancel()
err = self.retrySocketDial(ctx, localSocket)
if err != nil {
return nil, fmt.Errorf("ssh tunneled socket never became available: %w", err)
}
// construct the new DOCKER_HOST url with the proper scheme
newDockerHostURL := url.URL{Scheme: "unix", Path: localSocket}
return &tunneledDockerHost{
socketPath: newDockerHostURL.String(),
cmd: cmd,
}, nil
}
// Attempt to dial the socket until it becomes available.
// The retry loop will continue until the parent context is canceled.
func (self *SSHHandler) retrySocketDial(ctx context.Context, socketPath string) error {
t := time.NewTicker(1 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
}
// attempt to dial the socket, exit on success
err := self.tryDial(ctx, socketPath)
if err != nil {
continue
}
return nil
}
}
// Try to dial the specified unix socket, immediately close the connection if successfully created.
func (self *SSHHandler) tryDial(ctx context.Context, socketPath string) error {
conn, err := self.deps.dialContext(ctx, "unix", socketPath)
if err != nil {
return err
}
defer conn.Close()
return nil
}
func (self *SSHHandler) tunnelSSH(ctx context.Context, host, localSocket string) (*exec.Cmd, error) {
cmd := exec.CommandContext(ctx, "ssh", "-L", localSocket+":/var/run/docker.sock", host, "-N")
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
err := self.deps.startCmd(cmd)
if err != nil {
return nil, err
}
return cmd, nil
}

@ -0,0 +1,102 @@
package ssh
import (
"context"
"io"
"os/exec"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSSHHandlerHandleSSHDockerHost(t *testing.T) {
type scenario struct {
testName string
envVarValue string
expectedDialContextCount int
expectedStartCmdCount int
}
scenarios := []scenario{
{
testName: "No env var set",
envVarValue: "",
expectedDialContextCount: 0,
expectedStartCmdCount: 0,
},
{
testName: "Env var set with https scheme",
envVarValue: "https://myhost.com",
expectedStartCmdCount: 0,
expectedDialContextCount: 0,
},
{
testName: "Env var set with ssh scheme",
envVarValue: "ssh://myhost@192.168.5.178",
expectedStartCmdCount: 1,
expectedDialContextCount: 1,
},
}
for _, s := range scenarios {
s := s
t.Run(s.testName, func(t *testing.T) {
getenv := func(key string) string {
if key != "DOCKER_HOST" {
t.Errorf("Expected key to be DOCKER_HOST, got %s", key)
}
return s.envVarValue
}
tempDir := func(dir string, pattern string) (string, error) {
assert.Equal(t, "/tmp", dir)
assert.Equal(t, "lazydocker-sshtunnel-", pattern)
return "/tmp/lazydocker-ssh-tunnel-12345", nil
}
setenv := func(key, value string) error {
assert.Equal(t, "DOCKER_HOST", key)
assert.Equal(t, "unix:///tmp/lazydocker-ssh-tunnel-12345/dockerhost.sock", value)
return nil
}
startCmdCount := 0
startCmd := func(cmd *exec.Cmd) error {
assert.EqualValues(t, []string{"ssh", "-L", "/tmp/lazydocker-ssh-tunnel-12345/dockerhost.sock:/var/run/docker.sock", "192.168.5.178", "-N"}, cmd.Args)
assert.Equal(t, true, cmd.SysProcAttr.Setpgid)
startCmdCount++
return nil
}
dialContextCount := 0
dialContext := func(ctx context.Context, network string, address string) (io.Closer, error) {
assert.Equal(t, "unix", network)
assert.Equal(t, "/tmp/lazydocker-ssh-tunnel-12345/dockerhost.sock", address)
dialContextCount++
return noopCloser{}, nil
}
handler := &SSHHandler{
deps: dependencies{
dialContext: dialContext,
startCmd: startCmd,
tempDir: tempDir,
getenv: getenv,
setenv: setenv,
},
}
_, err := handler.HandleSSHDockerHost()
assert.NoError(t, err)
assert.Equal(t, s.expectedDialContextCount, dialContextCount)
assert.Equal(t, s.expectedStartCmdCount, startCmdCount)
})
}
}

@ -373,4 +373,3 @@ func CloseMany(closers []io.Closer) error {
}
return multiErr(errs)
}

Loading…
Cancel
Save