Spaces:
Runtime error
Runtime error
| 'use client'; | |
| import { useMemo } from 'react'; | |
| import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options'; | |
| import { defaultDatasetConfig } from './jobConfig'; | |
| import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; | |
| import { objectCopy } from '@/utils/basic'; | |
| import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs'; | |
| import Card from '@/components/Card'; | |
| import { X } from 'lucide-react'; | |
| import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; | |
| import {FlipHorizontal2, FlipVertical2} from "lucide-react" | |
/** Props accepted by the SimpleJob form component. */
interface Props {
  /** The job configuration currently being edited. */
  jobConfig: JobConfig;
  /**
   * Writes `value` into the job config at `key`, a path string such as
   * `'config.process[0].model.arch'` (path semantics are implemented by the caller).
   */
  setJobConfig: (value: any, key: string) => void;
  /** Save state of the form; 'success' and 'error' render a message under the form. */
  status: 'idle' | 'saving' | 'success' | 'error';
  /** Submit handler attached to the form element. */
  handleSubmit: (event: React.FormEvent<HTMLFormElement>) => void;
  /** Id of an existing run; when non-null the training name is read-only. */
  runId: string | null;
  /** Currently selected GPU id (stringified index), or null. */
  gpuIDs: string | null;
  /** Updates the selected GPU id. */
  setGpuIDs: (value: string | null) => void;
  /** GPUs available for selection; each entry is expected to expose an `index`. */
  gpuList: any;
  /** Options for the dataset pickers. */
  datasetOptions: any;
}
/** True when the app is built/running with NODE_ENV set to 'development'. */
const isDev = 'development' === process.env.NODE_ENV;
/**
 * Form-based editor for a training job configuration.
 *
 * Renders cards for job, model, quantization, multistage, target, save,
 * training, dataset and sample settings. Every edit is written back through
 * `setJobConfig(value, path)`, where `path` addresses a field of `jobConfig`
 * (e.g. `'config.process[0].train.lr'`). Which optional cards and fields are
 * shown is driven by the selected model architecture's `additionalSections`
 * and `disableSections` lists.
 */
