package sshutil

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/crypto/ssh"
)

// Server does SSH.
//
// SSH consists of connections, channels and requests. Each connection can have
// an arbitrary number of channels, and has an out-of-band request mechanism.
// Each channel consists of a duplex data stream and its own scoped request
// mechanism. Clients may request channels for arbitrary reasons, and so
// arbitrary application protocols can be hosted atop SSH. The familiar channel
// is a "session", which represents the client's desire to execute a program
// on the Server's host environment, shell or otherwise.
//
// The zero value of Server is a valid Server. If the config is not specified,
// the DefaultServerConfig() is used and an ephemeral host key is generated on
// Server initialization. The server is initialized on the first call to Serve
// or ListenAndServe, after which point the Server is considered to be running.
// Handlers may be modified on a running Server but the ssh config and any hook
// functions must not be changed.
//
type Server struct {
	// Addr is the TCP address ListenAndServe listens on; ":22" is used when
	// it is empty. Config is the ssh server configuration; when nil,
	// initialization installs DefaultServerConfig() with a generated host key.
	Addr   string
	Config *ssh.ServerConfig

	// Optional hooks. Each nil hook is replaced by a package default during
	// initialization.
	ConnectionHook ConnectionHook
	HandshakeHook  HandshakeHook
	DepartureHook  DepartureHook

	// GlobalRequests maps a global request type to its handler;
	// DefaultRequestHandler serves request types without an entry.
	GlobalRequests        map[string]RequestHandler
	DefaultRequestHandler RequestHandler

	// ChannelHandlers maps a channel type to its handler. Channel types
	// without an entry go to DefaultChannelHandler, and are rejected when
	// that is nil as well.
	ChannelHandlers       map[string]ChannelHandler
	DefaultChannelHandler ChannelHandler

	// L receives the server's log output; a package default is installed
	// during initialization when it is nil.
	L *log.Logger

	// Idle counts accepted connections that have not finished yet.
	// ShutdownAndWait waits on it.
	Idle sync.WaitGroup

	// mu guards listeners, connections and the handler maps. once guards
	// initialization. ctx is the root context of every connection and is
	// canceled through cancel by Shutdown. shutdown and closed are flags
	// accessed atomically and set to 1 by Shutdown and Close respectively.
	mu          sync.RWMutex
	once        sync.Once
	ctx         context.Context
	listeners   map[io.Closer]struct{}
	connections map[string]io.Closer
	shutdown    uint64
	closed      uint64
	cancel      context.CancelFunc
}

// ConnectionHook allows for custom connection logic after a connection is
// established but prior to the SSH handshake. The net.Conn it returns is the
// one the handshake is performed on, so a hook may wrap or replace c. A
// non-nil error means c will be closed and no handshake will be performed.
type ConnectionHook func(c net.Conn) (net.Conn, error)

// HandshakeHook allows execution of code right after a successful handshake
// on a new connection. The full ssh.ServerConn is provided to the hook whereas
// the ssh.AuthLogCallback hook available in the server config only provides
// connection metadata. A non-nil error aborts the connection: it is closed
// and the DepartureHook is not invoked for it.
type HandshakeHook func(conn *ServerConn) error

// DepartureHook allows execution of code during teardown of an authenticated
// connection (handshake completed). The provided connection may already be
// closed if the peer departed of their own volition. The hook runs only for
// connections whose HandshakeHook returned a nil error.
type DepartureHook func(conn *ServerConn)

// DefaultServerConfig returns a fresh ssh.ServerConfig that accepts public
// key authentication from any client presenting a key or certificate. No
// host key is attached to the returned config.
//
// TODO rename this "InsecureOpenServerConfig"?
//
func DefaultServerConfig() *ssh.ServerConfig {
	config := new(ssh.ServerConfig)
	config.ServerVersion = "SSH-2.0-Go sshutil"
	config.PublicKeyCallback = allowAllPublicKeys
	return config
}

// allowAllPublicKeys is a public key callback that accepts every key. The
// SHA256 fingerprint of the presented key is recorded under the "pubkey-fp"
// extension of the returned permissions.
func allowAllPublicKeys(meta ssh.ConnMetadata, pubkey ssh.PublicKey) (*ssh.Permissions, error) {
	extensions := map[string]string{
		"pubkey-fp": ssh.FingerprintSHA256(pubkey),
	}
	perms := &ssh.Permissions{Extensions: extensions}
	return perms, nil
}

