Sometimes resources (such as database servers) are not publicly accessible. This is critical for security, but it can be a pain when writing scripts that need to access these resources for debugging and other ad-hoc tasks.

One solution is to create an SSH tunnel in bash and point your script to it. However:

  1. You may need to write scripts that are too complicated for bash.
  2. It can make your scripts brittle if you need to run multiple tunnels, or if you forget to clean them up for long-running processes.
  3. You may not have access to a separate terminal to run the SSH tunnel, for example when running under an automation script.
  4. You want to use all your existing Go code and just bolt on the tunnel.
  5. You dislike bash.

Well, here you go. The following code supports creating multiple hassle-free SSH tunnels in pure Go, using either private key or password authentication:

import (
"fmt"
"golang.org/x/crypto/ssh"
"io"
"io/ioutil"
"log"
"net"
"strconv"
"strings"
)

// Endpoint is a network address written as [user@]host[:port].
// User is only meaningful for the SSH server endpoint. Port is 0 when it was
// not specified (or could not be parsed), letting callers apply a default.
type Endpoint struct {
	Host string
	Port int
	User string
}

// NewEndpoint parses s, of the form [user@]host[:port], into an Endpoint.
//
// The user part is everything before the LAST "@", so user names that
// themselves contain "@" (e.g. "me@corp@host") are preserved. IPv6 hosts may
// be written in brackets, with or without a port: "[::1]:22" or "[::1]".
// A missing or non-numeric port leaves Port at 0.
func NewEndpoint(s string) *Endpoint {
	endpoint := &Endpoint{Host: s}

	// Host names never contain "@", user names might: split on the last one.
	if i := strings.LastIndex(endpoint.Host, "@"); i >= 0 {
		endpoint.User = endpoint.Host[:i]
		endpoint.Host = endpoint.Host[i+1:]
	}

	// net.SplitHostPort understands bracketed IPv6 literals, unlike a naive
	// split on ":". It fails when there is no port, in which case the host
	// is kept as-is (apart from stripping IPv6 brackets below).
	if host, port, err := net.SplitHostPort(endpoint.Host); err == nil {
		endpoint.Host = host
		// The signature has no error return, so an unparsable port
		// deliberately leaves Port at 0 ("use the default").
		if p, err := strconv.Atoi(port); err == nil {
			endpoint.Port = p
		}
	} else if strings.HasPrefix(endpoint.Host, "[") && strings.HasSuffix(endpoint.Host, "]") {
		// Bracketed IPv6 literal without a port, e.g. "[::1]".
		endpoint.Host = endpoint.Host[1 : len(endpoint.Host)-1]
	}

	return endpoint
}

// String returns the endpoint as "host:port", suitable for net.Dial and
// net.Listen. IPv6 hosts are bracketed (e.g. "[::1]:22"); User is omitted.
func (endpoint *Endpoint) String() string {
	return net.JoinHostPort(endpoint.Host, strconv.Itoa(endpoint.Port))
}

// SSHTunnel forwards TCP connections accepted on Local to Remote by relaying
// them through an SSH connection to Server (local port forwarding, the same
// idea as `ssh -L`).
type SSHTunnel struct {
	// Local is the address the tunnel listens on. Start overwrites Local.Port
	// with the port actually bound, which matters when it was 0.
	Local *Endpoint
	// Server is the SSH server (jump host) that connections are relayed through.
	Server *Endpoint
	// Remote is the destination, dialed from the SSH server's side.
	Remote *Endpoint
	// Config is the SSH client configuration used when dialing Server.
	Config *ssh.ClientConfig
	// Log, if non-nil, receives debug messages; leave nil for silence.
	Log *log.Logger
}

// logf writes a formatted debug message to tunnel.Log. It does nothing when
// no logger is configured.
//
// The format parameter is deliberately not named "fmt", which would shadow
// the imported fmt package inside this method.
func (tunnel *SSHTunnel) logf(format string, args ...interface{}) {
	if tunnel.Log == nil {
		return
	}
	tunnel.Log.Printf(format, args...)
}

// Start binds tunnel.Local and serves forever, relaying each accepted
// connection in its own goroutine. It only returns when binding or
// accepting fails, and always returns that error.
//
// Start records the bound port in tunnel.Local.Port. When Start runs in a
// background goroutine, callers must wait until the port has been written
// before reading it.
func (tunnel *SSHTunnel) Start() error {
	ln, listenErr := net.Listen("tcp", tunnel.Local.String())
	if listenErr != nil {
		return listenErr
	}
	defer ln.Close()

	// With port 0 the OS chose the port; publish the real one.
	bound := ln.Addr().(*net.TCPAddr)
	tunnel.Local.Port = bound.Port

	for {
		client, acceptErr := ln.Accept()
		if acceptErr != nil {
			return acceptErr
		}
		tunnel.logf("accepted connection")
		go tunnel.forward(client)
	}
}

