import os
import sys
import argparse
import subprocess

import numpy as np
import torch
import mlflow
import mlflow.pytorch
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torchvision import transforms

# Generator is not referenced directly, but the class must be importable so
# that mlflow.pytorch.load_model can unpickle the saved model.
from model import Generator

EXPERIMENT_NAME = "Colorizer_Experiment"


def setup_mlflow():
    """Return the experiment ID for EXPERIMENT_NAME, creating it if needed."""
    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if experiment is None:
        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
    else:
        experiment_id = experiment.experiment_id
    return experiment_id


def load_model(run_id, device):
    """Load the generator logged under the given MLflow training run."""
    print(f"Loading model from run: {run_id}")
    model_uri = f"runs:/{run_id}/generator_model"
    # Extra kwargs such as map_location are forwarded to torch.load.
    model = mlflow.pytorch.load_model(model_uri, map_location=device)
    return model


# Configuration variables
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
RUN_ID = "your_run_id_here"  # Replace with the actual run ID
IMAGE_PATH = "path/to/your/image.jpg"  # Replace with the path to your input image
SAVE_MODEL = False  # Set True to re-log and register the model after inference
SERVE_MODEL = False  # Set True to serve the model over HTTP after inference
SERVE_PORT = 5000


def preprocess_image(image_path):
    """Return the image's L (lightness) channel as a 1x1xHxW tensor in [-1, 1]."""
    img = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(img)
    lab_img = rgb2lab(img_tensor.permute(1, 2, 0).numpy())
    L = lab_img[:, :, 0]
    L = (L - 50) / 50  # L lies in [0, 100]; rescale to [-1, 1]
    L = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float()
    return L


def postprocess_output(L, ab):
    """Merge the input L channel with the predicted ab channels; return 8-bit RGB."""
    L = L.squeeze().cpu().numpy()    # (H, W)
    ab = ab.squeeze().cpu().numpy()  # (2, H, W)
    ab = ab.transpose(1, 2, 0)       # channels-last (H, W, 2) for concatenation
    L = (L + 1.0) * 50.0  # back to [0, 100]
    ab = ab * 128.0       # back to roughly [-128, 128]
    Lab = np.concatenate([L[..., np.newaxis], ab], axis=2)
    rgb_img = lab2rgb(Lab)
    return (rgb_img * 255).astype(np.uint8)


def colorize_image(model, image_path, device):
    L = preprocess_image(image_path).to(device)
    with torch.no_grad():
        ab = model(L)
    colorized = postprocess_output(L, ab)
    return colorized
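

# Example end-to-end usage (hypothetical run ID and paths, for illustration):
#   model = load_model("1234abcd", DEVICE)
#   rgb = colorize_image(model, "portrait.jpg", DEVICE)
#   Image.fromarray(rgb).save("portrait_color.png")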


def save_model(model, run_id):
    with mlflow.start_run(run_id=run_id):
        # Log the model
        mlflow.pytorch.log_model(model, "model")
        # Register the model
        model_uri = f"runs:/{run_id}/model"
        mlflow.register_model(model_uri, "colorizer_model")
    print(f"Model saved and registered with run_id: {run_id}")


def serve_model(run_id, port=5000):
    # mlflow.pytorch has no serve() function; MLflow serves models through its
    # CLI, so shell out to `mlflow models serve`.
    model_uri = f"runs:/{run_id}/model"
    subprocess.run(["mlflow", "models", "serve", "-m", model_uri, "-p", str(port)], check=True)
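

# Example (hypothetical run ID): serve_model("1234abcd", port=SERVE_PORT)
# exposes the model at http://127.0.0.1:5000/invocations for scoring requests.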


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colorizer Inference")
    parser.add_argument("--run_id", type=str, help="MLflow run ID of the trained model")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the input grayscale image")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for inference (cuda/cpu)")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # If run_id is not provided, try to load it from the file
    if not args.run_id:
        try:
            with open("latest_run_id.txt", "r") as f:
                args.run_id = f.read().strip()
        except FileNotFoundError:
            print("No run ID provided and couldn't find latest_run_id.txt")
            sys.exit(1)

    experiment_id = setup_mlflow()
    with mlflow.start_run(experiment_id=experiment_id, run_name="inference_run"):
        try:
            model = load_model(args.run_id, device)
            colorized = colorize_image(model, args.image_path, device)
            output_path = f"colorized_{os.path.basename(args.image_path)}"
            Image.fromarray(colorized).save(output_path)
            print(f"Colorized image saved as: {output_path}")
            mlflow.log_artifact(output_path)
            mlflow.log_param("input_image", args.image_path)
            mlflow.log_param("model_run_id", args.run_id)
        except Exception as e:
            print(f"Error during inference: {e}")
            mlflow.log_param("error", str(e))
            # Exit non-zero so failures are visible to callers; the with block
            # already ends the run, so an explicit mlflow.end_run() is not needed.
            sys.exit(1)
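
    # Optional follow-ups gated by the module-level flags. This wiring is a
    # sketch: the original script defined save_model and serve_model but never
    # invoked them. It runs after the inference run has closed, because
    # save_model opens its own MLflow run.
    if SAVE_MODEL:
        save_model(model, args.run_id)
    if SERVE_MODEL:
        serve_model(args.run_id, port=SERVE_PORT)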