// init must be called prior to serving connections. init is safe to be called
// arbitrarily: the setup itself runs once, and a failed setup is reported on
// every call, not just the first. If init returns an error the server should
// not proceed with serving connections and convey such to the caller.
func (srv *Server) init() error {
	var failure error
	srv.once.Do(func() {
		if srv.L == nil {
			srv.L = logger
		}

		srv.listeners = make(map[io.Closer]struct{})
		srv.connections = make(map[string]io.Closer)
		srv.ctx, srv.cancel = context.WithCancel(context.Background())

		if srv.Config == nil {
			// Start from a fresh copy of the default config.
			config := DefaultServerConfig()

			// Use an ephemeral key.
			signer, err := GenerateKey()
			if err != nil {
				failure = err
				return
			}
			f := ssh.FingerprintSHA256(signer.PublicKey())
			srv.L.Printf("Server ephermeral key '%s'", f)
			config.AddHostKey(signer)
			srv.Config = config
		}

		if srv.ConnectionHook == nil {
			srv.ConnectionHook = defaultConncectionHook
		}
		if srv.HandshakeHook == nil {
			srv.HandshakeHook = defaultHandshakeHook
		}
		if srv.DepartureHook == nil {
			srv.DepartureHook = defaultDepartureHook
		}
		if srv.DefaultRequestHandler == nil {
			srv.DefaultRequestHandler = RequestHandlerFunc(defaultRequestFunc)
		}

		// Both handler maps are also written by Request and Channel, which
		// hold mu, so they are initialized under the same lock.
		srv.mu.Lock()
		defer srv.mu.Unlock()
		if srv.GlobalRequests == nil {
			srv.GlobalRequests = make(map[string]RequestHandler)
		}
		// Setting a DefaultChannelHandler would cause every type of
		// channel to be accepted. Instead, add a default session
		// handler.
		if srv.ChannelHandlers == nil {
			srv.ChannelHandlers = make(map[string]ChannelHandler)
			srv.ChannelHandlers["session"] = DefaultSessionHandler()
		}
	})
	// once.Do never re-runs the setup, so a failed first attempt would look
	// like success on later calls. A successful setup always leaves Config
	// non-nil, which makes a nil Config the marker of a failed one.
	if failure == nil && srv.Config == nil {
		failure = errors.New("sshutil: server initialization failed")
	}
	return failure
}

// ErrServerClosed is returned from ListenAndServe and Serve when either
// method returned due to a call to Close or Shutdown. Repeated calls to
// Shutdown and Close return it as well.
var ErrServerClosed = errors.New("sshutil: Server closed")

// ListenAndServe opens a TCP listener on Server.Addr and blocks serving it.
// When Addr is empty the address ":22" is used, i.e. the ssh port on all
// interfaces. The returned error is never nil. After Shutdown or Close, this
// method returns ErrServerClosed.
func (srv *Server) ListenAndServe() error {
	if srv.ShutdownCalled() {
		return ErrServerClosed
	}
	address := ":22"
	if srv.Addr != "" {
		address = srv.Addr
	}
	listener, err := net.Listen("tcp", address)
	if err != nil {
		return err
	}
	return srv.Serve(listener)
}

// Serve blocks accepting connections on the provided listener. Serve always
// returns a non-nil error and closes listener. After Shutdown or Close, the
// returned error is ErrServerClosed.
//
// Every accepted connection is passed through the ConnectionHook and then
// handed to its own goroutine for the SSH handshake. Temporary accept errors
// are retried with an exponential backoff capped at one second.
func (srv *Server) Serve(listener net.Listener) error {
	// Synchronize on an initialized Server.
	if err := srv.init(); err != nil {
		listener.Close()
		return err
	}
	if err := srv.trackListener(listener); err != nil {
		listener.Close()
		return err
	}
	defer func() {
		// Forget first so that Shutdown never tries to close a listener
		// which has already been closed here.
		srv.forgetListener(listener)
		// Serve promises to close the listener. The error is ignored on
		// purpose: after Shutdown the listener is already closed.
		_ = listener.Close()
	}()

	srv.L.Printf("Server commence listening on %s", listener.Addr())

	// Delay inspired by net/http:
	var delay time.Duration
	delayedTryAgain := func(err error) {
		if delay == 0 {
			delay = 5 * time.Millisecond
		} else {
			delay *= 2
		}
		if max := 1 * time.Second; delay > max {
			delay = max
		}
		// A time.Duration prints with its own unit, so none is appended.
		srv.L.Printf("Server accept error '%v'; retrying in %v", err, delay)
		time.Sleep(delay)
	}

	// Accept loop:
	for {
		conn, err := listener.Accept()
		if err != nil {
			var ne net.Error
			if errors.As(err, &ne) && ne.Temporary() {
				delayedTryAgain(ne)
				continue
			}
			addr := listener.Addr().String()
			srv.L.Printf("Server finished listening on %s", addr)
			if srv.ShutdownCalled() {
				return ErrServerClosed
			}
			return err
		}
		// A successful accept ends the current run of temporary errors;
		// the next one starts again from the shortest delay.
		delay = 0

		// Counted as busy until the hook rejects the connection or the
		// handshake goroutine finishes.
		srv.Idle.Add(1)
		c, err := srv.ConnectionHook(conn)
		if err != nil {
			srv.L.Printf("Server connection hook rejected %s: %v", conn.RemoteAddr(), err)
			conn.Close()
			srv.Idle.Done()
			continue
		}
		go srv.handshake(c)
	}
}

// defaultConncectionHook is the ConnectionHook installed when none is
// configured. It accepts every connection and returns it unchanged.
func defaultConncectionHook(conn net.Conn) (net.Conn, error) {
	return conn, nil
}

// ShutdownCalled reports whether Shutdown has been invoked on the server.
func (srv *Server) ShutdownCalled() bool {
	state := atomic.LoadUint64(&srv.shutdown)
	return state == 1
}

// Shutdown stops any listen loops from accepting new connections and closes
// all tracked listeners. Active connections are not closed, but the server
// context is canceled so that handlers can wind down gracefully.
//
// Every listener is closed even if closing one of them fails; the first such
// error is returned. Only the first call performs the shutdown, any later
// call returns ErrServerClosed. Shutdown may be called on a Server that never
// served.
func (srv *Server) Shutdown() error {
	if !atomic.CompareAndSwapUint64(&srv.shutdown, 0, 1) {
		return ErrServerClosed
	}
	// First call to Shutdown. Stop listen loops:
	srv.mu.RLock()
	defer srv.mu.RUnlock()
	var first error
	for l := range srv.listeners {
		// Keep going on failure so that one bad listener cannot keep the
		// others accepting connections.
		if err := l.Close(); err != nil && first == nil {
			first = err
		}
	}
	// Notify handlers so they can wind down gracefully. cancel is only set
	// by init, which has not run on a Server that never served.
	if srv.cancel != nil {
		srv.cancel()
	}
	return first
}

// ShutdownAndWait shuts the server down and waits for all connections to
// drain gracefully. If the provided context is canceled first, the context's
// error is returned. A server that has already been shut down is waited on
// all the same.
func (srv *Server) ShutdownAndWait(ctx context.Context) error {
	err := srv.Shutdown()
	if err != nil && err != ErrServerClosed {
		return err
	}

	// The logger is only defaulted during initialization, so it may be nil
	// on a Server that never served.
	if srv.L != nil {
		srv.L.Println("Server waiting for peer connections to drain...")
	}
	idle := make(chan struct{})
	go func() {
		// This goroutine outlives the call when ctx is canceled first; it
		// ends once the remaining connections are done.
		srv.Idle.Wait()
		close(idle)
	}()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-idle:
		return nil
	}
}

// CloseCalled reports whether Close has been invoked on the server. A closed
// server does not allow new connections to be tracked.
func (srv *Server) CloseCalled() bool {
	state := atomic.LoadUint64(&srv.closed)
	return state == 1
}

