Skip to content
Snippets Groups Projects
Unverified Commit fa5edb8c authored by Hamid Shojanazeri's avatar Hamid Shojanazeri Committed by GitHub
Browse files

Add gradio library for user interface in inference.py (#367)

parents 5edd1a2c e2efc79e
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,9 @@ Disclaimer - The purpose of the code is to provide a configurable setup to measu ...@@ -7,7 +7,9 @@ Disclaimer - The purpose of the code is to provide a configurable setup to measu
# Azure - Getting Started # Azure - Getting Started
To get started, there are certain steps we need to take to deploy the models: To get started, there are certain steps we need to take to deploy the models:
<!-- markdown-link-check-disable -->
* Register for a valid Azure account with subscription [here](https://azure.microsoft.com/en-us/free/search/?ef_id=_k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&OCID=AIDcmm5edswduu_SEM__k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&gad_source=1&gclid=CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE) * Register for a valid Azure account with subscription [here](https://azure.microsoft.com/en-us/free/search/?ef_id=_k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&OCID=AIDcmm5edswduu_SEM__k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&gad_source=1&gclid=CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE)
<!-- markdown-link-check-enable -->
* Take a quick look on what is the [Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/what-is-ai-studio?tabs=home) and navigate to the website from the link in the article * Take a quick look on what is the [Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/what-is-ai-studio?tabs=home) and navigate to the website from the link in the article
* Follow the demos in the article to create a project and [resource](https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/manage-resource-groups-portal) group, or you can also follow the guide [here](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=azure-studio) * Follow the demos in the article to create a project and [resource](https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/manage-resource-groups-portal) group, or you can also follow the guide [here](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=azure-studio)
* Select Llama models from Model catalog * Select Llama models from Model catalog
......
...@@ -7,6 +7,7 @@ import fire ...@@ -7,6 +7,7 @@ import fire
import os import os
import sys import sys
import time import time
import gradio as gr
import torch import torch
from transformers import LlamaTokenizer from transformers import LlamaTokenizer
...@@ -39,18 +40,8 @@ def main( ...@@ -39,18 +40,8 @@ def main(
use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformer memory-efficient kernels use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformer memory-efficient kernels
**kwargs **kwargs
): ):
if prompt_file is not None:
assert os.path.exists(
prompt_file
), f"Provided Prompt file does not exist {prompt_file}"
with open(prompt_file, "r") as f:
user_prompt = "\n".join(f.readlines())
elif not sys.stdin.isatty():
user_prompt = "\n".join(sys.stdin.readlines())
else:
print("No user prompt provided. Exiting.")
sys.exit(1)
def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,):
safety_checker = get_safety_checker(enable_azure_content_safety, safety_checker = get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics, enable_sensitive_topics,
enable_salesforce_content_safety, enable_salesforce_content_safety,
...@@ -126,7 +117,49 @@ def main( ...@@ -126,7 +117,49 @@ def main(
if not is_safe: if not is_safe:
print(method) print(method)
print(report) print(report)
return output_text
if prompt_file is not None:
assert os.path.exists(
prompt_file
), f"Provided Prompt file does not exist {prompt_file}"
with open(prompt_file, "r") as f:
user_prompt = "\n".join(f.readlines())
inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
elif not sys.stdin.isatty():
user_prompt = "\n".join(sys.stdin.readlines())
inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
else:
gr.Interface(
fn=inference,
inputs=[
gr.components.Textbox(
lines=9,
label="User Prompt",
placeholder="none",
),
gr.components.Slider(
minimum=0, maximum=1, value=1.0, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=1, value=1.0, label="Top p"
),
gr.components.Slider(
minimum=0, maximum=100, step=1, value=50, label="Top k"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
),
],
outputs=[
gr.components.Textbox(
lines=5,
label="Output",
)
],
title="Llama2 Playground",
description="https://github.com/facebookresearch/llama-recipes",
).queue().launch(server_name="0.0.0.0", share=True)
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(main) fire.Fire(main)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment