import datetime
import io
import json
import os
import random
import sys
import tempfile
import time
import uuid

import boto3
import cv2
import func_timeout
import gradio as gr
import numpy as np
import requests
from botocore.client import Config
from PIL import Image

# All credentials arrive in a single "OneKey" environment variable, '#'-separated, in order:
# TOKEN # APIKEY # UKAPIURL # LLMKEY # R2_ACCESS_KEY # R2_SECRET_KEY # R2_ENDPOINT
OneKey = os.environ['OneKey'].strip()
OneKey = OneKey.split("#")
TOKEN = OneKey[0]
APIKEY = OneKey[1]
UKAPIURL = OneKey[2]
LLMKEY = OneKey[3]
R2_ACCESS_KEY = OneKey[4]
R2_SECRET_KEY = OneKey[5]
R2_ENDPOINT = OneKey[6]

samples = [
    {
        "image": ["datas/data01/hei_cat01.webp", "datas/data01/tom.webp"],
        "prompt": "Let all cats dance together in the park",
        "aspect_ratio": "16:9"
    },
    {
        "image": ["datas/data02/girl.jpg", "datas/data02/cloth.jpeg"],
        "prompt": "Let the girl in first image, wear the dress in second image",
        "aspect_ratio": "9:16"
    },
    {
        "image": ["datas/data03/girl.webp", "datas/data03/cloth.jpg"],
        "prompt": "Let the girl in first image, wear the bikini in second image, lying on the beach",
        "aspect_ratio": "1:1"
    }
]


class R2Api:
    """Upload images to the "omni-creator" bucket through presigned S3 PUT URLs.

    The bucket is reached through boto3's S3 client pointed at ``R2_ENDPOINT``;
    uploaded objects are reported under the public ``domain`` prefix.
    """

    def __init__(self, session=None):
        super().__init__()
        self.R2_BUCKET = "omni-creator"
        self.domain = "https://www.omnicreator.net/"
        self.R2_ACCESS_KEY = R2_ACCESS_KEY
        self.R2_SECRET_KEY = R2_SECRET_KEY
        self.R2_ENDPOINT = R2_ENDPOINT
        self.client = boto3.client(
            "s3",
            endpoint_url=self.R2_ENDPOINT,
            aws_access_key_id=self.R2_ACCESS_KEY,
            aws_secret_access_key=self.R2_SECRET_KEY,
            config=Config(signature_version="s3v4"),
        )
        self.session = requests.Session() if session is None else session

    def _presign_put(self, cloud_path, content_type):
        """Return a presigned PUT URL for ``cloud_path`` (valid for 7 days = 604800 s)."""
        return self.client.generate_presigned_url(
            "put_object",
            Params={"Bucket": self.R2_BUCKET, "Key": cloud_path, "ContentType": content_type},
            ExpiresIn=604800,
        )

    def _put_with_retry(self, url, data, headers, timeout, max_retries=3, retry_delay=1):
        """PUT ``data`` to ``url``, retrying on non-200 responses and request errors.

        Raises:
            Exception: when every one of ``max_retries`` attempts fails. A plain
                ``Exception`` is kept so existing ``except Exception`` callers still work.
        """
        last_error = None
        last_exc = None
        for attempt in range(1, max_retries + 1):
            try:
                response = self.session.put(url, data=data, headers=headers, timeout=timeout)
            except requests.exceptions.RequestException as e:  # includes Timeout
                last_error, last_exc = e, e
                print(f"⚠️ Upload retry {attempt}/{max_retries} failed: {e}")
            else:
                if response.status_code == 200:
                    return
                last_error = f"HTTP {response.status_code}"
                print(f"⚠️ Upload failed with status code: {response.status_code}")
            if attempt < max_retries:
                time.sleep(retry_delay)  # brief back-off before the next attempt
        raise Exception(
            f'Failed to upload file to R2 after {max_retries} retries! Last error: {last_error}'
        ) from last_exc

    def upload_from_memory(self, image_data, filename, content_type='image/jpeg'):
        """
        Upload image data directly from memory to R2

        Args:
            image_data (bytes): Image data in bytes
            filename (str): Filename for the uploaded file
            content_type (str): MIME type of the image

        Returns:
            str: Public URL of the uploaded file

        Raises:
            Exception: if the upload fails after 3 attempts (previously a
                non-200 response was silently reported as success)
        """
        t1 = time.time()
        headers = {"Content-Type": content_type}
        cloud_path = f"ImageEdit/Uploads/{str(datetime.date.today())}/{filename}"
        url = self._presign_put(cloud_path, content_type)
        self._put_with_retry(url, image_data, headers, timeout=15)
        print("upload_from_memory time is ====>", time.time() - t1)
        return f"{self.domain}{cloud_path}"

    def upload_file(self, local_path, cloud_path):
        """Upload the file at ``local_path``; return its public URL.

        The Content-Type is derived from the file extension.

        Raises:
            Exception: if the upload fails after 3 attempts.
        """
        t1 = time.time()
        head_dict = {
            'jpg': 'image/jpeg',
            'jpeg': 'image/jpeg',
            'png': 'image/png',
            'gif': 'image/gif',
            'bmp': 'image/bmp',
            'webp': 'image/webp',
            'ico': 'image/x-icon'
        }
        ftype = os.path.basename(local_path).split(".")[-1].lower()
        ctype = head_dict.get(ftype, 'application/octet-stream')
        headers = {"Content-Type": ctype}
        # NOTE(review): the ``cloud_path`` argument is ignored and overwritten here (as in the
        # original); kept for backward compatibility — confirm whether callers expect it honoured.
        cloud_path = f"ImageEdit/Uploads/{str(datetime.date.today())}/{os.path.basename(local_path)}"
        url = self._presign_put(cloud_path, ctype)
        with open(local_path, 'rb') as f:
            data = f.read()
        self._put_with_retry(url, data, headers, timeout=8)
        print("upload_file time is ====>", time.time() - t1)
        return f"{self.domain}{cloud_path}"


