import React, { createContext, useCallback, useContext, useMemo, useRef, useState } from 'react';
import type { JobConfig, JobResult } from '../types';
import { MLBiasAPI } from '../services/api';
/** Plot references for the original and counterfactual sentiment runs (presumably file paths/URLs — not verified here). */
type PipelinePlots = {
  original_sentiment: string;
  counterfactual_sentiment: string;
};

/** Metrics block of a pipeline result; the two optional metrics may be omitted by the server. */
type PipelineMetrics = {
  finalMeanDiff: number;
  cfFinalMeanDiff: number;
  reductionPct?: number;
  stableCoverage?: number;
};

/** Optional run-configuration exports, keyed by format. */
type PipelineRunConfigFiles = {
  json?: string;
  markdown?: string;
};

/** The `results` payload of a pipeline run. */
type PipelineResultsDTO = {
  generation_file: string;
  sentiment_subset_file: string;
  cf_sentiment_subset_file: string;
  metrics: PipelineMetrics;
  plots: PipelinePlots;
  finetuned_model_zip?: string;
  finetuned_model_dir?: string;
  run_config_files?: PipelineRunConfigFiles;
};

/** Response envelope stored by the provider after a pipeline run (shape assumed to match the API — confirm server-side). */
type PipelineResponseDTO = {
  status: 'success' | 'error';
  message: string;
  timestamp: string;
  results: PipelineResultsDTO;
};
/**
 * Run options passed to `start` alongside the JobConfig but not part of it.
 * `datasetLimit` is merged into the request payload sent to the pipeline API.
 */
type Extras = { datasetLimit: number };
/** Value published by JobRunnerProvider and returned by useJobRunner. */
type Ctx = {
  // --- state ---
  /** The job record of the most recent `start` call, or `null` before any run. */
  result: JobResult | null;
  /** `true` while a run is in flight. */
  loading: boolean;
  /** Pipeline response of the last successful run; cleared when a new run starts. */
  resp?: PipelineResponseDTO;
  /** Message from the last failed run, if any. */
  error?: string;

  // --- actions ---
  /** Runs the pipeline for `cfg`, forwarding `extras.datasetLimit` with the request. */
  start: (cfg: JobConfig, extras: Extras) => Promise<void>;
  /** Maps a path through the API client's `resolvePath`. */
  url: (p?: string) => string;
};

/** `undefined` here means no JobRunnerProvider is mounted above the consumer. */
const JobRunnerContext = createContext<Ctx | undefined>(undefined);
/** Best-effort readable message for an arbitrary thrown value (never throws itself). */
function toErrorMessage(e: unknown): string {
  if (typeof e === 'object' && e !== null) {
    const { message } = e as { message?: unknown };
    if (typeof message === 'string' && message) return message;
  }
  return String(e);
}

/**
 * Owns the state of a pipeline run and publishes it through JobRunnerContext.
 *
 * `start` resets state, records a `running` JobResult, calls
 * `MLBiasAPI.runPipeline`, then marks the job `completed` (copying the
 * response metrics) or `failed` (storing the error message).
 * If `start` is called again before an earlier run settles, the earlier run's
 * outcome is discarded so it cannot overwrite the newer run's state.
 */
export function JobRunnerProvider({ children }: { children: React.ReactNode }) {
  const [result, setResult] = useState<JobResult | null>(null);
  const [resp, setResp] = useState<PipelineResponseDTO | undefined>();
  const [loading, setLoading] = useState(false);
  const [error, setErr] = useState<string | undefined>();
  // Id of the most recently started run; older runs compare against it before writing state.
  const latestRunId = useRef<string | null>(null);

  // Uses only state setters, a ref and module imports, so one instance suffices.
  const start = useCallback<Ctx['start']>(async (cfg, extras) => {
    const runId = crypto.randomUUID();
    latestRunId.current = runId;
    const isLatest = () => latestRunId.current === runId;

    setLoading(true);
    setErr(undefined);
    setResp(undefined);
    const now = new Date().toISOString();
    setResult({
      id: runId,
      status: 'running',
      progress: 0,
      config: cfg,
      createdAt: now,
      updatedAt: now,
    });

    try {
      // datasetLimit travels in the same payload as the JobConfig fields.
      const cfgToSend = { ...cfg, datasetLimit: extras.datasetLimit };
      // NOTE(review): runPipeline's parameter type is not visible here; the original
      // cast is kept — tighten it once the API signature is known.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const r = await MLBiasAPI.runPipeline(cfgToSend as any);
      if (!isLatest()) return; // superseded by a newer start()
      setResp(r);
      const done = new Date().toISOString();
      setResult((prev) =>
        prev && prev.id === runId
          ? {
              ...prev,
              status: 'completed',
              progress: 100,
              updatedAt: done,
              completedAt: done,
              metrics: {
                finalMeanDiff: r.results.metrics.finalMeanDiff,
                // Missing optional metrics fall back to 0 reduction / 100 coverage.
                reductionPct: r.results.metrics.reductionPct ?? 0,
                stableCoverage: r.results.metrics.stableCoverage ?? 100,
              },
            }
          : prev
      );
    } catch (e: unknown) {
      if (!isLatest()) return; // a newer run owns error/result now
      setErr(toErrorMessage(e));
      setResult((prev) =>
        prev && prev.id === runId
          ? { ...prev, status: 'failed', progress: 100, updatedAt: new Date().toISOString() }
          : prev
      );
    } finally {
      // Once a newer run has started, it is responsible for clearing `loading`.
      if (isLatest()) setLoading(false);
    }
  }, []);

  // Read off the imported API object; presumably the same reference every render,
  // but listed in the memo deps so the context value stays correct if it changes.
  const url = MLBiasAPI.resolvePath;

  const value = useMemo<Ctx>(
    () => ({ result, resp, loading, error, start, url }),
    [result, resp, loading, error, start, url]
  );

  return <JobRunnerContext.Provider value={value}>{children}</JobRunnerContext.Provider>;
}
/**
 * Returns the value published by the nearest JobRunnerProvider.
 *
 * @throws Error if called from a component that is not inside a JobRunnerProvider.
 */
export function useJobRunner() {
  const runner = useContext(JobRunnerContext);
  if (runner) {
    return runner;
  }
  throw new Error('useJobRunner must be used within JobRunnerProvider');
}