Spaces:
Runtime error
Runtime error
| import torch | |
| from PIL import Image | |
| from PIL import ImageDraw | |
def encode_scene(obj_list, H=320, W=320, src_bbox_format='xywh', tgt_bbox_format='xyxy'):
    """Turn a list of scene objects into captions plus bounding boxes.

    Args:
        obj_list: list of dicts. Each dict carries either a ready-made
            'caption' string, or the three strings 'color', 'material' and
            'shape' (joined with single spaces to form the caption). Each
            dict also carries 'bbox': four unnormalized numbers laid out
            according to ``src_bbox_format``.
        H: canvas height used to bound and normalize y coordinates.
        W: canvas width used to bound and normalize x coordinates.
        src_bbox_format: layout of the incoming boxes, 'xywh'
            ([x0, y0, w, h]) or 'xyxy' ([x0, y0, x1, y1]).
        tgt_bbox_format: layout of the returned boxes, 'xywh' or 'xyxy'.

    Returns:
        dict with three parallel lists:
            'box_captions': one caption per object,
            'boxes_normalized': boxes divided by W (x terms) / H (y terms),
            'boxes_unnormalized': boxes in the original units,
        both box lists being expressed in ``tgt_bbox_format``.

    Raises:
        AssertionError: on an unknown format name, on a box whose far corner
            is not strictly greater than its near corner, or on a box whose
            far corner exceeds W / H.
    """
    # An explicit caption wins; otherwise compose one from the attributes.
    box_captions = [
        obj['caption'] if 'caption' in obj
        else f"{obj['color']} {obj['material']} {obj['shape']}"
        for obj in obj_list
    ]

    assert src_bbox_format in ['xywh', 'xyxy'], f"src_bbox_format must be 'xywh' or 'xyxy', not {src_bbox_format}"
    assert tgt_bbox_format in ['xywh', 'xyxy'], f"tgt_bbox_format must be 'xywh' or 'xyxy', not {tgt_bbox_format}"

    # Per-coordinate divisors: x-like terms scale by W, y-like terms by H.
    divisors = (W, H, W, H)

    boxes_unnormalized = []
    boxes_normalized = []
    for obj in obj_list:
        # Recover both representations (corners and extent) from the source.
        if src_bbox_format == 'xyxy':
            left, top, right, bottom = obj['bbox']
            width = right - left
            height = bottom - top
        elif src_bbox_format == 'xywh':
            left, top, width, height = obj['bbox']
            right = left + width
            bottom = top + height

        assert right > left, f"x1={right} <= x0={left}"
        assert bottom > top, f"y1={bottom} <= y0={top}"
        assert right <= W, f"x1={right} > W={W}"
        assert bottom <= H, f"y1={bottom} > H={H}"

        if tgt_bbox_format == 'xyxy':
            raw_box = [left, top, right, bottom]
        elif tgt_bbox_format == 'xywh':
            raw_box = [left, top, width, height]

        boxes_unnormalized.append(raw_box)
        boxes_normalized.append(
            [value / divisor for value, divisor in zip(raw_box, divisors)])

    assert len(box_captions) == len(boxes_normalized), f"len(box_captions)={len(box_captions)} != len(boxes_normalized)={len(boxes_normalized)}"

    return {
        'box_captions': box_captions,
        'boxes_normalized': boxes_normalized,
        'boxes_unnormalized': boxes_unnormalized,
    }
def encode_from_custom_annotation(custom_annotations, size=512):
    """Encode UI-style box annotations for a square canvas of side ``size``.

    Args:
        custom_annotations: sequence of dicts with keys 'x', 'y', 'width',
            'height' (top-left corner and extent) and 'label' (the caption).
        size: side length used as both H and W when calling encode_scene.

    Returns:
        Whatever encode_scene returns for these objects, with boxes passed
        in and requested back in 'xyxy' layout.
    """
    # Example input:
    # custom_annotations = [
    #     {'x': 83, 'y': 335, 'width': 70, 'height': 69, 'label': 'blue metal cube'},
    #     {'x': 162, 'y': 302, 'width': 110, 'height': 138, 'label': 'blue metal cube'},
    #     {'x': 274, 'y': 250, 'width': 191, 'height': 234, 'label': 'blue metal cube'},
    #     {'x': 14, 'y': 18, 'width': 155, 'height': 205, 'label': 'blue metal cube'},
    #     {'x': 175, 'y': 79, 'width': 106, 'height': 119, 'label': 'blue metal cube'},
    #     {'x': 288, 'y': 111, 'width': 69, 'height': 63, 'label': 'blue metal cube'}
    # ]
    objects = []
    for annotation in custom_annotations:
        # Convert corner + extent into the two corners encode_scene is told to expect.
        left = annotation['x']
        top = annotation['y']
        right = left + annotation['width']
        bottom = top + annotation['height']
        objects.append({
            'caption': annotation['label'],
            'bbox': [left, top, right, bottom],
        })

    return encode_scene(
        objects,
        H=size,
        W=size,
        src_bbox_format='xyxy',
        tgt_bbox_format='xyxy',
    )