def upload_user_img_r2(clientIp, timeId, pil_image):
    """
    Upload PIL Image directly to R2 without saving to local file

    The image is converted to RGB if needed and encoded as JPEG (quality 95).

    Args:
        clientIp (str): Client IP address (currently unused)
        timeId (int): Timestamp, embedded in the filename
        pil_image (PIL.Image): PIL Image object

    Returns:
        str: Uploaded URL
    """
    # A UUID keeps filenames unique when several requests upload concurrently.
    unique_id = str(uuid.uuid4())
    fileName = f"user_img_{unique_id}_{timeId}.jpg"

    img_buffer = io.BytesIO()
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')  # JPEG cannot store alpha/palette modes
    pil_image.save(img_buffer, format='JPEG', quality=95)

    return R2Api().upload_from_memory(img_buffer.getvalue(), fileName, 'image/jpeg')


def _layer_to_binary_mask(layer_array):
    """Return a uint8 array with 255 where the layer is drawn and 0 elsewhere.

    3-channel-or-more arrays: with 4 channels, alpha > 0 counts as drawn; otherwise
    any non-zero channel counts as drawn. 2-D arrays: non-zero counts as drawn.
    Returns None for any other rank.
    """
    if layer_array.ndim == 3:
        if layer_array.shape[2] == 4:
            drawn = layer_array[:, :, 3] > 0
        else:
            drawn = np.any(layer_array > 0, axis=2)
    elif layer_array.ndim == 2:
        drawn = layer_array > 0
    else:
        return None
    return np.where(drawn, 255, 0).astype(np.uint8)


