Joseph Catrambone
				
			
		First import.  Add ControlNetSD21 Laion Face (full, pruned, and safetensors).  Add README and samples.  Add surrounding tools for example use.
		568dc2c
		
		| import json | |
| import os | |
| import sys | |
| from dataclasses import dataclass, field | |
| from glob import glob | |
| from typing import Mapping | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from laion_face_common import generate_annotation | |
@dataclass
class RunProgress:
    """Bookkeeping for a dataset-generation run.

    Each field is a list of input image paths.  The instance ``__dict__`` is what gets dumped to (and restored
    from) the checkpoint JSON file, so every field must stay JSON-serializable.

    The ``@dataclass`` decorator is required: without it the ``field(...)`` defaults below would be shared
    ``Field`` objects on the class instead of a fresh list per instance.
    """
    # Images still waiting to be processed.
    pending: list = field(default_factory=list)
    # Images that were annotated and written to the prompt file.
    success: list = field(default_factory=list)
    # Images rejected because they were smaller or larger than the allowed size.
    skipped_size: list = field(default_factory=list)
    # Images rejected because their metadata did not mark them as unlikely to be NSFW.
    skipped_nsfw: list = field(default_factory=list)
    # Images in which no face was detected at all.
    skipped_noface: list = field(default_factory=list)
    # Images in which faces were detected, but none survived the minimum-size filter.
    skipped_smallface: list = field(default_factory=list)
def _save_checkpoint(status, status_filename: str) -> None:
    """Dump the run state (the attribute dict of ``status``) to the checkpoint JSON file."""
    with open(status_filename, 'wt') as fout:
        json.dump(status.__dict__, fout)


def main(
        status_filename: str,
        prompt_filename: str,
        input_glob: str,
        output_directory: str,
        annotated_output_directory: str = "",
        min_image_size: int = 384,
        max_image_size: int = 32766,
        min_face_size_pixels: int = 64,
        prompt_mapping: dict = None,  # If present, maps a filename to a text prompt.
):
    """Generate face-annotation images for every image matched by ``input_glob`` and log them to a JSONL file.

    For each input image the function optionally reads a sibling ``.json`` metadata file, filters the image on
    its NSFW marker, its dimensions and the number of detected faces, then saves the annotation image(s) and
    appends one ``{"source", "target", "prompt"}`` record to ``prompt_filename``.

    Args:
        status_filename: Checkpoint JSON file.  If it exists the run resumes from it, otherwise it is created.
        prompt_filename: JSONL output file.  Appended to when resuming, truncated when starting a new run.
        input_glob: Glob pattern selecting the input images (only used when starting a new run).
        output_directory: Directory receiving the first image returned by ``generate_annotation``.
        annotated_output_directory: Optional directory receiving the second ("annotated") image.
        min_image_size: Images whose shorter side is below this are skipped.
        max_image_size: Images whose longer side is above this are skipped.
        min_face_size_pixels: Forwarded to ``generate_annotation``.
        prompt_mapping: Optional mapping from a bare file name to a prompt, used when the image metadata has
            no "caption" entry.
    """
    status = RunProgress()
    resuming = os.path.exists(status_filename)
    if resuming:
        print("Continuing from checkpoint.")
        # Restore a saved state:
        with open(status_filename, 'rt') as fin:
            saved_state = json.load(fin)
        for key in list(vars(status)):
            # Keys missing from the checkpoint keep their freshly-initialized value instead of raising.
            setattr(status, key, saved_state.get(key, getattr(status, key)))
    else:
        print("Starting run.")
        status.pending = list(glob(input_glob))

    # If we don't have a preexisting set of labels (like for ImageNet/MSCOCO), just null-fill the mapping.
    # We will try on a per-image basis to see if there's a metadata .json.
    if prompt_mapping is None:
        prompt_mapping = dict()

    # Output label file: append when resuming so earlier records survive, truncate on a new run.
    # The context manager guarantees the file is closed even if processing raises.
    with open(prompt_filename, 'at' if resuming else 'wt') as pout:
        if not resuming:
            _save_checkpoint(status, status_filename)
        print(f"{len(status.pending)} images remaining")

        step = 0
        with tqdm(total=len(status.pending)) as pbar:
            while status.pending:
                # Checkpoint save every 100 images.  This happens *before* popping the next image so that the
                # state on disk never omits an image that is still in flight.
                if step > 0 and step % 100 == 0:
                    _save_checkpoint(status, status_filename)

                full_filename = status.pending.pop()
                pbar.update(1)
                step += 1

                fname = os.path.basename(full_filename)

                # Make our output filenames.  No existence check is needed here because the cached 'status'
                # already tells us which files have been handled.
                annotation_filename = ""
                if annotated_output_directory:
                    annotation_filename = os.path.join(annotated_output_directory, fname)
                output_filename = os.path.join(output_directory, fname)

                # The LAION dataset has accompanying .json files with each image.
                partial_filename, _extension = os.path.splitext(full_filename)
                candidate_json_fullpath = partial_filename + ".json"
                image_metadata = {}
                if os.path.exists(candidate_json_fullpath):
                    try:
                        with open(candidate_json_fullpath, 'rt', encoding='utf-8') as fin:
                            loaded_metadata = json.load(fin)
                    except Exception as e:
                        # Metadata is best-effort: report the problem and carry on without it.
                        print(e)
                    else:
                        # Only a JSON object can be queried below; anything else is treated as "no metadata".
                        if isinstance(loaded_metadata, dict):
                            image_metadata = loaded_metadata

                if "NSFW" in image_metadata:
                    nsfw_marker = image_metadata.get("NSFW")  # This can be "", None, or other weird things.
                    # str() keeps non-string markers from crashing on .lower(); anything that is not exactly
                    # "unlikely" (case-insensitive) is still treated as NSFW.
                    if nsfw_marker is not None and str(nsfw_marker).lower() != "unlikely":
                        # Skip NSFW images.
                        status.skipped_nsfw.append(full_filename)
                        continue

                # Try to get a prompt/caption from the metadata or the prompt mapping.
                image_prompt = image_metadata.get("caption", prompt_mapping.get(fname, ""))

                # Load image.  convert() returns a new image, so the underlying file can be closed right away.
                with Image.open(full_filename) as raw_image:
                    img = raw_image.convert("RGB")
                img_width, img_height = img.size
                if min(img_width, img_height) < min_image_size or max(img_width, img_height) > max_image_size:
                    status.skipped_size.append(full_filename)
                    continue

                # Original author's note: the detector is re-initialized every time because it has a habit of
                # triggering weird race conditions.
                empty, annotated, faces_before_filtering, faces_after_filtering = generate_annotation(
                    img,
                    max_faces=5,
                    min_face_size_pixels=min_face_size_pixels,
                    return_annotation_data=True
                )
                if faces_before_filtering == 0:
                    # Skip images with no faces.
                    status.skipped_noface.append(full_filename)
                    continue
                if faces_after_filtering == 0:
                    # Skip images with no faces large enough
                    status.skipped_smallface.append(full_filename)
                    continue

                Image.fromarray(empty).save(output_filename)
                if annotation_filename:
                    Image.fromarray(annotated).save(annotation_filename)

                # See https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md for the training file format.
                # prompt.json
                # a JSONL file with {"source": "source/0.jpg", "target": "target/0.jpg", "prompt": "..."}.
                # a source/xxxxx.jpg or source/xxxx.png file for each of the inputs.
                # a target/xxxxx.jpg for each of the outputs.
                pout.write(json.dumps({
                    "source": output_filename,
                    "target": full_filename,
                    "prompt": image_prompt,
                }) + "\n")
                pout.flush()
                status.success.append(full_filename)

        # We do save every 100 iterations, but it's good to save on completion, too.
        _save_checkpoint(status, status_filename)

    print("Done!")
    print(f"{len(status.success)} images added to dataset.")
    print(f"{len(status.skipped_size)} images rejected for size.")
    print(f"{len(status.skipped_smallface)} images rejected for having faces too small.")
    print(f"{len(status.skipped_noface)} images rejected for not having faces.")
    print(f"{len(status.skipped_nsfw)} images rejected for NSFW.")
