Skip to content
Snippets Groups Projects
pixtral.py 1.21 KiB
Newer Older
  • Learn to ignore specific revisions
  • from mistral_inference.transformer import Transformer
    from mistral_inference.generate import generate
    
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
    from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk, ImageChunk
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    
    from PIL import Image
    from tqdm import tqdm
    
    def generate_response(queries, model_path):
        tokenizer = MistralTokenizer.from_file(f"{model_path}/tekken.json")
        model = Transformer.from_folder(model_path)
        for k in tqdm(queries):
            query = queries[k]['question']
            image = queries[k]["figure_path"]
            image = Image.open(image).convert('RGB')
            completion_request = ChatCompletionRequest(messages=[UserMessage(content=[ImageChunk(image=image), TextChunk(text=query)])])
            encoded = tokenizer.encode_chat_completion(completion_request)
            images = encoded.images
            tokens = encoded.tokens
            out_tokens, _ = generate([tokens], model, images=[images], max_tokens=1024, temperature=0., eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
            response = tokenizer.decode(out_tokens[0])
            queries[k]['response'] = response