package sshutil
    
    import (
    	"context"
    	"errors"
    	"fmt"
    	"io"
    	"log"
    	"net"
    	"sync"
    	"sync/atomic"
    	"time"
    
    	"golang.org/x/crypto/ssh"
    )
    
    
    // Server does SSH.
    //
    // SSH consists of connections, channels and requests. Each connection can have
    // an arbitrary number of channels, and has an out-of-band request mechanism.
    // Each channel consists of a duplex data stream and its own scoped request
    // mechanism. Clients may request channels for arbitrary reasons, and so
    // arbitrary application protocols can be hosted atop SSH. The familiar channel
    // is a "session", which represents the client's desire to execute a program
    // on the Server's host environment, shell or otherwise.
    
    //
    // The zero value of Server is a valid Server. If the config is not specified,
    // the DefaultServerConfig() is used and an ephemeral host key is generated on
    
    // Server initialization. The server is initialized on the first call to Listen
    // or ListenAndServe, after which point the Server is considered to be running.
    // Handlers may be modified on a running Server but the ssh config and any hook
    // functions must not be changed.
    //
    
    type Server struct {
    	Addr   string
    	Config *ssh.ServerConfig
    
    	ConnectionHook ConnectionHook
    	HandshakeHook  HandshakeHook
    	DepartureHook  DepartureHook
    
    	GlobalRequests        map[string]RequestHandler
    	DefaultRequestHandler RequestHandler
    
    	ChannelHandlers       map[string]ChannelHandler
    	DefaultChannelHandler ChannelHandler
    
    	L *log.Logger
    
    
    	Idle sync.WaitGroup
    
    
    	mu          sync.RWMutex
    	once        sync.Once
    	ctx         context.Context
    	listeners   map[io.Closer]struct{}
    	connections map[string]io.Closer
    	shutdown    uint64
    	closed      uint64
    	cancel      context.CancelFunc
    }
    
    // ConnectionHook allows for custom connection logic after a connection is
    // established but prior to the SSH handshake. A non-nil error means c will
    // be closed and no handshake will be performed.
    type ConnectionHook func(c net.Conn) (net.Conn, error)
    
    // HandshakeHook allows execution of code right after a successful handshake
    // on a new connection. The full ssh.ServerConn is provided to the hook whereas
    // the ssh.AuthLogCallback hook available in the server config only provides
    // connection metadata.
    type HandshakeHook func(conn *ServerConn) error
    
    // DepartureHook allows execution of code during teardown of an authenticated
    // connection (handshake completed). The provided connection may already be
    // closed if the peer departed of their own volition.
    type DepartureHook func(conn *ServerConn)
    
    // DefaultServerConfig allows public key auth from any client presenting a key
    // or certificate.
    
    // TODO rename this "InsecureOpenServerConfig"?
    
    func DefaultServerConfig() *ssh.ServerConfig {
    	return &ssh.ServerConfig{
    
    		ServerVersion:     "SSH-2.0-Go sshutil",
    
    		PublicKeyCallback: allowAllPublicKeys,
    	}
    }
    
    func allowAllPublicKeys(meta ssh.ConnMetadata, pubkey ssh.PublicKey) (*ssh.Permissions, error) {
    	return &ssh.Permissions{
    		// Record the public key used for authentication.
    		Extensions: map[string]string{
    			"pubkey-fp": ssh.FingerprintSHA256(pubkey),
    		},
    	}, nil
    }
    
    // init must be called prior to serving connections. init is safe to be called
    // arbitrarily. If init returns an error the server should not proceed with
    // serving connections and convey such to the caller.
    func (srv *Server) init() error {
    	var failure error
    	srv.once.Do(func() {
    		if srv.L == nil {
    			srv.L = logger
    		}
    
    		srv.listeners = make(map[io.Closer]struct{})
    		srv.connections = make(map[string]io.Closer)
    		srv.ctx, srv.cancel = context.WithCancel(context.Background())
    
    		if srv.Config == nil {
    			// copy the default config
    			config := DefaultServerConfig()
    
    			// Use an ephemeral key.
    			signer, err := GenerateKey()
    			if err != nil {
    				failure = err
    				return
    			}
    			f := ssh.FingerprintSHA256(signer.PublicKey())
    			srv.L.Printf("Server ephermeral key '%s'", f)
    			config.AddHostKey(signer)
    			srv.Config = config
    		}
    
    		if srv.ConnectionHook == nil {
    			srv.ConnectionHook = defaultConncectionHook
    		}
    		if srv.HandshakeHook == nil {
    			srv.HandshakeHook = defaultHandshakeHook
    		}
    		if srv.DepartureHook == nil {
    			srv.DepartureHook = defaultDepartureHook
    		}
    		if srv.DefaultRequestHandler == nil {
    			srv.DefaultRequestHandler = RequestHandlerFunc(defaultRequestFunc)
    		}
    		if srv.GlobalRequests == nil {
    			srv.GlobalRequests = make(map[string]RequestHandler)
    		}
    		// Setting a DefaultChannelHandler would cause every type of
    		// channel to be accepted. Instead, add a default session
    		// handler.
    
    		if srv.ChannelHandlers == nil {
    			srv.ChannelHandlers = make(map[string]ChannelHandler)
    			srv.ChannelHandlers["session"] = DefaultSessionHandler()
    		}
    
    	})
    	return failure
    }
    
    // ErrServerClosed is returned from ListenAndServe and Serve when either
    // method returnd due to a call to Close or Shutdown.
    
    var ErrServerClosed = errors.New("sshutil: Server closed")
    
    
    // ListenAndServe blocks listening on Server.Addr. If Addr is empty, the
    // server listens on localhost:22, the ssh port. The returned error is never
    // nil. After Shutdown or Close, this method returns ErrServerClosed.
    func (srv *Server) ListenAndServe() error {
    	if srv.ShutdownCalled() {
    		return ErrServerClosed
    	}
    	addr := srv.Addr
    	if addr == "" {
    		addr = ":22"
    	}
    	ln, err := net.Listen("tcp", addr)
    	if err != nil {
    		return err
    	}
    	return srv.Serve(ln)
    }
    
    // Serve blocks accepting connections on the provided listener. Serve always
    // returns a non-nil error and closes listener. After Shutdown or Close, the
    // returned error is ErrServerClosed.
    func (srv *Server) Serve(listener net.Listener) error {
    	// Synchonize on an initialized Server
    	if err := srv.init(); err != nil {
    		listener.Close()
    		return err
    	}
    	if err := srv.trackListener(listener); err != nil {
    		listener.Close()
    		return err
    	}
    	defer srv.forgetListener(listener)
    
    	srv.L.Printf("Server commence listening on %s", listener.Addr())
    
    	// Delay inspired by net/http:
    	var delay time.Duration
    	delayedTryAgain := func(err error) {
    		if delay == 0 {
    			delay = 5 * time.Millisecond
    		} else {
    			delay *= 2
    		}
    		if max := 1 * time.Second; delay > max {
    			delay = max
    		}
    
    David Cowden's avatar
    David Cowden committed
    		srv.L.Printf("Server accept error '%v'; retrying in %v ns", err, delay)
    
    		time.Sleep(delay)
    	}
    
    	// Accept loop:
    	for {
    		conn, err := listener.Accept()
    		if err != nil {
    			var ne net.Error
    			if errors.As(err, &ne) && ne.Temporary() {
    				delayedTryAgain(ne)
    				continue
    			}
    			addr := listener.Addr().String()
    			srv.L.Printf("Server finished listening on %s", addr)
    			if srv.ShutdownCalled() {
    				return ErrServerClosed
    			}
    			return err
    		}
    		srv.Idle.Add(1)
    		c, err := srv.ConnectionHook(conn)
    		if err != nil {
    			conn.Close()
    			srv.Idle.Done()
    			continue
    		}
    		go srv.handshake(c)
    	}
    }
    
    func defaultConncectionHook(conn net.Conn) (net.Conn, error) {
    	return conn, nil
    }
    
    // ShutdownCalled reports whether the server is Shutdown or not.
    func (srv *Server) ShutdownCalled() bool {
    	return atomic.LoadUint64(&srv.shutdown) == 1
    }
    
    // Shutdown stops any listen loops from accepting new connections and closes
    // all tracked listeners. Active connections are not closed.
    func (srv *Server) Shutdown() error {
    	if !atomic.CompareAndSwapUint64(&srv.shutdown, 0, 1) {
    		return ErrServerClosed
    	}
    	// First call to Shutdown. Stop listen loops:
    
    	srv.mu.RLock()
    	defer srv.mu.RUnlock()
    
    	for l := range srv.listeners {
    		err := l.Close()
    		if err != nil {
    			return err
    		}
    	}
    	// Notify handlers so they can wind down gracefully.
    	srv.cancel()
    	return nil
    }
    
    
    // Shutdown the server and wait for all connections to drain gracefully. If the
    // provided context is canceled return the context's error.
    func (srv *Server) ShutdownAndWait(ctx context.Context) error {
    	err := srv.Shutdown()
    	if err != nil && err != ErrServerClosed {
    		return err
    	}
    
    
    	srv.L.Println("Server waiting for peer connections to drain...")
    
    	idle := make(chan struct{}, 1)
    	go func() {
    		srv.Idle.Wait()
    		close(idle)
    	}()
    	select {
    	case <-ctx.Done():
    		return ctx.Err()
    	case <-idle:
    		return nil
    	}
    }
    
    
    // CloseCalled reports whether the server is Closed or not. A closed server
    // does not allow new connections to be tracked.
    func (srv *Server) CloseCalled() bool {
    	return atomic.LoadUint64(&srv.closed) == 1
    }
    
    // Close stops the server immediately. Close calls Shutdown and then procedes
    // to close any open connections. ...A close handler allows you inject custom
    // teardown logic on an open ssh stream...
    func (srv *Server) Close() error {
    	err := srv.Shutdown()
    	if err != nil && err != ErrServerClosed {
    		return err
    	}
    	if !atomic.CompareAndSwapUint64(&srv.closed, 0, 1) {
    		return ErrServerClosed
    	}
    	// First call to Close. Drain connections forcibly:
    
    	srv.mu.Lock()
    	defer srv.mu.Unlock()
    
    	for _, c := range srv.connections {
    		err := c.Close()
    		if err != nil {
    			return err
    		}
    	}
    	return nil
    }
    
    
