From b1b78f7cac84bef2bf5b412f92992a794b978c85 Mon Sep 17 00:00:00 2001 From: David Cowden <dcow@smallstep.com> Date: Wed, 17 Jun 2020 00:10:03 -0700 Subject: [PATCH] server: Initialize handler maps independently Don't trigger initialization of the whole server when adding channel and request handlers. We only want to initialize the server right before it begins serving. --- server.go | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/server.go b/server.go index 12ca67c..f3fe9e3 100644 --- a/server.go +++ b/server.go @@ -162,10 +162,12 @@ func (srv *Server) init() error { // Setting a DefaultChannelHandler would cause every type of // channel to be accepted. Instead, add a default session // handler. + srv.mu.Lock() if srv.ChannelHandlers == nil { srv.ChannelHandlers = make(map[string]ChannelHandler) srv.ChannelHandlers["session"] = DefaultSessionHandler() } + srv.mu.Unlock() }) return failure } @@ -482,32 +484,32 @@ func defaultRequestFunc(req *ssh.Request) { // GlobalRequest registers a handler to be called on incomming global requests // of type reqType. Only one handler may be registered for a given reqType. It // is an error if this method is called twice with the same reqType. -func (srv *Server) GlobalRequest(reqType string, h RequestHandler) error { - err := srv.init() - if err != nil { - return err - } +func (srv *Server) GlobalRequest(reqType string, handler RequestHandler) error { srv.mu.Lock() defer srv.mu.Unlock() - _, exists := srv.GlobalRequests[reqType] - if exists { - return fmt.Errorf("sshutil: request name '%s' already registered", reqType) + if srv.GlobalRequests == nil { + srv.GlobalRequests = make(map[string]RequestHandler) + } else { + _, exists := srv.GlobalRequests[reqType] + if exists { + return fmt.Errorf("sshutil: request name '%s' already registered", reqType) + } } - srv.GlobalRequests[reqType] = h + srv.GlobalRequests[reqType] = handler return nil } // Channel registers a handler for incomming channels named name. func (srv *Server) Channel(name string, handler ChannelHandler) error { - err := srv.init() - if err != nil { - return err - } srv.mu.Lock() defer srv.mu.Unlock() - _, exists := srv.ChannelHandlers[name] - if exists { - return fmt.Errorf("sshutil: channel name '%s' already registered", name) + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = make(map[string]ChannelHandler) + } else { + _, exists := srv.ChannelHandlers[name] + if exists { + return fmt.Errorf("sshutil: channel name '%s' already registered", name) + } } srv.ChannelHandlers[name] = handler return nil -- GitLab