Eddie Pick committed
Commit · bb1d601
1 Parent(s): 9438062

Updates

- .gitignore +1 -0
- models.py +107 -63
.gitignore
CHANGED

@@ -121,6 +121,7 @@ celerybeat.pid
 
 # Environments
 .env
+.env.local
 .venv
 env/
 venv/

models.py
CHANGED

@@ -1,11 +1,5 @@
 import os
-import …
-from langchain.schema import SystemMessage, HumanMessage
-from langchain.prompts.chat import (
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate,
-    ChatPromptTemplate
-)
+from typing import Tuple, Optional
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
@@ -17,83 +11,122 @@ from langchain_fireworks.embeddings import FireworksEmbeddings
 from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain_anthropic.chat_models import ChatAnthropic
+from langchain_mistralai.chat_models import ChatMistralAI
+from langchain_mistralai.embeddings import MistralAIEmbeddings
 from langchain_ollama.chat_models import ChatOllama
 from langchain_ollama.embeddings import OllamaEmbeddings
 from langchain_cohere.embeddings import CohereEmbeddings
 from langchain_cohere.chat_models import ChatCohere
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
 from langchain_community.chat_models import ChatPerplexity
 from langchain_together import ChatTogether
 from langchain_together.embeddings import TogetherEmbeddings
+from langchain.chat_models.base import BaseChatModel
+from langchain.embeddings.base import Embeddings
 
-def split_provider_model(provider_model):
-    parts = provider_model.split(…
+def split_provider_model(provider_model: str) -> Tuple[str, Optional[str]]:
+    parts = provider_model.split(":", 1)
     provider = parts[0]
-    …
+    if len(parts) > 1:
+        model = parts[1] if parts[1] else None
+    else:
+        model = None
     return provider, model
 
-def get_model(provider_model, temperature=0.…
+def get_model(provider_model: str, temperature: float = 0.7) -> BaseChatModel:
+    """
+    Get a model from a provider and model name.
+    returns BaseChatModel
+    """
     provider, model = split_provider_model(provider_model)
-    …
+    try:
+        match provider.lower():
+            case 'anthropic':
+                if model is None:
+                    model = "claude-3-sonnet-20240229"
+                chat_llm = ChatAnthropic(model=model, temperature=temperature)
+            case 'bedrock':
+                if model is None:
+                    model = "us.anthropic.claude-3-5-haiku-20241022-v1:0"
+                chat_llm = ChatBedrockConverse(model=model, temperature=temperature)
+            case 'cohere':
+                if model is None:
+                    model = 'command-r-plus'
+                chat_llm = ChatCohere(model=model, temperature=temperature)
+            case 'fireworks':
+                if model is None:
+                    model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
+                chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
+            case 'googlegenerativeai':
+                if model is None:
+                    model = "gemini-1.5-flash"
+                chat_llm = ChatGoogleGenerativeAI(model=model, temperature=temperature,
+                                                  max_tokens=None, timeout=None, max_retries=2,)
+            case 'groq':
+                if model is None:
+                    model = 'llama-3.1-8b-instant'
+                chat_llm = ChatGroq(model_name=model, temperature=temperature)
+            case 'huggingface' | 'hf':
+                if model is None:
+                    model = 'mistralai/Mistral-Nemo-Instruct-2407'
+                llm = HuggingFaceEndpoint(
+                    repo_id=model,
+                    max_length=8192,
+                    temperature=temperature,
+                    huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY"),
+                )
+                chat_llm = ChatHuggingFace(llm=llm)
+            case 'ollama':
+                if model is None:
+                    model = 'llama3.1'
+                chat_llm = ChatOllama(model=model, temperature=temperature)
+            case 'openai':
+                if model is None:
+                    model = "gpt-4o-mini"
+                chat_llm = ChatOpenAI(model=model, temperature=temperature)
+            case 'openrouter':
+                if model is None:
+                    model = "google/gemini-flash-1.5-exp"
+                chat_llm = ChatOpenAI(model=model, temperature=temperature, base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY"))
+            case 'mistralai' | 'mistral':
+                if model is None:
+                    model = "open-mistral-nemo"
+                chat_llm = ChatMistralAI(model=model, temperature=temperature)
+            case 'perplexity':
+                if model is None:
+                    model = 'llama-3.1-sonar-small-128k-online'
+                chat_llm = ChatPerplexity(model=model, temperature=temperature)
+            case 'together':
+                if model is None:
+                    model = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
+                chat_llm = ChatTogether(model=model, temperature=temperature)
+            case 'xai':
+                if model is None:
+                    model = 'grok-beta'
+                chat_llm = ChatOpenAI(model=model, api_key=os.getenv("XAI_API_KEY"), base_url="https://api.x.ai/v1", temperature=temperature)
+            case _:
+                raise ValueError(f"Unknown LLM provider {provider}")
+    except Exception as e:
+        raise ValueError(f"Unexpected error with {provider}: {str(e)}")
 
     return chat_llm
 
 
-def get_embedding_model(provider_model):
+def get_embedding_model(provider_model: str) -> Embeddings:
     provider, model = split_provider_model(provider_model)
-    match provider:
+    match provider.lower():
         case 'bedrock':
             if model is None:
                 model = "amazon.titan-embed-text-v2:0"
             embedding_model = BedrockEmbeddings(model_id=model)
         case 'cohere':
             if model is None:
-                model = "embed-…
+                model = "embed-multilingual-v3"
             embedding_model = CohereEmbeddings(model=model)
         case 'fireworks':
             if model is None:
@@ -113,6 +146,14 @@ def get_embedding_model(provider_model):
             embedding_model = GoogleGenerativeAIEmbeddings(model=model)
         case 'groq':
             embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
+        case 'huggingface' | 'hf':
+            if model is None:
+                model = 'sentence-transformers/all-MiniLM-L6-v2'
+            embedding_model = HuggingFaceInferenceAPIEmbeddings(model_name=model, api_key=os.getenv("HUGGINGFACE_API_KEY"))
+        case 'mistral':
+            if model is None:
+                model = "mistral-embed"
+            embedding_model = MistralAIEmbeddings(model=model)
         case 'perplexity':
             raise ValueError(f"Cannot use Perplexity for embedding model")
         case 'together':
@@ -193,12 +234,15 @@ from models import get_model # Make sure this import is correct
 class TestGetModel(unittest.TestCase):
 
     @patch('models.ChatBedrockConverse')
-    def …
+    def test_bedrock_model_no_specific_model(self, mock_bedrock):
         result = get_model('bedrock')
-        mock_bedrock.assert_called_once_with(…
-        …
+        mock_bedrock.assert_called_once_with(model=None, temperature=0.0)
+        self.assertEqual(result, mock_bedrock.return_value)
+
+    @patch('models.ChatBedrockConverse')
+    def test_bedrock_model_with_specific_model(self, mock_bedrock):
+        result = get_model('bedrock:specific-model')
+        mock_bedrock.assert_called_once_with(model='specific-model', temperature=0.0)
         self.assertEqual(result, mock_bedrock.return_value)
 
     @patch('models.ChatCohere')
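
For reference, a minimal usage sketch of the reworked helpers, assuming the matching provider API key (e.g. OPENAI_API_KEY) is set in the environment; the "provider:model" strings follow the defaults visible in the diff above:

from models import split_provider_model, get_model

# The spec splits on the first ':' only, so model ids that themselves
# contain ':' (e.g. Bedrock's "...-v1:0") survive intact.
split_provider_model("openai:gpt-4o-mini")  # -> ("openai", "gpt-4o-mini")
split_provider_model("openai")              # -> ("openai", None): provider default applies

chat_llm = get_model("openai:gpt-4o-mini", temperature=0.2)
print(chat_llm.invoke("Say hello in one word.").content)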