diff --git a/server.go b/server.go index 12ca67cb6ecd9f3b398c35e95474de7d54b593b8..f3fe9e30d26cc9d086aa46b7d474ce1a118167c3 100644 --- a/server.go +++ b/server.go @@ -162,10 +162,12 @@ func (srv *Server) init() error { // Setting a DefaultChannelHandler would cause every type of // channel to be accepted. Instead, add a default session // handler. + srv.mu.Lock() if srv.ChannelHandlers == nil { srv.ChannelHandlers = make(map[string]ChannelHandler) srv.ChannelHandlers["session"] = DefaultSessionHandler() } + srv.mu.Unlock() }) return failure } @@ -482,32 +484,32 @@ func defaultRequestFunc(req *ssh.Request) { // GlobalRequest registers a handler to be called on incomming global requests // of type reqType. Only one handler may be registered for a given reqType. It // is an error if this method is called twice with the same reqType. -func (srv *Server) GlobalRequest(reqType string, h RequestHandler) error { - err := srv.init() - if err != nil { - return err - } +func (srv *Server) GlobalRequest(reqType string, handler RequestHandler) error { srv.mu.Lock() defer srv.mu.Unlock() - _, exists := srv.GlobalRequests[reqType] - if exists { - return fmt.Errorf("sshutil: request name '%s' already registered", reqType) + if srv.GlobalRequests == nil { + srv.GlobalRequests = make(map[string]RequestHandler) + } else { + _, exists := srv.GlobalRequests[reqType] + if exists { + return fmt.Errorf("sshutil: request name '%s' already registered", reqType) + } } - srv.GlobalRequests[reqType] = h + srv.GlobalRequests[reqType] = handler return nil } // Channel registers a handler for incomming channels named name. func (srv *Server) Channel(name string, handler ChannelHandler) error { - err := srv.init() - if err != nil { - return err - } srv.mu.Lock() defer srv.mu.Unlock() - _, exists := srv.ChannelHandlers[name] - if exists { - return fmt.Errorf("sshutil: channel name '%s' already registered", name) + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = make(map[string]ChannelHandler) + } else { + _, exists := srv.ChannelHandlers[name] + if exists { + return fmt.Errorf("sshutil: channel name '%s' already registered", name) + } } srv.ChannelHandlers[name] = handler return nil