From a978de102709f9755ade652fd5f3393bfa30cf2b Mon Sep 17 00:00:00 2001
From: David Cowden <dcow@smallstep.com>
Date: Wed, 17 Jun 2020 00:54:10 -0700
Subject: [PATCH] server: Add ServerConn to the request handler

Rename GlobalRequest to Request since it's global nature is implicit by
the fact that it is a method on Server, and explicit by the fact that it
now receives a reference to the associated connection when it is called.
---
 server.go | 150 +++++++++++++++++++++++++++---------------------------
 1 file changed, 75 insertions(+), 75 deletions(-)

diff --git a/server.go b/server.go
index 7fadf90..4c1eacc 100644
--- a/server.go
+++ b/server.go
@@ -311,6 +311,52 @@ func (srv *Server) Close() error {
 	return nil
 }
 
+// ServerConn is a facade that decorates an embedded ssh.ServerConn with an
+// associated context and a reference to the server instance.
+type ServerConn struct {
+	*ssh.ServerConn
+	Context context.Context
+	Server  *Server
+}
+
+type ctxKey int
+
+const (
+	_ ctxKey = iota
+	// CtxKeyClientVersion retrieves the client version string associated with
+	// the connection.
+	CtxKeyClientVersion
+
+	// CtxKeyLocalAddr retrieves the address on which the incomming connection
+	// was accepted.
+	CtxKeyLocalAddr
+
+	// CtxKeyPermissions retrieves the permissions set during authentication.
+	// The ssh.Permissions type is used to convey information up the stack.
+	CtxKeyPermissions
+
+	// CtxKeyRemoteAddr retrieve the address of the client with which the server
+	// is communicating.
+	CtxKeyRemoteAddr
+
+	// CtxKeyServerVersion retrieves the server version string the Server sent to
+	// the client during the ssh handshake.
+	CtxKeyServerVersion
+
+	// CtxKeySessionID retrieves a string that is unique per-connection.
+	CtxKeySessionID
+)
+
+func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context {
+	ctx = context.WithValue(ctx, CtxKeyClientVersion, string(conn.ClientVersion()))
+	ctx = context.WithValue(ctx, CtxKeyLocalAddr, conn.LocalAddr())
+	ctx = context.WithValue(ctx, CtxKeyPermissions, conn.Permissions)
+	ctx = context.WithValue(ctx, CtxKeyRemoteAddr, conn.RemoteAddr())
+	ctx = context.WithValue(ctx, CtxKeyServerVersion, string(conn.ServerVersion()))
+	ctx = context.WithValue(ctx, CtxKeySessionID, conn.SessionID())
+	return ctx
+}
+
 func (srv *Server) handshake(c net.Conn) {
 	// Before use, a handshake must be performed on the incoming net.Conn.
 	ssh, channels, global, err := ssh.NewServerConn(c, srv.Config)
@@ -339,7 +385,7 @@ func (srv *Server) handshake(c net.Conn) {
 	defer srv.DepartureHook(conn)
 
 	// Process the global requests
-	go srv.handleRequests(global)
+	go srv.handleRequests(conn, global)
 
 	// Process the channels
 	srv.ssh(conn, channels)
@@ -356,50 +402,37 @@ func defaultDepartureHook(conn *ServerConn) {
 	conn.Server.L.Printf("Server peer egression '%s'", pk)
 }
 
-// ServerConn is a facade that decorates an embedded ssh.ServerConn with an
-// associated context and a reference to the server instance.
-type ServerConn struct {
-	*ssh.ServerConn
-	Context context.Context
-	Server  *Server
+// RequestHandler is called for incoming global requests.
+type RequestHandler interface {
+	ServeRequest(conn *ServerConn, r *ssh.Request)
 }
 
-type ctxKey int
-
-const (
-	_ ctxKey = iota
-	// CtxKeyClientVersion retrieves the client version string associated with
-	// the connection.
-	CtxKeyClientVersion
-
-	// CtxKeyLocalAddr retrieves the address on which the incomming connection
-	// was accepted.
-	CtxKeyLocalAddr
-
-	// CtxKeyPermissions retrieves the permissions set during authentication.
-	// The ssh.Permissions type is used to convey information up the stack.
-	CtxKeyPermissions
-
-	// CtxKeyRemoteAddr retrieve the address of the client with which the server
-	// is communicating.
-	CtxKeyRemoteAddr
+// RequestHandlerFunc is ye ol' http/handler adapter type.
+// https://golang.org/src/net/http/server.go#L2004
+type RequestHandlerFunc func(c *ServerConn, r *ssh.Request)
 
-	// CtxKeyServerVersion retrieves the server version string the Server sent to
-	// the client during the ssh handshake.
-	CtxKeyServerVersion
+// ServeRequest calls f(r).
+func (f RequestHandlerFunc) ServeRequest(c *ServerConn, r *ssh.Request) {
+	f(c, r)
+}
 
-	// CtxKeySessionID retrieves a string that is unique per-connection.
-	CtxKeySessionID
-)
+func (srv *Server) handleRequests(conn *ServerConn, requests <-chan *ssh.Request) {
+	for req := range requests {
+		srv.mu.RLock()
+		handler, ok := srv.GlobalRequests[req.Type]
+		srv.mu.RUnlock()
+		if !ok {
+			handler = srv.DefaultRequestHandler
+		}
+		handler.ServeRequest(conn, req)
+	}
+}
 
-func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context {
-	ctx = context.WithValue(ctx, CtxKeyClientVersion, string(conn.ClientVersion()))
-	ctx = context.WithValue(ctx, CtxKeyLocalAddr, conn.LocalAddr())
-	ctx = context.WithValue(ctx, CtxKeyPermissions, conn.Permissions)
-	ctx = context.WithValue(ctx, CtxKeyRemoteAddr, conn.RemoteAddr())
-	ctx = context.WithValue(ctx, CtxKeyServerVersion, string(conn.ServerVersion()))
-	ctx = context.WithValue(ctx, CtxKeySessionID, conn.SessionID())
-	return ctx
+// discard request
+func defaultRequestFunc(conn *ServerConn, req *ssh.Request) {
+	if req.WantReply {
+		req.Reply(false, nil)
+	}
 }
 
 func (srv *Server) ssh(conn *ServerConn, channels <-chan ssh.NewChannel) {
@@ -463,43 +496,10 @@ func (f ChannelHandlerFunc) ServeChannel(stream Channel, requests <-chan *ssh.Re
 // server to acept all requests.
 //
 
-// RequestHandler is called for incoming global requests.
-type RequestHandler interface {
-	ServeRequest(r *ssh.Request)
-}
-
-// RequestHandlerFunc is ye ol' http/handler adapter type.
-// https://golang.org/src/net/http/server.go#L2004
-type RequestHandlerFunc func(r *ssh.Request)
-
-// ServeRequest calls f(r).
-func (f RequestHandlerFunc) ServeRequest(r *ssh.Request) {
-	f(r)
-}
-
-func (srv *Server) handleRequests(requests <-chan *ssh.Request) {
-	for req := range requests {
-		srv.mu.RLock()
-		handler, ok := srv.GlobalRequests[req.Type]
-		srv.mu.RUnlock()
-		if !ok {
-			handler = srv.DefaultRequestHandler
-		}
-		handler.ServeRequest(req)
-	}
-}
-
-// discard request
-func defaultRequestFunc(req *ssh.Request) {
-	if req.WantReply {
-		req.Reply(false, nil)
-	}
-}
-
-// GlobalRequest registers a handler to be called on incomming global requests
+// Request registers a handler to be called on incomming global (conn) requests
 // of type reqType. Only one handler may be registered for a given reqType. It
 // is an error if this method is called twice with the same reqType.
-func (srv *Server) GlobalRequest(reqType string, handler RequestHandler) error {
+func (srv *Server) Request(reqType string, handler RequestHandler) error {
 	srv.mu.Lock()
 	defer srv.mu.Unlock()
 	if srv.GlobalRequests == nil {
-- 
GitLab