mihailik commited on
Commit
6fa125c
·
1 Parent(s): ed0a425

Updating to handle prompt queries better.

Browse files
package.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "name": "localm",
3
- "version": "1.0.8",
4
  "description": "",
5
  "main": "chat-full.js",
6
  "scripts": {
 
1
  {
2
  "name": "localm",
3
+ "version": "1.0.9",
4
  "description": "",
5
  "main": "chat-full.js",
6
  "scripts": {
src/app/worker-connection.js CHANGED
@@ -75,7 +75,7 @@ export function workerConnection() {
75
  async function loadModel(modelName) {
76
  await workerLoaded;
77
  const { send } = await workerLoaded;
78
- return send({ type: 'loadModel', model: modelName });
79
  }
80
 
81
  /**
@@ -85,6 +85,6 @@ export function workerConnection() {
85
  async function runPrompt(promptText, modelName) {
86
  await workerLoaded;
87
  const { send } = await workerLoaded;
88
- return send({ type: 'runPrompt', prompt: promptText, model: modelName });
89
  }
90
  }
 
75
  async function loadModel(modelName) {
76
  await workerLoaded;
77
  const { send } = await workerLoaded;
78
+ return send({ type: 'loadModel', modelName });
79
  }
80
 
81
  /**
 
85
  async function runPrompt(promptText, modelName) {
86
  await workerLoaded;
87
  const { send } = await workerLoaded;
88
+ return send({ type: 'runPrompt', prompt: promptText, modelName });
89
  }
90
  }
src/worker/boot-worker.js CHANGED
@@ -1,8 +1,9 @@
1
  // @ts-check
2
 
3
- import { pipeline } from '@huggingface/transformers';
4
 
5
  export function bootWorker() {
 
6
  // Report starting
7
  try {
8
  self.postMessage({ type: 'status', status: 'initializing' });
@@ -10,160 +11,70 @@ export function bootWorker() {
10
  // ignore if postMessage not available for some reason
11
  }
12
 
13
- (async () => {
14
- // named import `pipeline` is available from the bundled runtime
15
 
16
- // Detect available acceleration backends
17
- let backend = 'wasm';
18
- try {
19
- const hasWebGPU = typeof navigator !== 'undefined' && !!navigator.gpu;
20
- let hasWebGL2 = false;
21
- try {
22
- // In a worker environment prefer OffscreenCanvas to test webgl2
23
- if (typeof OffscreenCanvas !== 'undefined') {
24
- const c = new OffscreenCanvas(1, 1);
25
- const gl = c.getContext('webgl2') || c.getContext('webgl');
26
- hasWebGL2 = !!gl;
27
- } else if (typeof document !== 'undefined') {
28
- const canvas = document.createElement('canvas');
29
- const gl = canvas.getContext('webgl2') || canvas.getContext('webgl');
30
- hasWebGL2 = !!gl;
31
- }
32
- } catch (e) {
33
- hasWebGL2 = false;
34
- }
35
 
36
- if (hasWebGPU) backend = 'webgpu';
37
- else if (hasWebGL2) backend = 'webgl';
38
- } catch (e) {
39
- backend = 'wasm';
40
- }
41
 
42
- self.postMessage({ type: 'status', status: 'backend-detected', backend });
 
43
 
44
- // verify the named import is present
 
45
  try {
46
- if (!pipeline) throw new Error('transformers pipeline import not available');
47
- self.postMessage({ type: 'status', status: 'transformers-loaded', source: '@huggingface/transformers' });
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  } catch (err) {
49
- self.postMessage({ type: 'status', status: 'transformers-load-failed', error: String(err) });
50
  }
 
51
 
52
- // Model cache to avoid loading the same model multiple times.
53
- // value = { promise, pipeline }
54
- const modelCache = new Map();
55
-
56
- const availableModels = [
57
- 'Xenova/phi-3-mini-4k-instruct',
58
- 'Xenova/phi-1.5',
59
- 'Xenova/all-MiniLM-L6-v2'
60
- ];
61
-
62
- // signal ready to main thread (worker script loaded; model runtime may still be pending)
63
- self.postMessage({ type: 'ready' });
64
-
65
- // helper: create or return existing pipeline promise
66
- async function ensureModel(modelName, id) {
67
- if (modelCache.has(modelName)) {
68
- const entry = modelCache.get(modelName);
69
- // If pipeline already resolved, return it, otherwise await the promise
70
- if (entry.pipeline) return entry.pipeline;
71
- return entry.promise;
72
- }
73
-
74
- // create loader promise
75
- const loader = (async () => {
76
- if (!pipeline) {
77
- throw new Error('transformers runtime not available');
78
- }
79
-
80
- // Post progress and status
81
- if (id) self.postMessage({ id, type: 'status', status: 'model-loading', model: modelName });
82
-
83
- // Choose device hint as a literal union. Cast only at the call site if TypeScript
84
- // needs help narrowing.
85
- const deviceOption = backend === 'webgpu' ? 'webgpu' : (backend === 'webgl' ? 'gpu' : 'wasm');
86
-
87
- // Create a text-generation pipeline. Depending on the model this may
88
- // perform downloads of model weights; the library should report progress
89
- // via its own callbacks if available.
90
- const pipe = await pipeline('text-generation', modelName, /** @type {any} */ ({
91
- device: deviceOption,
92
- progress_callback: (progress) => {
93
- if (id) self.postMessage({ id, type: 'model-progress', progress, model: modelName });
94
- }
95
- }));
96
-
97
- // store pipeline for reuse
98
- const entry = modelCache.get(modelName) || {};
99
- entry.pipeline = pipe;
100
- modelCache.set(modelName, entry);
101
-
102
- if (id) self.postMessage({ id, type: 'status', status: 'model-loaded', model: modelName });
103
- return pipe;
104
- })();
105
-
106
- // temporarily store the in-progress promise so concurrent requests reuse it
107
- modelCache.set(modelName, { promise: loader });
108
- return loader;
109
  }
 
 
110
 
111
- // helper to extract generated text from various runtime outputs
112
- function extractText(output) {
113
- // typical shapes: [{ generated_text: '...' }] or [{ text: '...' }] or string
114
- try {
115
- if (!output) return '';
116
- if (typeof output === 'string') return output;
117
- if (Array.isArray(output) && output.length > 0) {
118
- const el = output[0];
119
- if (el.generated_text) return el.generated_text;
120
- if (el.text) return el.text;
121
- // Some runtimes return an array of strings
122
- if (typeof el === 'string') return el;
123
- }
124
- // Fallback: try JSON stringify
125
- return String(output);
126
- } catch (e) {
127
- return '';
128
- }
129
  }
130
-
131
- // handle incoming requests from the UI thread
132
- self.addEventListener('message', async (ev) => {
133
- const msg = ev.data || {};
134
- const id = msg.id;
135
- try {
136
- if (msg.type === 'listModels') {
137
- self.postMessage({ id, type: 'response', result: availableModels });
138
- } else if (msg.type === 'loadModel') {
139
- const modelName = msg.model;
140
- try {
141
- await ensureModel(modelName, id);
142
- self.postMessage({ id, type: 'response', result: { model: modelName, status: 'loaded' } });
143
- } catch (err) {
144
- self.postMessage({ id, type: 'error', error: String(err) });
145
- }
146
- } else if (msg.type === 'runPrompt') {
147
- const prompt = msg.prompt || '';
148
- const modelName = msg.model;
149
- try {
150
- const pipe = await ensureModel(modelName, id);
151
- // run the pipeline
152
- if (!pipe) throw new Error('pipeline not available');
153
- self.postMessage({ id, type: 'status', status: 'inference-start', model: modelName });
154
- const out = await pipe(prompt, msg.options || {});
155
- const text = extractText(out);
156
- self.postMessage({ id, type: 'status', status: 'inference-done', model: modelName });
157
- self.postMessage({ id, type: 'response', result: text });
158
- } catch (err) {
159
- self.postMessage({ id, type: 'error', error: String(err) });
160
- }
161
- } else {
162
- if (id) self.postMessage({ id, type: 'error', error: 'unknown-message-type' });
163
- }
164
- } catch (err) {
165
- if (id) self.postMessage({ id, type: 'error', error: String(err) });
166
- }
167
- });
168
- })();
169
  }
 