// forward relays one accepted local connection to tunnel.Remote through a
// new SSH connection to tunnel.Server.
//
// It blocks until the relay has ended in both directions. Before returning,
// it closes localConn, the remote connection and the SSH client, so it is
// meant to run in its own goroutine (as Start does).
func (tunnel *SSHTunnel) forward(localConn net.Conn) {
	// Every exit path, including dial failures, must release the accepted
	// connection.
	defer localConn.Close()

	serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
	if err != nil {
		tunnel.logf("server dial error: %s", err)
		return
	}
	// Each forwarded connection owns its own SSH client. Without this Close,
	// the client and its underlying TCP connection leak.
	defer serverConn.Close()

	tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String())

	remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
	if err != nil {
		tunnel.logf("remote dial error: %s", err)
		return
	}
	defer remoteConn.Close()

	tunnel.logf("connected to %s (2 of 2)\n", tunnel.Remote.String())

	// Buffered so neither copier blocks when signalling completion.
	done := make(chan struct{}, 2)
	copyConn := func(writer, reader net.Conn) {
		if _, err := io.Copy(writer, reader); err != nil {
			tunnel.logf("io.Copy error: %s", err)
		}
		done <- struct{}{}
	}

	go copyConn(localConn, remoteConn)
	go copyConn(remoteConn, localConn)

	// Once either direction finishes, close both ends so the other io.Copy
	// unblocks; that copier may log an io.Copy error as a result. Then wait
	// for it, so no goroutine outlives forward. The deferred Close calls run
	// again afterwards; a repeated Close is expected to just return an error,
	// which is ignored.
	<-done
	localConn.Close()
	remoteConn.Close()
	<-done
}

// PrivateKeyFile returns an ssh.AuthMethod that authenticates with the
// private key stored in file. ssh.ParsePrivateKey does not handle
// passphrase-protected keys.
//
// The signature has no error return. If reading or parsing the key fails,
// the returned AuthMethod carries the wrapped error and makes authentication
// with this method fail, so the SSH dial fails instead of panicking. The
// original returned a nil AuthMethod, which x/crypto/ssh dereferences during
// the handshake. Depending on the x/crypto/ssh version, the wrapped error may
// appear in the error returned by ssh.Dial — verify for your version.
func PrivateKeyFile(file string) ssh.AuthMethod {
	fail := func(err error) ssh.AuthMethod {
		return ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
			return nil, err
		})
	}

	buffer, err := ioutil.ReadFile(file)
	if err != nil {
		return fail(fmt.Errorf("reading private key %q: %w", file, err))
	}

	key, err := ssh.ParsePrivateKey(buffer)
	if err != nil {
		return fail(fmt.Errorf("parsing private key %q: %w", file, err))
	}

	return ssh.PublicKeys(key)
}

// NewSSHTunnel builds, but does not start, a tunnel that forwards a randomly
// chosen localhost port to destination through the SSH server given by
// tunnel.
//
// Parameters:
//   - tunnel has the form [user@]host[:port]; the port defaults to 22.
//   - destination has the form host:port and is dialed from the SSH server.
//   - auth is the authentication method. A nil auth is skipped, so
//     authentication fails at dial time instead of panicking inside the SSH
//     client.
//
// SECURITY: the returned Config accepts ANY host key, so the SSH server's
// identity is not verified. For anything beyond ad-hoc use, replace
// Config.HostKeyCallback, e.g. with golang.org/x/crypto/ssh/knownhosts.
func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string) *SSHTunnel {
	// Port 0 lets the OS pick a free port; Start records the real one.
	localEndpoint := NewEndpoint("localhost:0")

	server := NewEndpoint(tunnel)
	if server.Port == 0 {
		server.Port = 22
	}

	// x/crypto/ssh calls methods on every entry of Auth, so a nil entry
	// would panic during the handshake.
	var authMethods []ssh.AuthMethod
	if auth != nil {
		authMethods = append(authMethods, auth)
	}

	return &SSHTunnel{
		Config: &ssh.ClientConfig{
			User: server.User,
			Auth: authMethods,
			// Same accept-everything behavior as before, but explicit about
			// the missing host key verification.
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		},
		Local:  localEndpoint,
		Server: server,
		Remote: NewEndpoint(destination),
	}
}

Here is an example of usage:

// main demonstrates opening a tunnel through a jump box to a private
// PostgreSQL-compatible server and connecting to it with database/sql.
//
// Besides the imports used by the tunnel code, this example needs
// "database/sql", "os" and "time". It also needs a registered PostgreSQL
// driver, for example: _ "github.com/lib/pq".
func main() {
	// Set up the tunnel, but do not start it yet.
	tunnel := NewSSHTunnel(
		// User and host of the tunnel server; the port defaults to 22
		// if not specified.
		"ec2-user@jumpbox.us-east-1.mydomain.com",

		// Pick ONE of the following authentication methods (NewSSHTunnel
		// takes exactly one, so keep the other commented out):
		PrivateKeyFile("path/to/private/key.pem"), // 1. private key
		// ssh.Password("password"),              // 2. password

		// The destination host and port of the actual server.
		"dqrsdfdssdfx.us-east-1.redshift.amazonaws.com:5439",
	)

	// You can provide a logger for debugging, or remove this line to
	// make it silent.
	tunnel.Log = log.New(os.Stdout, "", log.Ldate|log.Lmicroseconds)

	// Start the server in the background. Start only returns on failure,
	// so treat any return as fatal. You will need to wait a small amount
	// of time for it to bind to the localhost port before you can start
	// sending connections. The sleep is a crude wait: reading
	// tunnel.Local.Port below is not synchronized with Start.
	go func() {
		if err := tunnel.Start(); err != nil {
			log.Fatalf("tunnel stopped: %s", err)
		}
	}()
	time.Sleep(100 * time.Millisecond)

	// NewSSHTunnel binds to a random port so that you can have multiple
	// SSH tunnels available. The port is available through:
	// tunnel.Local.Port
	// You can use any normal Go code to connect to the destination server
	// through localhost. You may need to use 127.0.0.1 for some libraries.
	//
	// Here is an example of connecting to a PostgreSQL server. The
	// connection-string keyword for the user name is "user".
	conn := fmt.Sprintf("host=127.0.0.1 port=%d user=foo", tunnel.Local.Port)
	db, err := sql.Open("postgres", conn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// ...
}

A big thanks to Svetlin Ralchev who provided a lot of the original bits and pieces.