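"""
Extract character images from bangumi (anime) videos, cluster them by
character identity using CCIP features, package the clusters into a dataset
project, and publish the result to a Hugging Face dataset repository.
"""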
import datetime
import glob
import json
import logging
import os.path
import random
import re
import shutil
import zipfile
from contextlib import contextmanager
from textwrap import dedent
from typing import Iterator

import numpy as np
import pandas as pd
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
from imgutils.data import load_image
from imgutils.metrics import ccip_extract_feature, ccip_batch_differences, ccip_default_threshold
from natsort import natsorted
from sklearn.cluster import OPTICS
from tqdm.auto import tqdm
from waifuc.action import PaddingAlignAction, PersonSplitAction, FaceCountAction, MinSizeFilterAction, \
    NoMonochromeAction, FilterSimilarAction, HeadCountAction, FileOrderAction, TaggingAction, RandomFilenameAction, \
    BackgroundRemovalAction, ModeConvertAction, FileExtAction
from waifuc.action.filter import MinAreaFilterAction
from waifuc.export import SaveExporter, TextualInversionExporter
from waifuc.model import ImageItem
from waifuc.source import VideoSource, BaseDataSource, LocalSource, EmptySource

from ...utils import number_to_tag, get_hf_client, get_hf_fs

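# A data source that replays image files alongside their precomputed CCIP
# features, letting downstream consumers skip feature re-extraction.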
class ListFeatImageSource(BaseDataSource):
    def __init__(self, image_files, feats):
        self.image_files = image_files
        self.feats = feats

    def _iter(self) -> Iterator[ImageItem]:
        for file, feat in zip(self.image_files, self.feats):
            yield ImageItem(load_image(file), {'ccip_feature': feat, 'filename': os.path.basename(file)})

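# Cluster the PNG images in `src_dir` by character identity:
#   1. extract a CCIP feature per image and precompute the pairwise
#      difference matrix;
#   2. run OPTICS over sample indices, with lookups into that matrix as the
#      distance metric;
#   3. optionally re-attach noise samples (label -1) to their closest cluster;
#   4. greedily merge clusters whose cross-cluster "same character" ratio
#      reaches `merge_threshold`;
#   5. copy each surviving cluster into a numbered subdirectory of `dst_dir`
#      and return the list of directory ids (including -1 for noise, if any).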
def cluster_from_directory(src_dir, dst_dir, merge_threshold: float = 0.85, clu_min_samples: int = 5,
                           extract_from_noise: bool = True):
    image_files = np.array(natsorted(glob.glob(os.path.join(src_dir, '*.png'))))

    logging.info(f'Extracting features of {plural_word(len(image_files), "image")} ...')
    images = [ccip_extract_feature(img) for img in tqdm(image_files, desc='Extract features')]
    batch_diff = ccip_batch_differences(images)
    batch_same = batch_diff <= ccip_default_threshold()

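    # `batch_diff[i, j]` is the CCIP difference between images i and j;
    # `batch_same` marks the pairs falling under the default "same character"
    # threshold. Everything downstream is a lookup into these two matrices.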
    # clustering
    def _metric(x, y):
        return batch_diff[int(x), int(y)].item()

    logging.info('Clustering ...')
    samples = np.arange(len(images)).reshape(-1, 1)
    # max_eps, _ = ccip_default_clustering_params(method='optics_best')
    clustering = OPTICS(min_samples=clu_min_samples, metric=_metric).fit(samples)
    labels = clustering.labels_

    max_clu_id = labels.max().item()
    all_label_ids = np.array([-1, *range(0, max_clu_id + 1)])
    logging.info(f'Cluster complete, with {plural_word(max_clu_id + 1, "cluster")}.')
    label_cnt = {i: (labels == i).sum() for i in all_label_ids if (labels == i).sum() > 0}
    logging.info(f'Current label count: {label_cnt}')

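    # OPTICS marks outliers with label -1. Optionally re-attach each outlier
    # to the cluster with the smallest mean CCIP difference, but only when at
    # least 90% of that cluster's members fall under the "same character"
    # threshold relative to the outlier.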
    if extract_from_noise:
        mask_labels = labels.copy()
        for nid in tqdm(np.where(labels == -1)[0], desc='Matching for noises'):
            avg_dists = np.array([
                batch_diff[nid][labels == i].mean()
                for i in range(0, max_clu_id + 1)
            ])
            r_sames = np.array([
                batch_same[nid][labels == i].mean()
                for i in range(0, max_clu_id + 1)
            ])
            best_id = np.argmin(avg_dists)
            if r_sames[best_id] >= 0.90:
                mask_labels[nid] = best_id

        labels = mask_labels
        logging.info('Noise extraction complete.')
        label_cnt = {i: (labels == i).sum() for i in all_label_ids if (labels == i).sum() > 0}
        logging.info(f'Current label count: {label_cnt}')

    # Try to merge clusters: when the mean cross-cluster "same character"
    # ratio between two clusters reaches `merge_threshold`, relabel the later
    # cluster into the earlier one; repeat passes until no merge happens.
    _exist_ids = set(range(0, max_clu_id + 1))
    while True:
        _round_merged = False
        for xi in range(0, max_clu_id + 1):
            if xi not in _exist_ids:
                continue
            for yi in range(xi + 1, max_clu_id + 1):
                if yi not in _exist_ids:
                    continue
                score = (batch_same[labels == xi][:, labels == yi]).mean()
                logging.info(f'Label {xi} and {yi}\'s similarity score: {score}')
                if score >= merge_threshold:
                    labels[labels == yi] = xi
                    logging.info(f'Merging label {yi} into {xi} ...')
                    _exist_ids.remove(yi)
                    _round_merged = True

        if not _round_merged:
            break

    logging.info(f'Merge complete, remaining cluster ids: {sorted(_exist_ids)}.')
    label_cnt = {i: (labels == i).sum() for i in all_label_ids if (labels == i).sum() > 0}
    logging.info(f'Current label count: {label_cnt}')

    ids = []
    for i, clu_id in enumerate(tqdm(sorted(_exist_ids))):
        total = (labels == clu_id).sum()
        logging.info(f'Cluster {clu_id} will be renamed as #{i}, {plural_word(total, "image")} in total.')
        os.makedirs(os.path.join(dst_dir, str(i)), exist_ok=True)
        for imgfile in image_files[labels == clu_id]:
            shutil.copyfile(imgfile, os.path.join(dst_dir, str(i), os.path.basename(imgfile)))
        ids.append(i)

    n_total = (labels == -1).sum()
    if n_total > 0:
        logging.info(f'Saving noise images, {plural_word(n_total, "image")} in total.')
        os.makedirs(os.path.join(dst_dir, str(-1)), exist_ok=True)
        for imgfile in image_files[labels == -1]:
            shutil.copyfile(imgfile, os.path.join(dst_dir, str(-1), os.path.basename(imgfile)))
        ids.append(-1)

    return ids

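# A minimal standalone usage sketch (directory names here are illustrative):
#
#     ids = cluster_from_directory('frames', 'clustered', merge_threshold=0.85)
#     create_project_by_result('my bangumi', ids, 'clustered', 'project')
#
# Package the clustered images from `clu_dir` into a publishable project under
# `dst_dir`: a `dataset.zip` and padded previews per character, tagged
# "regular" datasets (normal and background-removed), `all.zip`/`raw.zip`
# bundles, `meta.json`, and a README with a Markdown preview table.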
def create_project_by_result(bangumi_name: str, ids, clu_dir, dst_dir, preview_count: int = 8, regsize: int = 1000):
    total_image_cnt = 0
    columns = ['#', 'Images', 'Download', *(f'Preview {i}' for i in range(1, preview_count + 1))]
    rows = []
    reg_source = EmptySource()
    for id_ in ids:
        logging.info(f'Packing for #{id_} ...')
        person_dir = os.path.join(dst_dir, str(id_))
        new_reg_source = LocalSource(os.path.join(clu_dir, str(id_)), shuffle=True).attach(
            MinAreaFilterAction(400)
        )
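        # `|` combines waifuc sources, so `reg_source` accumulates shuffled,
        # area-filtered images from every character; the `[:regsize]` cap is
        # applied when exporting below.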
        reg_source = reg_source | new_reg_source

        os.makedirs(person_dir, exist_ok=True)
        with zipfile.ZipFile(os.path.join(person_dir, 'dataset.zip'), 'w') as zf:
            all_person_images = glob.glob(os.path.join(clu_dir, str(id_), '*.png'))
            total_image_cnt += len(all_person_images)
            for file in all_person_images:
                zf.write(file, os.path.basename(file))

        for i, file in enumerate(random.sample(all_person_images, k=min(len(all_person_images), preview_count)),
                                 start=1):
            PaddingAlignAction((512, 704))(ImageItem(load_image(file))) \
                .image.save(os.path.join(person_dir, f'preview_{i}.png'))

        rel_zip_path = os.path.relpath(os.path.join(person_dir, 'dataset.zip'), dst_dir)
        row = [id_ if id_ != -1 else 'noise', len(all_person_images), f'[Download]({rel_zip_path})']
        for i in range(1, preview_count + 1):
            if os.path.exists(os.path.join(person_dir, f'preview_{i}.png')):
                relpath = os.path.relpath(os.path.join(person_dir, f'preview_{i}.png'), dst_dir)
                row.append(f'![preview {i}]({relpath})')
            else:
                row.append('N/A')
        rows.append(row)

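    # Build the "regular" dataset (likely intended as a regularization set for
    # training): draw up to `regsize` shuffled images across all characters,
    # auto-tag them, and export as image/caption pairs.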
    with TemporaryDirectory() as td:
        logging.info('Creating regular normal dataset ...')
        reg_source.attach(
            TaggingAction(force=False, character_threshold=1.01),
            RandomFilenameAction(),
        )[:regsize].export(TextualInversionExporter(td))

        logging.info('Packing regular normal dataset ...')
        reg_zip = os.path.join(dst_dir, 'regular', 'normal.zip')
        os.makedirs(os.path.dirname(reg_zip), exist_ok=True)
        with zipfile.ZipFile(reg_zip, 'w') as zf:
            for file in glob.glob(os.path.join(td, '*')):
                zf.write(file, os.path.relpath(file, td))

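        # The no-background variant re-reads the exported dataset, strips
        # backgrounds, flattens onto white, and re-tags with force=True since
        # the image content has changed.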
        with TemporaryDirectory() as td_nobg:
            logging.info('Creating regular no-background dataset ...')
            LocalSource(td).attach(
                BackgroundRemovalAction(),
                ModeConvertAction('RGB', 'white'),
                TaggingAction(force=True, character_threshold=1.01),
                FileExtAction('.png'),
            ).export(TextualInversionExporter(td_nobg))

            logging.info('Packing regular no-background dataset ...')
            reg_nobg_zip = os.path.join(dst_dir, 'regular', 'nobg.zip')
            os.makedirs(os.path.dirname(reg_nobg_zip), exist_ok=True)
            with zipfile.ZipFile(reg_nobg_zip, 'w') as zf:
                for file in glob.glob(os.path.join(td_nobg, '*')):
                    zf.write(file, os.path.relpath(file, td_nobg))

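    # `all.zip` preserves the per-cluster directory layout, while `raw.zip`
    # flattens every image into the archive root.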
    logging.info('Packing all images ...')
    all_zip = os.path.join(dst_dir, 'all.zip')
    with zipfile.ZipFile(all_zip, 'w') as zf:
        for file in glob.glob(os.path.join(clu_dir, '*', '*.png')):
            zf.write(file, os.path.relpath(file, clu_dir))

    logging.info('Packing raw package ...')
    raw_zip = os.path.join(dst_dir, 'raw.zip')
    with zipfile.ZipFile(raw_zip, 'w') as zf:
        for file in glob.glob(os.path.join(clu_dir, '*', '*.png')):
            zf.write(file, os.path.basename(file))

    with open(os.path.join(dst_dir, 'meta.json'), 'w', encoding='utf-8') as f:
        json.dump({
            'name': bangumi_name,
            'ids': ids,
            'total': total_image_cnt,
        }, f, indent=4, sort_keys=True, ensure_ascii=False)

    with open(os.path.join(dst_dir, 'README.md'), 'w', encoding='utf-8') as f:
        print(dedent(f"""
            ---
            license: mit
            tags:
            - art
            size_categories:
            - {number_to_tag(total_image_cnt)}
            ---
        """).strip(), file=f)
        print('', file=f)

        c_name = ' '.join(map(str.capitalize, re.split(r'\s+', bangumi_name)))
        print(f'# Bangumi Image Base of {c_name}', file=f)
        print('', file=f)

        print(f'This is the image base of bangumi {bangumi_name}; '
              f'we detected {plural_word(len(ids), "character")}, '
              f'{plural_word(total_image_cnt, "image")} in total. '
              f'The full dataset is [here]({os.path.relpath(all_zip, dst_dir)}).', file=f)
        print('', file=f)

        print(f'**Please note that these image bases are not guaranteed to be 100% clean; '
              f'they may contain noisy samples.** If you intend to manually train models using this dataset, '
              f'we recommend performing necessary preprocessing on the downloaded dataset to eliminate '
              f'potential noisy samples (approximately 1% probability).', file=f)
        print('', file=f)

        print('Here are the character previews:', file=f)
        print('', file=f)

        df = pd.DataFrame(columns=columns, data=rows)
        print(df.to_markdown(index=False), file=f)
        print('', file=f)

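# Run the full pipeline as a context manager: the yielded path holds the
# finished project while the `with` block is open; the backing temporary
# directories are removed on exit.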
@contextmanager
def extract_from_videos(video_or_directory: str, bangumi_name: str, no_extract: bool = False,
                        min_size: int = 320, merge_threshold: float = 0.85, preview_count: int = 8):
    if no_extract:
        source = LocalSource(video_or_directory)
    else:
        if os.path.isfile(video_or_directory):
            source = VideoSource(video_or_directory)
        elif os.path.isdir(video_or_directory):
            source = VideoSource.from_directory(video_or_directory)
        else:
            raise TypeError(f'Unknown video - {video_or_directory!r}.')

        # Filter the raw frames down to clean, single-person, single-face crops.
        source = source.attach(
            NoMonochromeAction(),
            PersonSplitAction(keep_original=False, level='n'),
            FaceCountAction(1, level='n'),
            HeadCountAction(1, level='n'),
            MinSizeFilterAction(min_size),
            FilterSimilarAction('all'),
            FileOrderAction(ext='.png'),
        )

    with TemporaryDirectory() as src_dir:
        logging.info('Extracting figures from videos ...')
        source.export(SaveExporter(src_dir, no_meta=True))

        with TemporaryDirectory() as clu_dir:
            logging.info(f'Clustering from {src_dir!r} to {clu_dir!r} ...')
            ids = cluster_from_directory(src_dir, clu_dir, merge_threshold)

            with TemporaryDirectory() as dst_dir:
                create_project_by_result(bangumi_name, ids, clu_dir, dst_dir, preview_count)
                yield dst_dir

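# Publish a finished project to a Hugging Face dataset repository in a single
# commit. A minimal usage sketch (repository name is illustrative):
#
#     extract_to_huggingface('videos/', 'my bangumi', 'someuser/my_bangumi')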
def extract_to_huggingface(video_or_directory: str, bangumi_name: str,
                           repository: str, revision: str = 'main', no_extract: bool = False,
                           min_size: int = 320, merge_threshold: float = 0.85, preview_count: int = 8):
    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    if not hf_fs.exists(f'datasets/{repository}/.gitattributes'):
        hf_client.create_repo(repo_id=repository, repo_type='dataset', exist_ok=True)

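    # Enumerate files already in the repo so stale ones can be deleted later.
    # The prefix check keeps only leaf paths (real files), skipping entries
    # that are prefixes of the next sorted entry (i.e. directories).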
    # The 'datasets/' prefix matches the exists() check above, since this is
    # a dataset repository.
    _exist_files = [os.path.relpath(file, f'datasets/{repository}')
                    for file in hf_fs.glob(f'datasets/{repository}/**')]
    _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1])
    pre_exist_files = set()
    for i, (file, segments) in enumerate(_exist_ps):
        if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]:
            continue
        if file != '.':
            pre_exist_files.add(file)

    with extract_from_videos(video_or_directory, bangumi_name, no_extract,
                             min_size, merge_threshold, preview_count) as dst_dir:
        operations = []
        for directory, _, files in os.walk(dst_dir):
            for file in files:
                filename = os.path.abspath(os.path.join(directory, file))
                file_in_repo = os.path.relpath(filename, dst_dir)
                operations.append(CommitOperationAdd(
                    path_in_repo=file_in_repo,
                    path_or_fileobj=filename,
                ))
                if file_in_repo in pre_exist_files:
                    pre_exist_files.remove(file_in_repo)

        logging.info(f'Stale files to be removed: {sorted(pre_exist_files)} ...')
        for file in sorted(pre_exist_files):
            operations.append(CommitOperationDelete(path_in_repo=file))

        current_time = datetime.datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S %Z')
        commit_message = f'Publish {bangumi_name}\'s data, on {current_time}'
        logging.info(f'Publishing {bangumi_name}\'s data to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='dataset',
            revision=revision,
        )