Spaces:
Paused
Paused
| import os | |
| import json | |
| from langchain.schema import SystemMessage, HumanMessage | |
| from langchain.prompts.chat import ( | |
| HumanMessagePromptTemplate, | |
| SystemMessagePromptTemplate, | |
| ChatPromptTemplate | |
| ) | |
| from langchain.prompts.prompt import PromptTemplate | |
| from langchain.retrievers.multi_query import MultiQueryRetriever | |
| from langchain_aws import BedrockEmbeddings | |
| from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse | |
| from langchain_cohere import ChatCohere | |
| from langchain_fireworks.chat_models import ChatFireworks | |
| from langchain_fireworks.embeddings import FireworksEmbeddings | |
| from langchain_groq.chat_models import ChatGroq | |
| from langchain_openai import ChatOpenAI | |
| from langchain_openai.embeddings import OpenAIEmbeddings | |
| from langchain_ollama.chat_models import ChatOllama | |
| from langchain_ollama.embeddings import OllamaEmbeddings | |
| from langchain_cohere.embeddings import CohereEmbeddings | |
| from langchain_cohere.chat_models import ChatCohere | |
| from langchain_openai.embeddings import OpenAIEmbeddings | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings | |
| from langchain_community.chat_models import ChatPerplexity | |
| from langchain_together import ChatTogether | |
| from langchain_together.embeddings import TogetherEmbeddings | |
def split_provider_model(provider_model):
    """Split a ``provider`` or ``provider:model`` spec into its two parts.

    Only the first colon is treated as the separator, so model identifiers
    that themselves contain colons (for example
    ``anthropic.claude-3-sonnet-20240229-v1:0``) are kept intact.

    Args:
        provider_model: A string such as ``"openai"`` or ``"openai:gpt-4"``.

    Returns:
        A ``(provider, model)`` tuple. ``model`` is ``None`` when the spec
        carries no model, i.e. there is no colon or nothing follows it.
    """
    provider, _, model = provider_model.partition(':')
    # An empty model (e.g. "openai:") is normalised to None so that callers
    # checking ``model is None`` apply their default instead of passing an
    # empty model name through.
    return provider, (model or None)
def get_model(provider_model, temperature=0.0):
    """Instantiate the chat model described by a provider spec.

    Args:
        provider_model: ``"provider"`` or ``"provider:model"``; it is parsed
            with ``split_provider_model``. When no model is given, a
            per-provider default is used.
        temperature: Sampling temperature forwarded to the chat model
            constructor.

    Returns:
        The chat model instance built for the requested provider.

    Raises:
        ValueError: If the provider is not one of the supported names.
    """
    provider, requested = split_provider_model(provider_model)

    def resolve(fallback):
        # Only a missing model (None) falls back to the provider default.
        return fallback if requested is None else requested

    if provider == 'bedrock':
        return ChatBedrockConverse(
            model=resolve("anthropic.claude-3-sonnet-20240229-v1:0"),
            temperature=temperature,
        )
    if provider == 'cohere':
        return ChatCohere(model=resolve('command-r-plus'), temperature=temperature)
    if provider == 'fireworks':
        return ChatFireworks(
            model_name=resolve('accounts/fireworks/models/llama-v3p1-8b-instruct'),
            temperature=temperature,
            max_tokens=120000,
        )
    if provider == 'googlegenerativeai':
        return ChatGoogleGenerativeAI(
            model=resolve("gemini-1.5-flash"),
            temperature=temperature,
            max_tokens=None,
            timeout=None,
            max_retries=2,
        )
    if provider == 'groq':
        return ChatGroq(model_name=resolve('llama-3.1-8b-instant'), temperature=temperature)
    if provider == 'ollama':
        return ChatOllama(model=resolve('llama3.1'), temperature=temperature)
    if provider == 'openai':
        return ChatOpenAI(model_name=resolve("gpt-4o-mini"), temperature=temperature)
    if provider == 'perplexity':
        return ChatPerplexity(
            model=resolve('llama-3.1-sonar-small-128k-online'),
            temperature=temperature,
        )
    if provider == 'together':
        return ChatTogether(
            model=resolve('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'),
            temperature=temperature,
        )

    raise ValueError(f"Unknown LLM provider {provider}")