1
  // @ts-check
2
 
3
+ import { ModelCache } from './model-cache';
4
 
5
  export function bootWorker() {
6
+ const modelCache = new ModelCache();
7
  // Report starting
8
  try {
9
  self.postMessage({ type: 'status', status: 'initializing' });
 
11
  // ignore if postMessage not available for some reason
12
  }
13
 
14
+ self.postMessage({ type: 'status', status: 'backend-detected', backend: modelCache.backend });
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ // signal ready to main thread (worker script loaded; model runtime may still be pending)
18
+ self.postMessage({ type: 'ready' });
 
 
 
19
 
20
+ // handle incoming requests from the UI thread
21
+ self.addEventListener('message', handleMessage);
22
 
23
+ async function handleMessage({ data }) {
24
+ const { id } = data;
25
  try {
26
+ if (data.type === 'listModels') {
27
+ self.postMessage({ id, type: 'response', result: modelCache.knownModels });
28
+ } else if (data.type === 'loadModel') {
29
+ const { modelName = modelCache.knownModels[0] } = data;
30
+ try {
31
+ const pipe = await modelCache.getModel({ modelName });
32
+ self.postMessage({ id, type: 'response', result: { model: modelName, status: 'loaded' } });
33
+ } catch (err) {
34
+ self.postMessage({ id, type: 'error', error: String(err) });
35
+ }
36
+ } else if (data.type === 'runPrompt') {
37
+ handleRunPrompt(data);
38
+ } else {
39
+ if (id) self.postMessage({ id, type: 'error', error: 'unknown-message-type' });
40
+ }
41
  } catch (err) {
42
+ if (id) self.postMessage({ id, type: 'error', error: String(err) });
43
  }
44
+ }
45
 
46
+ async function handleRunPrompt({ prompt, modelName = modelCache.knownModels[0], id, options }) {
47
+ try {
48
+ const pipe = await modelCache.getModel({ modelName });
49
+ // run the pipeline
50
+ if (!pipe) throw new Error('pipeline not available');
51
+ self.postMessage({ id, type: 'status', status: 'inference-start', model: modelName });
52
+ const out = await pipe(prompt, options || {});
53
+ const text = extractText(out);
54
+ self.postMessage({ id, type: 'status', status: 'inference-done', model: modelName });
55
+ self.postMessage({ id, type: 'response', result: text });
56
+ } catch (err) {
57
+ self.postMessage({ id, type: 'error', error: String(err) });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  }
59
+ }
60
+ }
61
 
