| # OminiControl Training 🛠️ | |
| ## Preparation | |
| ### Setup | |
| 1. **Environment** | |
| ```bash | |
| conda create -n omini python=3.10 | |
| conda activate omini | |
| ``` | |
| 2. **Requirements** | |
| ```bash | |
| pip install -r train/requirements.txt | |
| ``` | |
| ### Dataset | |
| 1. Download the [Subjects200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (**subject-driven generation**): | |
| ```bash | |
| bash train/script/data_download/data_download1.sh | |
| ``` | |
| 2. Download the [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M) dataset (**spatial control task**): | |
| ```bash | |
| bash train/script/data_download/data_download2.sh | |
| ``` | |
| **Note:** By default, the script downloads only a few files. To download more, edit `data_download2.sh`, then update the config file so that it points to the additional training data. | |
| ## Training | |
| ### Start training | |
| **Config file path**: `./train/config` | |
| **Scripts path**: `./train/script` | |
| 1. Subject-driven generation | |
| ```bash | |
| bash train/script/train_subject.sh | |
| ``` | |
| 2. Spatial control task | |
| ```bash | |
| bash train/script/train_canny.sh | |
| ``` | |
| **Note**: Detailed WandB and GPU settings can be found in the script files and the config files. | |
| ### Other spatial control tasks | |
| This repository supports 7 spatial control tasks: | |
| 1. Canny edge to image (`canny`) | |
| 2. Image colorization (`coloring`) | |
| 3. Image deblurring (`deblurring`) | |
| 4. Depth map to image (`depth`) | |
| 5. Image to depth map (`depth_pred`) | |
| 6. Image inpainting (`fill`) | |
| 7. Super resolution (`sr`) | |
| To switch between tasks, change the `condition_type` parameter in the config file `config/canny_512.yaml`. | |
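A mistyped `condition_type` may only surface once training has started, so it can help to check the value beforehand. The snippet below is a minimal sketch and not part of the repository: `SUPPORTED_CONDITION_TYPES` simply mirrors the task list above, and `check_condition_type` is a hypothetical helper name.

```python
# Hypothetical helper, not part of the repository.
# The set mirrors the spatial control tasks listed in this README.
SUPPORTED_CONDITION_TYPES = {
    "canny",
    "coloring",
    "deblurring",
    "depth",
    "depth_pred",
    "fill",
    "sr",
}


def check_condition_type(condition_type: str) -> str:
    """Return condition_type unchanged, or raise ValueError if it is not listed."""
    if condition_type not in SUPPORTED_CONDITION_TYPES:
        supported = ", ".join(sorted(SUPPORTED_CONDITION_TYPES))
        raise ValueError(
            f"Unknown condition_type {condition_type!r}; expected one of: {supported}"
        )
    return condition_type
```

Calling `check_condition_type("depth")` returns the string unchanged, while a typo such as `"cany"` raises a `ValueError` that lists the accepted names.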
| ### Customize your own task | |
| You can customize your own task by constructing a new dataset and modifying the training code. | |
| <details> | |
| <summary>Instructions</summary> | |
| 1. **Dataset**: | |
| Construct a new dataset in the following format (`src/train/data.py`): | |
| ```python | |
| class MyDataset(Dataset): | |
|     def __init__(self, ...): | |
|         ... | |
|     def __len__(self): | |
|         ... | |
|     def __getitem__(self, idx): | |
|         ... | |
|         return { | |
|             "image": image, | |
|             "condition": condition_img, | |
|             "condition_type": "your_condition_type", | |
|             "description": description, | |
|             "position_delta": position_delta, | |
|         } | |
| ``` | |
| **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial control tasks, set it to `[0, -condition_width // 16]`. | |
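The rule in the note can be written as a small helper. This is only a sketch: `make_position_delta` and its `spatial` flag are hypothetical names introduced for illustration, and the arithmetic is copied directly from the note.

```python
def make_position_delta(condition_width: int, spatial: bool) -> list[int]:
    """Build position_delta as described in the note (hypothetical helper)."""
    if spatial:
        # Spatial control tasks use no offset.
        return [0, 0]
    # Python applies the unary minus before `//`, so this evaluates
    # (-condition_width) // 16, which rounds toward negative infinity.
    return [0, -condition_width // 16]


print(make_position_delta(512, spatial=True))   # [0, 0]
print(make_position_delta(512, spatial=False))  # [0, -32]
```

For widths that are multiples of 16 the rounding direction does not matter. For other widths, `-condition_width // 16` and `-(condition_width // 16)` differ by one, so it is worth keeping the expression exactly as the note gives it.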
| 2. **Condition**: | |
| Add a new condition type in the `Condition` class. (`src/flux/condition.py`) | |
| ```python | |
| condition_dict = { | |
|     ... | |
|     "your_condition_type": your_condition_id_number,  # Add your condition type here | |
| } | |
| ... | |
| if condition_type in [ | |
|     ... | |
|     "your_condition_type",  # Add your condition type here | |
| ]: | |
|     ... | |
| ``` | |
| 3. **Test**: | |
| Add a new test function for your task. (`src/train/callbacks.py`) | |
| ```python | |
| if self.condition_type == "your_condition_type": | |
|     condition_img = ( | |
|         Image.open("images/vase.jpg") | |
|         .resize((condition_size, condition_size)) | |
|         .convert("RGB") | |
|     ) | |
|     ... | |
|     test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) | |
| ``` | |
| 4. **Import the relevant dataset in the training script**: | |
| Update the following sections of the training script (`src/train/train.py`): | |
| ```python | |
| from .data import ( | |
|     ImageConditionDataset, | |
|     Subject200KDateset, | |
|     MyDataset, | |
| ) | |
| ... | |
| # Initialize dataset and dataloader | |
| if training_config["dataset"]["type"] == "your_condition_type": | |
|     ... | |
| ``` | |
| </details> | |
| ## Hardware requirements | |
| **Note**: Memory optimizations (such as dynamic T5 model loading) are not yet implemented. | |
| **Recommended** | |
| - Hardware: 2x NVIDIA H100 GPUs | |
| - Memory: ~80GB GPU memory | |
| **Minimal** | |
| - Hardware: 1x NVIDIA L20 GPU | |
| - Memory: ~48GB GPU memory |