import gradio as gr from PIL import Image import time import threading import tqdm # -------------------------------model-------------------------- import subprocess import sys def run_model(txt='wqq'): print("start running model") ckpt_path = '/home/zhutiantian/code/One-DM/One-DM-ckpt.pt' dir_path = '/home/zhutiantian/code/One-DM/Generated/English' code_path = '/home/zhutiantian/code/One-DM/test.py' # 构造 Conda 命令 conda_path = "/home/zhutiantian/anaconda3/condabin/conda" command = f"export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH && cd /home/zhutiantian/code/One-DM && {conda_path} run -n torch13 CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 {code_path} --one_dm {ckpt_path} --generate_type oov_u --dir {dir_path} --input_text {txt}" subprocess.run(command, shell=True) # -------------------------------gradio------------------------- # 所有button响应的函数 # 任务处理函数 def process_task(txt, p1, p2, p3, p4, p5): # 这里需要添加对txt的检查! # txt为空、非英文字母时,返回错误,要求重新填写 # p3、p5 非空 # p1: 时间,p2:图片数量,p4:步数 # p3:等于 "DDIM" 或 "DDPM",生成方式 # p5:等于 "iv_s","iv_u","oov_s","oov_u",生成类型 # 启动一个 15 秒的命令模拟 def execute_command(): # 这里用time.sleep(50)来模拟命令执行 # get_wsl.open_wsl_and_run_model() # 单独修改UI的时候把这里换成time.sleep(15) # run_model(txt) time.sleep(5) print("命令执行完毕!") # 启动执行命令的线程 command_thread = threading.Thread(target=execute_command) command_thread.start() time.sleep(5) command_thread.join() # 等待命令执行完成 # img_path = f'/home/zhutiantian/code/One-DM/Generated/English/oov_u/168/{txt}.png' img_path = ["img/bg.png"] return img_path, gr.update(visible=True) def func(txt, p1, p2, p3, p4, progress=gr.Progress()): progress(0, desc="Starting") time.sleep(1) progress(0.3, desc="Progressing") time.sleep(p1) progress(1, desc="Completed") img = Image.open('img/bg.png') time.sleep(2) # 返回显示评价区域 return img, gr.update(visible=True) def funcgal(txt, p1, p2, p3, p4, progress=gr.Progress()): # imggallery = ["img/1.jpg", "img/2.jpg", "img/3.jpg", "img/4.jpg", "img/5.jpg", "img/6.jpg", "img/7.jpg", "img/8.jpg", "img/9.jpg", "img/10.jpg", "img/11.jpg", "img/12.jpg", "img/13.jpg", "img/14.jpg", "img/15.jpg", "img/16.jpg", "img/17.jpg", "img/18.jpg", "img/bg.png", "img/sample.png"] imggallery = ["img/1.jpg"] return imggallery, gr.update(visible=True) # 主函数,调用demo def main(): with gr.Blocks(theme='NoCrypt/miku') as demo: feedback_visible = gr.State(False) def submit_feedback(feedback): # 隐藏评价区域,并返回提示信息 gr.Info("感谢您的评价") return gr.update(visible=False) with gr.Column(): introtext1 = gr.Markdown( # 在此输入描述,使用 Markdown """ ## “一眼临摹”*手写风格迁移* 该项目旨在以手写文本图像为基础,学习其手写字迹风格并生成对应风格的特定文本内容。 该项目以论文 [One-Shot Diffusion Mimicker for Handwritten Text Generation](https://arxiv.org/abs/2409.04004) 为基础。该论文详细介绍了如何通过提取单一参考样本的高频信息来改进样式提取,并在此基础上生成对应的手写文本图像。部分代码修改自该论文的 [源代码](https://github.com/dailenson/One-DM)。 """ ) with gr.Group(visible=False) as feedback_group: feedback = gr.Textbox(label="请输入您的使用评价") submit_btn = gr.Button("提交") with gr.Row(): with gr.Column(): textinfo = gr.Textbox(label="此处输入你要生成的文字") with gr.Row(): # inputimage = gr.Image(label="此处上传你要生成的字迹风格图像") with gr.Column(): para1 = gr.Slider(label="运行 GPU 卡数", minimum = 1, maximum = 4, step = 1) para2 = gr.Slider(label="生成风格数量", minimum = 10, maximum = 150) para3 = gr.Radio(["DDIM", "DDPM"], label = "生成方式", info = "解释") para4 = gr.Slider(label="生成 sample 步数", minimum = 50, maximum = 1000) para5 = gr.Radio(["iv_s", "iv_u", "oov_s", "oov_u"], label="生成类型", info = "解释") with gr.Row(): # genebutton = gr.Button("生成用户字体风格") genegallery = gr.Button("生成示例字体风格") with gr.Column(): # outputimage = gr.Image(label="用户字体风格图片") outputgallery = gr.Gallery(label="示例字体风格图片", columns=5) with gr.Row(): introtext2 = gr.Image(label="演示图片,来自 One-DM", value="img/intro1.jpg") # 在此修改描述图片路径 introtext3 = gr.Image(label="演示图片,来自 One-DM", value="img/intro2.png") with gr.Row(): gr.Markdown(""" 这里写第一段。 """) gr.Markdown(""" 这里写第二段。 """) submit_btn.click( fn=submit_feedback, inputs=[feedback], outputs=[feedback_group] ) # genebutton.click( # fn=process_task, # inputs=[textinfo, para1, para2, para3, para4], # outputs=[outputimage, feedback_group] # ) genegallery.click( fn=process_task, inputs=[textinfo, para1, para2, para3, para4, para5], outputs=[outputgallery, feedback_group] ) # demo.launch(server_name="172.30.180.28", server_port=45632, debug=True, show_error=True) demo.launch() if __name__ == "__main__": main()