Spaces: Running on Zero
jedick committed
Commit · d6be5fa
Parent(s): 77b89d7

Download model during app startup

Files changed:
- app.py +10 -6
- main.py +5 -3
- mods/langchain_chroma.py +1 -1
app.py CHANGED

@@ -1,12 +1,12 @@
 import gradio as gr
-from main import GetChatModel
+from main import GetChatModel, openai_model, model_id
 from graph import BuildGraph
 from retriever import db_dir
-from langgraph.checkpoint.memory import MemorySaver
-from dotenv import load_dotenv
-from main import openai_model, model_id
 from util import get_sources, get_start_end_months
 from mods.tool_calling_llm import extract_think
+from huggingface_hub import snapshot_download
+from langgraph.checkpoint.memory import MemorySaver
+from dotenv import load_dotenv
 import requests
 import zipfile
 import shutil

@@ -19,10 +19,14 @@ import ast
 import os
 import re
 
-
 # Setup environment variables
 load_dotenv(dotenv_path=".env", override=True)
 
+# Download model snapshots from Hugging Face Hub
+print(f"Downloading/loading checkpoints for {model_id}...")
+ckpt_dir = snapshot_download(model_id, local_dir_use_symlinks=False)
+print(f"Using checkpoints from {ckpt_dir}")
+
 # Global setting for search type
 search_type = "hybrid"
 
@@ -86,7 +90,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
         title=f"Model loading...",
     )
     # Get the chat model and build the graph
-    chat_model = GetChatModel(compute_mode)
+    chat_model = GetChatModel(compute_mode, ckpt_dir)
     graph_builder = BuildGraph(
         chat_model, compute_mode, search_type, think_answer=True
     )
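The startup download added above boils down to a single snapshot_download call. A minimal runnable sketch of the same pattern, with "gpt2" as a stand-in repo id rather than the model_id this app imports from main.py:

    from huggingface_hub import snapshot_download

    model_id = "gpt2"  # placeholder; the app uses model_id from main.py

    # The first call downloads every file in the repo to the local cache;
    # later calls find the cached snapshot and return its path right away.
    ckpt_dir = snapshot_download(model_id)
    print(f"Using checkpoints from {ckpt_dir}")

Doing this at import time front-loads the slow network work into app startup instead of the first chat request. Note that in current huggingface_hub releases, local_dir_use_symlinks only takes effect when local_dir is also passed, so the call in app.py effectively relies on the default cache layout.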
main.py CHANGED

@@ -128,12 +128,13 @@ def ProcessDirectory(path, compute_mode):
         print(f"Chroma: no change for {file_path}")
 
 
-def GetChatModel(compute_mode):
+def GetChatModel(compute_mode, ckpt_dir=None):
     """
     Get a chat model.
 
     Args:
         compute_mode: Compute mode for chat model (remote or local)
+        ckpt_dir: Checkpoint directory for model weights (optional)
     """
 
     if compute_mode == "remote":

@@ -148,9 +149,10 @@ def GetChatModel(compute_mode):
 
     # Define the pipeline to pass to the HuggingFacePipeline class
     # https://huggingface.co/blog/langchain
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    id_or_dir = ckpt_dir if ckpt_dir else model_id
+    tokenizer = AutoTokenizer.from_pretrained(id_or_dir)
     model = AutoModelForCausalLM.from_pretrained(
-        model_id,
+        id_or_dir,
         # We need this to load the model in BF16 instead of fp32 (torch.float)
         torch_dtype=torch.bfloat16,
     )
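In isolation, the fallback works as in the sketch below, assuming the transformers and torch imports that main.py already uses; load_local_model is an illustrative name, not a function in this repo:

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    def load_local_model(model_id, ckpt_dir=None):
        # Prefer the snapshot directory downloaded at app startup; fall back
        # to the Hub id, letting from_pretrained manage the download itself.
        id_or_dir = ckpt_dir if ckpt_dir else model_id
        tokenizer = AutoTokenizer.from_pretrained(id_or_dir)
        model = AutoModelForCausalLM.from_pretrained(
            id_or_dir,
            # Load the weights in BF16 instead of the fp32 default
            torch_dtype=torch.bfloat16,
        )
        return tokenizer, model

Keeping ckpt_dir optional means callers that never pre-download a snapshot (for example, local development scripts) keep working unchanged.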
mods/langchain_chroma.py CHANGED

@@ -470,7 +470,7 @@ class Chroma(VectorStore):
 
         See more: https://docs.trychroma.com/reference/py-collection#query
         """
-        #
+        # Possible fix for ValueError('Could not connect to tenant default_tenant. Are you sure it exists?')
         # https://github.com/langchain-ai/langchain/issues/26884
         chromadb.api.client.SharedSystemClient.clear_system_cache()
        return self._collection.query(
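As a standalone illustration of the workaround (the persist path below is a placeholder, not the db_dir this app actually uses):

    import chromadb

    # Drop chromadb's cached client/tenant state so the next client starts
    # fresh; the issue linked above suggests this for the intermittent
    # "Could not connect to tenant default_tenant" error.
    chromadb.api.client.SharedSystemClient.clear_system_cache()
    client = chromadb.PersistentClient(path="db")  # placeholder path

Clearing the cache on every query is heavy-handed but cheap, and it avoids reusing stale client state after the underlying database has been recreated.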