Spaces:
Running
Running
| import { | |
| ChatInterface, | |
| ChatModule, | |
| ChatRestModule, | |
| ChatWorkerClient, | |
| } from "@mlc-ai/web-llm"; | |
| function getElementAndCheck(id: string): HTMLElement { | |
| const element = document.getElementById(id); | |
| if (element == null) { | |
| throw Error("Cannot find element " + id); | |
| } | |
| return element; | |
| } | |
| const appConfig = { | |
| model_list: [ | |
| { | |
| model_url: | |
| "https://huggingface.co/hrishioa/wasm-ANIMA-Phi-Neptune-Mistral-7B-q4f32_1/resolve/main/params/", | |
| local_id: "ANIMA-Phi-Neptune-Mistral-7B-q4f32_1", | |
| }, | |
| ], | |
| model_lib_map: { | |
| "ANIMA-Phi-Neptune-Mistral-7B-q4f32_1": | |
| "https://huggingface.co/hrishioa/wasm-ANIMA-Phi-Neptune-Mistral-7B-q4f32_1/resolve/main/ANIMA-Phi-Neptune-Mistral-7B-q4f32_1-webgpu.wasm", | |
| }, | |
| use_web_worker: true, | |
| }; | |
| class ChatUI { | |
| private uiChat: HTMLElement; | |
| private uiChatInput: HTMLInputElement; | |
| private uiChatInfoLabel: HTMLLabelElement; | |
| private chat: ChatInterface; | |
| private localChat: ChatInterface; | |
| private config = appConfig; | |
| private selectedModel: string; | |
| private chatLoaded = false; | |
| private requestInProgress = false; | |
| // We use a request chain to ensure that | |
| // all requests send to chat are sequentialized | |
| private chatRequestChain: Promise<void> = Promise.resolve(); | |
| constructor(chat: ChatInterface, localChat: ChatInterface) { | |
| // use web worker to run chat generation in background | |
| this.chat = chat; | |
| this.localChat = localChat; | |
| // get the elements | |
| this.uiChat = getElementAndCheck("chatui-chat"); | |
| this.uiChatInput = getElementAndCheck("chatui-input") as HTMLInputElement; | |
| this.uiChatInfoLabel = getElementAndCheck( | |
| "chatui-info-label" | |
| ) as HTMLLabelElement; | |
| // register event handlers | |
| getElementAndCheck("chatui-reset-btn").onclick = () => { | |
| this.onReset(); | |
| }; | |
| getElementAndCheck("chatui-send-btn").onclick = () => { | |
| this.onGenerate(); | |
| }; | |
| // TODO: find other alternative triggers | |
| getElementAndCheck("chatui-input").onkeypress = (event) => { | |
| if (event.keyCode === 13) { | |
| this.onGenerate(); | |
| } | |
| }; | |
| const modelSelector = getElementAndCheck( | |
| "chatui-select" | |
| ) as HTMLSelectElement; | |
| for (let i = 0; i < this.config.model_list.length; ++i) { | |
| const item = this.config.model_list[i]; | |
| const opt = document.createElement("option"); | |
| opt.value = item.local_id; | |
| opt.innerHTML = item.local_id; | |
| opt.selected = i == 0; | |
| modelSelector.appendChild(opt); | |
| } | |
| // Append local server option to the model selector | |
| const localServerOpt = document.createElement("option"); | |
| localServerOpt.value = "Local Server"; | |
| localServerOpt.innerHTML = "Local Server"; | |
| modelSelector.append(localServerOpt); | |
| this.selectedModel = modelSelector.value; | |
| modelSelector.onchange = () => { | |
| this.onSelectChange(modelSelector); | |
| }; | |
| } | |
| /** | |
| * Push a task to the execution queue. | |
| * | |
| * @param task The task to be executed; | |
| */ | |
| private pushTask(task: () => Promise<void>) { | |
| const lastEvent = this.chatRequestChain; | |
| this.chatRequestChain = lastEvent.then(task); | |
| } | |
| // Event handlers | |
| // all event handler pushes the tasks to a queue | |
| // that get executed sequentially | |
| // the tasks previous tasks, which causes them to early stop | |
| // can be interrupted by chat.interruptGenerate | |
| private async onGenerate() { | |
| if (this.requestInProgress) { | |
| return; | |
| } | |
| this.pushTask(async () => { | |
| await this.asyncGenerate(); | |
| }); | |
| } | |
| private async onSelectChange(modelSelector: HTMLSelectElement) { | |
| if (this.requestInProgress) { | |
| // interrupt previous generation if any | |
| this.chat.interruptGenerate(); | |
| } | |
| // try reset after previous requests finishes | |
| this.pushTask(async () => { | |
| await this.chat.resetChat(); | |
| this.resetChatHistory(); | |
| await this.unloadChat(); | |
| this.selectedModel = modelSelector.value; | |
| await this.asyncInitChat(); | |
| }); | |
| } | |
| private async onReset() { | |
| if (this.requestInProgress) { | |
| // interrupt previous generation if any | |
| this.chat.interruptGenerate(); | |
| } | |
| // try reset after previous requests finishes | |
| this.pushTask(async () => { | |
| await this.chat.resetChat(); | |
| this.resetChatHistory(); | |
| }); | |
| } | |
| // Internal helper functions | |
| private appendMessage(kind, text) { | |
| if (kind == "init") { | |
| text = "[System Initalize] " + text; | |
| } | |
| if (this.uiChat === undefined) { | |
| throw Error("cannot find ui chat"); | |
| } | |
| const msg = ` | |
| <div class="msg ${kind}-msg"> | |
| <div class="msg-bubble"> | |
| <div class="msg-text">${text}</div> | |
| </div> | |
| </div> | |
| `; | |
| this.uiChat.insertAdjacentHTML("beforeend", msg); | |
| this.uiChat.scrollTo(0, this.uiChat.scrollHeight); | |
| } | |
| private updateLastMessage(kind, text) { | |
| if (kind == "init") { | |
| text = "[System Initalize] " + text; | |
| } | |
| if (this.uiChat === undefined) { | |
| throw Error("cannot find ui chat"); | |
| } | |
| const matches = this.uiChat.getElementsByClassName(`msg ${kind}-msg`); | |
| if (matches.length == 0) throw Error(`${kind} message do not exist`); | |
| const msg = matches[matches.length - 1]; | |
| const msgText = msg.getElementsByClassName("msg-text"); | |
| if (msgText.length != 1) throw Error("Expect msg-text"); | |
| if (msgText[0].innerHTML == text) return; | |
| const list = text.split("\n").map((t) => { | |
| const item = document.createElement("div"); | |
| item.textContent = t; | |
| return item; | |
| }); | |
| msgText[0].innerHTML = ""; | |
| list.forEach((item) => msgText[0].append(item)); | |
| this.uiChat.scrollTo(0, this.uiChat.scrollHeight); | |
| } | |
| private resetChatHistory() { | |
| const clearTags = ["left", "right", "init", "error"]; | |
| for (const tag of clearTags) { | |
| // need to unpack to list so the iterator don't get affected by mutation | |
| const matches = [...this.uiChat.getElementsByClassName(`msg ${tag}-msg`)]; | |
| for (const item of matches) { | |
| this.uiChat.removeChild(item); | |
| } | |
| } | |
| if (this.uiChatInfoLabel !== undefined) { | |
| this.uiChatInfoLabel.innerHTML = ""; | |
| } | |
| } | |
| private async asyncInitChat() { | |
| if (this.chatLoaded) return; | |
| this.requestInProgress = true; | |
| this.appendMessage("init", ""); | |
| const initProgressCallback = (report) => { | |
| this.updateLastMessage("init", report.text); | |
| }; | |
| this.chat.setInitProgressCallback(initProgressCallback); | |
| try { | |
| if (this.selectedModel != "Local Server") { | |
| await this.chat.reload(this.selectedModel, undefined, this.config); | |
| } | |
| } catch (err) { | |
| this.appendMessage("error", "Init error, " + err.toString()); | |
| console.log(err.stack); | |
| this.unloadChat(); | |
| this.requestInProgress = false; | |
| return; | |
| } | |
| this.requestInProgress = false; | |
| this.chatLoaded = true; | |
| } | |
| private async unloadChat() { | |
| await this.chat.unload(); | |
| this.chatLoaded = false; | |
| } | |
| /** | |
| * Run generate | |
| */ | |
| private async asyncGenerate() { | |
| await this.asyncInitChat(); | |
| this.requestInProgress = true; | |
| const prompt = this.uiChatInput.value; | |
| if (prompt == "") { | |
| this.requestInProgress = false; | |
| return; | |
| } | |
| this.appendMessage("right", prompt); | |
| this.uiChatInput.value = ""; | |
| this.uiChatInput.setAttribute("placeholder", "Generating..."); | |
| this.appendMessage("left", ""); | |
| const callbackUpdateResponse = (step, msg) => { | |
| if (msg.length === 0) return this.chat.interruptGenerate(); | |
| this.updateLastMessage("left", msg); | |
| }; | |
| try { | |
| if (this.selectedModel == "Local Server") { | |
| await this.localChat.generate(prompt, callbackUpdateResponse); | |
| this.uiChatInfoLabel.innerHTML = | |
| await this.localChat.runtimeStatsText(); | |
| } else { | |
| await this.chat.generate(prompt, callbackUpdateResponse); | |
| this.uiChatInfoLabel.innerHTML = await this.chat.runtimeStatsText(); | |
| } | |
| } catch (err) { | |
| this.appendMessage("error", "Generate error, " + err.toString()); | |
| console.log(err.stack); | |
| await this.unloadChat(); | |
| } | |
| this.uiChatInput.setAttribute("placeholder", "Enter your message..."); | |
| this.requestInProgress = false; | |
| } | |
| } | |
| const useWebWorker = appConfig.use_web_worker; | |
| let chat: ChatInterface; | |
| let localChat: ChatInterface; | |
| if (useWebWorker) { | |
| chat = new ChatWorkerClient( | |
| new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }) | |
| ); | |
| localChat = new ChatRestModule(); | |
| } else { | |
| chat = new ChatModule(); | |
| localChat = new ChatRestModule(); | |
| } | |
| new ChatUI(chat, localChat); | |