Skip to content
Snippets Groups Projects
Commit 83efccb2 authored by Joone Hur
Browse files

Add gradio library for user interface in inference.py

parent 5edd1a2c
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ import fire ...@@ -7,6 +7,7 @@ import fire
import os import os
import sys import sys
import time import time
import gradio as gr
import torch import torch
from transformers import LlamaTokenizer from transformers import LlamaTokenizer
...@@ -39,18 +40,8 @@ def main( ...@@ -39,18 +40,8 @@ def main(
use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs **kwargs
): ):
if prompt_file is not None:
assert os.path.exists(
prompt_file
), f"Provided Prompt file does not exist {prompt_file}"
with open(prompt_file, "r") as f:
user_prompt = "\n".join(f.readlines())
elif not sys.stdin.isatty():
user_prompt = "\n".join(sys.stdin.readlines())
else:
print("No user prompt provided. Exiting.")
sys.exit(1)
def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,):
safety_checker = get_safety_checker(enable_azure_content_safety, safety_checker = get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics, enable_sensitive_topics,
enable_salesforce_content_safety, enable_salesforce_content_safety,
...@@ -126,7 +117,49 @@ def main( ...@@ -126,7 +117,49 @@ def main(
if not is_safe: if not is_safe:
print(method) print(method)
print(report) print(report)
return output_text
if prompt_file is not None:
assert os.path.exists(
prompt_file
), f"Provided Prompt file does not exist {prompt_file}"
with open(prompt_file, "r") as f:
user_prompt = "\n".join(f.readlines())
inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
elif not sys.stdin.isatty():
user_prompt = "\n".join(sys.stdin.readlines())
inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
else:
gr.Interface(
fn=inference,
inputs=[
gr.components.Textbox(
lines=9,
label="User Prompt",
placeholder="none",
),
gr.components.Slider(
minimum=0, maximum=1, value=1.0, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=1, value=1.0, label="Top p"
),
gr.components.Slider(
minimum=0, maximum=100, step=1, value=50, label="Top k"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
),
],
outputs=[
gr.components.Textbox(
lines=5,
label="Output",
)
],
title="Llama2 Playground",
description="https://github.com/facebookresearch/llama-recipes",
).queue().launch(server_name="0.0.0.0", share=True)
# Script entry point: hand the `main` function to python-fire, which turns
# its keyword arguments into command-line flags.
if __name__ == "__main__":
    fire.Fire(main)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment