# Many-Llamas Human-Eval
In this directory, we run an experiment answering the question:
*If we run enough Llama models in parallel, can they outperform GPT-4o on HumanEval?*
It seeks to increase model performance not by scaling model parameters, but by scaling inference-time compute.
### Technical Blog
This experiment was built by the team at [Modal](https://modal.com) and is described in the following blog post:
[Beat GPT-4o at Python by searching with 100 small Llamas](https://modal.com/blog/llama-human-eval)
The experiment has since been upgraded to use the [Llama 3.2 3B Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) model, and is runnable end-to-end on the Modal serverless platform.
## Run it yourself
### Install the Modal CLI
From within your virtual environment, run:
```bash
pip install modal
```
And if you're new to Modal, authenticate with:
```bash
modal setup
# or if that doesn't work, try
# python -m modal setup
```
That's all!
This CLI will execute your Modal apps, which build and run containers in the cloud on your GPU of choice.
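If you're new to Modal, a minimal app looks something like the sketch below (the file and function names here are purely illustrative and not part of this repo):

```python
# hello_modal.py -- a minimal Modal app (illustrative sketch)
import modal

app = modal.App("hello-modal")


@app.function()
def square(x: int) -> int:
    # This body runs inside a container in the cloud
    return x * x


@app.local_entrypoint()
def main():
    # .remote() ships the call to the cloud and returns the result locally
    print(square.remote(7))
```

Running `modal run hello_modal.py` builds the container image, runs `main` locally, and executes `square` remotely.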
### HuggingFace Pull Access
To download the model, you'll first need to accept the [Llama 3.2 License](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) on HuggingFace and be approved for access.
Then, create a [Modal Secret](https://modal.com/secrets) named `huggingface`, to which you'll add your `HF_TOKEN` as an environment variable.
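If you prefer the CLI to the web dashboard, creating the secret should look roughly like this (a sketch; replace the placeholder with your own token):

```bash
modal secret create huggingface HF_TOKEN=hf_xxxxxxxxxxxxxxxx
```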
### Run The Experiment
This command will run every step for you:
```bash
bash run_e2e.sh
```
Or, if you prefer to run it manually, you can step through each of the Modal commands in [the script](./run_e2e.sh); they are listed below, after the steps.
This will execute:
1. Downloading the Llama 3.2 3B Instruct model to a cloud volume
2. Deploying a vLLM inference server to GPUs
3. Running hundreds of parallel generations on the HumanEval test set
4. Running the evaluation script to compute pass@k and fail@k
5. Generating graphs of pass@k and fail@k
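For reference, the individual commands in [the script](./run_e2e.sh) are:

```bash
modal run download.py
modal deploy inference.py
modal run generate.py --data-dir test --no-dry-run --n 1000 --subsample 100
modal run eval.py
modal run plot.py
```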
### Results
The resulting plots of the evals will be saved locally to:
- `/tmp/plot-pass-k.jpeg`
- `/tmp/plot-fail-k.jpeg`
`/tmp/plot-pass-k.jpeg` shows pass@k for the Llama 3.2 3B Instruct model vs pass@1 for GPT-4o.
![plot-pass-k](https://github.com/user-attachments/assets/11e9dc6e-4322-4d44-b928-4ed7c4ce8262)
You'll see that at 100 generations, the Llama model performs on par with GPT-4o. At higher k, the Llama model outperforms GPT-4o.
`/tmp/plot-fail-k.jpeg` shows fail@k on a log scale, demonstrating the smooth scaling of this method.
![plot-fail-k](https://github.com/user-attachments/assets/7286e4ff-5090-4288-bd62-8a078c6dc5a1)
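For reference, pass@k is computed with the standard unbiased estimator from the HumanEval paper, `1 - C(n - c, k) / C(n, k)` for `n` samples of which `c` pass, mirroring the implementation in `plot.py`:

```python
import numpy as np

def estimator(n: int, c: int, k: int) -> float:
    # Probability that at least one of k samples drawn from n is correct,
    # given that c of the n samples are correct: 1 - C(n - c, k) / C(n, k)
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
```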
# ## Downloading Llama 3.2 3B Instruct Model
# This script uses a Modal Function to download the model into a cloud Volume.
#
# Run it with:
# modal run download
import modal
MODELS_DIR = "/llamas"
DEFAULT_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MINUTES = 60
HOURS = 60 * MINUTES
# Create a modal Volume to store the model
volume = modal.Volume.from_name("llamas", create_if_missing=True)
# This defines the image to use for the modal function
image = (
modal.Image.debian_slim(python_version="3.10")
.pip_install(
[
"huggingface_hub", # download models from the Hugging Face Hub
"hf-transfer", # download models faster with Rust
]
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)
# We run the function from a Modal App, which will have our HF_TOKEN env var set via the `huggingface` Secret.
# Add your HuggingFace secret access token here: https://modal.com/secrets
# secret name: huggingface
# env var name: HF_TOKEN
app = modal.App(image=image, secrets=[modal.Secret.from_name("huggingface")])
# This function will be run in the cloud, with the volume mounted.
@app.function(volumes={MODELS_DIR: volume}, timeout=4 * HOURS)
def download_model(model_name, force_download=False):
from huggingface_hub import snapshot_download
volume.reload()
snapshot_download(
model_name,
local_dir=MODELS_DIR + "/" + model_name,
ignore_patterns=[
"*.pt",
"*.bin",
"*.pth",
"original/*",
], # Ensure safetensors
force_download=force_download,
)
volume.commit()
print("Model successfully downloaded")
@app.local_entrypoint()
def main(
model_name: str = DEFAULT_NAME,
force_download: bool = False,
):
download_model.remote(model_name, force_download)
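# Note: Modal exposes the local entrypoint's parameters as CLI flags, so you can
# point this at a different checkpoint (illustrative example; gated models still
# require accepting their license on HuggingFace):
#
#   modal run download.py --model-name meta-llama/Llama-3.2-1B-Instruct
#
# Pass --force-download to re-fetch weights that are already in the Volume.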
# ## Evaluating HumanEval Results using Modal Sandboxes
# This script will take generated results and evaluate them.
# We use Modal Sandboxes to safely evaluate LLM-generated results.
#
# Run it with:
# modal run eval
from pathlib import Path
import modal
app = modal.App("many-llamas-human-eval")
volume = modal.Volume.from_name("humaneval", create_if_missing=True)
sandbox_image = (
modal.Image.debian_slim()
.apt_install("git")
.run_commands(
"git clone https://github.com/modal-labs/human-eval.git",
"pip install -e human-eval",
)
)
MINUTES = 60
@app.function(volumes={"/humaneval": volume}, timeout=10 * MINUTES)
def eval_single_task(sample_file_path: str, problem_file_path: str):
with modal.Volume.ephemeral() as vol:
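# The ephemeral Volume lives only for the duration of this context manager; it is
# used to hand the samples and problems to the sandboxed evaluator, which sees it
# mounted at /vol.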
with vol.batch_upload() as batch:
batch.put_file(sample_file_path, "samples.jsonl")
batch.put_file(problem_file_path, "problems.jsonl")
print(f"Starting sandbox for {sample_file_path}")
sandbox = modal.Sandbox.create(
"bash",
"-c",
"evaluate_functional_correctness vol/samples.jsonl --problem_file=vol/problems.jsonl --n_workers=32",
image=sandbox_image,
volumes={"/vol": vol},
timeout=10 * MINUTES,
cpu=32,
)
try:
sandbox.wait()
print(f"Finished sandbox for {sample_file_path}")
except modal.exception.FunctionTimeoutError:
print("Sandbox timed out")
if sandbox.returncode == 0:
print(sandbox.stdout.read())
data = b""
for chunk in vol.read_file("samples.jsonl_results.jsonl"):
data += chunk
with open(f"{sample_file_path}_results.jsonl", "wb") as f:
f.write(data)
else:
print(f"Tests failed with code {sandbox.returncode}")
print(sandbox.stderr.read())
@app.function(volumes={"/humaneval": volume}, timeout=10 * MINUTES)
def eval_all_tasks():
import os
volume.reload()
# Find all files matching /humaneval/{env}/{run}/{id}.jsonl
envs = [element for element in Path("/humaneval").iterdir() if element.is_dir()]
for env in envs:
print(f"looking in {env}")
problem_file = env / "data.jsonl"
pattern = "*/*.jsonl"
handles = []
for file_path in env.glob(pattern):
# Skip files that end with _results.jsonl
if str(file_path).endswith("_results.jsonl"):
continue
print(f"Checking {file_path}")
# Check if the corresponding results file exists
results_file = f"{file_path}_results.jsonl"
if not os.path.exists(results_file):
# If it doesn't exist, run do_eval
print("Spawning on", file_path, problem_file)
handles.append(eval_single_task.spawn(file_path, problem_file))
for handle in handles:
handle.get()
@app.local_entrypoint()
def main():
eval_all_tasks.remote()
# ## Generating HumanEval Results with our Llama 3.2 3B Instruct Model
# This app starts many parallel clients to send requests to the vLLM server.
#
# For each of the tasks in the HumanEval test set, we'll run a client to request 1000 completions.
# Results are saved to our mounted volume.
#
# Run it with:
# modal run generate --data-dir test --no-dry-run --n 1000 --subsample 100
from datetime import datetime
import json
from pathlib import Path
from dataclasses import dataclass, asdict
import modal
# This defines the image to use for running openai clients in parallel
image = modal.Image.debian_slim(python_version="3.11").pip_install(
"openai==1.38.0", "datasets==2.20.0"
)
app = modal.App("many-llamas-human-eval", image=image)
volume = modal.Volume.from_name("humaneval", create_if_missing=True)
DATA_DIR = Path("/mnt/humaneval")
default_system_prompt = "Write the body for the Python function provided in the prompt below. Do not write anything else. Your output will be directly concatenated with the prompt and the resulting function executed against tests."
MINUTES = 60 # seconds
HOURS = 60 * MINUTES
@dataclass
class CompletionParams:
model: str = None
max_tokens: int = 1024
temperature: float = 0.7
top_p: float = 0.9
frequency_penalty: float = 0
presence_penalty: float = 0
n: int = 1
stop: str = None
seed: int = None
@dataclass
class ClientParams:
app_name: str = "many-llamas-human-eval"
workspace: str = None
api_key: str = "super-secret-token" # match the secret in inference.py
@property
def url(self):
return f"https://{self.workspace}--{self.app_name}-serve.modal.run/v1"
@app.local_entrypoint()
def main(
app_name: str = "many-llamas-human-eval",
workspace: str = None,
api_key: str = "super-secret-token",
model: str = None,
max_tokens: int = 1024,
temperature: float = 0.7,
top_p: float = 0.9,
frequency_penalty: float = 0,
presence_penalty: float = 0,
n: int = 1,
stop: str = None,
seed: int = None,
data_dir: str = "dev-llm",
subsample: int = 1, # percent of the test split to read
system_prompt: str = default_system_prompt,
dry_run: bool = True,
):
if workspace is None:
workspace = modal.config._profile
client_params = ClientParams(app_name, workspace, api_key)
completion_params = CompletionParams(
model=model,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
n=n,
stop=stop,
seed=seed,
)
# Run a remote download function to save the HumanEval dataset in the cloud volume
save_dataset.remote(path=data_dir, subsample=subsample)
# Run a remote generation function
results = run_human_eval.remote(
client_params=client_params,
completion_params=completion_params,
system_prompt=system_prompt,
data_dir=data_dir,
dry_run=dry_run,
)
if results:
with open("/tmp/results.jsonl", "w") as f:
f.writelines(json.dumps(result) + "\n" for result in results)
print(f"results saved locally to {f.name}")
# This is the parent function that spawns a client for each eval task
@app.function(volumes={DATA_DIR: volume}, timeout=1 * HOURS)
def run_human_eval(
client_params: ClientParams,
completion_params: CompletionParams,
data_dir="dev-llm",
system_prompt: str = default_system_prompt,
dry_run=True,
):
dataset = load_dataset(data_dir)
timestamp = datetime.utcnow().isoformat() + "Z"
output_dir = Path(DATA_DIR) / data_dir / f"run-{timestamp}"
output_dir.mkdir(parents=True, exist_ok=True)
handles = []
print(f"Eval set contains {len(dataset)} items")
# For each eval item in the dataset, spawn a parallel openAI client worker that generates n completions each
print(Colors.BOLD, f"Spawning clients for each eval item. You may notice a brief wait while the inference server(s) boot.", Colors.END, sep="")
for i, item in enumerate(dataset):
handles.append(
run_item.spawn(
item,
client_params,
completion_params,
system_prompt,
output_dir,
dry_run,
)
)
for handle in handles:
result = handle.get()
if not dry_run:
return result
# This function is responsible for generating n completions for a single eval item
# It calls into our deployed vLLM server and saves results to the cloud volume
@app.function(volumes={DATA_DIR: volume}, timeout=1 * HOURS)
def run_item(
item: dict,
client_params: ClientParams,
completion_params: CompletionParams,
system_prompt: str,
output_dir: Path,
dry_run: bool,
):
client = create_client(client_params)
if not completion_params.model:
model = client.models.list().data[0]
model = model.id
completion_params.model = model
prompt = item["prompt"]
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
per_request = 250
ct, completions = completion_params.n, []
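# Request at most `per_request` completions per API call and loop until n total
# have been requested; the vLLM server batches across requests, and smaller calls
# keep each individual HTTP response manageable.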
if not dry_run:
while ct > 0:
response = get_completion(
client,
messages=messages,
**asdict(completion_params) | dict(n=min(ct, per_request)),
)
if response:
completions += [
{
"task_id": item["task_id"],
"completion": choice.message.content,
}
for choice in response.choices
]
ct -= per_request
index = item["task_id"].split("/")[-1]
output_path = output_dir / f"{index}.jsonl"
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
f.writelines(json.dumps(completion) + "\n" for completion in completions)
print(Colors.GREEN + f"Completions saved to {output_path}" + Colors.END)
class Colors:
"""ANSI color codes"""
GREEN = "\033[0;32m"
RED = "\033[0;31m"
BLUE = "\033[0;34m"
GRAY = "\033[0;90m"
BOLD = "\033[1m"
END = "\033[0m"
def get_completion(client, **kwargs):
try:
response = client.chat.completions.create(**kwargs)
return response
except Exception as e:
print(Colors.RED, f"Error during API call: {e}", Colors.END, sep="")
return None
def create_client(client_params: ClientParams):
from openai import OpenAI
client = OpenAI(api_key=client_params.api_key)
client.base_url = client_params.url
return client
# This function downloads the HumanEval dataset
@app.function(volumes={DATA_DIR: volume})
def save_dataset(path="dev-llm", subsample: int = 1):
import datasets
path = DATA_DIR / path
ds = datasets.load_dataset(
"openai/openai_humaneval",
# reads 0% to subsample% of the test split
split=datasets.ReadInstruction("test", to=subsample, unit="%"),
)
ds.to_json(path / "data.jsonl")
volume.commit()
def load_dataset(path="dev-llm"):
import datasets
path = DATA_DIR / path
ds = datasets.load_dataset(path=str(path), data_files="data.jsonl")
return ds["train"]
# ## Serving Llama 3.2 3B Instruct Model With vLLM
# This app runs a vLLM server on an A100 GPU.
#
# Run it with:
# modal deploy inference
import modal
# This defines the image to use for the vLLM server container
vllm_image = modal.Image.debian_slim(python_version="3.10").pip_install(
"vllm==0.5.3post1"
)
MODELS_DIR = "/llamas"
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# Ensure the model is downloaded and the volume exists
try:
volume = modal.Volume.lookup("llamas", create_if_missing=False)
except modal.exception.NotFoundError:
raise Exception("Download models first with modal run download")
app = modal.App("many-llamas-human-eval")
N_GPU = 1 # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count
TOKEN = (
"super-secret-token" # auth token. for production use, replace with a modal.Secret
)
MINUTES = 60 # seconds
HOURS = 60 * MINUTES
@app.function(
image=vllm_image,
gpu=modal.gpu.A100(count=N_GPU, size="40GB"),
container_idle_timeout=5 * MINUTES,
timeout=24 * HOURS,
allow_concurrent_inputs=20, # VLLM will batch requests so many can be received at once
volumes={MODELS_DIR: volume},
concurrency_limit=10, # max 10 GPUs
)
@modal.asgi_app()
def serve():
import fastapi
import vllm.entrypoints.openai.api_server as api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import (
OpenAIServingCompletion,
)
from vllm.usage.usage_lib import UsageContext
volume.reload() # ensure we have the latest version of the weights
# create a fastAPI app that uses vLLM's OpenAI-compatible router
web_app = fastapi.FastAPI(
title=f"OpenAI-compatible {MODEL_NAME} server",
description="Run an OpenAI-compatible LLM server with vLLM on modal.com",
version="0.0.1",
docs_url="/docs",
)
# security: HTTP bearer-token auth scheme plus CORS middleware for external requests
http_bearer = fastapi.security.HTTPBearer(
scheme_name="Bearer Token",
description="See code for authentication details.",
)
web_app.add_middleware(
fastapi.middleware.cors.CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# security: inject dependency on authed routes
async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
if api_key.credentials != TOKEN:
raise fastapi.HTTPException(
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
)
return {"username": "authenticated_user"}
router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])
# wrap vllm's router in auth router
router.include_router(api_server.router)
# add authed vllm to our fastAPI app
web_app.include_router(router)
engine_args = AsyncEngineArgs(
model=MODELS_DIR + "/" + MODEL_NAME,
tensor_parallel_size=N_GPU,
gpu_memory_utilization=0.90,
max_model_len=2048,
enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s)
)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER
)
model_config = get_model_config(engine)
request_logger = RequestLogger(max_log_len=2048)
api_server.openai_serving_chat = OpenAIServingChat(
engine,
model_config=model_config,
served_model_names=[MODEL_NAME],
chat_template=None,
response_role="assistant",
lora_modules=[],
prompt_adapters=[],
request_logger=request_logger,
)
api_server.openai_serving_completion = OpenAIServingCompletion(
engine,
model_config=model_config,
served_model_names=[MODEL_NAME],
lora_modules=[],
prompt_adapters=[],
request_logger=request_logger,
)
return web_app
def get_model_config(engine):
import asyncio
try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running():
# If the current process is instanced by Ray Serve,
# there is already a running event loop
model_config = event_loop.run_until_complete(engine.get_model_config())
else:
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
return model_config
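# Once deployed with `modal deploy inference`, the server exposes an OpenAI-compatible
# API protected by the bearer token above. A quick smoke test might look like this
# (illustrative; the exact hostname depends on your Modal workspace name):
#
#   curl https://<your-workspace>--many-llamas-human-eval-serve.modal.run/v1/models \
#     -H "Authorization: Bearer super-secret-token"
#
# The same base URL (ending in /v1) is what the clients in generate.py talk to.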
# ## Plotting HumanEval Results
# This script will calculate pass@k and fail@k for our experiment and plot them.
#
# Run it with:
# modal run plot
import io
import json
from pathlib import Path
from typing import List, Union
import itertools
import modal
try:
volume = modal.Volume.lookup("humaneval", create_if_missing=False)
except modal.exception.NotFoundError:
raise Exception("Generate results first with modal run generate --data-dir test --no-dry-run --n 1000 --subsample 100")
image = modal.Image.debian_slim(python_version="3.11").pip_install(
"numpy==1.26.4",
"pandas==2.2.3",
"matplotlib==3.9.2",
"seaborn==0.13.2",
)
app = modal.App("many-llamas-human-eval", image=image)
DATA_DIR = Path("/mnt/humaneval")
with image.imports():
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
@app.function(volumes={DATA_DIR: volume})
def render_plots():
run_dirs = list(sorted((DATA_DIR / "test").glob("run-*")))
for run_dir in reversed(run_dirs):
if len(list(run_dir.iterdir())) < 150:
print(f"skipping incomplete run {run_dir}")
else:
break
all_result_paths = list(run_dir.glob("*.jsonl_results.jsonl"))
data = []
for path in all_result_paths:
data += [json.loads(line) for line in path.read_text(encoding='utf-8').splitlines()]
for element in data:
del element["completion"]
df = pd.DataFrame.from_records(data)
gb = df.groupby("task_id")
passes = gb["passed"].sum()
def estimate_pass_at_k(
num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray],
k: int
) -> np.ndarray:
"""
Estimates pass@k of each problem and returns them in an array.
"""
def estimator(n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
if isinstance(num_samples, int):
num_samples_it = itertools.repeat(num_samples, len(num_correct))
else:
assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)
return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
pass_at_ks = {}
for k in [1, 10, 100, 1000]:
pass_at_ks[k] = estimate_pass_at_k(1000, passes, k)
pass_at_k = {k: np.mean(v) for k, v in pass_at_ks.items()}
plot_df = pd.DataFrame(
{"k": pass_at_k.keys(),
"pass@k": pass_at_k.values()}
)
plot_df["fail@k"] = 1 - plot_df["pass@k"]
sns.set_theme(style='dark')
plt.style.use("dark_background")
plt.rcParams['font.sans-serif'] = ["Inter", "Arial", "DejaVu Sans", "Liberation Sans", "Bitstream Vera Sans", "sans-serif"]
sns.despine()
sns.set_context("talk", rc={"lines.linewidth": 2.5})
gpt4o_benchmark = 0.902
# First plot
plt.figure(figsize=(10, 6))
fg = sns.lineplot(
x="k",
y="pass@k",
data=plot_df,
color="#7FEE64",
linewidth=6,
alpha=0.9,
label="Llama 3.2 3B Instruct pass@k"
)
initial_lim = fg.axes.get_xlim()
fg.axes.hlines(
gpt4o_benchmark, *initial_lim,
linestyle="--",
alpha=0.6,
zorder=-1,
label="GPT-4o fail@1"
)
fg.axes.set_xlim(*initial_lim)
fg.axes.set_ylabel("")
fg.axes.set_ylim(0, 1)
plt.tight_layout(pad=1.2)
plt.legend()
# Save the first plot as bytes
img_buffer = io.BytesIO()
plt.savefig(img_buffer, format='jpeg')
plot_1_img_bytes = img_buffer.getvalue()
plt.close()
# Second plot
plt.figure(figsize=(10, 6))
fg = sns.lineplot(
x="k",
y="fail@k",
data=plot_df,
color="#7FEE64",
linewidth=6,
alpha=0.9,
label="Llama 3.2 3B Instruct fail@k"
)
initial_lim = fg.axes.get_xlim()
fg.axes.hlines(
1 - gpt4o_benchmark, *initial_lim,
linestyle="--",
alpha=0.6,
zorder=-1,
label="GPT-4o fail@1"
)
fg.axes.set_xlim(*initial_lim)
fg.axes.set_ylabel("")
fg.axes.set_yscale("log")
fg.axes.set_xscale("log")
fg.axes.set_xlim(0.5, 2000)
fg.axes.set_ylim(1e-2, 1e0)
plt.tight_layout(pad=1.2)
plt.legend()
# Save the second plot as bytes
img_buffer = io.BytesIO()
plt.savefig(img_buffer, format='jpeg')
plot_2_img_bytes = img_buffer.getvalue()
plt.close()
return [plot_1_img_bytes, plot_2_img_bytes]
@app.local_entrypoint()
def main():
plots = render_plots.remote()
assert len(plots) == 2
with open("/tmp/plot-pass-k.jpeg", "wb") as f:
f.write(plots[0])
with open("/tmp/plot-fail-k.jpeg", "wb") as f:
f.write(plots[1])
print("Plots saved to:")
print(" /tmp/plot-pass-k.jpeg")
print(" /tmp/plot-fail-k.jpeg")
#!/bin/bash
set -euo pipefail
IFS=$'\n\t'
command -v modal >/dev/null 2>&1 || { echo >&2 "modal command not found. Install modal first! Aborting."; exit 1; }
echo 'downloading LLaMA 3.2 3B Instruct model'
echo 'make sure to create a Secret called huggingface on Modal and accept the LLaMA 3.2 license'
modal run download.py
echo 'deploying vLLM inference server'
modal deploy inference.py
echo 'running HumanEval generation'
modal run generate.py --data-dir test --no-dry-run --n 1000 --subsample 100
echo 'running HumanEval evaluation'
modal run eval.py
echo 'generating graphs for pass@k and fail@k'
modal run plot.py