Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import os | |
| from PIL import Image | |
| import torch | |
| import torchvision.transforms as transforms | |
| import options | |
| import test | |
| import importlib | |
| from scipy.interpolate import interp1d, splev, splprep | |
| import cv2 | |
| import subprocess | |
# Run the bundled install script once at start-up (presumably it installs the
# `imaginaire` package the model code depends on -- confirm against the script).
# The exit status used to be ignored; report a failure so a later import error
# can be traced back to it, but keep going so the best-effort behaviour stays.
_install_result = subprocess.run(["bash", "install_imaginaire.sh"])
if _install_result.returncode != 0:
    print("WARNING: install_imaginaire.sh exited with status {}".format(
        _install_result.returncode))
def _find_demo_case(style_img):
    """Return the ``demo_img`` case folder whose ground-view image equals *style_img*.

    Args:
        style_img: Style image as a numpy array.

    Returns:
        Name of the first (alphabetically) ``demo_img/*case*`` folder whose
        ``groundview.image.png`` is pixel-identical to *style_img*.

    Raises:
        gr.Error: If no bundled case matches; the sky mask is looked up by
            case name, so an unknown style image cannot be processed.
    """
    for case in sorted(os.listdir('demo_img')):
        if 'case' not in case:
            continue
        with Image.open('demo_img/{}/groundview.image.png'.format(case)) as img:
            reference = np.array(img.convert('RGB'))
        # array_equal is shape-safe; `(a == b).all()` breaks when shapes differ.
        if np.array_equal(reference, style_img):
            return case
    raise gr.Error("The style image must be one of the bundled examples "
                   "(no matching sky mask was found in demo_img).")


def _build_input_dict(sat_img, style_img):
    """Assemble the batched input dictionary handed to ``model.set_input``.

    Raises:
        gr.Error: If either image is missing or the style image is unknown.
    """
    if sat_img is None or style_img is None:
        raise gr.Error("Please provide both a satellite image and a style image.")
    case = _find_demo_case(style_img)

    to_tensor = transforms.ToTensor()
    with Image.open('demo_img/{}/groundview.sky.png'.format(case)) as sky_img:
        sky = to_tensor(sky_img.convert("L"))
    pano = to_tensor(style_img)

    # Per-channel histogram of the sky-masked panorama, channels taken in
    # reversed order, with the first 10 bins of each histogram dropped.
    masked_pano = pano * sky
    sky_histc = torch.cat([masked_pano[c].histc()[10:] for c in reversed(range(3))])

    input_dict = {
        'sat': to_tensor(sat_img),
        'pano': pano,
        'paths': "demo.png",
        'sky_histc': sky_histc,
        'sky_mask': sky,
    }
    # Add a leading batch dimension to every tensor entry.
    return {
        key: value.unsqueeze(0) if isinstance(value, torch.Tensor) else value
        for key, value in input_dict.items()
    }