def create_mask_from_layers(base_image, layers):
    """
    Create mask image from ImageEditor layers

    Args:
        base_image (PIL.Image): Original image (only its size is used)
        layers (list): ImageEditor layer data (anything ``np.array`` accepts; None entries skipped)

    Returns:
        PIL.Image: 'L' mode mask, same size as ``base_image``. Drawn areas are 255.
            Layers whose size differs are LANCZOS-resized, which can leave
            intermediate grey values along edges.
    """
    if not layers:
        return Image.new('L', base_image.size, 0)

    combined = np.array(Image.new('L', base_image.size, 0))
    for layer in layers:
        if layer is None:
            continue
        layer_mask = _layer_to_binary_mask(np.array(layer))
        if layer_mask is None:
            continue
        # A 2-D uint8 array maps to an 'L' image without the deprecated ``mode=`` argument.
        layer_image = Image.fromarray(layer_mask)
        if layer_image.size != base_image.size:
            layer_image = layer_image.resize(base_image.size, Image.LANCZOS)
        # Union of all layers: keep the brightest value per pixel.
        combined = np.maximum(combined, np.asarray(layer_image))

    return Image.fromarray(combined)


def upload_mask_image_r2(client_ip, time_id, mask_image):
    """
    Upload mask image to R2 directly from memory

    Args:
        client_ip (str): Client IP (currently unused)
        time_id (int): Timestamp, embedded in the filename
        mask_image (PIL.Image): Mask image

    Returns:
        str | None: Uploaded URL, or None if encoding/upload failed (error is printed)
    """
    # A UUID keeps filenames unique when several requests upload concurrently.
    unique_id = str(uuid.uuid4())
    file_name = f"mask_img_{unique_id}_{time_id}.png"

    try:
        img_buffer = io.BytesIO()
        mask_image.save(img_buffer, format='PNG')  # lossless, so the mask edges survive
        return R2Api().upload_from_memory(img_buffer.getvalue(), file_name, 'image/png')
    except Exception as e:
        print(f"Failed to upload mask image: {e}")
        return None


