Skip to content
Snippets Groups Projects
Commit c8d4f38d authored by lchu's avatar lchu
Browse files

replace init_empty_weights with torch.device(meta)

parent d8a81bb5
No related branches found
No related tags found
No related merge requests found
......@@ -42,8 +42,6 @@ from utils.train_utils import (
get_policies
)
from accelerate import init_empty_weights
from utils.dataset_utils import get_preprocessed_dataset
from utils.config_utils import (
......@@ -107,7 +105,7 @@ def main(**kwargs):
)
else:
llama_config = LlamaConfig.from_pretrained(train_config.model_name)
with init_empty_weights():
with torch.device("meta"):
model = LlamaForCausalLM(llama_config)
else:
model = LlamaForCausalLM.from_pretrained(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment