From b1b78f7cac84bef2bf5b412f92992a794b978c85 Mon Sep 17 00:00:00 2001
From: David Cowden <dcow@smallstep.com>
Date: Wed, 17 Jun 2020 00:10:03 -0700
Subject: [PATCH] server: Initialize handler maps independently

Don't trigger initialization of the whole server when adding channel and
request handlers. We only want to initialize the server right before it
begins serving.
---
 server.go | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/server.go b/server.go
index 12ca67c..f3fe9e3 100644
--- a/server.go
+++ b/server.go
@@ -162,10 +162,12 @@ func (srv *Server) init() error {
 		// Setting a DefaultChannelHandler would cause every type of
 		// channel to be accepted. Instead, add a default session
 		// handler.
+		srv.mu.Lock()
 		if srv.ChannelHandlers == nil {
 			srv.ChannelHandlers = make(map[string]ChannelHandler)
 			srv.ChannelHandlers["session"] = DefaultSessionHandler()
 		}
+		srv.mu.Unlock()
 	})
 	return failure
 }
@@ -482,32 +484,32 @@ func defaultRequestFunc(req *ssh.Request) {
 // GlobalRequest registers a handler to be called on incomming global requests
 // of type reqType. Only one handler may be registered for a given reqType. It
 // is an error if this method is called twice with the same reqType.
-func (srv *Server) GlobalRequest(reqType string, h RequestHandler) error {
-	err := srv.init()
-	if err != nil {
-		return err
-	}
+func (srv *Server) GlobalRequest(reqType string, handler RequestHandler) error {
 	srv.mu.Lock()
 	defer srv.mu.Unlock()
-	_, exists := srv.GlobalRequests[reqType]
-	if exists {
-		return fmt.Errorf("sshutil: request name '%s' already registered", reqType)
+	if srv.GlobalRequests == nil {
+		srv.GlobalRequests = make(map[string]RequestHandler)
+	} else {
+		_, exists := srv.GlobalRequests[reqType]
+		if exists {
+			return fmt.Errorf("sshutil: request name '%s' already registered", reqType)
+		}
 	}
-	srv.GlobalRequests[reqType] = h
+	srv.GlobalRequests[reqType] = handler
 	return nil
 }
 
 // Channel registers a handler for incomming channels named name.
 func (srv *Server) Channel(name string, handler ChannelHandler) error {
-	err := srv.init()
-	if err != nil {
-		return err
-	}
 	srv.mu.Lock()
 	defer srv.mu.Unlock()
-	_, exists := srv.ChannelHandlers[name]
-	if exists {
-		return fmt.Errorf("sshutil: channel name '%s' already registered", name)
+	if srv.ChannelHandlers == nil {
+		srv.ChannelHandlers = make(map[string]ChannelHandler)
+	} else {
+		_, exists := srv.ChannelHandlers[name]
+		if exists {
+			return fmt.Errorf("sshutil: channel name '%s' already registered", name)
+		}
 	}
 	srv.ChannelHandlers[name] = handler
 	return nil
-- 
GitLab