import { useState, useEffect, useRef, useCallback } from "react";
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  TextStreamer,
} from "@huggingface/transformers";
import { MODEL_OPTIONS } from "../constants/models";

interface LLMState {
  isLoading: boolean;
  isReady: boolean;
  error: string | null;
  progress: number;
}

interface LLMInstance {
  model: any;
  tokenizer: any;
}
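
// Module-level cache, keyed by model ID, so loaded models and in-flight load
// promises are shared across every component that uses this hook.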
let moduleCache: {
  [modelId: string]: {
    instance: LLMInstance | null;
    loadingPromise: Promise<LLMInstance> | null;
  };
} = {};
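
/**
 * React hook that loads a causal language model in the browser via
 * Transformers.js (WebGPU), exposing load progress, streaming generation,
 * and a persistent KV cache for multi-turn conversations.
 */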
export const useLLM = (modelName?: string) => {
  const [state, setState] = useState<LLMState>({
    isLoading: false,
    isReady: false,
    error: null,
    progress: 0,
  });

  const instanceRef = useRef<LLMInstance | null>(null);
  const loadingPromiseRef = useRef<Promise<LLMInstance> | null>(null);

  const abortControllerRef = useRef<AbortController | null>(null);
  const pastKeyValuesRef = useRef<any>(null);
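
  // Resolve the configured model ID and dtype for the requested model option.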
  const modelOption = MODEL_OPTIONS.find((opt) => opt.id === modelName);
  const modelId = modelOption?.modelId;
  const dtype = modelOption?.dtype;

  const loadModel = useCallback(async () => {
    if (!modelId) {
      throw new Error("Model ID is required");
    }

    if (!moduleCache[modelId]) {
      moduleCache[modelId] = {
        instance: null,
        loadingPromise: null,
      };
    }

    const cache = moduleCache[modelId];
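
    // Reuse an instance that this hook or another consumer has already loaded.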
    const existingInstance = instanceRef.current || cache.instance;
    if (existingInstance) {
      instanceRef.current = existingInstance;
      cache.instance = existingInstance;
      setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
      return existingInstance;
    }
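
    // If a load for this model is already in flight, await it rather than
    // starting a duplicate download.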
    const existingPromise = loadingPromiseRef.current || cache.loadingPromise;
    if (existingPromise) {
      try {
        const instance = await existingPromise;
        instanceRef.current = instance;
        cache.instance = instance;
        setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
        return instance;
      } catch (error) {
        setState((prev) => ({
          ...prev,
          isLoading: false,
          error:
            error instanceof Error ? error.message : "Failed to load model",
        }));
        throw error;
      }
    }

    setState((prev) => ({
      ...prev,
      isLoading: true,
      error: null,
      progress: 0,
    }));

    abortControllerRef.current = new AbortController();

    const loadingPromise = (async () => {
      try {
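        // Report download progress only for the large .onnx_data weights file.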
        const progress_callback = (progress: any) => {
          if (
            progress.status === "progress" &&
            progress.file.endsWith(".onnx_data")
          ) {
            const percentage = Math.round(
              (progress.loaded / progress.total) * 100,
            );
            setState((prev) => ({ ...prev, progress: percentage }));
          }
        };

        const tokenizer = await AutoTokenizer.from_pretrained(modelId, {
          progress_callback,
        });

        const model = await AutoModelForCausalLM.from_pretrained(modelId, {
          dtype,
          device: "webgpu",
          progress_callback,
        });

        const instance = { model, tokenizer };
        instanceRef.current = instance;
        cache.instance = instance;
        loadingPromiseRef.current = null;
        cache.loadingPromise = null;

        setState((prev) => ({
          ...prev,
          isLoading: false,
          isReady: true,
          progress: 100,
        }));
        return instance;
      } catch (error) {
        loadingPromiseRef.current = null;
        cache.loadingPromise = null;
        setState((prev) => ({
          ...prev,
          isLoading: false,
          error:
            error instanceof Error ? error.message : "Failed to load model",
        }));
        throw error;
      }
    })();

    loadingPromiseRef.current = loadingPromise;
    cache.loadingPromise = loadingPromise;
    return loadingPromise;
  }, [modelId, dtype]);

  const generateResponse = useCallback(
    async (
      messages: Array<{ role: string; content: string }>,
      tools: Array<any>,
      onToken?: (token: string) => void,
    ): Promise<string> => {
      const instance = instanceRef.current;
      if (!instance) {
        throw new Error("Model not loaded. Call loadModel() first.");
      }

      const { model, tokenizer } = instance;
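
      // Render the chat history and tool definitions into tokenized model inputs.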
      const input = tokenizer.apply_chat_template(messages, {
        tools,
        add_generation_prompt: true,
        return_dict: true,
      });
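
      // Stream decoded tokens to the caller as they are generated, skipping the prompt.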
      const streamer = onToken
        ? new TextStreamer(tokenizer, {
            skip_prompt: true,
            skip_special_tokens: false,
            callback_function: (token: string) => {
              onToken(token);
            },
          })
        : undefined;
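
      // Reuse the KV cache from the previous turn so earlier context is not re-processed.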
      const { sequences, past_key_values } = await model.generate({
        ...input,
        past_key_values: pastKeyValuesRef.current,
        max_new_tokens: 512,
        do_sample: false,
        streamer,
        return_dict_in_generate: true,
      });
      pastKeyValuesRef.current = past_key_values;
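
      // Decode only the newly generated tokens (everything after the prompt)
      // and strip the trailing end-of-text marker.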
      const response = tokenizer
        .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
          skip_special_tokens: false,
        })[0]
        .replace(/<\|end_of_text\|>$/, "");

      return response;
    },
    [],
  );
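
  // Drop the cached KV state, e.g. when the conversation is reset.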
  const clearPastKeyValues = useCallback(() => {
    pastKeyValuesRef.current = null;
  }, []);

  const cleanup = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
    }
  }, []);

  useEffect(() => {
    return cleanup;
  }, [cleanup]);
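
  // If this model was already loaded elsewhere, mark this hook instance as ready.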
  useEffect(() => {
    if (modelId && moduleCache[modelId]) {
      const existingInstance =
        instanceRef.current || moduleCache[modelId].instance;
      if (existingInstance) {
        instanceRef.current = existingInstance;
        setState((prev) => ({ ...prev, isReady: true }));
      }
    }
  }, [modelId]);

  return {
    ...state,
    loadModel,
    generateResponse,
    clearPastKeyValues,
    cleanup,
  };
};