| #### The code below is for HF diffusers | |
def iterinpaint_sample_diffusers(pipe, datum, paste=True, verbose=False, guidance_scale=4.0, size=512, background_instruction='Add gray background'):
    """Build an image one object at a time with an inpainting pipeline.

    For each box in ``datum`` the pipeline is called with the prompt
    ``"Add <caption>"``, the current context image, and a mask that is 255
    inside the box and 0 elsewhere. After all boxes, one more call is made
    with ``background_instruction`` and a mask that is 255 everywhere except
    inside the boxes.

    Args:
        pipe: callable invoked as ``pipe(prompt, image, mask,
            guidance_scale=...)`` whose result exposes ``.images[0]``.
            Presumably a Hugging Face diffusers inpainting pipeline that
            returns PIL images -- confirm against the caller.
        datum: dict with 'box_captions' (one caption per box) and
            'boxes_unnormalized' (one ``[x0, y0, x1, y1]`` pixel box per
            object; lists, tuples, arrays or tensors are accepted).
            NOTE: this function adds the alias key 'unnormalized_boxes' to
            ``datum`` as a side effect.
        paste: if True, only the box region of each generated image is
            pasted onto the running context image; if False, the whole
            generated image becomes the next context image.
        verbose: print progress information.
        guidance_scale: forwarded to every ``pipe`` call.
        size: side length of the square initial context image and of the
            background mask.
        background_instruction: prompt used for the final background pass.

    Returns:
        dict of lists:
            'context_imgs': context image fed to each per-object call
                (one per box; the background pass is not recorded here),
            'mask_imgs': mask for each call, background mask last,
            'prompts': prompt for each call, background prompt last,
            'generated_images': context after each object step, followed by
                the raw output of the background pass.
    """
    d = datum
    # Kept for backward compatibility: callers may read this alias afterwards.
    d['unnormalized_boxes'] = d['boxes_unnormalized']
    boxes = d['unnormalized_boxes']
    n_total_boxes = len(boxes)

    context_imgs = []
    mask_imgs = []
    generated_images = []
    prompts = []

    # Image.new defaults to an all-black canvas.
    context_img = Image.new('RGB', (size, size))
    if verbose:
        print('Initiailzed context image')

    # Starts fully white (inpaint everywhere); each object box is blacked out
    # below so the background pass leaves the generated objects untouched.
    background_mask_img = Image.new('L', (size, size), 255)
    background_mask_draw = ImageDraw.Draw(background_mask_img)

    for i, box in enumerate(boxes):
        if verbose:
            print('Iter: ', i+1, 'total: ', n_total_boxes)

        target_caption = d['box_captions'][i]
        if verbose:
            print('Drawing ', target_caption)

        # as_tensor accepts lists, tuples, arrays and tensors alike; long()
        # truncates fractional coordinates to whole pixels. Computed once and
        # reused for the masks, the crop and the paste.
        xyxy = torch.as_tensor(box).long().tolist()

        # Fresh black mask (Image.new default) with only this box in white.
        mask_img = Image.new('L', context_img.size)
        ImageDraw.Draw(mask_img).rectangle(xyxy, fill=255)
        background_mask_draw.rectangle(xyxy, fill=0)
        mask_imgs.append(mask_img.copy())

        prompt = f"Add {target_caption}"
        if verbose:
            print('prompt:', prompt)
        prompts += [prompt]

        context_imgs.append(context_img.copy())

        generated_image = pipe(
            prompt,
            context_img,
            mask_img,
            guidance_scale=guidance_scale).images[0]

        if paste:
            x0, y0, x1, y1 = xyxy
            # The destination is grown by one pixel on every side and the
            # cropped patch is resized to fit it. Presumably this hides a
            # seam at the mask border -- TODO confirm with the author.
            paste_box = [x0 - 1, y0 - 1, x1 + 1, y1 + 1]
            box_w = paste_box[2] - paste_box[0]
            box_h = paste_box[3] - paste_box[1]
            patch = generated_image.crop(xyxy).resize((box_w, box_h))
            context_img.paste(patch, paste_box)
        else:
            context_img = generated_image

        generated_images.append(context_img.copy())

    if verbose:
        print('Fill background')

    mask_img = background_mask_img
    mask_imgs.append(mask_img)

    prompt = background_instruction
    if verbose:
        print('prompt:', prompt)
    prompts += [prompt]

    generated_image = pipe(
        prompt,
        context_img,
        mask_img,
        guidance_scale=guidance_scale).images[0]
    generated_images.append(generated_image)

    return {
        'context_imgs': context_imgs,
        'mask_imgs': mask_imgs,
        'prompts': prompts,
        'generated_images': generated_images,
    }