Spaces:
Runtime error
Runtime error
| import datetime | |
| import os | |
| import pathlib | |
| import pytz | |
| from typing import Optional | |
| from ditk import logging | |
| from hbutils.system import TemporaryDirectory | |
| from huggingface_hub import CommitOperationAdd, CommitOperationDelete | |
| from huggingface_hub.utils import RepositoryNotFoundError | |
| from .export import export_workdir, _GITLFS | |
| from .steps import find_steps_in_workdir | |
| from ..infer.draw import _DEFAULT_INFER_MODEL | |
| from ..utils import get_hf_client, get_hf_fs | |
def _deploy_timestamp() -> str:
    """Return the current time in the Asia/Tokyo zone, formatted for commit messages."""
    tokyo_tz = pytz.timezone('Asia/Tokyo')
    return datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')


def _ensure_lfs_attributes(hf_client, hf_fs, repository: str, name: str, revision: str) -> None:
    """Create ``repository`` if needed and make sure its .gitattributes tracks PNG files with LFS.

    A missing remote ``.gitattributes`` is treated as "repository not initialized": the model
    repository is created (``exist_ok=True``) and ``.gitattributes`` is written. An existing
    ``.gitattributes`` is only rewritten when it lacks the ``*.png`` LFS rule. The rewrite
    replaces the file with ``_GITLFS`` in its own commit on ``revision``.
    """
    gitattributes_path = f'{repository}/.gitattributes'
    # Query the hub once instead of calling `exists` twice.
    if not hf_fs.exists(gitattributes_path):
        hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)
        needs_update = True
    else:
        needs_update = '*.png filter=lfs diff=lfs merge=lfs -text' not in hf_fs.read_text(gitattributes_path)

    if not needs_update:
        return

    logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
    with TemporaryDirectory() as td:
        git_attr_file = os.path.join(td, '.gitattributes')
        with open(git_attr_file, 'w', encoding='utf-8') as f:
            print(_GITLFS, file=f)

        operations = [
            CommitOperationAdd(
                path_in_repo='.gitattributes',
                path_or_fileobj=git_attr_file,
            )
        ]
        commit_message = f'Update {name}\'s .gitattributes, on {_deploy_timestamp()}'
        logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )


def _write_model_card_header(readme_file: str, dataset_repository: Optional[str]) -> None:
    """Prepend YAML front-matter (license, optional dataset link, pipeline tag, tags) to ``readme_file``.

    :param readme_file: Local README path; it is read, then rewritten in place (UTF-8).
    :param dataset_repository: Dataset repo id listed under ``datasets:``, or ``None`` to omit it.
    """
    readme_text = pathlib.Path(readme_file).read_text(encoding='utf-8')
    header = ['---', 'license: mit']
    if dataset_repository is not None:
        header += ['datasets:', f'- {dataset_repository}']
    header += ['pipeline_tag: text-to-image', 'tags:', '- art', '---', '']
    with open(readme_file, 'w', encoding='utf-8') as f:
        for line in header:
            print(line, file=f)
        print(readme_text, file=f)


def _list_remote_files(hf_fs, repository: str) -> set:
    """Return the '/'-separated paths, relative to ``repository``, of files currently in it.

    ``hf_fs.glob`` can also return directories (and the repository root, which becomes '.').
    After sorting by path segments, an entry whose segments prefix the next entry's
    segments is a directory and is dropped.
    """
    # Hub paths are always '/'-separated: use posixpath so this also holds on Windows,
    # where os.path.relpath would emit backslashes and break the split('/') below.
    import posixpath

    relative_paths = [posixpath.relpath(path, repository) for path in hf_fs.glob(f'{repository}/**')]
    entries = sorted(((path, path.split('/')) for path in relative_paths), key=lambda item: item[1])

    files = set()
    for index, (path, segments) in enumerate(entries):
        is_directory = index + 1 < len(entries) and entries[index + 1][1][:len(segments)] == segments
        if not is_directory and path != '.':
            files.add(path)
    return files