def get_embedding_model(provider_model):
    """Instantiate the embedding model described by a provider spec.

    Args:
        provider_model: ``"provider"`` or ``"provider:model"``; it is parsed
            with ``split_provider_model``. When no model is given, a
            per-provider default is used.

    Returns:
        The embedding model instance built for the requested provider. For
        ``groq`` this is always ``OpenAIEmbeddings`` with
        ``text-embedding-3-small``; any model named in the spec is ignored.

    Raises:
        ValueError: If the provider is ``perplexity`` or is not one of the
            supported names.
    """
    provider, requested = split_provider_model(provider_model)

    def resolve(fallback):
        # Only a missing model (None) falls back to the provider default.
        return fallback if requested is None else requested

    # Providers that never honour the requested model are handled first.
    if provider == 'perplexity':
        raise ValueError("Cannot use Perplexity for embedding model")
    if provider == 'groq':
        return OpenAIEmbeddings(model="text-embedding-3-small")

    if provider == 'bedrock':
        return BedrockEmbeddings(model_id=resolve("amazon.titan-embed-text-v2:0"))
    if provider == 'cohere':
        return CohereEmbeddings(model=resolve("embed-english-light-v3.0"))
    if provider == 'fireworks':
        return FireworksEmbeddings(model=resolve('nomic-ai/nomic-embed-text-v1.5'))
    if provider == 'ollama':
        return OllamaEmbeddings(model=resolve('nomic-embed-text:latest'))
    if provider == 'openai':
        return OpenAIEmbeddings(model=resolve("text-embedding-3-small"))
    if provider == 'googlegenerativeai':
        return GoogleGenerativeAIEmbeddings(model=resolve("models/embedding-001"))
    if provider == 'together':
        return TogetherEmbeddings(
            model=resolve('togethercomputer/m2-bert-80M-2k-retrieval')
        )

    raise ValueError(f"Unknown LLM provider {provider}")
| import unittest | |
| from unittest.mock import patch | |
| from models import get_embedding_model # Make sure this import is correct | |
class TestGetEmbeddingModel(unittest.TestCase):
    """Unit tests for ``get_embedding_model``.

    Each provider class is patched where ``models`` looks it up, so the
    tests only check which class is called and with which arguments; no
    real embedding client is constructed.
    """

    @patch('models.BedrockEmbeddings')
    def test_bedrock_embedding(self, mock_bedrock):
        result = get_embedding_model('bedrock')
        mock_bedrock.assert_called_once_with(model_id='amazon.titan-embed-text-v2:0')
        self.assertEqual(result, mock_bedrock.return_value)

    @patch('models.CohereEmbeddings')
    def test_cohere_embedding(self, mock_cohere):
        result = get_embedding_model('cohere')
        mock_cohere.assert_called_once_with(model='embed-english-light-v3.0')
        self.assertEqual(result, mock_cohere.return_value)

    @patch('models.FireworksEmbeddings')
    def test_fireworks_embedding(self, mock_fireworks):
        result = get_embedding_model('fireworks')
        mock_fireworks.assert_called_once_with(model='nomic-ai/nomic-embed-text-v1.5')
        self.assertEqual(result, mock_fireworks.return_value)

    @patch('models.OllamaEmbeddings')
    def test_ollama_embedding(self, mock_ollama):
        result = get_embedding_model('ollama')
        mock_ollama.assert_called_once_with(model='nomic-embed-text:latest')
        self.assertEqual(result, mock_ollama.return_value)

    @patch('models.OpenAIEmbeddings')
    def test_openai_embedding(self, mock_openai):
        result = get_embedding_model('openai')
        mock_openai.assert_called_once_with(model='text-embedding-3-small')
        self.assertEqual(result, mock_openai.return_value)

    @patch('models.GoogleGenerativeAIEmbeddings')
    def test_google_embedding(self, mock_google):
        result = get_embedding_model('googlegenerativeai')
        mock_google.assert_called_once_with(model='models/embedding-001')
        self.assertEqual(result, mock_google.return_value)

    @patch('models.TogetherEmbeddings')
    def test_together_embedding(self, mock_together):
        result = get_embedding_model('together')
        mock_together.assert_called_once_with(
            model='togethercomputer/m2-bert-80M-2k-retrieval'
        )
        self.assertEqual(result, mock_together.return_value)

    def test_invalid_provider(self):
        with self.assertRaises(ValueError):
            get_embedding_model('invalid_provider')

    @patch('models.OpenAIEmbeddings')
    def test_groq_provider(self, mock_openai):
        # get_embedding_model does not raise for groq; it returns OpenAI
        # embeddings with a fixed model instead.
        result = get_embedding_model('groq')
        mock_openai.assert_called_once_with(model='text-embedding-3-small')
        self.assertEqual(result, mock_openai.return_value)

    def test_perplexity_provider(self):
        with self.assertRaises(ValueError):
            get_embedding_model('perplexity')