def submit_image_edit_task(user_image_url, prompt, task_type="80", mask_image_url="", user_image2_url="", user_image3_url="", width=0, height=0):
    """
    Submit image editing task with improved error handling using API v2
    Supports multi-image editing with user_image2, user_image3, width, height parameters

    Retries up to 3 times on HTTP 502/503/504, timeouts and connection errors.

    Returns:
        tuple: ``(task_id, None)`` on success, ``(None, error_message)`` otherwise
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }

    data = {
        "user_image": user_image_url,
        "user_image2": user_image2_url,
        "user_image3": user_image3_url,
        "user_mask": mask_image_url,
        "type": task_type,
        "width": width,
        "height": height,
        "text": prompt,
        "user_uuid": APIKEY,
        "priority": 0,
        # NOTE(review): hard-coded shared secret in source; consider moving it into OneKey.
        "secret_key": "219ngu"
    }

    retry_count = 0
    max_retries = 3

    while retry_count < max_retries:
        try:
            response = requests.post(
                f'{UKAPIURL}/public_image_edit_v2',
                headers=headers,
                json=data,
                timeout=30  # generous timeout for task submission
            )

            if response.status_code == 200:
                result = response.json()
                if result.get('code') == 0:
                    return result['data']['task_id'], None
                else:
                    return None, f"API Error: {result.get('message', 'Unknown error')}"
            elif response.status_code in [502, 503, 504]:
                # Gateway/server errors are usually transient, so retry.
                retry_count += 1
                if retry_count < max_retries:
                    print(f"⚠️ Server error {response.status_code}, retrying {retry_count}/{max_retries}")
                    time.sleep(2)  # wait 2 s before retrying
                    continue
                else:
                    return None, f"HTTP Error after {max_retries} retries: {response.status_code}"
            else:
                return None, f"HTTP Error: {response.status_code}"

        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
            retry_count += 1
            if retry_count < max_retries:
                print(f"⚠️ Network error, retrying {retry_count}/{max_retries}: {e}")
                time.sleep(2)
                continue
            else:
                return None, f"Network error after {max_retries} retries: {str(e)}"
        except Exception as e:
            return None, f"Request Exception: {str(e)}"

    return None, f"Failed after {max_retries} retries"


def check_task_status(task_id):
    """
    Query task status with improved error handling using API v2

    Retries up to 2 times on HTTP 502/503/504, timeouts and connection errors.

    Returns:
        tuple: ``(status, image_url, task_data)``. On success ``task_data`` is the API's
            ``data`` dict. On failure the status is ``'error'``, the URL is None, and
            the third item is an error message string.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }

    data = {
        "task_id": task_id
    }

    retry_count = 0
    max_retries = 2  # fewer retries for status polling than for submission

    while retry_count < max_retries:
        try:
            response = requests.post(
                f'{UKAPIURL}/status_image_edit_v2',
                headers=headers,
                json=data,
                timeout=15  # shorter timeout for status polling
            )

            if response.status_code == 200:
                result = response.json()
                if result.get('code') == 0:
                    task_data = result['data']
                    status = task_data['status']
                    image_url = task_data.get('image_url')

                    # Log queue position when the API reports it.
                    queue_info = task_data.get('queue_info', {})
                    if queue_info:
                        tasks_ahead = queue_info.get('tasks_ahead', 0)
                        current_priority = queue_info.get('current_priority', 0)
                        print(f"📊 Queue Info - Tasks ahead: {tasks_ahead}, Priority: {current_priority}")

                    return status, image_url, task_data
                else:
                    return 'error', None, result.get('message', 'Unknown error')
            elif response.status_code in [502, 503, 504]:
                # Gateway/server errors are usually transient, so retry.
                retry_count += 1
                if retry_count < max_retries:
                    print(f"⚠️ Status check server error {response.status_code}, retrying {retry_count}/{max_retries}")
                    time.sleep(1)  # shorter back-off for status polling
                    continue
                else:
                    return 'error', None, f"HTTP Error after {max_retries} retries: {response.status_code}"
            else:
                return 'error', None, f"HTTP Error: {response.status_code}"

        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
            retry_count += 1
            if retry_count < max_retries:
                print(f"⚠️ Status check network error, retrying {retry_count}/{max_retries}: {e}")
                time.sleep(1)
                continue
            else:
                return 'error', None, f"Network error after {max_retries} retries: {str(e)}"
        except Exception as e:
            return 'error', None, f"Request Exception: {str(e)}"

    return 'error', None, f"Failed after {max_retries} retries"


# Statuses reported while a task is still waiting or running.
_ACTIVE_STATUSES = ('queued', 'processing', 'running', 'created', 'working')


def _task_progress_message(status, task_data, busy_message, status_prefix):
    """Build the progress text shown while a task is still active."""
    if task_data and isinstance(task_data, dict):
        queue_info = task_data.get('queue_info', {})
        if queue_info and status in ('queued', 'created'):
            tasks_ahead = queue_info.get('tasks_ahead', 0) or 0  # tolerate an explicit null
            if tasks_ahead > 0:
                return (f"⏳ In queue: {tasks_ahead} tasks ahead. "
                        "Visit https://omnicreator.net/multi-image-edit#generator for instant processing!")
            return busy_message
    return f"{status_prefix} (status: {status})"


