Commit cf678b9b authored by Matthias Reso

Adjust imports to package structure + cleaned up imports

parent 02428c99
Showing changed files with 45 additions and 93 deletions
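Note: the import changes in this commit follow a single pattern. Intra-repository absolute imports (for example, from configs import train_config) become package-relative imports (from .configs import train_config), and imports that are no longer used are removed. Below is a minimal sketch of the same idea on a hypothetical two-module package; the names pkg, configs, and finetuning are illustrative and not taken from this repository.

# pkg/__init__.py  (empty file that marks "pkg" as an importable package)

# pkg/configs.py
train_config = {"lr": 1e-4}  # stand-in for the real config dataclasses

# pkg/finetuning.py
# Old style: absolute import that only resolves when running from the repo root
#   from configs import train_config
# New style: package-relative import that resolves relative to "pkg"
from .configs import train_config

def main():
    print(train_config)

if __name__ == "__main__":
    main()

# Run as a module so the relative import can be resolved:
#   python -m pkg.finetuning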
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from dataclasses import dataclass, field
-from typing import ClassVar
+from dataclasses import dataclass
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import ClassVar, List
@dataclass
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from dataclasses import dataclass
from typing import ClassVar
@dataclass
......
@@ -5,12 +5,10 @@
import copy
import json
import os
-import torch
-from sentencepiece import SentencePieceProcessor
+import torch
from torch.utils.data import Dataset
from typing import List
PROMPT_DICT = {
"prompt_input": (
......
@@ -4,29 +4,13 @@
# For dataset details visit: https://huggingface.co/datasets/jfleg
# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb
-import argparse
-import csv
-import glob
-import os
-import json
-import time
-import logging
-import random
-import re
-from itertools import chain
-from string import punctuation
-import pandas as pd
-import numpy as np
-import torch
-from torch.utils.data import Dataset
from datasets import load_dataset
from pathlib import Path
-from ft_datasets.utils import ConcatDataset
+from torch.utils.data import Dataset
+from ..utils import ConcatDataset
class grammar(Dataset):
......
@@ -4,6 +4,7 @@
# For dataset details visit: https://huggingface.co/datasets/samsum
import datasets
from .utils import Concatenator
def get_preprocessed_samsum(dataset_config, tokenizer, split):
......
@@ -3,6 +3,7 @@
from tqdm import tqdm
from itertools import chain
from torch.utils.data import Dataset
class Concatenator(object):
......
@@ -2,13 +2,13 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
-from pkg_resources import packaging
import fire
import torch
import torch.distributed as dist
import torch.optim as optim
from peft import get_peft_model, prepare_model_for_int8_training
+from pkg_resources import packaging
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
@@ -22,19 +22,18 @@ from transformers import (
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-import policies
-from configs import fsdp_config, train_config
-from policies import AnyPrecisionAdamW
+from .configs import fsdp_config, train_config
+from .policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
-from utils import fsdp_auto_wrap_policy
-from utils.config_utils import (
+from .utils import fsdp_auto_wrap_policy
+from .utils.config_utils import (
update_config,
generate_peft_config,
generate_dataset_config,
)
-from utils.dataset_utils import get_preprocessed_dataset
+from .utils.dataset_utils import get_preprocessed_dataset
-from utils.train_utils import (
+from .utils.train_utils import (
train,
freeze_transformer_layers,
setup,
@@ -153,7 +152,7 @@ def main(**kwargs):
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
-policies.apply_fsdp_checkpointing(model)
+apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
model.to("cuda")
......
@@ -2,18 +2,18 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import fire
-import torch
import os
import sys
import warnings
from typing import List
-from peft import PeftModel, PeftConfig
-from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
-from safety_utils import get_safety_checker
+import torch
from model_utils import load_model, load_peft_model
-from chat_utils import read_dialogs_from_file, format_tokens
+from transformers import LlamaTokenizer
+from safety_utils import get_safety_checker
+from .chat_utils import read_dialogs_from_file, format_tokens
def main(
model_name,
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from typing import List, Literal, Optional, Tuple, TypedDict, Union
import json
+from typing import List, Literal, TypedDict
Role = Literal["user", "assistant"]
......
@@ -4,12 +4,14 @@
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import fire
import torch
import os
import sys
import yaml
from transformers import LlamaTokenizer
-from model_utils import load_llama_from_config
+from .model_utils import load_llama_from_config
# Get the current file's directory
current_directory = os.path.dirname(os.path.abspath(__file__))
......
@@ -4,15 +4,16 @@
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import fire
-import torch
import os
import sys
import time
from typing import List
+import torch
from transformers import LlamaTokenizer
-from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model, load_llama_from_config
+from .safety_utils import get_safety_checker
+from .model_utils import load_model, load_peft_model
def main(
model_name,
......
@@ -5,8 +5,6 @@ import os
import torch
import warnings
-from peft import PeftConfig
-from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
# Class for performing safety checks using AuditNLG library
class AuditNLGSensitiveTopics(object):
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import fire
import torch
-import os
-import sys
-from peft import PeftModel, PeftConfig
-from transformers import (
-LlamaConfig,
-LlamaTokenizer,
-LlamaForCausalLM
-)
from vllm import LLM
from vllm import LLM, SamplingParams
torch.cuda.manual_seed(42)
torch.manual_seed(42)
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-import torch
-import os
-import torch.distributed as dist
+from functools import partial
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
-from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from functools import partial
non_reentrant_wrapper = partial(
checkpoint_wrapper,
......
@@ -4,11 +4,7 @@
import torch
from torch.distributed.fsdp import (
-# FullyShardedDataParallel as FSDP,
-# CPUOffload,
MixedPrecision,
-# BackwardPrefetch,
-# ShardingStrategy,
)
# requires grad scaler in main loop
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-import torch.distributed as dist
-import torch.nn as nn
-import torch
+import functools
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-FullyShardedDataParallel as FSDP,
-CPUOffload,
-BackwardPrefetch,
-MixedPrecision,
-)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
-enable_wrap,
-wrap,
)
-import functools
-from typing import Type
def get_size_policy(min_params=1e8):
num_wrap_policy = functools.partial(
......
@@ -3,14 +3,14 @@
import inspect
from dataclasses import fields
from peft import (
LoraConfig,
AdaptionPromptConfig,
PrefixTuningConfig,
)
-import configs.datasets as datasets
-from configs import lora_config, llama_adapter_config, prefix_config, train_config
+from ..configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from .dataset_utils import DATASET_PREPROC
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-import torch
from functools import partial
-from ft_datasets import (
+import torch
+from ..datasets import (
get_grammar_dataset,
get_alpaca_dataset,
get_samsum_dataset,
)
from typing import Optional
DATASET_PREPROC = {
......
@@ -3,10 +3,7 @@
def fsdp_auto_wrap_policy(model, transformer_layer_name):
import functools
-import os
-from accelerate import FullyShardedDataParallelPlugin
-from transformers.models.t5.modeling_t5 import T5Block
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
......