# hash-map's picture
# files
# b466b8b verified
import os
import random
import numpy as np
import tensorflow as tf
from pathlib import Path
from tqdm import tqdm
import cv2 # for reading images from disk
# -------------------------------------------------
# CONFIGURATION
# -------------------------------------------------
# Relative paths below resolve against the current working directory.
GENERATOR_PATH: str = "generator_final.h5"  # trained Keras generator to load
VIS_FOLDER: str = "data/train/visible"  # visible-spectrum input images
IR_FOLDER: str = "data/train/infrared"  # optional ground-truth IR, matched by filename
SAVE_DIR: str = "output/train_results"  # where side-by-side PNGs are written
NUM_SAMPLES: int = 10  # number of images to sample
IMG_SIZE: tuple = (256, 256)  # cv2.resize dsize, i.e. (width, height); match the model input
SEED: int = 42  # seed for reproducible sampling
# -------------------------------------------------
# 1. Load the generator
# -------------------------------------------------
print(f"Loading generator from {GENERATOR_PATH} ...")
# compile=False: only inference is performed, so the saved optimizer/loss
# configuration does not need to be restored.
generator = tf.keras.models.load_model(
    filepath=GENERATOR_PATH,
    compile=False,
)
print("Generator loaded successfully.")
# -------------------------------------------------
# 2. Helper: preprocess image (resize + normalize to [-1, 1])
# -------------------------------------------------
def load_and_preprocess_image(img_path, target_size=IMG_SIZE):
    """Read an image file and turn it into a single-item batch for the generator.

    The image is converted from OpenCV's BGR order to RGB, resized, and
    rescaled from [0, 255] to [-1, 1].

    Args:
        img_path: Path string passed to ``cv2.imread``.
        target_size: ``dsize`` for ``cv2.resize``, i.e. ``(width, height)``.

    Returns:
        float32 array of shape ``(1, height, width, 3)``.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` (missing or
            undecodable file).
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise FileNotFoundError(f"Image not found: {img_path}")
    rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), target_size)
    scaled = rgb.astype(np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    # Prepend the batch axis: (H, W, 3) -> (1, H, W, 3).
    return scaled[np.newaxis, ...]
# -------------------------------------------------
# 3. Helper: convert [-1,1] tensor → uint8 image
# -------------------------------------------------
def to_uint8(tensor):
    """Map float values in the generator's [-1, 1] range to a uint8 image.

    This is the inverse of the ``/ 127.5 - 1.0`` scaling applied in
    ``load_and_preprocess_image``. Values outside [-1, 1] are clipped first,
    so the result always lies in [0, 255].

    Args:
        tensor: Array-like of floats of any shape (e.g. a generator output
            converted with ``.numpy()``).

    Returns:
        ``np.uint8`` array with the same shape as ``tensor``.
    """
    clipped = np.clip(np.asarray(tensor, dtype=np.float32), -1.0, 1.0)
    scaled = (clipped + 1.0) * 127.5  # now within [0, 255]
    # Round to nearest instead of truncating: astype() floors, so float error
    # such as 199.99998 would otherwise become 199 instead of 200.
    return np.rint(scaled).astype(np.uint8)
# -------------------------------------------------
# 4. Main function
# -------------------------------------------------
# Image extensions accepted as visible inputs (compared case-insensitively).
_IMAGE_SUFFIXES = frozenset({".png", ".jpg", ".jpeg", ".bmp"})


def _list_images(folder):
    """Return the sorted image files directly inside *folder* (empty if none/missing)."""
    return sorted(
        p
        for p in Path(folder).glob("*")
        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
    )


def _as_rgb3(img):
    """Return *img* as (H, W, 3); single-channel (H, W) or (H, W, 1) input is repeated."""
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    return img


def _load_ground_truth_ir(ir_path, height, width):
    """Return the image at *ir_path* as RGB uint8 of shape (height, width, 3), or None.

    Prints a warning and returns None when the file is missing or cannot be decoded.
    """
    if not ir_path.exists():
        tqdm.write(f"Warning: IR not found for {ir_path.name}, using black placeholder.")
        return None
    bgr = cv2.imread(str(ir_path))
    if bgr is None:
        tqdm.write(f"Warning: could not read IR image {ir_path}, using black placeholder.")
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # cv2.resize takes dsize as (width, height); match the generated panel so
    # the two images can be concatenated side by side.
    return cv2.resize(rgb, (width, height))


def generate_from_folder(
    vis_folder, ir_folder=None, save_dir=SAVE_DIR, num_samples=NUM_SAMPLES, seed=SEED
):
    """Save [ground-truth IR | generated IR] PNGs for randomly chosen visible images.

    Runs the module-level ``generator`` on each sampled image.

    Args:
        vis_folder: Folder scanned (non-recursively) for .png/.jpg/.jpeg/.bmp files.
        ir_folder: Optional folder of ground-truth IR images matched by identical
            filename. When it is None, or a match is missing or unreadable, the
            left panel is black.
        save_dir: Output folder (created if needed). Files are named
            ``sample_{idx:02d}_{stem}.png``.
        num_samples: Number of images to draw; capped at the number available.
        seed: Seed for a private ``random.Random``, so the same files are chosen
            for a given folder and seed without touching the global RNG.

    Raises:
        ValueError: if no images are found in ``vis_folder``.
    """
    os.makedirs(save_dir, exist_ok=True)

    vis_paths = _list_images(vis_folder)
    if not vis_paths:
        raise ValueError(f"No images found in {vis_folder}")

    rng = random.Random(seed)
    sample_paths = rng.sample(vis_paths, min(num_samples, len(vis_paths)))
    print(f"Generating {len(sample_paths)} random side-by-side images...")

    saved = 0
    for idx, vis_path in enumerate(tqdm(sample_paths)):
        # Generate the IR prediction for this visible image.
        vis_tensor = load_and_preprocess_image(str(vis_path))
        pred_tensor = generator(vis_tensor, training=False)  # (1, H, W, C)
        pred_img = _as_rgb3(to_uint8(pred_tensor[0].numpy()))  # (H, W, 3)
        height, width = pred_img.shape[:2]

        # Left panel: matching ground-truth IR when available, otherwise black.
        left = None
        if ir_folder:
            left = _load_ground_truth_ir(Path(ir_folder) / vis_path.name, height, width)
        if left is None:
            left = np.zeros_like(pred_img)

        row = np.concatenate([left, pred_img], axis=1)
        save_path = os.path.join(save_dir, f"sample_{idx:02d}_{vis_path.stem}.png")
        # cv2.imwrite reports failure via its return value, not an exception.
        if cv2.imwrite(save_path, cv2.cvtColor(row, cv2.COLOR_RGB2BGR)):
            saved += 1
        else:
            tqdm.write(f"Warning: could not write {save_path}")

    if saved == len(sample_paths):
        print(f"All {len(sample_paths)} images saved to {save_dir}")
    else:
        print(f"Saved {saved} of {len(sample_paths)} images to {save_dir}")
# -------------------------------------------------
# 5. RUN
# -------------------------------------------------
def main():
    """Write side-by-side samples for the configured training folders.

    Ground-truth IR images are matched to visible images by identical
    filename. If no ground truth is available, call the function with
    ``ir_folder=None`` instead, which leaves the left panel black:

        generate_from_folder(vis_folder=VIS_FOLDER, ir_folder=None)
    """
    generate_from_folder(
        vis_folder=VIS_FOLDER,
        ir_folder=IR_FOLDER,  # set to None if you don't have GT
        num_samples=NUM_SAMPLES,
    )


if __name__ == "__main__":
    main()