Skip to content
Snippets Groups Projects
Commit b1b78f7c authored by David Cowden's avatar David Cowden
Browse files

server: Initialize handler maps independently

Don't trigger initialization of the whole server when adding channel and
request handlers. We only want to initialize the server right before it
begins serving.
parent 199671d0
Branches
Tags
No related merge requests found
...@@ -162,10 +162,12 @@ func (srv *Server) init() error { ...@@ -162,10 +162,12 @@ func (srv *Server) init() error {
// Setting a DefaultChannelHandler would cause every type of // Setting a DefaultChannelHandler would cause every type of
// channel to be accepted. Instead, add a default session // channel to be accepted. Instead, add a default session
// handler. // handler.
srv.mu.Lock()
if srv.ChannelHandlers == nil { if srv.ChannelHandlers == nil {
srv.ChannelHandlers = make(map[string]ChannelHandler) srv.ChannelHandlers = make(map[string]ChannelHandler)
srv.ChannelHandlers["session"] = DefaultSessionHandler() srv.ChannelHandlers["session"] = DefaultSessionHandler()
} }
srv.mu.Unlock()
}) })
return failure return failure
} }
...@@ -482,32 +484,32 @@ func defaultRequestFunc(req *ssh.Request) { ...@@ -482,32 +484,32 @@ func defaultRequestFunc(req *ssh.Request) {
// GlobalRequest registers a handler to be called on incomming global requests // GlobalRequest registers a handler to be called on incomming global requests
// of type reqType. Only one handler may be registered for a given reqType. It // of type reqType. Only one handler may be registered for a given reqType. It
// is an error if this method is called twice with the same reqType. // is an error if this method is called twice with the same reqType.
func (srv *Server) GlobalRequest(reqType string, h RequestHandler) error { func (srv *Server) GlobalRequest(reqType string, handler RequestHandler) error {
err := srv.init()
if err != nil {
return err
}
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
_, exists := srv.GlobalRequests[reqType] if srv.GlobalRequests == nil {
if exists { srv.GlobalRequests = make(map[string]RequestHandler)
return fmt.Errorf("sshutil: request name '%s' already registered", reqType) } else {
_, exists := srv.GlobalRequests[reqType]
if exists {
return fmt.Errorf("sshutil: request name '%s' already registered", reqType)
}
} }
srv.GlobalRequests[reqType] = h srv.GlobalRequests[reqType] = handler
return nil return nil
} }
// Channel registers a handler for incomming channels named name. // Channel registers a handler for incomming channels named name.
func (srv *Server) Channel(name string, handler ChannelHandler) error { func (srv *Server) Channel(name string, handler ChannelHandler) error {
err := srv.init()
if err != nil {
return err
}
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
_, exists := srv.ChannelHandlers[name] if srv.ChannelHandlers == nil {
if exists { srv.ChannelHandlers = make(map[string]ChannelHandler)
return fmt.Errorf("sshutil: channel name '%s' already registered", name) } else {
_, exists := srv.ChannelHandlers[name]
if exists {
return fmt.Errorf("sshutil: channel name '%s' already registered", name)
}
} }
srv.ChannelHandlers[name] = handler srv.ChannelHandlers[name] = handler
return nil return nil
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment