Skip to content
Snippets Groups Projects
Commit a978de10 authored by David Cowden's avatar David Cowden
Browse files

server: Add ServerConn to the request handler

Rename GlobalRequest to Request since it's global nature is implicit by
the fact that it is a method on Server, and explicit by the fact that it
now receives a reference to the associated connection when it is called.
parent ac2d2ff9
No related branches found
No related tags found
No related merge requests found
...@@ -311,6 +311,52 @@ func (srv *Server) Close() error { ...@@ -311,6 +311,52 @@ func (srv *Server) Close() error {
return nil return nil
} }
// ServerConn is a facade that decorates an embedded ssh.ServerConn with an
// associated context and a reference to the server instance.
type ServerConn struct {
*ssh.ServerConn
Context context.Context
Server *Server
}
type ctxKey int
const (
_ ctxKey = iota
// CtxKeyClientVersion retrieves the client version string associated with
// the connection.
CtxKeyClientVersion
// CtxKeyLocalAddr retrieves the address on which the incomming connection
// was accepted.
CtxKeyLocalAddr
// CtxKeyPermissions retrieves the permissions set during authentication.
// The ssh.Permissions type is used to convey information up the stack.
CtxKeyPermissions
// CtxKeyRemoteAddr retrieve the address of the client with which the server
// is communicating.
CtxKeyRemoteAddr
// CtxKeyServerVersion retrieves the server version string the Server sent to
// the client during the ssh handshake.
CtxKeyServerVersion
// CtxKeySessionID retrieves a string that is unique per-connection.
CtxKeySessionID
)
func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context {
ctx = context.WithValue(ctx, CtxKeyClientVersion, string(conn.ClientVersion()))
ctx = context.WithValue(ctx, CtxKeyLocalAddr, conn.LocalAddr())
ctx = context.WithValue(ctx, CtxKeyPermissions, conn.Permissions)
ctx = context.WithValue(ctx, CtxKeyRemoteAddr, conn.RemoteAddr())
ctx = context.WithValue(ctx, CtxKeyServerVersion, string(conn.ServerVersion()))
ctx = context.WithValue(ctx, CtxKeySessionID, conn.SessionID())
return ctx
}
func (srv *Server) handshake(c net.Conn) { func (srv *Server) handshake(c net.Conn) {
// Before use, a handshake must be performed on the incoming net.Conn. // Before use, a handshake must be performed on the incoming net.Conn.
ssh, channels, global, err := ssh.NewServerConn(c, srv.Config) ssh, channels, global, err := ssh.NewServerConn(c, srv.Config)
...@@ -339,7 +385,7 @@ func (srv *Server) handshake(c net.Conn) { ...@@ -339,7 +385,7 @@ func (srv *Server) handshake(c net.Conn) {
defer srv.DepartureHook(conn) defer srv.DepartureHook(conn)
// Process the global requests // Process the global requests
go srv.handleRequests(global) go srv.handleRequests(conn, global)
// Process the channels // Process the channels
srv.ssh(conn, channels) srv.ssh(conn, channels)
...@@ -356,50 +402,37 @@ func defaultDepartureHook(conn *ServerConn) { ...@@ -356,50 +402,37 @@ func defaultDepartureHook(conn *ServerConn) {
conn.Server.L.Printf("Server peer egression '%s'", pk) conn.Server.L.Printf("Server peer egression '%s'", pk)
} }
// ServerConn is a facade that decorates an embedded ssh.ServerConn with an // RequestHandler is called for incoming global requests.
// associated context and a reference to the server instance. type RequestHandler interface {
type ServerConn struct { ServeRequest(conn *ServerConn, r *ssh.Request)
*ssh.ServerConn
Context context.Context
Server *Server
} }
type ctxKey int // RequestHandlerFunc is ye ol' http/handler adapter type.
// https://golang.org/src/net/http/server.go#L2004
const ( type RequestHandlerFunc func(c *ServerConn, r *ssh.Request)
_ ctxKey = iota
// CtxKeyClientVersion retrieves the client version string associated with
// the connection.
CtxKeyClientVersion
// CtxKeyLocalAddr retrieves the address on which the incomming connection
// was accepted.
CtxKeyLocalAddr
// CtxKeyPermissions retrieves the permissions set during authentication.
// The ssh.Permissions type is used to convey information up the stack.
CtxKeyPermissions
// CtxKeyRemoteAddr retrieve the address of the client with which the server
// is communicating.
CtxKeyRemoteAddr
// CtxKeyServerVersion retrieves the server version string the Server sent to // ServeRequest calls f(r).
// the client during the ssh handshake. func (f RequestHandlerFunc) ServeRequest(c *ServerConn, r *ssh.Request) {
CtxKeyServerVersion f(c, r)
}
// CtxKeySessionID retrieves a string that is unique per-connection. func (srv *Server) handleRequests(conn *ServerConn, requests <-chan *ssh.Request) {
CtxKeySessionID for req := range requests {
) srv.mu.RLock()
handler, ok := srv.GlobalRequests[req.Type]
srv.mu.RUnlock()
if !ok {
handler = srv.DefaultRequestHandler
}
handler.ServeRequest(conn, req)
}
}
func connectionContext(ctx context.Context, conn *ssh.ServerConn) context.Context { // discard request
ctx = context.WithValue(ctx, CtxKeyClientVersion, string(conn.ClientVersion())) func defaultRequestFunc(conn *ServerConn, req *ssh.Request) {
ctx = context.WithValue(ctx, CtxKeyLocalAddr, conn.LocalAddr()) if req.WantReply {
ctx = context.WithValue(ctx, CtxKeyPermissions, conn.Permissions) req.Reply(false, nil)
ctx = context.WithValue(ctx, CtxKeyRemoteAddr, conn.RemoteAddr()) }
ctx = context.WithValue(ctx, CtxKeyServerVersion, string(conn.ServerVersion()))
ctx = context.WithValue(ctx, CtxKeySessionID, conn.SessionID())
return ctx
} }
func (srv *Server) ssh(conn *ServerConn, channels <-chan ssh.NewChannel) { func (srv *Server) ssh(conn *ServerConn, channels <-chan ssh.NewChannel) {
...@@ -463,43 +496,10 @@ func (f ChannelHandlerFunc) ServeChannel(stream Channel, requests <-chan *ssh.Re ...@@ -463,43 +496,10 @@ func (f ChannelHandlerFunc) ServeChannel(stream Channel, requests <-chan *ssh.Re
// server to acept all requests. // server to acept all requests.
// //
// RequestHandler is called for incoming global requests. // Request registers a handler to be called on incomming global (conn) requests
type RequestHandler interface {
ServeRequest(r *ssh.Request)
}
// RequestHandlerFunc is ye ol' http/handler adapter type.
// https://golang.org/src/net/http/server.go#L2004
type RequestHandlerFunc func(r *ssh.Request)
// ServeRequest calls f(r).
func (f RequestHandlerFunc) ServeRequest(r *ssh.Request) {
f(r)
}
func (srv *Server) handleRequests(requests <-chan *ssh.Request) {
for req := range requests {
srv.mu.RLock()
handler, ok := srv.GlobalRequests[req.Type]
srv.mu.RUnlock()
if !ok {
handler = srv.DefaultRequestHandler
}
handler.ServeRequest(req)
}
}
// discard request
func defaultRequestFunc(req *ssh.Request) {
if req.WantReply {
req.Reply(false, nil)
}
}
// GlobalRequest registers a handler to be called on incomming global requests
// of type reqType. Only one handler may be registered for a given reqType. It // of type reqType. Only one handler may be registered for a given reqType. It
// is an error if this method is called twice with the same reqType. // is an error if this method is called twice with the same reqType.
func (srv *Server) GlobalRequest(reqType string, handler RequestHandler) error { func (srv *Server) Request(reqType string, handler RequestHandler) error {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
if srv.GlobalRequests == nil { if srv.GlobalRequests == nil {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment