Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| from einops import rearrange | |
| from PIL import Image | |
def _load_if_path(value):
    """Return ``value`` unchanged, or load it with ``load_tensor_from_file`` if it is a path string."""
    if isinstance(value, str):
        from utils.file_utils import load_tensor_from_file

        return load_tensor_from_file(value, map_location="cpu")
    return value


def _mask_to_alpha(mask, spatial_shape):
    """Reshape ``mask`` to ``(*spatial_shape, 1)`` so it can be appended as an alpha channel."""
    spatial_shape = tuple(spatial_shape)
    if mask.dim() == len(spatial_shape) + 1:
        # The mask carries an explicit 1- or 3-channel axis (channel-last is
        # preferred, otherwise channel-first); keep only its first channel.
        # NOTE(review): a 3-channel mask is assumed to be one mask replicated per
        # channel -- confirm with callers.
        if mask.shape[-1] in (1, 3):
            mask = mask[..., 0]
        elif mask.shape[-3] in (1, 3):
            mask = mask[..., 0, :, :]
    if tuple(mask.shape) != spatial_shape:
        raise ValueError(
            f"Mask shape {tuple(mask.shape)} does not match image shape {spatial_shape}"
        )
    return mask.unsqueeze(-1)


def _tensor_to_uint8_hwc(tensor, mask=None, normalize=True):
    """Convert an image tensor (plus optional mask) to a uint8 array of shape (H, Nv * W, C)."""
    # detach() lets .numpy() work on tensors that require grad; .cpu() is a no-op
    # for CPU tensors and also covers non-CUDA devices (e.g. MPS).
    tensor = tensor.detach().cpu().float()
    if tensor.dim() not in (3, 4):
        raise ValueError(f"Unsupported tensor shape: {tensor.shape}")

    if normalize:
        # Map [-1, 1] -> [0, 1]. The mask is deliberately not rescaled.
        tensor = (tensor + 1.0) / 2.0

    # A trailing dim of 3 or 4 is read as channel-last; otherwise move the
    # channel axis last: (..., C, H, W) -> (..., H, W, C). This happens BEFORE the
    # mask is appended, so channel-first inputs can take a mask too.
    if tensor.shape[-1] not in (3, 4):
        tensor = tensor.movedim(-3, -1)

    if mask is not None:
        if tensor.shape[-1] != 3:
            raise ValueError(
                f"A mask is only supported for 3-channel images, got {tensor.shape[-1]} channels"
            )
        mask = _mask_to_alpha(mask.detach().cpu().float(), tensor.shape[:-1])
        tensor = torch.cat([tensor, mask], dim=-1)

    # Clamp everything -- including the alpha channel, and also when
    # normalize=False -- so out-of-range values cannot wrap in the uint8 cast.
    tensor = torch.clamp(tensor, 0.0, 1.0)

    if tensor.dim() == 4:
        # Tile the views side by side: (Nv, H, W, C) -> (H, Nv * W, C).
        num_views, height, width, channels = tensor.shape
        tensor = tensor.permute(1, 0, 2, 3).reshape(height, num_views * width, channels)

    if tensor.shape[-1] not in (3, 4):
        raise ValueError("Only support 3 or 4 channel images.")

    return (tensor.numpy() * 255).round().astype(np.uint8)


def tensor_to_pil(tensor, mask=None, normalize: bool = True):
    """
    Convert an image tensor (optionally with an alpha mask) to a PIL Image.

    Multi-view inputs are tiled left-to-right into one image of width Nv * W.

    :param tensor: torch.Tensor or str (file path, loaded with
        ``load_tensor_from_file(..., map_location="cpu")``). Supported shapes:
        (Nv, H, W, C), (Nv, C, H, W), (H, W, C), (C, H, W). A trailing dimension
        of 3 or 4 is treated as the channel axis; otherwise the layout is taken
        to be channel-first.
    :param mask: optional torch.Tensor or str (file path) appended as an alpha
        channel; only allowed when the image has C=3. Its spatial shape must match
        the image: (Nv, H, W) / (H, W), optionally with a 1- or 3-channel axis in
        channel-last or channel-first position (only the first channel is used).
        Values are clamped to [0, 1] and never rescaled.
    :param normalize: if True, map tensor values from [-1, 1] to [0, 1] first.
        Values are clamped to [0, 1] either way.
    :return: PIL.Image, "RGB" for 3 channels or "RGBA" for 4 channels
        (or 3 channels plus mask).
    :raises ValueError: on an unsupported rank or channel count, or when the
        mask does not fit the image.
    """
    tensor = _load_if_path(tensor)
    if mask is not None:
        mask = _load_if_path(mask)

    np_img = _tensor_to_uint8_hwc(tensor, mask, normalize)
    # Pillow infers "RGB" / "RGBA" from a uint8 (H, W, 3|4) array; passing
    # `mode=` explicitly to fromarray is deprecated since Pillow 11.3.
    return Image.fromarray(np_img)