// Close stops the server immediately. Close calls Shutdown and then proceeds
// to close every tracked connection.
//
// All connections are closed even if Shutdown or an individual close fails;
// the first error encountered is returned. Only the first call closes the
// connections, any later call returns ErrServerClosed.
func (srv *Server) Close() error {
	first := srv.Shutdown()
	if first == ErrServerClosed {
		// Already shut down, which is fine for Close.
		first = nil
	}
	if !atomic.CompareAndSwapUint64(&srv.closed, 0, 1) {
		if first != nil {
			return first
		}
		return ErrServerClosed
	}
	// First call to Close. Drain connections forcibly:
	srv.mu.Lock()
	defer srv.mu.Unlock()
	for _, c := range srv.connections {
		// Keep going on failure so that no connection is left open.
		if err := c.Close(); err != nil && first == nil {
			first = err
		}
	}
	return first
}

// ServerConn is a facade that decorates an embedded ssh.ServerConn with an
// associated context and a reference to the server instance.
type ServerConn struct {
	*ssh.ServerConn
	// Context is derived from the server's context and carries the CtxKey*
	// values describing the connection.
	Context context.Context
	// Server is the Server that accepted the connection.
	Server  *Server
}

// ctxKey is the type of the keys under which connection details are stored
// in a connection's context.
type ctxKey int

const (
	_ ctxKey = iota
	// CtxKeyClientVersion retrieves the client version string associated with
	// the connection.
	CtxKeyClientVersion

	// CtxKeyLocalAddr retrieves the address on which the incoming connection
	// was accepted.
	CtxKeyLocalAddr

	// CtxKeyPermissions retrieves the permissions set during authentication.
	// The ssh.Permissions type is used to convey information up the stack.
	CtxKeyPermissions

	// CtxKeyRemoteAddr retrieves the address of the client with which the
	// server is communicating.
	CtxKeyRemoteAddr

	// CtxKeyServerVersion retrieves the server version string the Server sent to
	// the client during the ssh handshake.
	CtxKeyServerVersion

	// CtxKeySessionID retrieves a value that is unique per-connection.
	// NOTE(review): connectionContext stores conn.SessionID() as returned,
	// without the string conversion it applies to the version values —
	// confirm which type consumers expect.
	CtxKeySessionID
)

// connectionContext returns a context derived from ctx which carries the
// details of conn under the CtxKey* keys.
func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context {
	entries := []struct {
		key   ctxKey
		value interface{}
	}{
		{CtxKeyClientVersion, string(conn.ClientVersion())},
		{CtxKeyLocalAddr, conn.LocalAddr()},
		{CtxKeyPermissions, conn.Permissions},
		{CtxKeyRemoteAddr, conn.RemoteAddr()},
		{CtxKeyServerVersion, string(conn.ServerVersion())},
		{CtxKeySessionID, conn.SessionID()},
	}
	for _, entry := range entries {
		ctx = context.WithValue(ctx, entry.key, entry.value)
	}
	return ctx
}

// handshake performs the SSH handshake on c and serves the resulting
// connection until it ends. It is run in its own goroutine by Serve, which
// has already added the connection to srv.Idle.
//
// After a successful handshake the connection is tracked, the HandshakeHook
// is consulted, global requests are served in a separate goroutine and
// channels are served in this one. On the way out the DepartureHook runs,
// the connection is forgotten, its context canceled and the connection
// closed.
func (srv *Server) handshake(c net.Conn) {
	// Deferred first so that it runs last, after all other teardown.
	defer srv.Idle.Done()

	// Before use, a handshake must be performed on the incoming net.Conn.
	sconn, channels, global, err := ssh.NewServerConn(c, srv.Config)
	if err != nil {
		srv.L.Printf("Server handshake failure '%v'", err)
		c.Close()
		return
	}
	ctx, cancel := context.WithCancel(srv.ctx)
	conn := &ServerConn{
		ServerConn: sconn,
		Context:    connectionContext(ctx, sconn),
		Server:     srv,
	}
	defer conn.Close()
	defer cancel()
	if err := srv.trackConnection(conn); err != nil {
		srv.L.Printf("Server refusing connection from %s: %v", conn.RemoteAddr(), err)
		return
	}
	defer srv.forgetConnection(conn)
	if err := srv.HandshakeHook(conn); err != nil {
		srv.L.Printf("Server handshake hook rejected %s: %v", conn.RemoteAddr(), err)
		return
	}
	defer srv.DepartureHook(conn)

	// Process the global requests
	go srv.handleRequests(conn, global)

	// Process the channels
	srv.ssh(conn, channels)
}