def _wait_for_task(task_id, progress_callback, *, busy_message, status_prefix,
                   done_message, no_result_message, max_attempts=60, poll_interval=1):
    """Poll ``check_task_status`` until the task finishes, fails, or attempts run out.

    With the defaults this polls roughly once per second for about a minute, plus
    request latency. (The old comments said "10 minutes", which the code never did.)

    Returns:
        tuple: ``(output_url | None, message, task_uuid | None)``
    """
    task_uuid = None
    for _ in range(max_attempts):
        status, output_url, task_data = check_task_status(task_id)

        if task_data and isinstance(task_data, dict):
            task_uuid = task_data.get('uuid', None)

        if status == 'completed':
            if output_url:
                return output_url, done_message, task_uuid
            return None, no_result_message, task_uuid
        if status in ('error', 'failed'):
            return None, f"task processing failed: {task_data}", task_uuid

        if progress_callback:
            if status in _ACTIVE_STATUSES:
                progress_callback(_task_progress_message(status, task_data, busy_message, status_prefix))
            else:
                progress_callback(f"unknown status: {status}")
        time.sleep(poll_interval)

    return None, "task processing timeout", task_uuid


def _upload_input_image(client_ip, time_id, index, img_input):
    """Upload a PIL image, or an image file path, and return its URL.

    ``index`` is the 0-based slot number used in log messages.
    """
    if hasattr(img_input, 'save'):
        print(f"💾 Using PIL Image {index+1} directly from memory")
        return upload_user_img_r2(client_ip, time_id, img_input)
    # Close the file handle once the upload has read the pixels.
    with Image.open(img_input) as pil_image:
        print(f"📁 Loaded image {index+1} from file: {img_input}")
        return upload_user_img_r2(client_ip, time_id, pil_image)


def process_multi_image_edit(img_inputs, prompt, width=0, height=0, progress_callback=None, min_images=2):
    """
    Complete process for multi-image editing

    Args:
        img_inputs: List of image inputs (PIL Image objects or file paths). At most 3
            non-None images are allowed; None slots are ignored.
        prompt: Editing instructions
        width: Output width (0 for auto)
        height: Output height (0 for auto)
        progress_callback: Progress callback function
        min_images: Minimum number of non-None images required (default 2, clamped to >= 1)

    Returns:
        tuple: ``(result_url | None, message, task_uuid | None)``
    """
    try:
        client_ip = "127.0.0.1"  # Default IP
        time_id = int(time.time())

        # Drop empty slots before validating, but keep each image's original slot
        # index so messages refer to the slot the user actually filled.
        provided = [(i, img) for i, img in enumerate(img_inputs or []) if img is not None]
        required = max(1, min_images)
        if len(provided) < required:
            plural = "s" if required != 1 else ""
            return None, f"Please upload at least {required} image{plural}", None
        if len(provided) > 3:
            return None, "Maximum 3 images allowed", None

        if progress_callback:
            progress_callback("uploading images...")

        uploaded_urls = []
        for i, img_input in provided:
            uploaded_url = _upload_input_image(client_ip, time_id + i, i, img_input)
            if not uploaded_url:
                return None, f"Image {i+1} upload failed", None
            # Strip any query string so only the bare object URL is submitted.
            uploaded_urls.append(uploaded_url.split("?")[0])

        # The API takes exactly three image slots; unused ones are sent as "".
        user_image_url, user_image2_url, user_image3_url = (uploaded_urls + ["", "", ""])[:3]

        if progress_callback:
            progress_callback("submitting multi-image edit task...")

        task_id, error = submit_image_edit_task(
            user_image_url,
            prompt,
            task_type="80",
            mask_image_url="",
            user_image2_url=user_image2_url,
            user_image3_url=user_image3_url,
            width=width,
            height=height
        )
        if error:
            return None, error, None

        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")

        return _wait_for_task(
            task_id,
            progress_callback,
            busy_message="🚀 Processing your multi-image request...",
            status_prefix="multi-image processing...",
            done_message="multi-image edit completed",
            no_result_message="Task completed but no result image returned",
        )

    except Exception as e:
        return None, f"error occurred during processing: {str(e)}", None


def process_image_edit(img_input, prompt, progress_callback=None):
    """
    Complete process for single image editing (backward compatibility)

    Previously this always failed, because the multi-image path required at least
    2 images. It now explicitly allows a single image.

    Args:
        img_input: Can be file path (str) or PIL Image object
        prompt: Editing instructions
        progress_callback: Progress callback function
    """
    return process_multi_image_edit([img_input], prompt, 0, 0, progress_callback, min_images=1)