if __name__ == "__main__":
    # Three positional arguments are mandatory (prompt file, input glob, output directory), so together with
    # the script name argv needs at least four entries; otherwise sys.argv[3] below would raise IndexError.
    if len(sys.argv) >= 4 and "-h" not in sys.argv:
        prompt_jsonl = sys.argv[1]
        in_glob = sys.argv[2]  # Should probably be in a directory called "target/*.jpg".
        output_dir = sys.argv[3]  # Should probably be a directory called "source".
        # Optional fourth argument: directory for the annotated images.
        annotation_dir = ""
        if len(sys.argv) > 4:
            annotation_dir = sys.argv[4]
        main("generate_face_poses_checkpoint.json", prompt_jsonl, in_glob, output_dir, annotation_dir)
    else:
        print(f"""Usage:
python {sys.argv[0]} prompt.jsonl target/*.jpg source/ [annotated/]
source and target are slightly confusing in this context. We are writing the image names to prompt.jsonl, so
the naming system has to be consistent with what ControlNet expects. In ControlNet, the source is the input and
target is the output. We are generating source images from targets in this application, so the second argument
should be a folder full of images. The third argument should be 'source', where the images should be places.
Optionally, an 'annotated' directory can be provided. Augmented images will be placed here.
A checkpoint file named 'generate_face_poses_checkpoint.json' will be created in the place where the script is
run. If a run is cancelled, it can be resumed from this checkpoint.
If invoking the script from bash, do not forget to enclose globs with quotes. Example usage:
`python ./tool_generate_face_poses.py ./face_prompt.jsonl "/home/josephcatrambone/training_data/data-mscoco/images/train2017/*" /home/josephcatrambone/training_data/data-mscoco/images/source_2017/`
""")