// defaultHandshakeHook is the HandshakeHook installed when none is
// configured. It logs the public key fingerprint recorded during
// authentication and accepts the connection.
func defaultHandshakeHook(conn *ServerConn) error {
	// Permissions is nil when the authentication callback returned none, as
	// may happen with a custom server config; log an empty fingerprint then.
	var pk string
	if conn.Permissions != nil {
		pk = conn.Permissions.Extensions["pubkey-fp"]
	}
	conn.Server.L.Printf("Server peer accession '%s'", pk)
	return nil
}

// defaultDepartureHook is the DepartureHook installed when none is
// configured. It logs the public key fingerprint recorded during
// authentication.
func defaultDepartureHook(conn *ServerConn) {
	// Permissions is nil when the authentication callback returned none, as
	// may happen with a custom server config; log an empty fingerprint then.
	var pk string
	if conn.Permissions != nil {
		pk = conn.Permissions.Extensions["pubkey-fp"]
	}
	conn.Server.L.Printf("Server peer egression '%s'", pk)
}

// RequestHandler is called for incoming global requests. The requests of a
// connection are served one at a time, in the order they arrive. The server
// does not reply on the handler's behalf; the built-in default handler
// replies false when r.WantReply is set.
type RequestHandler interface {
	ServeRequest(conn *ServerConn, r *ssh.Request)
}

// RequestHandlerFunc is an adapter that allows an ordinary function to be
// used as a RequestHandler, in the manner of net/http's HandlerFunc.
// https://golang.org/src/net/http/server.go#L2004
type RequestHandlerFunc func(c *ServerConn, r *ssh.Request)

// ServeRequest calls f(c, r).
func (f RequestHandlerFunc) ServeRequest(c *ServerConn, r *ssh.Request) {
	f(c, r)
}

// handleRequests serves the global requests of conn until the requests
// channel is closed. Each request goes to the handler registered for its
// type, or to the DefaultRequestHandler when there is none.
func (srv *Server) handleRequests(conn *ServerConn, requests <-chan *ssh.Request) {
	for {
		req, open := <-requests
		if !open {
			return
		}

		// Hold the lock for the map lookup only, never while serving.
		srv.mu.RLock()
		handler, found := srv.GlobalRequests[req.Type]
		srv.mu.RUnlock()

		if found {
			handler.ServeRequest(conn, req)
			continue
		}
		srv.DefaultRequestHandler.ServeRequest(conn, req)
	}
}

// defaultRequestFunc discards the request, answering with a failure when
// the peer asked for a reply.
func defaultRequestFunc(conn *ServerConn, req *ssh.Request) {
	if !req.WantReply {
		return
	}
	req.Reply(false, nil)
}

// ssh serves the channels opened on conn until the channels chan is closed.
//
// A channel is accepted when a handler is registered for its type or a
// DefaultChannelHandler is set, and rejected as an unknown channel type
// otherwise. Every accepted channel is served in its own goroutine with a
// context derived from the connection's context. That context is canceled
// as soon as the handler returns, or with the connection's context,
// whichever comes first.
func (srv *Server) ssh(conn *ServerConn, channels <-chan ssh.NewChannel) {
	for candidate := range channels {
		t := candidate.ChannelType()
		srv.mu.RLock()
		handler, ok := srv.ChannelHandlers[t]
		srv.mu.RUnlock()
		if !ok {
			handler = srv.DefaultChannelHandler
		}
		if handler == nil {
			unknown := ssh.UnknownChannelType
			err := candidate.Reject(unknown, "unknown channel type")
			if err != nil {
				srv.L.Printf("Server error rejecting channel '%s': %v", t, err)
			}
			continue
		}
		channel, requests, err := candidate.Accept()
		if err != nil {
			srv.L.Printf("Server error accepting channel '%s': %v", t, err)
			continue
		}
		ctx, cancel := context.WithCancel(conn.Context)
		stream := Channel{
			Channel: channel,
			Context: ctx,
			Conn:    conn,
		}
		// Release the channel's context when its handler is done. A defer in
		// this loop would instead pile up one pending cancel per channel for
		// as long as the connection lives.
		go func(h ChannelHandler, s Channel, reqs <-chan *ssh.Request, done context.CancelFunc) {
			defer done()
			h.ServeChannel(s, reqs)
		}(handler, stream, requests, cancel)
	}
}

