Adding basic SFT template

Files changed:
- app.py +61 -59
- notebooks/eda.json +1 -0
- notebooks/embeddings.json +1 -0
- notebooks/rag.json +1 -0
- notebooks/sft.json +56 -0
- utils/api_utils.py +33 -0
app.py
CHANGED

@@ -2,20 +2,22 @@ import gradio as gr
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
 import nbformat as nbf
 from huggingface_hub import HfApi
-from httpx import Client
 import logging
-import pandas as pd
 from utils.notebook_utils import (
     replace_wildcards,
     load_json_files_from_folder,
 )
+from utils.api_utils import get_compatible_libraries, get_first_rows, get_splits
 from dotenv import load_dotenv
 import os
 from nbconvert import HTMLExporter
 import uuid
+import pandas as pd

 load_dotenv()

+URL = "https://huggingface.co/spaces/asoria/auto-notebook-creator"
+
 HF_TOKEN = os.getenv("HF_TOKEN")
 assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"

@@ -25,12 +27,6 @@ assert (
 ), "You need to set NOTEBOOKS_REPOSITORY in your environment variables"


-URL = "https://huggingface.co/spaces/asoria/auto-notebook-creator"
-BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
-HEADERS = {"Accept": "application/json", "Content-Type": "application/json"}
-
-client = Client(headers=HEADERS)
-
 logging.basicConfig(level=logging.INFO)

 # TODO: Validate notebook templates format
@@ -39,18 +35,6 @@ notebook_templates = load_json_files_from_folder(folder_path)
 logging.info(f"Available notebooks {notebook_templates.keys()}")


-def get_compatible_libraries(dataset: str):
-    try:
-        response = client.get(
-            f"{BASE_DATASETS_SERVER_URL}/compatible-libraries?dataset={dataset}"
-        )
-        response.raise_for_status()
-        return response.json()
-    except Exception as e:
-        logging.error(f"Error fetching compatible libraries: {e}")
-        raise
-
-
 def create_notebook_file(cells, notebook_name):
     nb = nbf.v4.new_notebook()
     nb["cells"] = [
@@ -72,22 +56,6 @@ def create_notebook_file(cells, notebook_name):
     return html_data


-def get_first_rows_as_df(dataset: str, config: str, split: str, limit: int):
-    try:
-        resp = client.get(
-            f"{BASE_DATASETS_SERVER_URL}/first-rows?dataset={dataset}&config={config}&split={split}"
-        )
-        resp.raise_for_status()
-        content = resp.json()
-        rows = content["rows"]
-        rows = [row["row"] for row in rows]
-        first_rows_df = pd.DataFrame.from_dict(rows).sample(frac=1).head(limit)
-        return first_rows_df
-    except Exception as e:
-        logging.error(f"Error fetching first rows: {e}")
-        raise
-
-
 def longest_string_column(df):
     longest_col = None
     max_length = 0
@@ -127,34 +95,62 @@ def generate_cells(dataset_id, notebook_title):
     cells = notebook_templates[notebook_title]["notebook_template"]
     notebook_type = notebook_templates[notebook_title]["notebook_type"]
     dataset_types = notebook_templates[notebook_title]["dataset_types"]
-
+    compatible_library = notebook_templates[notebook_title]["compatible_library"]
     try:
         libraries = get_compatible_libraries(dataset_id)
+        if not libraries:
+            logging.error(
+                f"Dataset not compatible with any loading library (pandas/datasets)"
+            )
+            return (
+                "",
+                "## ❌ This dataset is not compatible with pandas or datasets libraries ❌",
+            )
+
+        library_code = next(
+            (
+                lib
+                for lib in libraries.get("libraries", [])
+                if lib["library"] == compatible_library
+            ),
+            None,
+        )
+        if not library_code:
+            logging.error(f"Dataset not compatible with {compatible_library} library")
+            return (
+                "",
+                f"## ❌ This dataset is not compatible with '{compatible_library}' library ❌",
+            )
+        first_config_loading_code = library_code["loading_codes"][0]
+        first_code = first_config_loading_code["code"]
+        first_config = first_config_loading_code["config_name"]
+        first_split = get_splits(dataset_id, first_config)[0]["split"]
+        first_rows = get_first_rows(dataset_id, first_config, first_split)
     except Exception as err:
         gr.Error("Unable to retrieve dataset info from HF Hub.")
         logging.error(f"Failed to fetch compatible libraries: {err}")
-        return "", "## ❌ This dataset is not accessible from the Hub ❌"
-
-    if not libraries:
-        logging.error(f"Dataset not compatible with pandas library - not libraries")
-        return "", "## ❌ This dataset is not compatible with pandas library ❌"
-    pandas_library = next(
-        (lib for lib in libraries.get("libraries", []) if lib["library"] == "pandas"),
-        None,
-    )
-    if not pandas_library:
-        logging.error("Dataset not compatible with pandas library - not pandas library")
-        return "", "## ❌ This dataset is not compatible with pandas library ❌"
-    first_config_loading_code = pandas_library["loading_codes"][0]
-    first_code = first_config_loading_code["code"]
-    first_config = first_config_loading_code["config_name"]
-    first_split = list(first_config_loading_code["arguments"]["splits"].keys())[0]
-    df = get_first_rows_as_df(dataset_id, first_config, first_split, 3)
+        return "", f"## ❌ This dataset is not accessible from the Hub {err}❌"
+
+    df = pd.DataFrame.from_dict(first_rows).sample(frac=1).head(3)

     longest_col = longest_string_column(df)
     html_code = f"<iframe src='https://huggingface.co/datasets/{dataset_id}/embed/viewer' width='80%' height='560px'></iframe>"
-    wildcards = [
-
+    wildcards = [
+        "{dataset_name}",
+        "{first_code}",
+        "{html_code}",
+        "{longest_col}",
+        "{first_config}",
+        "{first_split}",
+    ]
+    replacements = [
+        dataset_id,
+        first_code,
+        html_code,
+        longest_col,
+        first_config,
+        first_split,
+    ]
     has_numeric_columns = len(df.select_dtypes(include=["number"]).columns) > 0
     has_categoric_columns = len(df.select_dtypes(include=["object"]).columns) > 0

@@ -196,8 +192,12 @@ css = """

 with gr.Blocks(css=css) as demo:
     gr.Markdown("# 🤖 Dataset notebook creator 🕵️")
-    gr.Markdown(
-
+    gr.Markdown(
+        f"[}-blue.svg)]({URL}/tree/main/notebooks)"
+    )
+    gr.Markdown(
+        f"[]({URL}/blob/main/CONTRIBUTING.md)"
+    )
     text_input = gr.Textbox(label="Suggested notebook type", visible=False)

     gr.Markdown("## 1. Select and preview a dataset from Huggingface Hub")
@@ -259,6 +259,8 @@
         outputs=[code_component, go_to_notebook],
     )

-    gr.Markdown(
+    gr.Markdown(
+        "🚧 Note: Some code may not be compatible with datasets that contain binary data or complex structures. 🚧"
+    )

 demo.launch()
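Taken together, the app.py changes replace the hard-coded pandas path with a per-template library lookup. A condensed sketch of the new flow in generate_cells, using the names from the diff above (error handling and the downstream cell assembly omitted):

    # Each template now declares which loading library its generated code needs.
    compatible_library = notebook_templates[notebook_title]["compatible_library"]

    # Ask the datasets-server which libraries can load the dataset, then pick
    # the loading snippet matching the template's declared library.
    libraries = get_compatible_libraries(dataset_id)
    library_code = next(
        (lib for lib in libraries.get("libraries", []) if lib["library"] == compatible_library),
        None,
    )
    first_config_loading_code = library_code["loading_codes"][0]
    first_config = first_config_loading_code["config_name"]
    first_split = get_splits(dataset_id, first_config)[0]["split"]

    # Fetch a few raw rows and build the preview frame locally (this replaces
    # the removed get_first_rows_as_df helper).
    first_rows = get_first_rows(dataset_id, first_config, first_split)
    df = pd.DataFrame.from_dict(first_rows).sample(frac=1).head(3)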
notebooks/eda.json
CHANGED

@@ -2,6 +2,7 @@
     "notebook_title": "Exploratory data analysis (EDA)",
     "notebook_type": "eda",
     "dataset_types": ["numeric", "text"],
+    "compatible_library": "pandas",
     "notebook_template": [
         {
             "cell_type": "markdown",
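Each template JSON now carries a compatible_library field, which generate_cells matches against the library names returned by the datasets-server compatible-libraries endpoint. Once loaded via load_json_files_from_folder, a template dict has roughly this shape (an illustrative sketch, not a full template):

    template = {
        "notebook_title": "Exploratory data analysis (EDA)",
        "notebook_type": "eda",
        "dataset_types": ["numeric", "text"],
        "compatible_library": "pandas",  # matched against lib["library"] in generate_cells
        "notebook_template": [...],      # list of markdown/code cell dicts with wildcards
    }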
notebooks/embeddings.json
CHANGED

@@ -2,6 +2,7 @@
     "notebook_title": "Text Embeddings",
     "notebook_type": "embeddings",
     "dataset_types": ["text"],
+    "compatible_library": "pandas",
     "notebook_template": [
         {
             "cell_type": "markdown",
notebooks/rag.json
CHANGED

@@ -2,6 +2,7 @@
     "notebook_title": "Retrieval-augmented generation (RAG)",
     "notebook_type": "rag",
     "dataset_types": ["text"],
+    "compatible_library": "pandas",
     "notebook_template": [
         {
             "cell_type": "markdown",
notebooks/sft.json
ADDED

@@ -0,0 +1,56 @@
+{
+    "notebook_title": "Supervised fine-tuning (SFT)",
+    "notebook_type": "sft",
+    "dataset_types": ["text"],
+    "compatible_library": "datasets",
+    "notebook_template": [
+        {
+            "cell_type": "markdown",
+            "source": "---\n# **Supervised fine-tuning Notebook for {dataset_name} dataset**\n---"
+        },
+        {
+            "cell_type": "markdown",
+            "source": "## 1. Setup necessary libraries and load the dataset"
+        },
+        {
+            "cell_type": "code",
+            "source": "# Install and import necessary libraries.\n!pip install trl datasets transformers bitsandbytes"
+        },
+        {
+            "cell_type": "code",
+            "source": "from datasets import load_dataset\nfrom trl import SFTTrainer\nfrom transformers import TrainingArguments"
+        },
+        {
+            "cell_type": "code",
+            "source": "# Load the dataset\ndataset = load_dataset('{dataset_name}', name='{first_config}', split='{first_split}')\ndataset"
+        },
+        {
+            "cell_type": "code",
+            "source": "# Specify the column name that will be used for training\ndataset_text_field = '{longest_col}'"
+        },
+        {
+            "cell_type": "markdown",
+            "source": "## 2. Configure SFT trainer"
+        },
+        {
+            "cell_type": "code",
+            "source": "model_name = 'facebook/opt-350m'\noutput_model_name = f'{model_name}-{dataset_name}'.replace('/', '-')\n\ntrainer = SFTTrainer(\n    model = model_name,\n    train_dataset=dataset,\n    dataset_text_field=dataset_text_field,\n    max_seq_length=512,\n    args=TrainingArguments(\n        per_device_train_batch_size = 1, #Batch size per GPU for training\n        gradient_accumulation_steps = 4,\n        max_steps = 100, #Total number of training steps.(Overrides epochs)\n        learning_rate = 2e-4,\n        fp16 = True,\n        logging_steps=20,\n        output_dir = output_model_name,\n        optim = 'paged_adamw_8bit' #Optimizer to use\n    )\n)"
+        },
+        {
+            "cell_type": "code",
+            "source": "# Start training\ntrainer.train()"
+        },
+        {
+            "cell_type": "markdown",
+            "source": "## 3. Push model to hub"
+        },
+        {
+            "cell_type": "code",
+            "source": "# Authenticate to the Hugging Face Hub\nfrom huggingface_hub import notebook_login\nnotebook_login()"
+        },
+        {
+            "cell_type": "code",
+            "source": "# Push the model to Hugging Face Hub\ntrainer.push_to_hub()"
+        }
+    ]
+}
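To see what this template produces, here is roughly how its key cells render after replace_wildcards, assuming a hypothetical text dataset (dataset_name=user/my-dataset, first_config=default, first_split=train, longest_col=text; the real values come from the datasets-server lookups in app.py):

    # Rendered "load the dataset" cell (hypothetical wildcard values)
    from datasets import load_dataset
    dataset = load_dataset('user/my-dataset', name='default', split='train')

    # Rendered "column for training" cell
    dataset_text_field = 'text'

    # The trainer cell then fine-tunes facebook/opt-350m on that column for
    # 100 steps and pushes the result as 'facebook-opt-350m-user-my-dataset'.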
utils/api_utils.py
ADDED

@@ -0,0 +1,33 @@
+from httpx import Client
+
+BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
+HEADERS = {"Accept": "application/json", "Content-Type": "application/json"}
+
+client = Client(headers=HEADERS)
+
+
+def get_compatible_libraries(dataset: str):
+    response = client.get(
+        f"{BASE_DATASETS_SERVER_URL}/compatible-libraries?dataset={dataset}"
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def get_first_rows(dataset: str, config: str, split: str):
+    resp = client.get(
+        f"{BASE_DATASETS_SERVER_URL}/first-rows?dataset={dataset}&config={config}&split={split}"
+    )
+    resp.raise_for_status()
+    content = resp.json()
+    rows = content["rows"]
+    return [row["row"] for row in rows]
+
+
+def get_splits(dataset: str, config: str):
+    resp = client.get(
+        f"{BASE_DATASETS_SERVER_URL}/splits?dataset={dataset}&config={config}"
+    )
+    resp.raise_for_status()
+    content = resp.json()
+    return content["splits"]
|