def _load_model(input_dict):
    """Build the generator, load its checkpoint on CPU and feed it *input_dict*.

    Returns:
        ``(model, opt)`` ready for ``model.forward(opt)``.
    """
    args = ["--yaml=sat2density_cvact", "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth", "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
            "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1

    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()
    model.set_input(input_dict)
    model.style_temp = model.sky_histc
    return model, opt


def _render_rgb(model, opt, origin_h_w):
    """Run one forward pass at *origin_h_w* and return the prediction as a uint8 image."""
    opt.origin_H_W = origin_h_w
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        model.forward(opt)
    # First batch item, clamped to [0, 1], first axis moved last.
    rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
    return np.array(rgb * 255, dtype=np.uint8)


def get_single(sat_img, style_img, x_offset, y_offset):
    """Render one ground-view image at a chosen position of the satellite image.

    Args:
        sat_img: Satellite image array.
        style_img: Style (ground-view) image array; must equal one of the
            bundled ``demo_img`` examples.
        x_offset: Horizontal position as a fraction of the image width.
        y_offset: Vertical position as a fraction of the image height.

    Returns:
        The rendered image as a ``uint8`` numpy array.

    Raises:
        gr.Error: If an image is missing or the style image is not a bundled example.
    """
    model, opt = _load_model(_build_input_dict(sat_img, style_img))
    # Maps fractions on a 256-pixel grid to [-1, 1]; the vertical axis is negated.
    # TODO: hard code should be removed in the future
    origin = [-(y_offset*256-128)/128, (x_offset*256-128)/128]
    return _render_rgb(model, opt, origin)


def get_video(sat_img, style_img, positions):
    """Render a video along a smoothed trajectory through the clicked points.

    Args:
        sat_img: Satellite image array.
        style_img: Style (ground-view) image array; must equal one of the
            bundled ``demo_img`` examples.
        positions: Sequence of ``(x, y)`` pixel coordinates (or ``None``).

    Returns:
        Path of the written video file, ``"output_video.mp4"``.

    Raises:
        gr.Error: If fewer than two distinct points were given, an image is
            missing, or the style image is not a bundled example.
    """
    # De-duplicate while preserving click order.
    waypoints = list(dict.fromkeys(tuple(p) for p in (positions or [])))
    if len(waypoints) < 2:
        raise gr.Error("Please click at least two distinct points on the map "
                       "to define a trajectory.")

    model, opt = _load_model(_build_input_dict(sat_img, style_img))

    pixels = np.array(waypoints, dtype=float)
    # splprep needs more points than the spline degree; the default cubic
    # (k=3) fails for 2 or 3 points, so lower the degree when necessary.
    degree = min(3, len(waypoints) - 1)
    tck, u = splprep(pixels.T, s=25, per=0, k=degree)
    u_new = np.linspace(u.min(), u.max(), 80)
    x_new, y_new = splev(u_new, tck)
    smooth_path = np.array([x_new, y_new]).T

    frames = []
    for x, y in smooth_path:
        print('Rendering at ({}, {})'.format(x, y))
        # Maps pixel coordinates on a 256-pixel grid to [-1, 1].
        # TODO: hard code should be removed in the future
        frames.append(_render_rgb(model, opt, [(y-128)/128, (x-128)/128]))

    output_video_path = 'output_video.mp4'
    frame_rate = 15
    # Use the real frame size: a writer opened with a different size than the
    # frames it receives does not produce a usable video.
    frame_height, frame_width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
    try:
        for frame in frames:
            # Frames are RGB; OpenCV's writer expects BGR.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        out.release()
    return output_video_path
def copy_image(image):
    """Pass *image* through unchanged.

    Serves as an event handler that mirrors one image component's value
    into another; the very same object is returned, no copy is made.
    """
    mirrored = image
    return mirrored
def show_image_and_point(image, x, y):
    """Build the ``(image, annotations)`` value that highlights the render point.

    Args:
        image: Image array indexed as ``[row, column, ...]``, or ``None``.
        x: Horizontal position as a fraction of the image width.
        y: Vertical position as a fraction of the image height, measured
            from the bottom edge.

    Returns:
        ``None`` when *image* is ``None``; otherwise
        ``(image, [(mask, 'render point')])`` where ``mask`` is a float array
        of the image's height and width holding 1 inside a small disc around
        the point and 0 elsewhere.
    """
    if image is None:
        # Nothing to annotate (e.g. the source image was cleared).
        return None

    height, width = image.shape[:2]
    col = int(x * width)
    row = height - int(y * height)
    radius = min(height, width) // 60

    # Vectorised disc: points outside the image are simply clipped instead of
    # raising IndexError or wrapping around through negative indices.
    rows, cols = np.ogrid[:height, :width]
    inside = (cols - col) ** 2 + (rows - row) ** 2 <= radius ** 2
    mask = inside.astype(np.float64)
    return (image, [(mask, 'render point')])
def add_select_point(image, evt: gr.SelectData, state1):
    """Record a clicked point and paint a small black disc on the image there.

    Args:
        image: Image array indexed as ``[row, column, ...]``; modified in place.
        evt: Selection event; ``evt.index`` is unpacked as ``(x, y)``.
        state1: List of previously selected ``(x, y)`` points, or ``None``
            when no point has been selected yet.

    Returns:
        ``(image, state1)`` with the new point appended to the list.
    """
    if state1 is None:
        state1 = []
    if image is None:
        # No image to draw on; leave the selection untouched.
        return image, state1

    x, y = evt.index
    state1.append((x, y))
    print(state1)

    height, width = image.shape[:2]
    radius = min(height, width) // 60
    # Vectorised disc: a click near the border is clipped to the image instead
    # of raising IndexError or wrapping around through negative indices.
    rows, cols = np.ogrid[:height, :width]
    image[(cols - x) ** 2 + (rows - y) ** 2 <= radius ** 2] = 0
    return image, state1
def reset_select_points(image):
    """Return *image* paired with a fresh, empty list of selected points."""
    cleared_points = list()
    return (image, cleared_points)
# Bundled demo cases, listed once and sorted so that both example galleries
# show the cases in the same, stable order.
_case_dirs = sorted(i for i in os.listdir('demo_img') if 'case' in i)

# NOTE(review): the nesting of the layout containers below was reconstructed
# from a de-indented copy of this file -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Sat2Density Demos")
    gr.Markdown("### select/upload the satellite image and select the style image")
    with gr.Row():
        with gr.Column():
            sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
            img_examples = gr.Examples(examples=['demo_img/{}/satview-input.png'.format(i) for i in _case_dirs],
                                       inputs=sat_img, outputs=None, examples_per_page=20)
        with gr.Column():
            style_img = gr.Image()
            style_examples = gr.Examples(examples=['demo_img/{}/groundview.image.png'.format(i) for i in _case_dirs],
                                         inputs=style_img, outputs=None, examples_per_page=20)

    gr.Markdown("### select a certain point to generate single groundview image")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position")
                    slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position")
                    # A Button's displayed text is its `value`, not `label`.
                    btn_single = gr.Button(value="demo1")
                annotation_image = gr.AnnotatedImage()
        out_single = gr.Image()

    gr.Markdown("### draw a trajectory on the map to generate video")
    # Holds the list of clicked (x, y) points; starts out as None.
    state_select_points = gr.State()
    with gr.Row():
        with gr.Column():
            draw_img = gr.Image(shape=[256, 256], interactive=True)
        with gr.Column():
            out_video = gr.Video()
            reset_btn = gr.Button(value="Reset")
            # A Button's displayed text is its `value`, not `label`.
            btn_video = gr.Button(value="demo1")

    # Mirror the satellite image into the drawing canvas and record clicks on it.
    sat_img.change(copy_image, inputs=sat_img, outputs=draw_img)
    draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points])

    # Keep the render-point preview in sync with the image and both sliders.
    sat_img.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image)
    slider_x.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
    slider_y.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')

    btn_single.click(get_single, inputs=[sat_img, style_img, slider_x, slider_y], outputs=out_single)
    reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
    btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video)  # trigger video rendering

demo.launch()