def _collect_local_files(local_dir: str) -> dict:
    """Map every file under ``local_dir`` ('/'-separated relative path) to its absolute local path.

    Entries keep ``os.walk`` order, so the upload operations come out in the same order as before.
    """
    local_files = {}
    for directory, _, files in os.walk(local_dir):
        for file in files:
            # os.walk already yields `directory` rooted at `local_dir`; joining `local_dir`
            # again would duplicate it whenever `local_dir` is relative.
            filename = os.path.abspath(os.path.join(directory, file))
            # path_in_repo must use '/' regardless of the host OS.
            path_in_repo = pathlib.Path(os.path.relpath(filename, local_dir)).as_posix()
            local_files[path_in_repo] = filename
    return local_files


def deploy_to_huggingface(workdir: str, repository=None, revision: str = 'main', n_repeats: int = 3,
                          pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                          image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                          lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                          model_hash: Optional[str] = None, ds_dir: Optional[str] = None):
    """Export the LoRA training output in ``workdir`` and publish it to a Hugging Face model repo.

    Steps:
    1. Ensure the repo exists and tracks PNGs with LFS.
    2. Export the workdir into a temporary directory.
    3. Prepend model-card metadata to the exported README.md.
    4. Push one commit that uploads every exported file and deletes remote files the export
       no longer contains.

    :param workdir: Training work directory; its name comes from ``find_steps_in_workdir``.
    :param repository: Target model repo id; defaults to ``AppleHarem/{name}``. A dataset repo
        with the same id, if it exists, is linked from the model card.
    :param revision: Branch/revision the commits are pushed to.
    :param n_repeats: Forwarded to ``export_workdir``.
    :param pretrained_model: Forwarded to ``export_workdir``.
    :param clip_skip: Forwarded to ``export_workdir``.
    :param image_width: Forwarded to ``export_workdir``.
    :param image_height: Forwarded to ``export_workdir``.
    :param infer_steps: Forwarded to ``export_workdir``.
    :param lora_alpha: Forwarded to ``export_workdir``.
    :param sample_method: Forwarded to ``export_workdir``.
    :param model_hash: Forwarded to ``export_workdir``.
    :param ds_dir: Forwarded to ``export_workdir`` as ``ds_repo`` (a local or remote dataset).
    """
    name, _ = find_steps_in_workdir(workdir)
    repository = repository or f'AppleHarem/{name}'

    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    _ensure_lfs_attributes(hf_client, hf_fs, repository, name, revision)

    with TemporaryDirectory() as td:
        export_workdir(
            workdir, td, n_repeats, pretrained_model,
            clip_skip, image_width, image_height, infer_steps,
            lora_alpha, sample_method, model_hash, ds_repo=ds_dir,  # ds_repo: local or remote dataset
        )

        # Link the model card to a dataset repo sharing this repo id, when one exists.
        try:
            hf_client.repo_info(repo_id=repository, repo_type='dataset')
        except RepositoryNotFoundError:
            has_dataset_repo = False
        else:
            has_dataset_repo = True
        _write_model_card_header(os.path.join(td, 'README.md'), repository if has_dataset_repo else None)

        # NOTE(review): the glob path carries no `@revision`, so this presumably lists the
        # repo's default branch rather than `revision` — confirm this is intended when
        # publishing to a non-default branch.
        stale_files = _list_remote_files(hf_fs, repository)
        local_files = _collect_local_files(td)

        operations = [
            CommitOperationAdd(
                path_in_repo=path_in_repo,
                path_or_fileobj=filename,
            )
            for path_in_repo, filename in local_files.items()
        ]
        # Remote files that the fresh export no longer produces are deleted in the same commit.
        stale_files.difference_update(local_files)
        logging.info(f'Useless files: {sorted(stale_files)} ...')
        operations.extend(CommitOperationDelete(path_in_repo=file) for file in sorted(stale_files))

        commit_message = f'Publish {name}\'s lora, on {_deploy_timestamp()}'
        logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )