import cv2
import numpy as np
from PIL import Image
import torch
from transformers import DPTForDepthEstimation, DPTImageProcessor

# Initialize the depth estimator at module level so the model is loaded only once.
# Use the project-wide DEVICE/DTYPE from config when available; otherwise fall
# back to sensible defaults.
try:
    from config import DEVICE, DTYPE
except ImportError:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print(f"Loading Depth Estimator on {DEVICE}...")
# The depth model is kept in its default float32 weights: its output is
# converted to a NumPy array below, and NumPy has no bfloat16 support.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
depth_estimator.to(DEVICE)
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
print("Depth Estimator loaded.")

def apply_canny(image: Image.Image) -> Image.Image:
    """
    Applies Canny edge detection to a PIL Image.
    """
    # Convert to grayscale via PIL so RGB, RGBA and already-grayscale inputs all work
    image_np = np.array(image.convert("L"))
    
    # Apply Canny
    image_edges = cv2.Canny(image_np, 100, 200) # You can adjust thresholds
    
    # Convert back to 3-channel for ControlNet
    image_edges = image_edges[:, :, None]
    image_edges = np.concatenate([image_edges, image_edges, image_edges], axis=2)
    return Image.fromarray(image_edges)

def apply_depth(image: Image.Image) -> Image.Image:
    """
    Estimates depth from a PIL Image and returns a depth map image.
    """
    original_size = image.size
    # Downscale very large images to speed up depth estimation (aspect ratio is
    # kept); the depth map is interpolated back to the original size below.
    max_dim = max(original_size)
    if max_dim > 768:
        scale_factor = 768 / max_dim
        new_size = (int(original_size[0] * scale_factor), int(original_size[1] * scale_factor))
        image = image.resize(new_size, Image.BICUBIC)

    inputs = feature_extractor(images=image, return_tensors="pt").to(DEVICE)
    
    with torch.no_grad():
        outputs = depth_estimator(**inputs)
        predicted_depth = outputs.predicted_depth

    # Interpolate to original size and normalize
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=original_size[::-1], # PIL size is (width, height), interpolate expects (height, width)
        mode="bicubic",
        align_corners=False,
    )
    output = prediction.squeeze().cpu().numpy()
    
    # Normalize to 0-255 and convert to uint8
    formatted_output = np.interp(output, (output.min(), output.max()), (0, 255)).astype(np.uint8)
    return Image.fromarray(formatted_output)
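

# --- Usage sketch (assumption: run as a script with a local test image) ---
# The file names below ("test_image.png", "*_control.png") are illustrative and
# not part of this module's API. The resulting maps are the kind of conditioning
# images typically fed to canny / depth ControlNet pipelines.
if __name__ == "__main__":
    sample = Image.open("test_image.png").convert("RGB")

    canny_map = apply_canny(sample)
    canny_map.save("canny_control.png")

    depth_map = apply_depth(sample)
    depth_map.save("depth_control.png")

    print(f"Canny map: {canny_map.size} {canny_map.mode}, Depth map: {depth_map.size} {depth_map.mode}")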