import ctxLengthData from "$lib/data/context_length.json";
import { token } from "$lib/state/token.svelte";
import {
	isCustomModel,
	isHFModel,
	type Conversation,
	type ConversationMessage,
	type CustomModel,
	type Model,
} from "$lib/types.js";
import { tryGet } from "$lib/utils/object.js";
import { HfInference, snippets, type InferenceProvider } from "@huggingface/inference";
import type { ChatCompletionInputMessage, ChatCompletionOutputMessage, InferenceSnippet } from "@huggingface/tasks";
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import OpenAI from "openai";

// A single chunk of a multimodal chat message (text or image_url part).
type ChatCompletionInputMessageChunk =
	NonNullable<ChatCompletionInputMessage["content"]> extends string | (infer U)[] ? U : never;

function parseMessage(message: ConversationMessage): ChatCompletionInputMessage {
	if (!message.images) return message;

	return {
		...message,
		content: [
			{
				type: "text",
				text: message.content ?? "",
			},
			...message.images.map(img => {
				return {
					type: "image_url",
					image_url: { url: img },
				} satisfies ChatCompletionInputMessageChunk;
			}),
		],
	};
}

type HFCompletionMetadata = {
	type: "huggingface";
	client: HfInference;
	args: Parameters<HfInference["chatCompletion"]>[0];
};

type OpenAICompletionMetadata = {
	type: "openai";
	client: OpenAI;
	args: OpenAI.ChatCompletionCreateParams;
};

type CompletionMetadata = HFCompletionMetadata | OpenAICompletionMetadata;

export function maxAllowedTokens(conversation: Conversation) {
	const ctxLength = (() => {
		const { provider, model } = conversation;
		if (!provider || !isHFModel(model)) return;

		const idOnProvider = model.inferenceProviderMapping.find(data => data.provider === provider)?.providerId;
		if (!idOnProvider) return;

		const models = tryGet(ctxLengthData, provider);
		if (!models) return;

		return tryGet(models, idOnProvider) as number | undefined;
	})();
	if (!ctxLength) return customMaxTokens[conversation.model.id] ?? 100000;
	return ctxLength;
}

function getCompletionMetadata(conversation: Conversation, signal?: AbortSignal): CompletionMetadata {
	const { model, systemMessage } = conversation;
	const messages = [
		...(isSystemPromptSupported(model) && systemMessage.content?.length ? [systemMessage] : []),
		...conversation.messages,
	];

	// Handle OpenAI-compatible models
	if (isCustomModel(model)) {
		const openai = new OpenAI({
			apiKey: model.accessToken,
			baseURL: model.endpointUrl,
			dangerouslyAllowBrowser: true,
			fetch: (...args: Parameters<typeof fetch>) => {
				return fetch(args[0], { ...args[1], signal });
			},
		});

		return {
			type: "openai",
			client: openai,
			args: {
				messages: messages.map(parseMessage) as OpenAI.ChatCompletionMessageParam[],
				...conversation.config,
				model: model.id,
			},
		};
	}

	// Handle HuggingFace models
	return {
		type: "huggingface",
		client: new HfInference(token.value),
		args: {
			model: model.id,
			messages: messages.map(parseMessage),
			provider: conversation.provider,
			...conversation.config,
			// max_tokens: maxAllowedTokens(conversation) - currTokens,
		},
	};
}
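
// Illustrative sketch (not part of this module's API): each metadata variant pairs a client
// with the exact argument shape its SDK expects, so callers can forward `args` untouched.
// `conversation` below stands for a hypothetical, fully populated Conversation object.
//
//   const metadata = getCompletionMetadata(conversation);
//   const reply =
//     metadata.type === "openai"
//       ? await metadata.client.chat.completions.create({ ...metadata.args, stream: false })
//       : await metadata.client.chatCompletion(metadata.args);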

export async function handleStreamingResponse(
	conversation: Conversation,
	onChunk: (content: string) => void,
	abortController: AbortController
): Promise<void> {
	const metadata = getCompletionMetadata(conversation, abortController.signal);

	if (metadata.type === "openai") {
		const stream = await metadata.client.chat.completions.create({
			...metadata.args,
			stream: true,
		} as OpenAI.ChatCompletionCreateParamsStreaming);

		let out = "";
		for await (const chunk of stream) {
			if (chunk.choices[0]?.delta?.content) {
				out += chunk.choices[0].delta.content;
				onChunk(out);
			}
		}
		return;
	}

	// HuggingFace streaming
	let out = "";
	for await (const chunk of metadata.client.chatCompletionStream(metadata.args, {
		signal: abortController.signal,
	})) {
		if (chunk.choices && chunk.choices.length > 0 && chunk.choices[0]?.delta?.content) {
			out += chunk.choices[0].delta.content;
			onChunk(out);
		}
	}
}

export async function handleNonStreamingResponse(
	conversation: Conversation
): Promise<{ message: ChatCompletionOutputMessage; completion_tokens: number }> {
	const metadata = getCompletionMetadata(conversation);

	if (metadata.type === "openai") {
		const response = await metadata.client.chat.completions.create({
			...metadata.args,
			stream: false,
		} as OpenAI.ChatCompletionCreateParamsNonStreaming);

		if (response.choices && response.choices.length > 0 && response.choices[0]?.message) {
			return {
				message: {
					role: "assistant",
					content: response.choices[0].message.content || "",
				},
				completion_tokens: response.usage?.completion_tokens || 0,
			};
		}
		throw new Error("No response from the model");
	}

	// HuggingFace non-streaming
	const response = await metadata.client.chatCompletion(metadata.args);

	if (response.choices && response.choices.length > 0) {
		const { message } = response.choices[0]!;
		const { completion_tokens } = response.usage;
		return { message, completion_tokens };
	}
	throw new Error("No response from the model");
}
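
// Illustrative usage of the two handlers above (assumes a populated `conversation` and a
// UI callback `render`; both names are placeholders, not part of this module):
//
//   const controller = new AbortController();
//   await handleStreamingResponse(conversation, text => render(text), controller);
//
//   const { message, completion_tokens } = await handleNonStreamingResponse(conversation);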

export function isSystemPromptSupported(model: Model | CustomModel) {
	if (isCustomModel(model)) return true; // OpenAI-compatible models support system messages
	const template = model?.config.tokenizer_config?.chat_template;
	if (typeof template !== "string") return false;
	return template.includes("system");
}

export const defaultSystemMessage: { [key: string]: string } = {
	"Qwen/QwQ-32B-Preview":
		"You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
} as const;

export const customMaxTokens: { [key: string]: number } = {
	"01-ai/Yi-1.5-34B-Chat": 2048,
	"HuggingFaceM4/idefics-9b-instruct": 2048,
	"deepseek-ai/DeepSeek-Coder-V2-Instruct": 16384,
	"bigcode/starcoder": 8192,
	"bigcode/starcoderplus": 8192,
	"HuggingFaceH4/starcoderbase-finetuned-oasst1": 8192,
	"google/gemma-7b": 8192,
	"google/gemma-1.1-7b-it": 8192,
	"google/gemma-2b": 8192,
	"google/gemma-1.1-2b-it": 8192,
	"google/gemma-2-27b-it": 8192,
	"google/gemma-2-9b-it": 4096,
	"google/gemma-2-2b-it": 8192,
	"tiiuae/falcon-7b": 8192,
	"tiiuae/falcon-7b-instruct": 8192,
	"timdettmers/guanaco-33b-merged": 2048,
	"mistralai/Mixtral-8x7B-Instruct-v0.1": 32768,
	"Qwen/Qwen2.5-72B-Instruct": 32768,
	"Qwen/Qwen2.5-Coder-32B-Instruct": 32768,
	"meta-llama/Meta-Llama-3-70B-Instruct": 8192,
	"CohereForAI/c4ai-command-r-plus-08-2024": 32768,
	"NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
	"meta-llama/Llama-2-70b-chat-hf": 8192,
	"HuggingFaceH4/zephyr-7b-alpha": 17432,
	"HuggingFaceH4/zephyr-7b-beta": 32768,
	"mistralai/Mistral-7B-Instruct-v0.1": 32768,
	"mistralai/Mistral-7B-Instruct-v0.2": 32768,
	"mistralai/Mistral-7B-Instruct-v0.3": 32768,
	"mistralai/Mistral-Nemo-Instruct-2407": 32768,
	"meta-llama/Meta-Llama-3-8B-Instruct": 8192,
	"mistralai/Mistral-7B-v0.1": 32768,
	"bigcode/starcoder2-3b": 16384,
	"bigcode/starcoder2-15b": 16384,
	"HuggingFaceH4/starchat2-15b-v0.1": 16384,
	"codellama/CodeLlama-7b-hf": 8192,
	"codellama/CodeLlama-13b-hf": 8192,
	"codellama/CodeLlama-34b-Instruct-hf": 8192,
	"meta-llama/Llama-2-7b-chat-hf": 8192,
	"meta-llama/Llama-2-13b-chat-hf": 8192,
	"OpenAssistant/oasst-sft-6-llama-30b": 2048,
	"TheBloke/vicuna-7B-v1.5-GPTQ": 2048,
	"HuggingFaceH4/starchat-beta": 8192,
	"bigcode/octocoder": 8192,
	"vwxyzjn/starcoderbase-triviaqa": 8192,
	"lvwerra/starcoderbase-gsm8k": 8192,
	"NousResearch/Hermes-3-Llama-3.1-8B": 16384,
	"microsoft/Phi-3.5-mini-instruct": 32768,
	"meta-llama/Llama-3.1-70B-Instruct": 32768,
	"meta-llama/Llama-3.1-8B-Instruct": 8192,
} as const;

// Order of the elements in InferenceModal.svelte is determined by this const
export const inferenceSnippetLanguages = ["python", "js", "curl"] as const;

export type InferenceSnippetLanguage = (typeof inferenceSnippetLanguages)[number];

const GET_SNIPPET_FN = {
	curl: snippets.curl.getCurlInferenceSnippet,
	js: snippets.js.getJsInferenceSnippet,
	python: snippets.python.getPythonInferenceSnippet,
} as const;

export type GetInferenceSnippetReturn = (InferenceSnippet & { language: InferenceSnippetLanguage })[];

export function getInferenceSnippet(
	model: Model,
	provider: InferenceProvider,
	language: InferenceSnippetLanguage,
	accessToken: string,
	opts?: Record<string, unknown>
): GetInferenceSnippetReturn {
	// If it's a custom model, we don't generate inference snippets
	if (isCustomModel(model)) {
		return [];
	}

	const providerId = model.inferenceProviderMapping.find(p => p.provider === provider)?.providerId;
	const snippetsByClient = GET_SNIPPET_FN[language](
		{ ...model, inference: "" },
		accessToken,
		provider,
		providerId,
		opts
	);
	return snippetsByClient.map(snippetByClient => ({ ...snippetByClient, language }));
}

/**
 * - If language is defined, the function checks whether an inference snippet is available for that specific language
 */
export function hasInferenceSnippet(
	model: Model,
	provider: InferenceProvider,
	language: InferenceSnippetLanguage
): boolean {
	if (isCustomModel(model)) return false;
	return getInferenceSnippet(model, provider, language, "").length > 0;
}
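
// Illustrative sketch: gathering snippets for every supported language. The `model`,
// `provider`, and `accessToken` values are placeholders for whatever the caller has in scope:
//
//   const available = inferenceSnippetLanguages.filter(lang =>
//     hasInferenceSnippet(model, provider, lang)
//   );
//   const blocks = available.flatMap(lang => getInferenceSnippet(model, provider, lang, accessToken));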

const tokenizers = new Map<string, PreTrainedTokenizer | null>();

export async function getTokenizer(model: Model) {
	if (tokenizers.has(model.id)) return tokenizers.get(model.id)!;
	try {
		const tokenizer = await AutoTokenizer.from_pretrained(model.id);
		tokenizers.set(model.id, tokenizer);
		return tokenizer;
	} catch {
		tokenizers.set(model.id, null);
		return null;
	}
}

// When you don't have access to a tokenizer, guesstimate
export function estimateTokens(conversation: Conversation) {
	const content = conversation.messages.reduce((acc, curr) => {
		return acc + (curr?.content ?? "");
	}, "");

	return content.length / 4; // 1 token ~ 4 characters
}

export async function getTokens(conversation: Conversation): Promise<number> {
	const model = conversation.model;
	if (isCustomModel(model)) return estimateTokens(conversation);

	const tokenizer = await getTokenizer(model);
	if (tokenizer === null) return estimateTokens(conversation);

	// This is a simplified version - you might need to adjust based on your exact needs
	let formattedText = "";

	conversation.messages.forEach((message, index) => {
		let content = `<|start_header_id|>${message.role}<|end_header_id|>\n\n${message.content?.trim()}<|eot_id|>`;

		// Add BOS token to the first message
		if (index === 0) {
			content = "<|begin_of_text|>" + content;
		}

		formattedText += content;
	});

	// Encode the text to get tokens
	const encodedInput = tokenizer.encode(formattedText);

	// Return the number of tokens
	return encodedInput.length;
}
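
// Illustrative sketch: a guard a caller might add to compare the current prompt size against
// the model's context window (hypothetical usage; `conversation` is assumed to be populated):
//
//   const used = await getTokens(conversation);
//   const limit = maxAllowedTokens(conversation);
//   if (used >= limit) {
//     console.warn(`Prompt uses ~${Math.round(used)} tokens, above the ${limit}-token context window`);
//   }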