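// Browser-side chat client built on WebLLM (@mlc-ai/web-llm), running a
// fine-tuned Llama 3.2 1B model entirely in the browser via WebGPU.
// Assumes the host page provides elements with the ids: chat-box, user-input,
// send, download-status, progress-bar, progress-container and chat-stats,
// plus the `marked` Markdown renderer and Font Awesome icons.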
import * as webllm from "https://esm.run/@mlc-ai/web-llm";

const messages = [
  {
    content: "You are a helpful AI agent assisting users, and your name is 'Kanha'.",
    role: "system",
  },
];
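// The fine-tuned weights are hosted under the Kanha-AI organization on
// Hugging Face; the prebuilt WebGPU runtime (.wasm) is resolved from WebLLM's
// own model-lib location (webllm.modelLibURLPrefix / webllm.modelVersion).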
const modelLibURLPrefix = "https://huggingface.co/Kanha-AI/";
const modelVersion = "llama-3.2-1b-test_200steps_bs1_r0_lr2e-06_nq-q4f16_1-MLC";

const appConfig = {
  model_list: [
    {
      model: modelLibURLPrefix + modelVersion,
      model_id: modelVersion,
      model_lib:
        webllm.modelLibURLPrefix +
        webllm.modelVersion +
        "/Llama-3.2-1B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
      vram_required_MB: 3672.07,
      low_resource_required: false,
      overrides: {
        context_window_size: 4096,
      },
    },
  ],
};

let selectedModel = modelVersion;
let engine = null;
let isInitializing = false;
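// Lazily create a single MLCEngine instance and reuse it for the whole session.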
async function createEngine() {
  if (engine) return engine;
  // Register the progress callback at creation time so download/compile
  // progress is reported while the model is first being fetched.
  engine = await webllm.CreateMLCEngine(selectedModel, {
    appConfig: appConfig,
    initProgressCallback: updateEngineInitProgressCallback,
  });
  return engine;
}
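// WebLLM reports initialization progress as { progress: 0..1, text: "..." };
// mirror it into the status line and the progress bar.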
function updateEngineInitProgressCallback(report) {
  console.log("initialize", report.progress);
  const statusElement = document.getElementById("download-status");
  statusElement.textContent = report.text;
  statusElement.classList.remove("hidden");
  const progressBar = document.getElementById("progress-bar");
  progressBar.style.width = `${report.progress * 100}%`;
}
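// Coarse user-agent sniffing, used only to pick the right WebGPU instructions.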
function getPlatform() {
  const userAgent = navigator.userAgent || navigator.vendor || window.opera;
  if (/android/i.test(userAgent)) {
    return "Android";
  }
  if (/iPad|iPhone|iPod/.test(userAgent) && !window.MSStream) {
    return "iOS";
  }
  return "Other";
}
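// Human-readable steps for turning on WebGPU, keyed by platform.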
function getWebGPUInstructions(platform) {
  switch (platform) {
    case "Android":
      return "To enable WebGPU on Android:\n1. Go to chrome://flags\n2. Enable 'WebGPU Developer Features' and 'Unsafe WebGPU Support'\n3. Restart your browser";
    case "iOS":
      return "To enable WebGPU on iOS:\n1. Open Settings\n2. Tap Safari\n3. Tap Advanced\n4. Tap Feature Flags\n5. Turn on WebGPU";
    default:
      return "WebGPU might not be supported on your device. Please check if your browser is up to date.";
  }
}
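// Fail fast with actionable instructions if no WebGPU adapter/device is available.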
async function checkWebGPUSupport() {
  if (!navigator.gpu) {
    const platform = getPlatform();
    const instructions = getWebGPUInstructions(platform);
    throw new Error(`WebGPU is not supported in this browser. ${instructions}`);
  }
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) {
    throw new Error("Couldn't request WebGPU adapter. Please make sure WebGPU is enabled on your device.");
  }
  const device = await adapter.requestDevice();
  if (!device) {
    throw new Error("Couldn't request WebGPU device. Please make sure WebGPU is enabled on your device.");
  }
  return true;
}
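// navigator.deviceMemory is Chromium-only and reports total (not free) device
// memory, rounded and capped at 8 GB; if it is unavailable we assume enough RAM.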
async function checkRAM() {
  if (!navigator.deviceMemory) {
    console.warn("Device memory information is not available.");
    return true; // Assume it's okay if we can't check
  }
  const ramGB = navigator.deviceMemory;
  if (ramGB < 3) {
    throw new Error(`Insufficient memory. Required: at least 3GB of device memory, available: ${ramGB}GB`);
  }
  return true;
}
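// Check system requirements, then download and compile the model. Re-entry is
// guarded by isInitializing so repeated clicks don't start two downloads.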
async function initializeWebLLMEngine() {
  if (isInitializing) return;
  isInitializing = true;
  const progressContainer = document.getElementById("progress-container");
  const statusElement = document.getElementById("download-status");
  try {
    // Check system requirements
    await checkRAM();
    await checkWebGPUSupport();
    progressContainer.classList.remove("hidden");
    statusElement.classList.remove("hidden");
    selectedModel = modelVersion; // Using the default model
    const config = {
      temperature: 1.0,
      top_p: 1,
    };
    const engine = await createEngine();
    await engine.reload(selectedModel, config);
    statusElement.textContent = "Model initialized successfully!";
  } catch (error) {
    console.error("Error initializing WebLLM engine:", error);
    statusElement.textContent = `Error initializing: ${error.message}\n\nFor more information and troubleshooting, please visit kanha.ai/faq`;
    statusElement.classList.remove("hidden");
    throw error; // Re-throw the error to be caught in onMessageSend
  } finally {
    progressContainer.classList.add("hidden");
    isInitializing = false;
  }
}
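// Stream a chat completion through WebLLM's OpenAI-compatible API, calling
// onUpdate with the partial text as chunks arrive and onFinish with the final
// message plus token-usage statistics.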
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    let curMessage = "";
    let usage;
    const engine = await createEngine();
    const completion = await engine.chat.completions.create({
      stream: true,
      messages,
      stream_options: { include_usage: true },
    });
    for await (const chunk of completion) {
      const curDelta = chunk.choices[0]?.delta.content;
      if (curDelta) {
        curMessage += curDelta;
      }
      if (chunk.usage) {
        usage = chunk.usage;
      }
      onUpdate(curMessage);
    }
    const finalMessage = await engine.getMessage();
    onFinish(finalMessage, usage);
  } catch (err) {
    onError(err);
  }
}
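// Handle a user message: echo it into the chat box, lazily initialize the
// engine on first use, then stream the assistant's reply.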
async function onMessageSend() {
  const input = document.getElementById("user-input");
  const sendButton = document.getElementById("send");
  const message = {
    content: input.value.trim(),
    role: "user",
  };
  if (message.content.length === 0) {
    return;
  }
  sendButton.disabled = true;
  sendButton.innerHTML = '<i class="fas fa-spinner fa-spin"></i>';
  messages.push(message);
  appendMessage(message);
  input.value = "";
  input.setAttribute("placeholder", "AI is thinking...");
  const aiMessage = {
    content: "typing...",
    role: "assistant",
  };
  appendMessage(aiMessage);
  try {
    if (!engine) {
      await initializeWebLLMEngine();
    }
    const onFinishGenerating = (finalMessage, usage) => {
      updateLastMessage(finalMessage);
      sendButton.disabled = false;
      sendButton.innerHTML = '<i class="fas fa-paper-plane"></i>';
      input.setAttribute("placeholder", "Type your message here...");
      if (usage) {
        const usageText =
          `Prompt tokens: ${usage.prompt_tokens}, ` +
          `Completion tokens: ${usage.completion_tokens}, ` +
          `Prefill: ${usage.extra.prefill_tokens_per_s.toFixed(2)} tokens/sec, ` +
          `Decoding: ${usage.extra.decode_tokens_per_s.toFixed(2)} tokens/sec`;
        document.getElementById("chat-stats").classList.remove("hidden");
        document.getElementById("chat-stats").textContent = usageText;
      }
    };
    await streamingGenerating(
      messages,
      updateLastMessage,
      onFinishGenerating,
      onError
    );
  } catch (error) {
    onError(error);
    // Update the AI message to show the error
    updateLastMessage("I'm sorry, but I encountered an error: " + error.message);
  }
}
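// Render a message bubble. Assistant replies are rendered as Markdown via the
// global `marked` library, which the host page is expected to load.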
function appendMessage(message) {
  const chatBox = document.getElementById("chat-box");
  const messageElement = document.createElement("div");
  messageElement.classList.add("message");
  if (message.role === "user") {
    messageElement.classList.add("user-message");
    messageElement.textContent = message.content;
  } else {
    messageElement.classList.add("assistant-message");
    if (message.content === "typing...") {
      messageElement.classList.add("typing");
      messageElement.textContent = message.content;
    } else {
      messageElement.innerHTML = marked.parse(message.content);
    }
  }
  chatBox.appendChild(messageElement);
  chatBox.scrollTop = chatBox.scrollHeight;
}
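// Overwrite the most recent (assistant) bubble with the latest streamed content.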
function updateLastMessage(content) {
  const chatBox = document.getElementById("chat-box");
  const messageElements = chatBox.getElementsByClassName("message");
  const lastMessage = messageElements[messageElements.length - 1];
  lastMessage.innerHTML = marked.parse(content);
  lastMessage.classList.remove("typing");
}
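// Surface errors in the status area and re-enable the send button.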
function onError(err) {
  console.error(err);
  const statusElement = document.getElementById("download-status");
  statusElement.textContent = `Error: ${err.message}\n\nFor more information and troubleshooting, please visit kanha.ai/faq`;
  statusElement.classList.remove("hidden");
  document.getElementById("send").disabled = false;
  document.getElementById("send").innerHTML = '<i class="fas fa-paper-plane"></i>';
}
// UI binding
document.getElementById("send").addEventListener("click", onMessageSend);
// "keydown" is used instead of the deprecated "keypress" event.
document.getElementById("user-input").addEventListener("keydown", function (event) {
  if (event.key === "Enter") {
    event.preventDefault();
    onMessageSend();
  }
});