diff --git a/src/llama_recipes/configs/fsdp.py b/src/llama_recipes/configs/fsdp.py index 78fa1bac20ccae94030543408c0707d3ac03cdeb..ee10262d7554ab900ad439d2d4f9d849ab1f3ffc 100644 --- a/src/llama_recipes/configs/fsdp.py +++ b/src/llama_recipes/configs/fsdp.py @@ -1,8 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -from dataclasses import dataclass, field -from typing import ClassVar +from dataclasses import dataclass + from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType diff --git a/src/llama_recipes/configs/peft.py b/src/llama_recipes/configs/peft.py index cb88f146be21c53450b724f9a7d4d79e504b8c0a..6d2c37c3a768fd009794fde33e02d07c995df264 100644 --- a/src/llama_recipes/configs/peft.py +++ b/src/llama_recipes/configs/peft.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import ClassVar, List @dataclass diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index c0617887d2af8408a5324acba828ce2a70b8ad51..8062fda5990e5e454e272f868b553274edcb0b03 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + from dataclasses import dataclass -from typing import ClassVar @dataclass diff --git a/src/llama_recipes/datasets/alpaca_dataset.py b/src/llama_recipes/datasets/alpaca_dataset.py index 77cbb27eac56e73f9f9eb0f415398d2bd9b17921..091aef9e3b517da981dcca1d8bcd4a2b4517d6a6 100644 --- a/src/llama_recipes/datasets/alpaca_dataset.py +++ b/src/llama_recipes/datasets/alpaca_dataset.py @@ -5,12 +5,10 @@ import copy import json -import os -import torch -from sentencepiece import SentencePieceProcessor +import torch from torch.utils.data import Dataset -from typing import List + PROMPT_DICT = { "prompt_input": ( diff --git a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py index cd2e74cd64009c99feeec90644647adffac1fdea..3ef00523db5b09077076bffdb058bac183e80ba3 100644 --- a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py +++ b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py @@ -4,29 +4,13 @@ # For dataset details visit: https://huggingface.co/datasets/jfleg # For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb -import argparse -import csv -import glob -import os -import json -import time -import logging -import random -import re -from itertools import chain -from string import punctuation - - -import pandas as pd -import numpy as np -import torch -from torch.utils.data import Dataset from datasets import load_dataset from pathlib import Path -from ft_datasets.utils import ConcatDataset +from torch.utils.data import Dataset +from ..utils import ConcatDataset class grammar(Dataset): diff --git a/src/llama_recipes/datasets/samsum_dataset.py b/src/llama_recipes/datasets/samsum_dataset.py index a178e06d9b0bc1591678d551f906f1c49d033074..64b577d4a0c3758d8427cb208111d68bf32dcb9d 100644 --- a/src/llama_recipes/datasets/samsum_dataset.py +++ b/src/llama_recipes/datasets/samsum_dataset.py @@ -4,6 +4,7 @@ # For dataset details visit: https://huggingface.co/datasets/samsum import datasets + from .utils import Concatenator def get_preprocessed_samsum(dataset_config, tokenizer, split): diff --git a/src/llama_recipes/datasets/utils.py b/src/llama_recipes/datasets/utils.py index 3263d806afcebcff0d72754ecf70a14e7bf67a8c..4c6956d865c5cc6352b85ca7f626b05c6371e819 100644 --- a/src/llama_recipes/datasets/utils.py +++ b/src/llama_recipes/datasets/utils.py @@ -3,6 +3,7 @@ from tqdm import tqdm from itertools import chain + from torch.utils.data import Dataset class Concatenator(object): diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py index 5a2f86267cb2423127e3d28786409f50db52f09b..7715396cd19a07b0f36ba2f11f8ee738952cd5bf 100644 --- a/src/llama_recipes/finetuning.py +++ b/src/llama_recipes/finetuning.py @@ -2,13 +2,13 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import os +from pkg_resources import packaging import fire import torch import torch.distributed as dist import torch.optim as optim from peft import get_peft_model, prepare_model_for_int8_training -from pkg_resources import packaging from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, ) @@ -22,19 +22,18 @@ from transformers import ( ) from transformers.models.llama.modeling_llama import LlamaDecoderLayer -import policies -from configs import fsdp_config, train_config -from policies import AnyPrecisionAdamW +from .configs import fsdp_config, train_config +from .policies import AnyPrecisionAdamW, apply_fsdp_checkpointing -from utils import fsdp_auto_wrap_policy -from utils.config_utils import ( +from .utils import fsdp_auto_wrap_policy +from .utils.config_utils import ( update_config, generate_peft_config, generate_dataset_config, ) -from utils.dataset_utils import get_preprocessed_dataset +from .utils.dataset_utils import get_preprocessed_dataset -from utils.train_utils import ( +from .utils.train_utils import ( train, freeze_transformer_layers, setup, @@ -153,7 +152,7 @@ def main(**kwargs): if train_config.low_cpu_fsdp and rank != 0 else None, ) if fsdp_config.fsdp_activation_checkpointing: - policies.apply_fsdp_checkpointing(model) + apply_fsdp_checkpointing(model) elif not train_config.quantization and not train_config.enable_fsdp: model.to("cuda") diff --git a/src/llama_recipes/inference/chat_completion.py b/src/llama_recipes/inference/chat_completion.py index d5c8378bdbbdc5f2db3fc33ef032dc47e23a07a4..4ff89c1ba8390bd9a25bfc594e4df5628f2a636a 100644 --- a/src/llama_recipes/inference/chat_completion.py +++ b/src/llama_recipes/inference/chat_completion.py @@ -2,18 +2,18 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. # from accelerate import init_empty_weights, load_checkpoint_and_dispatch + import fire -import torch import os import sys -import warnings from typing import List -from peft import PeftModel, PeftConfig -from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM -from safety_utils import get_safety_checker +import torch from model_utils import load_model, load_peft_model -from chat_utils import read_dialogs_from_file, format_tokens +from transformers import LlamaTokenizer +from safety_utils import get_safety_checker + +from .chat_utils import read_dialogs_from_file, format_tokens def main( model_name, diff --git a/src/llama_recipes/inference/chat_utils.py b/src/llama_recipes/inference/chat_utils.py index 89911dd09ef32876d2d6622965b348e9f1b92c77..8d781e31a66ff08d4c7ef1fd06cf08fd0654743f 100644 --- a/src/llama_recipes/inference/chat_utils.py +++ b/src/llama_recipes/inference/chat_utils.py @@ -1,8 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -from typing import List, Literal, Optional, Tuple, TypedDict, Union import json +from typing import List, Literal, TypedDict + Role = Literal["user", "assistant"] diff --git a/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py b/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py index 582a871e57059df0ec99f2ef3884755415a65b05..dd52dc1882c07a68080cf72e6fa9a48829608598 100644 --- a/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py +++ b/src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py @@ -4,12 +4,14 @@ # from accelerate import init_empty_weights, load_checkpoint_and_dispatch import fire -import torch import os import sys import yaml + from transformers import LlamaTokenizer -from model_utils import load_llama_from_config + +from .model_utils import load_llama_from_config + # Get the current file's directory current_directory = os.path.dirname(os.path.abspath(__file__)) diff --git a/src/llama_recipes/inference/inference.py b/src/llama_recipes/inference/inference.py index 8f502b1783c96ef4158e164393dfb5538aa7dd64..3aee4c9a4dd60f7cbdea34ebe76a94243c8e368f 100644 --- a/src/llama_recipes/inference/inference.py +++ b/src/llama_recipes/inference/inference.py @@ -4,15 +4,16 @@ # from accelerate import init_empty_weights, load_checkpoint_and_dispatch import fire -import torch import os import sys import time -from typing import List +import torch from transformers import LlamaTokenizer -from safety_utils import get_safety_checker -from model_utils import load_model, load_peft_model, load_llama_from_config + +from .safety_utils import get_safety_checker +from .model_utils import load_model, load_peft_model + def main( model_name, diff --git a/src/llama_recipes/inference/safety_utils.py b/src/llama_recipes/inference/safety_utils.py index bc321eb929df52b2c87e1ed0b52f7061d468094c..38a44d42c954bdffb0e41bc03012195fcd1a193c 100644 --- a/src/llama_recipes/inference/safety_utils.py +++ b/src/llama_recipes/inference/safety_utils.py @@ -5,8 +5,6 @@ import os import torch import warnings -from peft import PeftConfig -from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM # Class for performing safety checks using AuditNLG library class AuditNLGSensitiveTopics(object): diff --git a/src/llama_recipes/inference/vLLM_inference.py b/src/llama_recipes/inference/vLLM_inference.py index 63c64414821d82adbb68dae3b01b36851d8eeb3b..e587bc038ceca65359c0573a9b3a719943d26843 100644 --- a/src/llama_recipes/inference/vLLM_inference.py +++ b/src/llama_recipes/inference/vLLM_inference.py @@ -1,20 +1,13 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -from accelerate import init_empty_weights, load_checkpoint_and_dispatch import fire + import torch -import os -import sys -from peft import PeftModel, PeftConfig -from transformers import ( - LlamaConfig, - LlamaTokenizer, - LlamaForCausalLM -) from vllm import LLM from vllm import LLM, SamplingParams + torch.cuda.manual_seed(42) torch.manual_seed(42) diff --git a/src/llama_recipes/policies/activation_checkpointing_functions.py b/src/llama_recipes/policies/activation_checkpointing_functions.py index 379bc6bfabc03f34745d995c3257181aa8ae030f..818b7daced3dc807ae43ee9b74afceab7c083ab8 100644 --- a/src/llama_recipes/policies/activation_checkpointing_functions.py +++ b/src/llama_recipes/policies/activation_checkpointing_functions.py @@ -1,18 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -import torch -import os -import torch.distributed as dist +from functools import partial + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, apply_activation_checkpointing, ) - -from transformers.models.t5.modeling_t5 import T5Block from transformers.models.llama.modeling_llama import LlamaDecoderLayer -from functools import partial non_reentrant_wrapper = partial( checkpoint_wrapper, diff --git a/src/llama_recipes/policies/mixed_precision.py b/src/llama_recipes/policies/mixed_precision.py index 410ee392edf846da59318bdc80fdd9ab3951cf0f..11df7edf60d09f1016049026636fca176440043b 100644 --- a/src/llama_recipes/policies/mixed_precision.py +++ b/src/llama_recipes/policies/mixed_precision.py @@ -4,11 +4,7 @@ import torch from torch.distributed.fsdp import ( - # FullyShardedDataParallel as FSDP, - # CPUOffload, MixedPrecision, - # BackwardPrefetch, - # ShardingStrategy, ) # requires grad scaler in main loop diff --git a/src/llama_recipes/policies/wrapping.py b/src/llama_recipes/policies/wrapping.py index d9fadc3347add4974ab57b858288c489e23463d3..da7981cac211efb880ebf36b5879a9202ac668a6 100644 --- a/src/llama_recipes/policies/wrapping.py +++ b/src/llama_recipes/policies/wrapping.py @@ -1,28 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -import torch.distributed as dist -import torch.nn as nn -import torch +import functools from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel as FSDP, - CPUOffload, - BackwardPrefetch, - MixedPrecision, -) from torch.distributed.fsdp.wrap import ( transformer_auto_wrap_policy, size_based_auto_wrap_policy, - enable_wrap, - wrap, ) -import functools -from typing import Type - def get_size_policy(min_params=1e8): num_wrap_policy = functools.partial( diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py index 1ba10419d264783d6faebab1d49e74ffd75803f7..cda692f591b5ffe5e1d37b765c59a8b0a215507b 100644 --- a/src/llama_recipes/utils/config_utils.py +++ b/src/llama_recipes/utils/config_utils.py @@ -3,14 +3,14 @@ import inspect from dataclasses import fields + from peft import ( LoraConfig, AdaptionPromptConfig, PrefixTuningConfig, ) -import configs.datasets as datasets -from configs import lora_config, llama_adapter_config, prefix_config, train_config +from ..configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config from .dataset_utils import DATASET_PREPROC diff --git a/src/llama_recipes/utils/dataset_utils.py b/src/llama_recipes/utils/dataset_utils.py index 9f2c0223d561e47623b25525dc5d3c9d425c5f56..1fe46017e263193880f61b2a1d21821af70818ae 100644 --- a/src/llama_recipes/utils/dataset_utils.py +++ b/src/llama_recipes/utils/dataset_utils.py @@ -1,16 +1,15 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -import torch - from functools import partial -from ft_datasets import ( +import torch + +from ..datasets import ( get_grammar_dataset, get_alpaca_dataset, get_samsum_dataset, ) -from typing import Optional DATASET_PREPROC = { diff --git a/src/llama_recipes/utils/fsdp_utils.py b/src/llama_recipes/utils/fsdp_utils.py index e7ed13d2a3f7614ee12e03ff585d0ac91d17a824..e2cd8d9108cb3dc994b47150efae0ff76478aaec 100644 --- a/src/llama_recipes/utils/fsdp_utils.py +++ b/src/llama_recipes/utils/fsdp_utils.py @@ -3,10 +3,7 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name): import functools - import os - from accelerate import FullyShardedDataParallelPlugin - from transformers.models.t5.modeling_t5 import T5Block from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder diff --git a/src/llama_recipes/utils/memory_utils.py b/src/llama_recipes/utils/memory_utils.py index ee134d286074ba7f95e17aec795bc347e17d32b1..725f2b0d5f7a5ec9551108f7bcd7e5c166af7af1 100644 --- a/src/llama_recipes/utils/memory_utils.py +++ b/src/llama_recipes/utils/memory_utils.py @@ -1,12 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + import gc -import os -import sys +import psutil import threading -import numpy as np -import psutil import torch def byte2gb(x): diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index d4eab21a2574c390b8ed117aa6b415d97cb3c757..fc65553f0887f7791df1a9c23a5ff977d489b730 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -2,40 +2,25 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. import os -import sys -from typing import List -import yaml import time +import yaml +from pathlib import Path +from pkg_resources import packaging + -import fire import torch -import transformers -from datasets import load_dataset -from tqdm import tqdm -""" -Unused imports: -import torch.nn as nn -import bitsandbytes as bnb -""" -from torch.nn import functional as F -from peft import ( - LoraConfig, - get_peft_model, - get_peft_model_state_dict, - prepare_model_for_int8_training, - set_peft_model_state_dict, -) -from transformers import LlamaForCausalLM, LlamaTokenizer -from torch.distributed.fsdp import StateDictType -import torch.distributed as dist -from pkg_resources import packaging -from .memory_utils import MemoryTrace -import model_checkpointing import torch.cuda.nccl as nccl +import torch.distributed as dist +from torch.distributed.fsdp import StateDictType from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from pathlib import Path -sys.path.append(str(Path(__file__).resolve().parent.parent)) -from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper +from tqdm import tqdm +from transformers import LlamaTokenizer + + +from .memory_utils import MemoryTrace +from ..model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint +from ..policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper + def set_tokenizer_params(tokenizer: LlamaTokenizer): tokenizer.pad_token_id = 0 @@ -162,21 +147,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche else: if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - model_checkpointing.save_model_checkpoint( + save_model_checkpoint( model, optimizer, rank, train_config, epoch=epoch ) elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") print("=====================================================") - model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config) + save_model_and_optimizer_sharded(model, rank, train_config) if train_config.save_optimizer: - model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) + save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") print("=====================================================") if not train_config.use_peft and train_config.save_optimizer: - model_checkpointing.save_optimizer_checkpoint( + save_optimizer_checkpoint( model, optimizer, rank, train_config, epoch=epoch ) print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")