Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from pathlib import Path | |
| from loguru import logger | |
| from dialoggen.dialoggen_demo import DialogGen | |
| from hydit.config import get_args | |
| from hydit.inference import End2End | |
def inferencer():
    """Build everything needed to run text-to-image inference.

    Reads the run configuration via ``get_args()``, checks that the model
    root directory exists, constructs the ``End2End`` generator and, when
    ``args.enhance`` is truthy, also loads the ``DialogGen`` prompt enhancer
    from the ``dialoggen`` sub-directory of the model root.

    Returns:
        A ``(args, gen, enhancer)`` tuple; ``enhancer`` is ``None`` when
        prompt enhancement is not requested.

    Raises:
        ValueError: If the path given by ``args.model_root`` does not exist.
    """
    args = get_args()

    root = Path(args.model_root)
    if not root.exists():
        raise ValueError(f"`models_root` not exists: {root}")

    # The generator is always built; the enhancer is optional.
    pipeline = End2End(args, root)

    prompt_enhancer = None
    if args.enhance:
        logger.info("Loading DialogGen model (for prompt enhancement)...")
        prompt_enhancer = DialogGen(str(root / "dialoggen"))
        logger.info("DialogGen model loaded.")

    return args, pipeline, prompt_enhancer
def next_image_index(save_dir: Path) -> int:
    """Return the next free numeric index for ``<index>.png`` files in *save_dir*.

    Every ``*.png`` file whose stem parses as an integer counts as taken; the
    result is one more than the largest such stem, or ``0`` when there is
    none. PNG files with a non-numeric stem (e.g. ``preview.png``) are
    ignored instead of aborting the run with a ``ValueError``.
    """
    taken = []
    for png_path in save_dir.glob("*.png"):
        try:
            taken.append(int(png_path.stem))
        except ValueError:
            # Not one of the numbered outputs; leave it alone.
            continue
    return max(taken, default=-1) + 1


if __name__ == "__main__":
    args, gen, enhancer = inferencer()

    if enhancer:
        logger.info("Prompt Enhancement...")
        success, enhanced_prompt = enhancer(args.prompt)
        if not success:
            logger.info("Sorry, the prompt is not compliant, refuse to draw.")
            # `exit()` is an interactive helper installed by the `site`
            # module and may be missing (e.g. `python -S`); raising
            # SystemExit directly is equivalent and always available.
            raise SystemExit(0)
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
    else:
        enhanced_prompt = None

    # Run inference
    logger.info("Generating images...")
    height, width = args.image_size
    results = gen.predict(
        args.prompt,
        height=height,
        width=width,
        seed=args.seed,
        enhanced_prompt=enhanced_prompt,
        negative_prompt=args.negative,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        batch_size=args.batch_size,
        src_size_cond=args.size_cond,
    )
    images = results['images']

    # Save images
    save_dir = Path('results')
    save_dir.mkdir(exist_ok=True)

    # Continue numbering after the highest index already on disk so that
    # earlier outputs are never overwritten.
    start = next_image_index(save_dir)

    for idx, pil_img in enumerate(images):
        save_path = save_dir / f"{idx + start}.png"
        pil_img.save(save_path)
        logger.info(f"Save to {save_path}")

