From a978de102709f9755ade652fd5f3393bfa30cf2b Mon Sep 17 00:00:00 2001 From: David Cowden <dcow@smallstep.com> Date: Wed, 17 Jun 2020 00:54:10 -0700 Subject: [PATCH] server: Add ServerConn to the request handler Rename GlobalRequest to Request since it's global nature is implicit by the fact that it is a method on Server, and explicit by the fact that it now receives a reference to the associated connection when it is called. --- server.go | 150 +++++++++++++++++++++++++++--------------------------- 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/server.go b/server.go index 7fadf90..4c1eacc 100644 --- a/server.go +++ b/server.go @@ -311,6 +311,52 @@ func (srv *Server) Close() error { return nil } +// ServerConn is a facade that decorates an embedded ssh.ServerConn with an +// associated context and a reference to the server instance. +type ServerConn struct { + *ssh.ServerConn + Context context.Context + Server *Server +} + +type ctxKey int + +const ( + _ ctxKey = iota + // CtxKeyClientVersion retrieves the client version string associated with + // the connection. + CtxKeyClientVersion + + // CtxKeyLocalAddr retrieves the address on which the incomming connection + // was accepted. + CtxKeyLocalAddr + + // CtxKeyPermissions retrieves the permissions set during authentication. + // The ssh.Permissions type is used to convey information up the stack. + CtxKeyPermissions + + // CtxKeyRemoteAddr retrieve the address of the client with which the server + // is communicating. + CtxKeyRemoteAddr + + // CtxKeyServerVersion retrieves the server version string the Server sent to + // the client during the ssh handshake. + CtxKeyServerVersion + + // CtxKeySessionID retrieves a string that is unique per-connection. + CtxKeySessionID +) + +func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context { + ctx = context.WithValue(ctx, CtxKeyClientVersion, string(conn.ClientVersion())) + ctx = context.WithValue(ctx, CtxKeyLocalAddr, conn.LocalAddr()) + ctx = context.WithValue(ctx, CtxKeyPermissions, conn.Permissions) + ctx = context.WithValue(ctx, CtxKeyRemoteAddr, conn.RemoteAddr()) + ctx = context.WithValue(ctx, CtxKeyServerVersion, string(conn.ServerVersion())) + ctx = context.WithValue(ctx, CtxKeySessionID, conn.SessionID()) + return ctx +} + func (srv *Server) handshake(c net.Conn) { // Before use, a handshake must be performed on the incoming net.Conn. ssh, channels, global, err := ssh.NewServerConn(c, srv.Config) @@ -339,7 +385,7 @@ func (srv *Server) handshake(c net.Conn) { defer srv.DepartureHook(conn) // Process the global requests - go srv.handleRequests(global) + go srv.handleRequests(conn, global) // Process the channels srv.ssh(conn, channels) @@ -356,50 +402,37 @@ func defaultDepartureHook(conn *ServerConn) { conn.Server.L.Printf("Server peer egression '%s'", pk) } -// ServerConn is a facade that decorates an embedded ssh.ServerConn with an -// associated context and a reference to the server instance. -type ServerConn struct { - *ssh.ServerConn - Context context.Context - Server *Server +// RequestHandler is called for incoming global requests. +type RequestHandler interface { + ServeRequest(conn *ServerConn, r *ssh.Request) } -type ctxKey int - -const ( - _ ctxKey = iota - // CtxKeyClientVersion retrieves the client version string associated with - // the connection. - CtxKeyClientVersion - - // CtxKeyLocalAddr retrieves the address on which the incomming connection - // was accepted. - CtxKeyLocalAddr - - // CtxKeyPermissions retrieves the permissions set during authentication. - // The ssh.Permissions type is used to convey information up the stack. - CtxKeyPermissions - - // CtxKeyRemoteAddr retrieve the address of the client with which the server - // is communicating. - CtxKeyRemoteAddr +// RequestHandlerFunc is ye ol' http/handler adapter type. +// https://golang.org/src/net/http/server.go#L2004 +type RequestHandlerFunc func(c *ServerConn, r *ssh.Request) - // CtxKeyServerVersion retrieves the server version string the Server sent to - // the client during the ssh handshake. - CtxKeyServerVersion +// ServeRequest calls f(r). +func (f RequestHandlerFunc) ServeRequest(c *ServerConn, r *ssh.Request) { + f(c, r) +} - // CtxKeySessionID retrieves a string that is unique per-connection. - CtxKeySessionID -) +func (srv *Server) handleRequests(conn *ServerConn, requests <-chan *ssh.Request) { + for req := range requests { + srv.mu.RLock() + handler, ok := srv.GlobalRequests[req.Type] + srv.mu.RUnlock() + if !ok { + handler = srv.DefaultRequestHandler + } + handler.ServeRequest(conn, req) + } +} -func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context { - ctx = context.WithValue(ctx, CtxKeyClientVersion, string(conn.ClientVersion())) - ctx = context.WithValue(ctx, CtxKeyLocalAddr, conn.LocalAddr()) - ctx = context.WithValue(ctx, CtxKeyPermissions, conn.Permissions) - ctx = context.WithValue(ctx, CtxKeyRemoteAddr, conn.RemoteAddr()) - ctx = context.WithValue(ctx, CtxKeyServerVersion, string(conn.ServerVersion())) - ctx = context.WithValue(ctx, CtxKeySessionID, conn.SessionID()) - return ctx +// discard request +func defaultRequestFunc(conn *ServerConn, req *ssh.Request) { + if req.WantReply { + req.Reply(false, nil) + } } func (srv *Server) ssh(conn *ServerConn, channels <-chan ssh.NewChannel) { @@ -463,43 +496,10 @@ func (f ChannelHandlerFunc) ServeChannel(stream Channel, requests <-chan *ssh.Re // server to acept all requests. // -// RequestHandler is called for incoming global requests. -type RequestHandler interface { - ServeRequest(r *ssh.Request) -} - -// RequestHandlerFunc is ye ol' http/handler adapter type. -// https://golang.org/src/net/http/server.go#L2004 -type RequestHandlerFunc func(r *ssh.Request) - -// ServeRequest calls f(r). -func (f RequestHandlerFunc) ServeRequest(r *ssh.Request) { - f(r) -} - -func (srv *Server) handleRequests(requests <-chan *ssh.Request) { - for req := range requests { - srv.mu.RLock() - handler, ok := srv.GlobalRequests[req.Type] - srv.mu.RUnlock() - if !ok { - handler = srv.DefaultRequestHandler - } - handler.ServeRequest(req) - } -} - -// discard request -func defaultRequestFunc(req *ssh.Request) { - if req.WantReply { - req.Reply(false, nil) - } -} - -// GlobalRequest registers a handler to be called on incomming global requests +// Request registers a handler to be called on incomming global (conn) requests // of type reqType. Only one handler may be registered for a given reqType. It // is an error if this method is called twice with the same reqType. -func (srv *Server) GlobalRequest(reqType string, handler RequestHandler) error { +func (srv *Server) Request(reqType string, handler RequestHandler) error { srv.mu.Lock() defer srv.mu.Unlock() if srv.GlobalRequests == nil { -- GitLab