def process_local_image_edit(base_image, layers, prompt, progress_callback=None):
    """
    Complete process for local (masked) image editing

    Args:
        base_image (PIL.Image): Original image
        layers (list): ImageEditor layer data used to build the mask
        prompt (str): Editing instructions
        progress_callback: Progress callback function

    Returns:
        tuple: ``(result_url | None, message, task_uuid | None)``
    """
    try:
        client_ip = "127.0.0.1"  # Default IP
        time_id = int(time.time())

        if progress_callback:
            progress_callback("creating mask image...")

        mask_image = create_mask_from_layers(base_image, layers)

        # Refuse to submit an empty mask (nothing was drawn).
        mask_array = np.array(mask_image)
        if np.max(mask_array) == 0:
            return None, "please draw mask", None

        print(f"📝 创建mask图片成功,绘制区域像素数: {np.sum(mask_array > 0)}")

        if progress_callback:
            progress_callback("uploading original image...")

        uploaded_url = upload_user_img_r2(client_ip, time_id, base_image)
        if not uploaded_url:
            return None, "original image upload failed", None
        uploaded_url = uploaded_url.split("?")[0]  # bare object URL, no query string

        if progress_callback:
            progress_callback("uploading mask image...")

        mask_url = upload_mask_image_r2(client_ip, time_id, mask_image)
        if not mask_url:
            return None, "mask image upload failed", None
        mask_url = mask_url.split("?")[0]

        print(f"📤 图片上传成功:")
        print(f" 原始图片: {uploaded_url}")
        print(f" Mask图片: {mask_url}")

        if progress_callback:
            progress_callback("submitting local edit task...")

        # Local (masked) edits use task_type 81.
        task_id, error = submit_image_edit_task(uploaded_url, prompt, task_type="81", mask_image_url=mask_url)
        if error:
            return None, error, None

        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")

        print(f"🚀 局部编辑任务已提交,任务ID: {task_id}")

        output_url, message, task_uuid = _wait_for_task(
            task_id,
            progress_callback,
            busy_message="🚀 Processing your local edit request...",
            status_prefix="processing...",
            done_message="local image edit completed",
            no_result_message="task completed but no result image returned",
        )
        if output_url:
            print(f"✅ 局部编辑任务完成,结果: {output_url}")
        return output_url, message, task_uuid

    except Exception as e:
        print(f"❌ 局部编辑处理异常: {str(e)}")
        return None, f"error occurred during processing: {str(e)}", None


def download_and_check_result_nsfw(image_url, nsfw_detector=None):
    """
    Download the result image and run NSFW detection on it

    Args:
        image_url (str): Result image URL
        nsfw_detector: Detector exposing ``predict_pil_label_only(PIL.Image) -> str``;
            when None, the check is skipped

    Returns:
        tuple: (is_nsfw, error_message). Failures are reported as ``(False, message)``.
    """
    if nsfw_detector is None:
        return False, None

    try:
        response = requests.get(image_url, timeout=30)
        if response.status_code != 200:
            return False, f"Failed to download result image: HTTP {response.status_code}"

        with Image.open(io.BytesIO(response.content)) as result_image:
            nsfw_result = nsfw_detector.predict_pil_label_only(result_image)

        is_nsfw = nsfw_result.lower() == "nsfw"
        print(f"🔍 结果图片NSFW检测: {'❌❌❌ ' + nsfw_result if is_nsfw else '✅✅✅ ' + nsfw_result}")

        return is_nsfw, None

    except Exception as e:
        print(f"⚠️ 结果图片NSFW检测失败: {e}")
        return False, f"Failed to check result image: {str(e)}"


if __name__ == "__main__":
    pass