From d3d7a1656e2181118a77288e49cc83267d94e7fc Mon Sep 17 00:00:00 2001 From: Rohan Varma <rvarm1@gmail.com> Date: Tue, 18 Jul 2023 15:55:10 -0400 Subject: [PATCH] Update llama_finetuning.py --- llama_finetuning.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/llama_finetuning.py b/llama_finetuning.py index 85bf18e3..02aebdec 100644 --- a/llama_finetuning.py +++ b/llama_finetuning.py @@ -53,10 +53,8 @@ import configs from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, MixedPrecision, - StateDictType, ) from torch.utils.data import DistributedSampler -from torch.distributed.fsdp._common_utils import _is_fsdp_flattened import policies from policies import AnyPrecisionAdamW from configs import fsdp_config, train_config @@ -66,7 +64,6 @@ from pkg_resources import packaging import torch import torch.cuda.nccl as nccl import torch.distributed as dist -from transformers.models.t5.modeling_t5 import T5Block from transformers.models.llama.modeling_llama import LlamaDecoderLayer @@ -239,4 +236,4 @@ def main(**kwargs): [print(f'Key: {k}, Value: {v}') for k, v in results.items()] if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file + fire.Fire(main) -- GitLab