// ServerConn is a facade that decorates an embedded ssh.ServerConn with an
// associated context and a reference to the server instance.
//
// Context is derived from the server's context. It is canceled when the
// connection is torn down or when the server shuts down, and it carries the
// connection metadata stored under the CtxKey* keys. Server is the Server
// that accepted the connection.
type ServerConn struct {
	*ssh.ServerConn
	Context context.Context
	Server  *Server
}
    
    type ctxKey int
    
    const (
    	_ ctxKey = iota
    	// CtxKeyClientVersion retrieves the client version string associated with
    	// the connection.
    	CtxKeyClientVersion
    
    	// CtxKeyLocalAddr retrieves the address on which the incomming connection
    	// was accepted.
    	CtxKeyLocalAddr
    
    	// CtxKeyPermissions retrieves the permissions set during authentication.
    	// The ssh.Permissions type is used to convey information up the stack.
    	CtxKeyPermissions
    
    	// CtxKeyRemoteAddr retrieve the address of the client with which the server
    	// is communicating.
    	CtxKeyRemoteAddr
    
    	// CtxKeyServerVersion retrieves the server version string the Server sent to
    	// the client during the ssh handshake.
    	CtxKeyServerVersion
    
    	// CtxKeySessionID retrieves a string that is unique per-connection.
    	CtxKeySessionID
    )
    
    func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context {
    	ctx = context.WithValue(ctx, CtxKeyClientVersion, string(conn.ClientVersion()))
    	ctx = context.WithValue(ctx, CtxKeyLocalAddr, conn.LocalAddr())
    	ctx = context.WithValue(ctx, CtxKeyPermissions, conn.Permissions)
    	ctx = context.WithValue(ctx, CtxKeyRemoteAddr, conn.RemoteAddr())
    	ctx = context.WithValue(ctx, CtxKeyServerVersion, string(conn.ServerVersion()))
    	ctx = context.WithValue(ctx, CtxKeySessionID, conn.SessionID())
    	return ctx
    }
    
    
    func (srv *Server) handshake(c net.Conn) {
    	// Before use, a handshake must be performed on the incoming net.Conn.
    	ssh, channels, global, err := ssh.NewServerConn(c, srv.Config)
    	if err != nil {
    		srv.L.Printf("Server handshake failure '%v'", err)
    		c.Close()
    		srv.Idle.Done()
    		return
    	}
    	ctx, cancel := context.WithCancel(srv.ctx)
    	conn := &ServerConn{
    		ServerConn: ssh,
    		Context:    connectionContext(ctx, ssh),
    		Server:     srv,
    	}
    	defer srv.Idle.Done()
    	defer conn.Close()
    	defer cancel()
    	if err := srv.trackConnection(conn); err != nil {
    		return
    	}
    	defer srv.forgetConnection(conn)
    	if err := srv.HandshakeHook(conn); err != nil {
    		return
    	}
    	defer srv.DepartureHook(conn)
    
    	// Process the global requests
    
    	go srv.handleRequests(conn, global)
    
    
    	// Process the channels
    	srv.ssh(conn, channels)
    }
    
    func defaultHandshakeHook(conn *ServerConn) error {
    	pk := conn.Permissions.Extensions["pubkey-fp"]
    	conn.Server.L.Printf("Server peer accession '%s'", pk)
    	return nil
    }
    
    func defaultDepartureHook(conn *ServerConn) {
    	pk := conn.Permissions.Extensions["pubkey-fp"]
    	conn.Server.L.Printf("Server peer egression '%s'", pk)
    }
    
    
    // RequestHandler is called for incoming global requests.
    type RequestHandler interface {
    	ServeRequest(conn *ServerConn, r *ssh.Request)
    
    // RequestHandlerFunc is ye ol' http/handler adapter type.
    // https://golang.org/src/net/http/server.go#L2004
    type RequestHandlerFunc func(c *ServerConn, r *ssh.Request)
    
    // ServeRequest calls f(r).
    func (f RequestHandlerFunc) ServeRequest(c *ServerConn, r *ssh.Request) {
    	f(c, r)
    }
    
    func (srv *Server) handleRequests(conn *ServerConn, requests <-chan *ssh.Request) {
    	for req := range requests {
    		srv.mu.RLock()
    		handler, ok := srv.GlobalRequests[req.Type]
    		srv.mu.RUnlock()
    		if !ok {
    			handler = srv.DefaultRequestHandler
    		}
    		handler.ServeRequest(conn, req)
    	}
    }
    
    // discard request
    func defaultRequestFunc(conn *ServerConn, req *ssh.Request) {
    	if req.WantReply {
    		req.Reply(false, nil)
    	}
    
    }
    
    func (srv *Server) ssh(conn *ServerConn, channels <-chan ssh.NewChannel) {
    	for candidate := range channels {
    		t := candidate.ChannelType()
    		srv.mu.RLock()
    		handler, ok := srv.ChannelHandlers[t]
    		srv.mu.RUnlock()
    		if !ok {
    			handler = srv.DefaultChannelHandler
    		}
    		if handler == nil {
    			unknown := ssh.UnknownChannelType
    			err := candidate.Reject(unknown, "unknown channel type")
    			if err != nil {
    
    David Cowden's avatar
    David Cowden committed
    				log.Printf("Server error rejecting channel '%s': %v", t, err)
    
    			}
    			continue
    		}
    		channel, requests, err := candidate.Accept()
    		if err != nil {
    
    David Cowden's avatar
    David Cowden committed
    			srv.L.Printf("Server error accepting channel '%s': %v", t, err)
    
    			continue
    		}
    		ctx, cancel := context.WithCancel(conn.Context)
    		defer cancel()
    		stream := Channel{
    			Channel: channel,
    			Context: ctx,
    			Conn:    conn,
    		}
    		go handler.ServeChannel(stream, requests)
    	}
    }
    
    // Channel is a facade that decorates an embedded ssh.Channel with a context
    // and a referrence to the server instance.
    type Channel struct {
    	ssh.Channel
    	Context context.Context
    	Conn    *ServerConn
    }
    
    
    // ChannelHandler is called for each new ssh stream. When the provided context
    // is canceled, the ctx.Done chan will have data ready.
    type ChannelHandler interface {
    	ServeChannel(stream Channel, requests <-chan *ssh.Request)
    }
    
    // ChannelHandlerFunc is ye ol' http/handler adapter type.
    // https://golang.org/src/net/http/server.go#L2004
    type ChannelHandlerFunc func(stream Channel, requests <-chan *ssh.Request)
    
    // ServeChannel calls f(stream, requests).
    func (f ChannelHandlerFunc) ServeChannel(stream Channel, requests <-chan *ssh.Request) {
    	f(stream, requests)
    }
    
//
// No DefaultChannelHandler is installed by default because setting one would
// cause the server to accept every channel type a client asks for.
//
    
    
    // Request registers a handler to be called on incomming global (conn) requests
    
    // of type reqType. Only one handler may be registered for a given reqType. It
    // is an error if this method is called twice with the same reqType.
    
    func (srv *Server) Request(reqType string, handler RequestHandler) error {
    
    	srv.mu.Lock()
    	defer srv.mu.Unlock()
    
    	if srv.GlobalRequests == nil {
    		srv.GlobalRequests = make(map[string]RequestHandler)
    	} else {
    		_, exists := srv.GlobalRequests[reqType]
    		if exists {
    			return fmt.Errorf("sshutil: request name '%s' already registered", reqType)
    		}
    
    	srv.GlobalRequests[reqType] = handler
    
    	return nil
    }
    
    // Channel registers a handler for incomming channels named name.
    func (srv *Server) Channel(name string, handler ChannelHandler) error {
    	srv.mu.Lock()
    	defer srv.mu.Unlock()
    
    	if srv.ChannelHandlers == nil {
    		srv.ChannelHandlers = make(map[string]ChannelHandler)
    	} else {
    		_, exists := srv.ChannelHandlers[name]
    		if exists {
    			return fmt.Errorf("sshutil: channel name '%s' already registered", name)
    		}
    
    	}
    	srv.ChannelHandlers[name] = handler
    	return nil
    }
    
    func (srv *Server) trackListener(ln net.Listener) error {
    	if srv.ShutdownCalled() {
    		return ErrServerClosed
    	}
    	srv.mu.Lock()
    	defer srv.mu.Unlock()
    	_, exists := srv.listeners[ln]
    	if exists {
    
    		return errors.New("sshutil: listener already registered")
    
    	}
    	srv.listeners[ln] = struct{}{}
    	return nil
    }
    
    func (srv *Server) forgetListener(ln net.Listener) {
    	srv.mu.Lock()
    	defer srv.mu.Unlock()
    	delete(srv.listeners, ln)
    }
    
    func (srv *Server) trackConnection(c *ServerConn) error {
    	if srv.CloseCalled() {
    		return ErrServerClosed
    	}
    	id := string(c.SessionID())
    	srv.mu.Lock()
    	defer srv.mu.Unlock()
    	_, exists := srv.connections[id]
    	if exists {
    
    		return errors.New("sshutil: connection already registered")
    
    	}
    	srv.connections[id] = c
    	return nil
    
    }
    
    func (srv *Server) forgetConnection(c *ServerConn) {
    	id := string(c.SessionID())
    	srv.mu.Lock()
    	defer srv.mu.Unlock()
    	delete(srv.connections, id)
    }