export default function SimpleJob({
  jobConfig,
  setJobConfig,
  handleSubmit,
  status,
  runId,
  gpuIDs,
  setGpuIDs,
  gpuList,
  datasetOptions,
}: Props) {
  // Architecture entry for the selected model. Undefined when the config names
  // an arch that is not present in `modelArchs`.
  const modelArch: ModelArch | undefined = useMemo(() => {
    return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch);
  }, [jobConfig.config.process[0].model.arch]);

  const isVideoModel = modelArch?.group === 'video';

  // Number of cards in the top row; drives the responsive grid layout below.
  const numTopCards = useMemo(() => {
    let count = 4; // job settings, model config, target config, save config
    if (modelArch?.additionalSections?.includes('model.multistage')) {
      count += 1; // add multistage card
    }
    if (!modelArch?.disableSections?.includes('model.quantize')) {
      count += 1; // add quantization card
    }
    return count;
  }, [modelArch]);

  let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
  if (numTopCards === 5) {
    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
  } else if (numTopCards === 6) {
    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
  }

  // Transformer quantization choices. When the arch defines accuracy recovery
  // adapters (ARAs), the flat list is regrouped as: the first two standard
  // options, then the ARAs, then any remaining quantization options.
  const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
    const adapters = modelArch?.accuracyRecoveryAdapters;
    if (!adapters || Object.keys(adapters).length === 0) {
      return quantizationOptions;
    }
    // adapters maps display label -> option value
    const ARAs: SelectOption[] = Object.entries(adapters).map(([label, value]) => ({ value, label }));
    const groups = [
      {
        label: 'Standard',
        options: [quantizationOptions[0], quantizationOptions[1]],
      },
      {
        label: 'Accuracy Recovery Adapters',
        options: ARAs,
      },
    ];
    const additionalQuantizationOptions = quantizationOptions.slice(2);
    if (additionalQuantizationOptions.length > 0) {
      groups.push({
        label: 'Additional Quantization Options',
        options: additionalQuantizationOptions,
      });
    }
    return groups;
  }, [modelArch]);

  return (
    <>
      <form onSubmit={handleSubmit} className="space-y-8">
        <div className={topBarClass}>
          <Card title="Job">
            <TextInput
              label="Training Name"
              value={jobConfig.config.name}
              docKey="config.name"
              onChange={value => setJobConfig(value, 'config.name')}
              placeholder="Enter training name"
              disabled={runId !== null}
              required
            />
            <SelectInput
              label="GPU ID"
              value={`${gpuIDs}`}
              docKey="gpuids"
              onChange={value => setGpuIDs(value)}
              options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
            />
            <TextInput
              label="Trigger Word"
              value={jobConfig.config.process[0].trigger_word || ''}
              docKey="config.process[0].trigger_word"
              onChange={(value: string | null) => {
                // store blank input as null rather than an empty string
                if (value?.trim() === '') {
                  value = null;
                }
                setJobConfig(value, 'config.process[0].trigger_word');
              }}
              placeholder=""
              required
            />
          </Card>

          {/* Model Configuration Section */}
          <Card title="Model">
            <SelectInput
              label="Model Architecture"
              value={jobConfig.config.process[0].model.arch}
              onChange={value => {
                const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch);
                // Re-selecting the current architecture is a no-op.
                if (currentArch?.name === value) {
                  return;
                }
                const newArch = modelArchs.find(model => model.name === value);

                // low_vram only applies to archs that expose the option
                if (!newArch?.additionalSections?.includes('model.low_vram')) {
                  setJobConfig(false, 'config.process[0].model.low_vram');
                }

                // Revert the previous arch's overrides (defaults[key][1]), then apply
                // the new arch's (defaults[key][0]). The previous arch can be unknown
                // (config names an arch not in modelArchs); there is then nothing to
                // revert, but the user must still be able to switch architectures.
                if (currentArch?.defaults) {
                  for (const key in currentArch.defaults) {
                    setJobConfig(currentArch.defaults[key][1], key);
                  }
                }
                if (newArch?.defaults) {
                  for (const key in newArch.defaults) {
                    setJobConfig(newArch.defaults[key][0], key);
                  }
                }

                // set new model
                setJobConfig(value, 'config.process[0].model.arch');

                // update datasets to match the capabilities of the new arch
                const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
                const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
                const controls = newArch?.controls ?? [];
                const datasets = jobConfig.config.process[0].datasets.map(dataset => {
                  const newDataset = objectCopy(dataset);
                  newDataset.controls = controls;
                  if (!hasControlPath) {
                    newDataset.control_path = null; // reset control path if not applicable
                  }
                  if (!hasNumFrames) {
                    newDataset.num_frames = 1; // reset num_frames if not applicable
                  }
                  return newDataset;
                });
                setJobConfig(datasets, 'config.process[0].datasets');

                // update samples
                const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
                const samples = jobConfig.config.process[0].sample.samples.map(sample => {
                  const newSample = objectCopy(sample);
                  if (!hasSampleCtrlImg) {
                    delete newSample.ctrl_img; // remove ctrl_img if not applicable
                  }
                  return newSample;
                });
                setJobConfig(samples, 'config.process[0].sample.samples');
              }}
              options={groupedModelOptions}
            />
            <TextInput
              label="Name or Path"
              value={jobConfig.config.process[0].model.name_or_path}
              docKey="config.process[0].model.name_or_path"
              onChange={(value: string | null) => {
                // store blank input as null rather than an empty string
                if (value?.trim() === '') {
                  value = null;
                }
                setJobConfig(value, 'config.process[0].model.name_or_path');
              }}
              placeholder=""
              required
            />
            {modelArch?.additionalSections?.includes('model.low_vram') && (
              <FormGroup label="Options">
                <Checkbox
                  label="Low VRAM"
                  checked={jobConfig.config.process[0].model.low_vram}
                  onChange={value => setJobConfig(value, 'config.process[0].model.low_vram')}
                />
              </FormGroup>
            )}
          </Card>

          {modelArch?.disableSections?.includes('model.quantize') ? null : (
            <Card title="Quantization">
              <SelectInput
                label="Transformer"
                value={jobConfig.config.process[0].model.quantize ? jobConfig.config.process[0].model.qtype : ''}
                onChange={value => {
                  // '' means "no quantization": turn the flag off and park qtype on the default
                  if (value === '') {
                    setJobConfig(false, 'config.process[0].model.quantize');
                    value = defaultQtype;
                  } else {
                    setJobConfig(true, 'config.process[0].model.quantize');
                  }
                  setJobConfig(value, 'config.process[0].model.qtype');
                }}
                options={transformerQuantizationOptions}
              />
              <SelectInput
                label="Text Encoder"
                value={jobConfig.config.process[0].model.quantize_te ? jobConfig.config.process[0].model.qtype_te : ''}
                onChange={value => {
                  // '' means "no quantization": turn the flag off and park qtype_te on the default
                  if (value === '') {
                    setJobConfig(false, 'config.process[0].model.quantize_te');
                    value = defaultQtype;
                  } else {
                    setJobConfig(true, 'config.process[0].model.quantize_te');
                  }
                  setJobConfig(value, 'config.process[0].model.qtype_te');
                }}
                options={quantizationOptions}
              />
            </Card>
          )}

          {modelArch?.additionalSections?.includes('model.multistage') && (
            <Card title="Multistage">
              <FormGroup label="Stages to Train" docKey={'model.multistage'}>
                <Checkbox
                  label="High Noise"
                  checked={jobConfig.config.process[0].model.model_kwargs?.train_high_noise || false}
                  onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
                />
                <Checkbox
                  label="Low Noise"
                  checked={jobConfig.config.process[0].model.model_kwargs?.train_low_noise || false}
                  onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
                />
              </FormGroup>
              <NumberInput
                label="Switch Every"
                value={jobConfig.config.process[0].train.switch_boundary_every}
                onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
                placeholder="eg. 1"
                docKey={'train.switch_boundary_every'}
                min={1}
                required
              />
            </Card>
          )}

          <Card title="Target">
            <SelectInput
              label="Target Type"
              value={jobConfig.config.process[0].network?.type ?? 'lora'}
              onChange={value => setJobConfig(value, 'config.process[0].network.type')}
              options={[
                { value: 'lora', label: 'LoRA' },
                { value: 'lokr', label: 'LoKr' },
              ]}
            />
            {jobConfig.config.process[0].network?.type === 'lokr' && (
              <SelectInput
                label="LoKr Factor"
                value={`${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
                onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
                options={[
                  { value: '-1', label: 'Auto' },
                  { value: '4', label: '4' },
                  { value: '8', label: '8' },
                  { value: '16', label: '16' },
                  { value: '32', label: '32' },
                ]}
              />
            )}
            {jobConfig.config.process[0].network?.type === 'lora' && (
              <>
                <NumberInput
                  label="Linear Rank"
                  value={jobConfig.config.process[0].network.linear}
                  onChange={value => {
                    // alpha is kept equal to the rank
                    setJobConfig(value, 'config.process[0].network.linear');
                    setJobConfig(value, 'config.process[0].network.linear_alpha');
                  }}
                  placeholder="eg. 16"
                  min={0}
                  max={1024}
                  required
                />
                {modelArch?.disableSections?.includes('network.conv') ? null : (
                  <NumberInput
                    label="Conv Rank"
                    value={jobConfig.config.process[0].network.conv}
                    onChange={value => {
                      // alpha is kept equal to the rank
                      setJobConfig(value, 'config.process[0].network.conv');
                      setJobConfig(value, 'config.process[0].network.conv_alpha');
                    }}
                    placeholder="eg. 16"
                    min={0}
                    max={1024}
                  />
                )}
              </>
            )}
          </Card>

          <Card title="Save">
            <SelectInput
              label="Data Type"
              value={jobConfig.config.process[0].save.dtype}
              onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
              options={[
                { value: 'bf16', label: 'BF16' },
                { value: 'fp16', label: 'FP16' },
                { value: 'fp32', label: 'FP32' },
              ]}
            />
            <NumberInput
              label="Save Every"
              value={jobConfig.config.process[0].save.save_every}
              onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
              placeholder="eg. 250"
              min={1}
              required
            />
            <NumberInput
              label="Max Step Saves to Keep"
              value={jobConfig.config.process[0].save.max_step_saves_to_keep}
              onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
              placeholder="eg. 4"
              min={1}
              required
            />
          </Card>
        </div>

        <div>
          <Card title="Training">
            <div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
              <div>
                <NumberInput
                  label="Batch Size"
                  value={jobConfig.config.process[0].train.batch_size}
                  onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
                  placeholder="eg. 4"
                  min={1}
                  required
                />
                <NumberInput
                  label="Gradient Accumulation"
                  className="pt-2"
                  value={jobConfig.config.process[0].train.gradient_accumulation}
                  onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
                  placeholder="eg. 1"
                  min={1}
                  required
                />
                <NumberInput
                  label="Steps"
                  className="pt-2"
                  value={jobConfig.config.process[0].train.steps}
                  onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
                  placeholder="eg. 2000"
                  min={1}
                  required
                />
              </div>
              <div>
                <SelectInput
                  label="Optimizer"
                  value={jobConfig.config.process[0].train.optimizer}
                  onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
                  options={[
                    { value: 'adamw8bit', label: 'AdamW8Bit' },
                    { value: 'adafactor', label: 'Adafactor' },
                  ]}
                />
                <NumberInput
                  label="Learning Rate"
                  className="pt-2"
                  value={jobConfig.config.process[0].train.lr}
                  onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
                  placeholder="eg. 0.0001"
                  min={0}
                  required
                />
                <NumberInput
                  label="Weight Decay"
                  className="pt-2"
                  value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
                  onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
                  placeholder="eg. 0.0001"
                  min={0}
                  required
                />
              </div>
              <div>
                {modelArch?.disableSections?.includes('train.timestep_type') ? null : (
                  <SelectInput
                    label="Timestep Type"
                    value={jobConfig.config.process[0].train.timestep_type}
                    onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
                    options={[
                      { value: 'sigmoid', label: 'Sigmoid' },
                      { value: 'linear', label: 'Linear' },
                      { value: 'shift', label: 'Shift' },
                      { value: 'weighted', label: 'Weighted' },
                    ]}
                  />
                )}
                <SelectInput
                  label="Timestep Bias"
                  className="pt-2"
                  value={jobConfig.config.process[0].train.content_or_style}
                  onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
                  options={[
                    { value: 'balanced', label: 'Balanced' },
                    { value: 'content', label: 'High Noise' },
                    { value: 'style', label: 'Low Noise' },
                  ]}
                />
                <SelectInput
                  label="Noise Scheduler"
                  className="pt-2"
                  value={jobConfig.config.process[0].train.noise_scheduler}
                  onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
                  options={[
                    { value: 'flowmatch', label: 'FlowMatch' },
                    { value: 'ddpm', label: 'DDPM' },
                  ]}
                />
              </div>
              <div>
                <FormGroup label="EMA (Exponential Moving Average)">
                  <Checkbox
                    label="Use EMA"
                    className="pt-1"
                    checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
                    onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
                  />
                </FormGroup>
                {jobConfig.config.process[0].train.ema_config?.use_ema && (
                  <NumberInput
                    label="EMA Decay"
                    className="pt-2"
                    value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
                    // plain dotted path: optional-chaining syntax does not belong in a config path
                    onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.ema_decay')}
                    placeholder="eg. 0.99"
                    min={0}
                  />
                )}
                <FormGroup label="Text Encoder Optimizations" className="pt-2">
                  <Checkbox
                    label="Unload TE"
                    checked={jobConfig.config.process[0].train.unload_text_encoder || false}
                    docKey={'train.unload_text_encoder'}
                    onChange={value => {
                      setJobConfig(value, 'config.process[0].train.unload_text_encoder');
                      // mutually exclusive with caching text embeddings
                      if (value) {
                        setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
                      }
                    }}
                  />
                  <Checkbox
                    label="Cache Text Embeddings"
                    checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
                    docKey={'train.cache_text_embeddings'}
                    onChange={value => {
                      setJobConfig(value, 'config.process[0].train.cache_text_embeddings');
                      // mutually exclusive with unloading the text encoder
                      if (value) {
                        setJobConfig(false, 'config.process[0].train.unload_text_encoder');
                      }
                    }}
                  />
                </FormGroup>
              </div>
              <div>
                <FormGroup label="Regularization">
                  <Checkbox
                    label="Differential Output Preservation"
                    className="pt-1"
                    checked={jobConfig.config.process[0].train.diff_output_preservation || false}
                    onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
                  />
                </FormGroup>
                {jobConfig.config.process[0].train.diff_output_preservation && (
                  <>
                    <NumberInput
                      label="DOP Loss Multiplier"
                      className="pt-2"
                      value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
                      onChange={value =>
                        setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
                      }
                      placeholder="eg. 1.0"
                      min={0}
                    />
                    <TextInput
                      label="DOP Preservation Class"
                      className="pt-2"
                      value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
                      onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
                      placeholder="eg. woman"
                    />
                  </>
                )}
              </div>
            </div>
          </Card>
        </div>

        <div>
          <Card title="Datasets">
            <>
              {jobConfig.config.process[0].datasets.map((dataset, i) => (
                <div key={i} className="p-4 rounded-lg bg-gray-800 relative">
                  <button
                    type="button"
                    onClick={() =>
                      setJobConfig(
                        jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
                        'config.process[0].datasets',
                      )
                    }
                    className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
                  >
                    <X />
                  </button>
                  <h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
                  <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
                    <div>
                      <SelectInput
                        label="Dataset"
                        value={dataset.folder_path}
                        onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
                        options={datasetOptions}
                      />
                      {modelArch?.additionalSections?.includes('datasets.control_path') && (
                        <SelectInput
                          label="Control Dataset"
                          docKey="datasets.control_path"
                          value={dataset.control_path ?? ''}
                          className="pt-2"
                          onChange={value =>
                            // the blank option maps to "no control dataset"
                            setJobConfig(value === '' ? null : value, `config.process[0].datasets[${i}].control_path`)
                          }
                          options={[{ value: '', label: <> </> }, ...datasetOptions]}
                        />
                      )}
                      <NumberInput
                        label="LoRA Weight"
                        value={dataset.network_weight}
                        className="pt-2"
                        onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
                        placeholder="eg. 1.0"
                      />
                    </div>
                    <div>
                      <TextInput
                        label="Default Caption"
                        value={dataset.default_caption}
                        onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
                        placeholder="eg. A photo of a cat"
                      />
                      <NumberInput
                        label="Caption Dropout Rate"
                        className="pt-2"
                        value={dataset.caption_dropout_rate}
                        onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)}
                        placeholder="eg. 0.05"
                        min={0}
                        required
                      />
                      {modelArch?.additionalSections?.includes('datasets.num_frames') && (
                        <NumberInput
                          label="Num Frames"
                          className="pt-2"
                          docKey="datasets.num_frames"
                          value={dataset.num_frames}
                          onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)}
                          placeholder="eg. 41"
                          min={1}
                          required
                        />
                      )}
                    </div>
                    <div>
                      <FormGroup label="Settings" className="">
                        <Checkbox
                          label="Cache Latents"
                          checked={dataset.cache_latents_to_disk || false}
                          onChange={value =>
                            setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
                          }
                        />
                        <Checkbox
                          label="Is Regularization"
                          checked={dataset.is_reg || false}
                          onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
                        />
                        {modelArch?.additionalSections?.includes('datasets.do_i2v') && (
                          <Checkbox
                            label="Do I2V"
                            checked={dataset.do_i2v || false}
                            onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
                            docKey="datasets.do_i2v"
                          />
                        )}
                      </FormGroup>
                      <FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
                        <Checkbox
                          label={
                            <>
                              Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" />
                            </>
                          }
                          checked={dataset.flip_x || false}
                          onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)}
                        />
                        <Checkbox
                          label={
                            <>
                              Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" />
                            </>
                          }
                          checked={dataset.flip_y || false}
                          onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)}
                        />
                      </FormGroup>
                    </div>
                    <div>
                      <FormGroup label="Resolutions" className="pt-2">
                        <div className="grid grid-cols-2 gap-2">
                          {[
                            [256, 512, 768],
                            [1024, 1280, 1536],
                          ].map(resGroup => (
                            <div key={resGroup[0]} className="space-y-2">
                              {resGroup.map(res => (
                                <Checkbox
                                  key={res}
                                  label={res.toString()}
                                  checked={dataset.resolution.includes(res)}
                                  onChange={() => {
                                    // toggle this resolution in the dataset's list
                                    const resolutions = dataset.resolution.includes(res)
                                      ? dataset.resolution.filter(r => r !== res)
                                      : [...dataset.resolution, res];
                                    setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
                                  }}
                                />
                              ))}
                            </div>
                          ))}
                        </div>
                      </FormGroup>
                    </div>
                  </div>
                </div>
              ))}
              <button
                type="button"
                onClick={() => {
                  const newDataset = objectCopy(defaultDatasetConfig);
                  // automatically add the controls for a new dataset
                  newDataset.controls = modelArch?.controls ?? [];
                  setJobConfig([...jobConfig.config.process[0].datasets, newDataset], 'config.process[0].datasets');
                }}
                className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
              >
                Add Dataset
              </button>
            </>
          </Card>
        </div>

        <div>
          <Card title="Sample">
            <div
              className={
                isVideoModel
                  ? 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6'
                  : 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'
              }
            >
              <div>
                <NumberInput
                  label="Sample Every"
                  value={jobConfig.config.process[0].sample.sample_every}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
                  placeholder="eg. 250"
                  min={1}
                  required
                />
                <SelectInput
                  label="Sampler"
                  className="pt-2"
                  value={jobConfig.config.process[0].sample.sampler}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
                  options={[
                    { value: 'flowmatch', label: 'FlowMatch' },
                    { value: 'ddpm', label: 'DDPM' },
                  ]}
                />
                <NumberInput
                  label="Guidance Scale"
                  value={jobConfig.config.process[0].sample.guidance_scale}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
                  placeholder="eg. 1.0"
                  className="pt-2"
                  min={0}
                  required
                />
                <NumberInput
                  label="Sample Steps"
                  value={jobConfig.config.process[0].sample.sample_steps}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
                  placeholder="eg. 1"
                  className="pt-2"
                  min={1}
                  required
                />
              </div>
              <div>
                <NumberInput
                  label="Width"
                  value={jobConfig.config.process[0].sample.width}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
                  placeholder="eg. 1024"
                  min={0}
                  required
                />
                <NumberInput
                  label="Height"
                  value={jobConfig.config.process[0].sample.height}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
                  placeholder="eg. 1024"
                  className="pt-2"
                  min={0}
                  required
                />
                {isVideoModel && (
                  <div>
                    <NumberInput
                      label="Num Frames"
                      value={jobConfig.config.process[0].sample.num_frames}
                      onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
                      placeholder="eg. 0"
                      className="pt-2"
                      min={0}
                      required
                    />
                    <NumberInput
                      label="FPS"
                      value={jobConfig.config.process[0].sample.fps}
                      onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
                      placeholder="eg. 0"
                      className="pt-2"
                      min={0}
                      required
                    />
                  </div>
                )}
              </div>
              <div>
                <NumberInput
                  label="Seed"
                  value={jobConfig.config.process[0].sample.seed}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
                  placeholder="eg. 0"
                  min={0}
                  required
                />
                <Checkbox
                  label="Walk Seed"
                  className="pt-4 pl-2"
                  checked={jobConfig.config.process[0].sample.walk_seed || false}
                  onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
                />
              </div>
              <div>
                <FormGroup label="Advanced Sampling" className="pt-2">
                  <div>
                    <Checkbox
                      label="Skip First Sample"
                      className="pt-4"
                      checked={jobConfig.config.process[0].train.skip_first_sample || false}
                      onChange={value => {
                        setJobConfig(value, 'config.process[0].train.skip_first_sample');
                        // cannot do both, so disable the other
                        if (value) {
                          setJobConfig(false, 'config.process[0].train.force_first_sample');
                        }
                      }}
                    />
                  </div>
                  <div>
                    <Checkbox
                      label="Force First Sample"
                      className="pt-1"
                      checked={jobConfig.config.process[0].train.force_first_sample || false}
                      docKey={'train.force_first_sample'}
                      onChange={value => {
                        setJobConfig(value, 'config.process[0].train.force_first_sample');
                        // cannot do both, so disable the other
                        if (value) {
                          setJobConfig(false, 'config.process[0].train.skip_first_sample');
                        }
                      }}
                    />
                  </div>
                  <div>
                    <Checkbox
                      label="Disable Sampling"
                      className="pt-1"
                      checked={jobConfig.config.process[0].train.disable_sampling || false}
                      onChange={value => {
                        setJobConfig(value, 'config.process[0].train.disable_sampling');
                        // forcing a first sample makes no sense with sampling disabled
                        if (value) {
                          setJobConfig(false, 'config.process[0].train.force_first_sample');
                        }
                      }}
                    />
                  </div>
                </FormGroup>
              </div>
            </div>
            <FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.samples.length})`} className="pt-2">
              <div></div>
            </FormGroup>
            {jobConfig.config.process[0].sample.samples.map((sample, i) => (
              <div key={i} className="rounded-lg pl-4 pr-1 mb-4 bg-gray-950">
                <div className="flex items-center space-x-2">
                  <div className="flex-1">
                    <div className="flex">
                      <div className="flex-1">
                        <TextInput
                          label={`Prompt`}
                          value={sample.prompt}
                          onChange={value => setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)}
                          placeholder="Enter prompt"
                          required
                        />
                      </div>
                      {modelArch?.additionalSections?.includes('sample.ctrl_img') && (
                        <div
                          className="h-14 w-14 mt-2 ml-4 border border-gray-500 flex items-center justify-center rounded cursor-pointer hover:bg-gray-700 transition-colors"
                          style={{
                            backgroundImage: sample.ctrl_img
                              ? `url(${`/api/img/${encodeURIComponent(sample.ctrl_img)}`})`
                              : 'none',
                            backgroundSize: 'cover',
                            backgroundPosition: 'center',
                            marginBottom: '-1rem',
                          }}
                          onClick={() => {
                            openAddImageModal(imagePath => {
                              // ignore a cancelled / empty selection
                              if (!imagePath) return;
                              setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`);
                            });
                          }}
                        >
                          {!sample.ctrl_img && (
                            <div className="text-gray-400 text-xs text-center font-bold">Add Control Image</div>
                          )}
                        </div>
                      )}
                    </div>
                    <div className="pb-4"></div>
                  </div>
                  <div>
                    <button
                      type="button"
                      onClick={() =>
                        setJobConfig(
                          jobConfig.config.process[0].sample.samples.filter((_, index) => index !== i),
                          'config.process[0].sample.samples',
                        )
                      }
                      className="rounded-full p-1 text-sm"
                    >
                      <X />
                    </button>
                  </div>
                </div>
              </div>
            ))}
            <button
              type="button"
              onClick={() =>
                setJobConfig(
                  [...jobConfig.config.process[0].sample.samples, { prompt: '' }],
                  'config.process[0].sample.samples',
                )
              }
              className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
            >
              Add Prompt
            </button>
          </Card>
        </div>

        {status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
        {status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
      </form>
      <AddSingleImageModal />
    </>
  );
}