{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5a088833-ce3e-4570-b030-390ae7af8f69",
   "metadata": {},
   "source": [
    "# Finetuning LLaMa + Text-to-SQL \n",
    "\n",
    "This walkthrough shows how to finetune LLaMa-7B on a Text-to-SQL dataset and then use the finetuned model, via LlamaIndex, to answer natural-language queries over a SQL database.\n",
    "\n",
    "**NOTE**: This code is taken and adapted from Modal's `doppel-bot` repo: https://github.com/modal-labs/doppel-bot.\n",
    "\n",
    "**NOTE**: Much of the code lives in the underlying Python scripts under `src/`. We encourage you to read through them!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f84b13d-cb3b-4857-b163-22957feca599",
   "metadata": {},
   "source": [
    "### Load Training Data for Finetuning LLaMa\n",
    "\n",
    "We load data from `b-mc2/sql-create-context` on Hugging Face: https://huggingface.co/datasets/b-mc2/sql-create-context.\n",
    "\n",
    "Each example is a tuple of a natural-language question, the `CREATE TABLE` statement(s) defining the relevant table schema, and the ground-truth SQL query. This is the dataset we use to finetune our Text-to-SQL model."
   ]
  },
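  {
   "cell_type": "markdown",
   "id": "b3e1c6a2-7d4f-4e8a-9c15-2f6a8d0e4b71",
   "metadata": {},
   "source": [
    "To make the data format concrete, the sketch below shows one way a `sql-create-context` record could be turned into a single training string. It is a minimal illustration that assumes only the dataset's `question`, `context`, and `answer` fields; the prompt template and the `format_example` helper are hypothetical, and the template actually used by the scripts in `src/` may differ. The sample record is hand-written, not taken from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f27d19-3a5b-4f60-8e4c-91d2b7a6e053",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "# Hypothetical prompt template; the one used by the src/ scripts may differ.\n",
    "PROMPT_TEMPLATE = \"\\n\".join([\n",
    "    \"### Question: {question}\",\n",
    "    \"### Schema: {context}\",\n",
    "    \"### SQL: \",\n",
    "])\n",
    "\n",
    "\n",
    "def format_example(record: Dict[str, str]) -> Dict[str, str]:\n",
    "    \"\"\"Build the prompt and the full training text for one record.\"\"\"\n",
    "    prompt = PROMPT_TEMPLATE.format(\n",
    "        question=record[\"question\"], context=record[\"context\"]\n",
    "    )\n",
    "    # The training text is the prompt followed by the ground-truth SQL.\n",
    "    return {\"prompt\": prompt, \"text\": prompt + record[\"answer\"]}\n",
    "\n",
    "\n",
    "# Hand-written record using the dataset's field names (not a real dataset row).\n",
    "example = {\n",
    "    \"question\": \"How many employees are older than 30?\",\n",
    "    \"context\": \"CREATE TABLE employees (name VARCHAR, age INTEGER)\",\n",
    "    \"answer\": \"SELECT COUNT(*) FROM employees WHERE age > 30\",\n",
    "}\n",
    "formatted = format_example(example)\n",
    "print(formatted[\"text\"])"
   ]
  },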
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "521d237d-d985-4675-8fda-12766d5182ce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from src.load_data_sql import load_data_sql\n",
    "\n",
    "# Load the b-mc2/sql-create-context training data into `data_sql`\n",
    "load_data_sql(data_dir=\"data_sql\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8286edb4-9e8a-4be1-8672-c0cec36b5443",
   "metadata": {},
   "source": [
    "### Run Finetuning Script\n",
    "\n",
    "We run our finetuning script on the loaded dataset.\n",
    "The finetuning script contains the following components:\n",
    "- We split the dataset into training and validation splits.\n",
    "- We tokenize each split into input/label pairs of token IDs, where the labels are identical to the inputs: the loss is computed over the full sequence (prompt and SQL), not just the generated portion.\n",
    "- We use `LoraConfig` from `peft` for parameter-efficient finetuning with LoRA.\n",
    "- We use `transformers.Trainer` to run the training loop.\n",
    "- If `WANDB_PROJECT` is set to a valid project and the corresponding Weights & Biases secret is configured in Modal, training metrics are logged to wandb.\n",
    "\n",
    "We use Modal to spin up an A100 GPU on which the finetuning code runs."
   ]
  },
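  {
   "cell_type": "markdown",
   "id": "d5a94e3c-1b7f-4c2d-a6e8-0f3b9c1d7e24",
   "metadata": {},
   "source": [
    "The split and labeling steps described above can be sketched in plain Python. This is a minimal illustration under stated assumptions: `toy_tokenize` is a hypothetical stand-in for the real LLaMa tokenizer, and `split_dataset`, `to_features`, and the 10% validation fraction are illustrative choices rather than the settings used by the finetuning script in `src/`. The key point is that `labels` is an exact copy of `input_ids`, so the loss covers the whole sequence; Hugging Face causal language models shift the labels internally when computing the next-token loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c6b8f1-9d40-4a7e-b3f5-6c1a0e8d2f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from typing import Dict, List, Sequence, Tuple\n",
    "\n",
    "\n",
    "def toy_tokenize(text: str) -> List[int]:\n",
    "    # Hypothetical stand-in for a real tokenizer: maps each whitespace-separated\n",
    "    # word to a deterministic integer id. Illustration only.\n",
    "    return [sum(map(ord, word)) % 32000 for word in text.split()]\n",
    "\n",
    "\n",
    "def split_dataset(\n",
    "    rows: Sequence[str], val_fraction: float = 0.1, seed: int = 42\n",
    ") -> Tuple[List[str], List[str]]:\n",
    "    \"\"\"Shuffle rows deterministically and split off a validation set.\"\"\"\n",
    "    shuffled = list(rows)\n",
    "    random.Random(seed).shuffle(shuffled)\n",
    "    n_val = max(1, int(len(shuffled) * val_fraction))\n",
    "    return shuffled[n_val:], shuffled[:n_val]\n",
    "\n",
    "\n",
    "def to_features(text: str) -> Dict[str, List[int]]:\n",
    "    \"\"\"Tokenize one training string; labels are an exact copy of the inputs.\"\"\"\n",
    "    input_ids = toy_tokenize(text)\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"attention_mask\": [1] * len(input_ids),\n",
    "        # Same as input_ids: the loss is measured on the full sequence,\n",
    "        # not only on the generated SQL.\n",
    "        \"labels\": list(input_ids),\n",
    "    }\n",
    "\n",
    "\n",
    "texts = [f\"### Question: q{i} ### SQL: SELECT {i}\" for i in range(10)]\n",
    "train_rows, val_rows = split_dataset(texts)\n",
    "train_features = [to_features(t) for t in train_rows]\n",
    "val_features = [to_features(t) for t in val_rows]\n",
    "print(len(train_features), len(val_features))"
   ]
  },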
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e98b8e-53c4-490b-a142-8ae27e1bfe20",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.finetune_sql import finetune\n",
    "\n",
    "# Launch the LoRA finetuning job on Modal using the data in `data_sql`\n",
    "finetune(data_dir=\"data_sql\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc6405f9-edca-42a1-bc51-c2aa7bd4b7e8",
   "metadata": {},
   "source": [
    "### Integrate Model with LlamaIndex\n",
    "\n",
    "Now that the model is "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "modal_finetune_sql",
   "language": "python",
   "name": "modal_finetune_sql"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}