|
|
|
|
|
import { useMemo, useRef, useState } from 'react'; |
|
|
import { fetchHFModel, fetchHFDataset } from '../services/hf'; |
|
|
|
|
|
/**
 * Wrap `fn` so that a burst of calls produces a single invocation.
 *
 * Every call cancels the previously scheduled invocation (if any) and schedules
 * `fn` to run `ms` milliseconds later with the arguments of that latest call.
 * The wrapper itself returns nothing; whatever `fn` returns (including a
 * promise) is discarded by `setTimeout`, so async callbacks must handle their
 * own errors.
 *
 * @param fn - Function to debounce.
 * @param ms - Quiet period in milliseconds before `fn` fires (default 350).
 * @returns A wrapper that forwards its arguments to `fn` after the quiet period.
 */
const debounce = (fn: (...args: any[]) => void, ms = 350) => {
  // Handle of the pending timer; undefined until the first call.
  let timer: ReturnType<typeof setTimeout> | undefined;

  return (...args: any[]) => {
    // clearTimeout(undefined) is a no-op, so the first call needs no guard.
    clearTimeout(timer);
    timer = setTimeout(() => fn(...args), ms);
  };
};
|
|
|
|
|
/**
 * React hook that checks whether a Hugging Face model id has the task a caller
 * expects.
 *
 * `validate(modelId, expected)` is debounced, so rapid calls collapse into one
 * lookup. Ids without a `/` clear `result`. Otherwise the model metadata is
 * fetched with `fetchHFModel` and `result` becomes either
 * `{ isValid: true, modelInfo }` or `{ isValid: false, error }`.
 *
 * - `expected === 'language'` accepts `text-generation` / `text2text-generation`.
 * - `expected === 'scorer'` accepts `text-classification`.
 *
 * Verdicts are cached per (expected, modelId) pair for the lifetime of the
 * hook; fetch errors are not cached. Only the most recent validation run may
 * update `result` and `loading`, so a slow response for an older id can never
 * overwrite the state for the current one.
 *
 * @returns `loading` (a lookup is in flight), `result` (latest outcome or
 *   `null`), and the debounced `validate` function.
 */
export function useHFModelValidator() {
  // Keyed by `${expected}:${modelId}`: the same model can be valid as a
  // language model and invalid as a scorer, so the role is part of the key.
  const cache = useRef<Map<string, any>>(new Map());
  // Sequence number of the newest validation run; older runs that settle late
  // must not touch state.
  const latestRun = useRef(0);
  const [loading, setLoading] = useState(false);
  // Left as `any` so existing callers reading `result.modelInfo` / `result.error`
  // without narrowing keep compiling.
  const [result, setResult] = useState<any>(null);

  const validate = useMemo(
    () =>
      debounce(async (modelId: string, expected: 'language' | 'scorer') => {
        const run = ++latestRun.current;
        const isCurrent = () => run === latestRun.current;

        if (!modelId?.includes('/')) {
          // Not an "owner/name" id; also clear the spinner of any superseded lookup.
          setLoading(false);
          setResult(null);
          return;
        }

        const cacheKey = `${expected}:${modelId}`;
        if (cache.current.has(cacheKey)) {
          setLoading(false);
          setResult(cache.current.get(cacheKey));
          return;
        }

        setLoading(true);

        try {
          const info = await fetchHFModel(modelId);
          const id: string = info.id?.toLowerCase() ?? '';
          const tags: string[] = info.tags ?? [];
          let actual: string | undefined = info.pipeline_tag;

          // Fall back to id/tag heuristics when no pipeline_tag is reported.
          // NOTE(review): `id.includes('t5')` is a plain substring match and may
          // hit ids unrelated to T5 — confirm this heuristic is intended.
          if (!actual) {
            if (id.includes('t5') || tags.includes('text2text-generation')) actual = 'text2text-generation';
            else if (tags.includes('text-generation')) actual = 'text-generation';
            else if (tags.includes('text-classification')) actual = 'text-classification';
          }

          const ok =
            expected === 'language'
              ? ['text-generation', 'text2text-generation'].includes(actual || '')
              : actual === 'text-classification';

          const payload = ok
            ? {
                isValid: true,
                modelInfo: {
                  id: info.id,
                  downloads: info.downloads ?? 0,
                  pipeline_tag: actual,
                  tags,
                  author: info.id?.split('/')?.[0] ?? 'unknown',
                  modelName: info.id?.split('/')?.[1] ?? info.id,
                },
              }
            : {
                isValid: false,
                error:
                  expected === 'language'
                    ? `Model task should be text-generation or text2text-generation, but is ${actual || 'Unknown'}`
                    : `Model task should be text-classification, but is ${actual || 'Unknown'}`,
              };

          // Cache even when superseded: the verdict for this key is still correct.
          cache.current.set(cacheKey, payload);
          if (isCurrent()) setResult(payload);
        } catch (e: unknown) {
          // Failures are deliberately not cached so a later call can retry.
          if (isCurrent()) {
            const message = e instanceof Error ? e.message : '';
            setResult({ isValid: false, error: message || 'Error when validating model' });
          }
        } finally {
          if (isCurrent()) setLoading(false);
        }
      }),
    []
  );

  return { loading, result, validate };
}
|
|
|
|
|
/** Accepted dataset id format: "owner/name" using letters, digits, `.`, `_` and `-`. */
const HF_DATASET_ID_PATTERN = /^[a-zA-Z0-9._-]+\/[a-zA-Z0-9._-]+$/;

