Skip to content
Snippets Groups Projects
utils.py 3.03 KiB
Newer Older
  • Learn to ignore specific revisions
  • Ming Ding's avatar
    Ming Ding committed
    # coding=utf-8
    # Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    
    import torch
    
    Ming Ding's avatar
    Ming Ding committed
    import math
    
    Ming Ding's avatar
    Ming Ding committed
    
    
    def ensure_divisibility(numerator, denominator):
        """Ensure that ``numerator`` is exactly divisible by ``denominator``.

        Raises:
            AssertionError: if ``numerator % denominator`` is non-zero. The
                error is raised explicitly rather than via an ``assert``
                statement so the check is not stripped when Python runs with
                ``-O``; the exception type and message are unchanged.
            ZeroDivisionError: if ``denominator`` is zero (from ``%``).
        """
        if numerator % denominator != 0:
            raise AssertionError('{} is not divisible by {}'.format(
                numerator, denominator))
    
    
    def divide(numerator, denominator):
        """Return the exact integer quotient ``numerator // denominator``.

        The operands are first passed to ``ensure_divisibility`` so that an
        inexact division is rejected there instead of being silently floored.
        """
        ensure_divisibility(numerator, denominator)
        quotient = numerator // denominator
        return quotient
    
    
    def split_tensor_along_last_dim(tensor, num_partitions,
                                    contiguous_split_chunks=False):
        """Cut ``tensor`` into ``num_partitions`` equal slices along its final axis.

        Arguments:
            tensor: input tensor; the size of its last dimension must be
                divisible by ``num_partitions`` (validated through ``divide``).
            num_partitions: number of equal-width chunks to produce.
            contiguous_split_chunks: when True, each chunk is made contiguous
                in memory; otherwise the chunks from ``torch.split`` are
                returned unchanged.

        Returns:
            A tuple of the chunks.
        """
        final_axis = tensor.dim() - 1
        chunk_width = divide(tensor.size()[final_axis], num_partitions)
        # torch.split returns views into ``tensor`` that need not be
        # contiguous, so callers wanting dense storage must opt into a copy.
        pieces = torch.split(tensor, chunk_width, dim=final_axis)
        if not contiguous_split_chunks:
            return pieces
        return tuple(piece.contiguous() for piece in pieces)
    
    
    class VocabUtility:
        """Locate the slice of a vocabulary that belongs to one partition.

        The vocabulary is split into ``world_size`` equal chunks, and the
        helpers return the first and last index of the chunk owned by
        partition ``rank``. The range is half-open: the indices are
        ``[first, last)``.
        """

        @staticmethod
        def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                      rank, world_size):
            """Return ``(first, last)`` for ``rank`` given the per-chunk size.

            Arguments:
                per_partition_vocab_size: number of vocabulary entries per chunk.
                rank: index of the partition whose range is requested.
                world_size: not used by this computation.

            Returns:
                ``(rank * per_partition_vocab_size,
                (rank + 1) * per_partition_vocab_size)``, a half-open range.
            """
            index_f = rank * per_partition_vocab_size
            index_l = index_f + per_partition_vocab_size
            return index_f, index_l

        @staticmethod
        def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
            """Return ``(first, last)`` for ``rank`` given the full vocabulary size.

            ``global_vocab_size`` must be divisible by ``world_size``; this is
            checked by ``divide``.
            """
            per_partition_vocab_size = divide(global_vocab_size, world_size)
            return VocabUtility.vocab_range_from_per_partition_vocab_size(
                per_partition_vocab_size, rank, world_size)
    
    def split_out_sums(x, BLOCK_SIZE=32, all_ret=False):
        """Separate the last position of every ``BLOCK_SIZE``-long block on dim 1.

        ``x`` is viewed as ``(b, L // BLOCK_SIZE, BLOCK_SIZE, *rest)``. Within
        each block, the final position is split from the preceding
        ``BLOCK_SIZE - 1`` positions. (The final positions are called "sums"
        here; presumably they hold a per-block summary token. Confirm this
        against the callers.)

        Arguments:
            x: tensor with at least two dimensions, ``(b, L, *rest)``. ``L``
                must be a multiple of ``BLOCK_SIZE``, otherwise ``view`` raises.
            BLOCK_SIZE: length of each block along dim 1.
            all_ret: if True, also return the non-final positions.

        Returns:
            ``sums`` with shape ``(b, L // BLOCK_SIZE, *rest)``. When
            ``all_ret`` is True, returns ``(oris, sums)``, where ``oris`` has
            shape ``(b, (L // BLOCK_SIZE) * (BLOCK_SIZE - 1), *rest)``.
        """
        b, L = x.shape[:2]
        rs = x.shape[2:]
        n_blocks = L // BLOCK_SIZE
        x = x.view(b, n_blocks, BLOCK_SIZE, *rs)
        oris, sums = x.split([BLOCK_SIZE - 1, 1], dim=2)
        # Spell out every output size rather than using -1: reshape cannot
        # infer a -1 dimension for a zero-element tensor (e.g. BLOCK_SIZE == 1
        # leaves ``oris`` empty, or b == 0), which previously raised.
        sums = sums.reshape(b, n_blocks, *rs)
        if all_ret:
            return oris.reshape(b, n_blocks * (BLOCK_SIZE - 1), *rs), sums
        return sums
    
    Ming Ding's avatar
    Ming Ding committed
    
    def sqrt(x):
        """Return the integer part of ``math.sqrt(x)`` after a +1e-4 nudge.

        The small upward offset makes a perfect square whose floating-point
        root lands marginally below the true integer still truncate to that
        integer. A side effect is that a value whose real root is within 1e-4
        below an integer also truncates up to that integer. Negative ``x``
        raises ``ValueError`` from ``math.sqrt``.
        """
        root = math.sqrt(x)
        return int(root + 1e-4)