// Channel is a facade that decorates an embedded ssh.Channel with a context
// and a reference to the connection the channel belongs to.
type Channel struct {
	ssh.Channel
	// Context is derived from the context of the owning connection.
	Context context.Context
	// Conn is the connection on which the channel was opened.
	Conn    *ServerConn
}

// ChannelHandler is called for each new ssh stream. ServeChannel runs in its
// own goroutine, after the channel has been accepted. When the context
// provided in stream.Context is canceled, its Done channel is closed.
type ChannelHandler interface {
	ServeChannel(stream Channel, requests <-chan *ssh.Request)
}

// ChannelHandlerFunc is an adapter that allows an ordinary function to be
// used as a ChannelHandler, in the manner of net/http's HandlerFunc.
// https://golang.org/src/net/http/server.go#L2004
type ChannelHandlerFunc func(stream Channel, requests <-chan *ssh.Request)

// ServeChannel calls f(stream, requests).
func (f ChannelHandlerFunc) ServeChannel(stream Channel, requests <-chan *ssh.Request) {
	f(stream, requests)
}

//
// There is no default channel handler because setting one would cause the
// server to accept channels of every type.
//

// Request registers handler for incoming global (conn) requests of type
// reqType. Only one handler may be registered per reqType; registering a
// second one for the same type returns an error.
func (srv *Server) Request(reqType string, handler RequestHandler) error {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	// A lookup in a nil map reports a miss, so this check needs no nil guard.
	if _, taken := srv.GlobalRequests[reqType]; taken {
		return fmt.Errorf("sshutil: request name '%s' already registered", reqType)
	}
	if srv.GlobalRequests == nil {
		srv.GlobalRequests = make(map[string]RequestHandler)
	}
	srv.GlobalRequests[reqType] = handler
	return nil
}

// Channel registers handler for incoming channels of the type name. Only one
// handler may be registered per name; registering a second one for the same
// name returns an error.
func (srv *Server) Channel(name string, handler ChannelHandler) error {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	// A lookup in a nil map reports a miss, so this check needs no nil guard.
	if _, taken := srv.ChannelHandlers[name]; taken {
		return fmt.Errorf("sshutil: channel name '%s' already registered", name)
	}
	if srv.ChannelHandlers == nil {
		srv.ChannelHandlers = make(map[string]ChannelHandler)
	}
	srv.ChannelHandlers[name] = handler
	return nil
}

// trackListener registers ln so that Shutdown can close it. It returns
// ErrServerClosed once Shutdown has been called, and an error if ln is
// already registered.
func (srv *Server) trackListener(ln net.Listener) error {
	srv.mu.Lock()
	defer srv.mu.Unlock()
	// The flag is checked while holding the lock. Shutdown sets the flag
	// before it takes the lock to sweep the listeners, so a listener that
	// passes this check is registered before the sweep and gets closed by it.
	if srv.ShutdownCalled() {
		return ErrServerClosed
	}
	if _, exists := srv.listeners[ln]; exists {
		return errors.New("sshutil: listener already registered")
	}
	srv.listeners[ln] = struct{}{}
	return nil
}

// forgetListener removes ln from the set of tracked listeners.
func (srv *Server) forgetListener(ln net.Listener) {
	srv.mu.Lock()
	delete(srv.listeners, ln)
	srv.mu.Unlock()
}

// trackConnection registers c under its session id so that Close can close
// it. It returns ErrServerClosed once Close has been called, and an error if
// a connection with the same session id is already registered.
func (srv *Server) trackConnection(c *ServerConn) error {
	id := string(c.SessionID())
	srv.mu.Lock()
	defer srv.mu.Unlock()
	// The flag is checked while holding the lock. Close sets the flag before
	// it takes the lock to sweep the connections, so a connection that
	// passes this check is registered before the sweep and gets closed by it.
	if srv.CloseCalled() {
		return ErrServerClosed
	}
	if _, exists := srv.connections[id]; exists {
		return errors.New("sshutil: connection already registered")
	}
	srv.connections[id] = c
	return nil
}

// forgetConnection removes c from the set of tracked connections.
func (srv *Server) forgetConnection(c *ServerConn) {
	key := string(c.SessionID())
	srv.mu.Lock()
	delete(srv.connections, key)
	srv.mu.Unlock()
}