/**
 * React hook that validates a Hugging Face dataset id and loads its metadata.
 *
 * `validate(datasetId)` is debounced, so rapid calls collapse into one lookup.
 * Ids without a `/` clear `result`; ids not matching `HF_DATASET_ID_PATTERN`
 * produce a format error without any network call. Otherwise the metadata is
 * fetched with `fetchHFDataset` and `result` becomes
 * `{ isValid: true, datasetInfo }`, or `{ isValid: false, error }` on failure.
 *
 * Successful lookups are cached per dataset id for the lifetime of the hook;
 * errors are not cached. Only the most recent validation run may update
 * `result` and `loading`, so a slow response for an older id can never
 * overwrite the state for the current one.
 *
 * @returns `loading` (a lookup is in flight), `result` (latest outcome or
 *   `null`), and the debounced `validate` function.
 */
export function useHFDatasetValidator() {
  const cache = useRef<Map<string, any>>(new Map());
  // Sequence number of the newest validation run; older runs that settle late
  // must not touch state.
  const latestRun = useRef(0);
  const [loading, setLoading] = useState(false);
  // Left as `any` so existing callers reading `result.datasetInfo` / `result.error`
  // without narrowing keep compiling.
  const [result, setResult] = useState<any>(null);

  const validate = useMemo(
    () =>
      debounce(async (datasetId: string) => {
        const run = ++latestRun.current;
        const isCurrent = () => run === latestRun.current;

        if (!datasetId?.includes('/')) {
          // Not an "owner/name" id; also clear the spinner of any superseded lookup.
          setLoading(false);
          setResult(null);
          return;
        }

        if (cache.current.has(datasetId)) {
          setLoading(false);
          setResult(cache.current.get(datasetId));
          return;
        }

        if (!HF_DATASET_ID_PATTERN.test(datasetId)) {
          setLoading(false);
          setResult({ isValid: false, error: 'Incorrect Dataset ID Format' });
          return;
        }

        setLoading(true);

        try {
          const info = await fetchHFDataset(datasetId);
          const payload = {
            isValid: true,
            datasetInfo: {
              id: info.id,
              author: info.id?.split('/')?.[0] ?? 'unknown',
              datasetName: info.id?.split('/')?.[1] ?? info.id,
              downloads: info.downloads ?? 0,
              tags: info.tags ?? [],
              description: info.description ?? 'No Description',
              task_categories: info.task_categories ?? [],
            },
          };

          // Cache even when superseded: the metadata for this id is still valid.
          cache.current.set(datasetId, payload);
          if (isCurrent()) setResult(payload);
        } catch (e: unknown) {
          // Failures are deliberately not cached so a later call can retry.
          if (isCurrent()) {
            const message = e instanceof Error ? e.message : '';
            setResult({ isValid: false, error: message || 'An error occurred while validating the dataset' });
          }
        } finally {
          if (isCurrent()) setLoading(false);
        }
      }),
    []
  );

  return { loading, result, validate };
}
|
|
|