62
+ // helper to extract generated text from various runtime outputs
63
+ function extractText(output) {
64
+ // typical shapes: [{ generated_text: '...' }] or [{ text: '...' }] or string
65
+ try {
66
+ if (!output) return '';
67
+ if (typeof output === 'string') return output;
68
+ if (Array.isArray(output) && output.length > 0) {
69
+ const el = output[0];
70
+ if (el.generated_text) return el.generated_text;
71
+ if (el.text) return el.text;
72
+ // Some runtimes return an array of strings
73
+ if (typeof el === 'string') return el;
 
 
 
 
 
 
74
  }
75
+ // Fallback: try JSON stringify
76
+ return String(output);
77
+ } catch (e) {
78
+ return '';
79
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  }
src/worker/load-model-core.js ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // @ts-check
2
+
3
+ import { pipeline } from '@huggingface/transformers';
4
+
5
+ /**
6
+ * @param {{
7
+ * modelName: string,
8
+ * device: import('@huggingface/transformers').DeviceType,
9
+ * onProgress?: import('@huggingface/transformers').ProgressCallback
10
+ * }} _
11
+ */
12
+ export async function loadModelCore({
13
+ modelName,
14
+ device,
15
+ onProgress
16
+ }) {
17
+ // Create a text-generation pipeline. Depending on the model this may
18
+ // perform downloads of model weights; the library should report progress
19
+ // via its own callbacks if available.
20
+ const pipe = await pipeline(
21
+ 'text-generation',
22
+ modelName,{
23
+ device,
24
+ progress_callback: (progress) => {
25
+ if (onProgress) onProgress(progress);
26
+ }
27
+ });
28
+
29
+ return pipe;
30
+ }
src/worker/model-cache.js ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // @ts-check
2
+
3
+ import { pipeline } from '@huggingface/transformers';
4
+ import { loadModelCore } from './load-model-core';
5
+
6
+ export class ModelCache {
7
+ cache = new Map();
8
+ /** @type {import('@huggingface/transformers').DeviceType | undefined} */
9
+ backend = undefined;
10
+
11
+ knownModels = [
12
+ 'Xenova/phi-3-mini-4k-instruct',
13
+ 'Xenova/phi-1.5',
14
+ 'Xenova/all-MiniLM-L6-v2'
15
+ ];
16
+
17
+ /**
18
+ * @param {{
19
+ * modelName: string
20
+ * }} _
21
+ */
22
+ getModel({ modelName }) {
23
+ return this.cache.get(modelName) || this._loadModelAndStore({ modelName });
24
+ }
25
+
26
+ /**
27
+ * @param {{
28
+ * modelName: string
29
+ * }} _
30
+ */
31
+ _loadModelAndStore({ modelName }) {
32
+ if (!this.backend) this.backend = detectTransformersBackend();
33
+ const modelPromise = loadModelCore({
34
+ modelName,
35
+ device: this.backend
36
+ });
37
+ this.cache.set(modelName, modelPromise);
38
+ modelPromise.then(
39
+ model => {
40
+ this.cache.set(modelName, model);
41
+ },
42
+ () => {
43
+ this.cache.delete(modelName);
44
+ });
45
+
46
+ return modelPromise;
47
+ }
48
+
49
+ }
50
+
51
+ export function detectTransformersBackend() {
52
+ /**
53
+ * Detect available acceleration backends
54
+ * @type {import('@huggingface/transformers').DeviceType}
55
+ */
56
+ let backend = 'wasm';
57
+ try {
58
+ const hasWebGPU = typeof navigator !== 'undefined' && !!/** @type {*} */(navigator).gpu;
59
+ let hasWebGL2 = false;
60
+ try {
61
+ // In a worker environment prefer OffscreenCanvas to test webgl2
62
+ if (typeof OffscreenCanvas !== 'undefined') {
63
+ const c = new OffscreenCanvas(1, 1);
64
+ const gl = c.getContext('webgl2') || c.getContext('webgl');
65
+ hasWebGL2 = !!gl;
66
+ } else if (typeof document !== 'undefined') {
67
+ const canvas = document.createElement('canvas');
68
+ const gl = canvas.getContext('webgl2') || canvas.getContext('webgl');
69
+ hasWebGL2 = !!gl;
70
+ }
71
+ } catch (e) {
72
+ hasWebGL2 = false;
73
+ }
74
+
75
+ if (hasWebGPU) backend = 'webgpu';
76
+ else if (hasWebGL2) backend = 'gpu';
77
+ } catch (e) {
78
+ backend = 'wasm';
79
+ }
80
+
81
+ return backend;
82
+ }