| import unittest | |
| from unittest.mock import patch | |
| from models import get_model # Make sure this import is correct | |
class TestGetModel(unittest.TestCase):
    """Unit tests for ``get_model``.

    Each chat model class is patched where ``models`` looks it up, so the
    tests only check which class is called and with which arguments; no
    real chat client is constructed.
    """

    @patch('models.ChatBedrockConverse')
    def test_bedrock_model(self, mock_bedrock):
        result = get_model('bedrock')
        mock_bedrock.assert_called_once_with(
            model="anthropic.claude-3-sonnet-20240229-v1:0",
            temperature=0.0
        )
        self.assertEqual(result, mock_bedrock.return_value)

    @patch('models.ChatCohere')
    def test_cohere_model(self, mock_cohere):
        result = get_model('cohere')
        mock_cohere.assert_called_once_with(model='command-r-plus', temperature=0.0)
        self.assertEqual(result, mock_cohere.return_value)

    @patch('models.ChatFireworks')
    def test_fireworks_model(self, mock_fireworks):
        result = get_model('fireworks')
        mock_fireworks.assert_called_once_with(
            model_name='accounts/fireworks/models/llama-v3p1-8b-instruct',
            temperature=0.0,
            max_tokens=120000
        )
        self.assertEqual(result, mock_fireworks.return_value)

    @patch('models.ChatGoogleGenerativeAI')
    def test_google_model(self, mock_google):
        result = get_model('googlegenerativeai')
        mock_google.assert_called_once_with(
            model="gemini-1.5-flash",
            temperature=0.0,
            max_tokens=None,
            timeout=None,
            max_retries=2
        )
        self.assertEqual(result, mock_google.return_value)

    @patch('models.ChatGroq')
    def test_groq_model(self, mock_groq):
        result = get_model('groq')
        mock_groq.assert_called_once_with(model_name='llama-3.1-8b-instant', temperature=0.0)
        self.assertEqual(result, mock_groq.return_value)

    @patch('models.ChatOllama')
    def test_ollama_model(self, mock_ollama):
        result = get_model('ollama')
        mock_ollama.assert_called_once_with(model='llama3.1', temperature=0.0)
        self.assertEqual(result, mock_ollama.return_value)

    @patch('models.ChatOpenAI')
    def test_openai_model(self, mock_openai):
        result = get_model('openai')
        mock_openai.assert_called_once_with(model_name='gpt-4o-mini', temperature=0.0)
        self.assertEqual(result, mock_openai.return_value)

    @patch('models.ChatPerplexity')
    def test_perplexity_model(self, mock_perplexity):
        result = get_model('perplexity')
        mock_perplexity.assert_called_once_with(
            model='llama-3.1-sonar-small-128k-online', temperature=0.0
        )
        self.assertEqual(result, mock_perplexity.return_value)

    @patch('models.ChatTogether')
    def test_together_model(self, mock_together):
        result = get_model('together')
        mock_together.assert_called_once_with(
            model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', temperature=0.0
        )
        self.assertEqual(result, mock_together.return_value)

    def test_invalid_provider(self):
        with self.assertRaises(ValueError):
            get_model('invalid_provider')

    def test_custom_temperature(self):
        with patch('models.ChatOpenAI') as mock_openai:
            result = get_model('openai', temperature=0.5)
            mock_openai.assert_called_once_with(model_name='gpt-4o-mini', temperature=0.5)
            self.assertEqual(result, mock_openai.return_value)

    def test_custom_model(self):
        with patch('models.ChatOpenAI') as mock_openai:
            # The provider and model are separated by a colon, not a slash.
            result = get_model('openai:gpt-4')
            mock_openai.assert_called_once_with(model_name='gpt-4', temperature=0.0)
            self.assertEqual(result, mock_openai.return_value)
# Run the test cases with unittest's runner when executed as a script.
if __name__ == "__main__":
    unittest.main()