# Hugging Face Space status at time of capture: Paused
import gc
import json
import os
import threading
from datetime import datetime
from pathlib import Path

# NOTE: `spaces` is kept ahead of `torch`, preserving the original import order.
import gradio as gr
import spaces
import torch
| # Disable torch dynamo globally to avoid ConstantVariable errors | |
| torch._dynamo.config.suppress_errors = True | |
| # Initialize directories | |
| DATA_DIR = Path("/data") if os.path.exists("/data") else Path("./data") | |
| DATA_DIR.mkdir(exist_ok=True) | |
| (DATA_DIR / "users").mkdir(exist_ok=True) | |
| (DATA_DIR / "monsters").mkdir(exist_ok=True) | |
| (DATA_DIR / "models").mkdir(exist_ok=True) | |
| (DATA_DIR / "cache").mkdir(exist_ok=True) | |
| # Ensure Gradio cache directory exists | |
| import tempfile | |
| gradio_cache_dir = Path("/tmp/gradio") | |
| gradio_cache_dir.mkdir(parents=True, exist_ok=True) | |
| # Set environment variable for Gradio cache | |
| os.environ.setdefault("GRADIO_TEMP_DIR", str(gradio_cache_dir)) | |
| # Import modules (to be created) | |
| from core.ai_pipeline import MonsterGenerationPipeline | |
| from core.game_mechanics import GameMechanics | |
| from core.state_manager import StateManager | |
| from core.auth_manager import AuthManager | |
| from ui.themes import get_cyberpunk_theme, CYBERPUNK_CSS | |
| from ui.interfaces import create_voice_interface, create_visual_interface | |
| # Initialize with GPU optimization | |
| def initialize_systems(): | |
| """Initialize all core systems with GPU""" | |
| pipeline = MonsterGenerationPipeline() | |
| return pipeline | |
# The pipeline is created lazily on the first request (not at import time).
pipeline = None
# Guards lazy creation of `pipeline` (see get_pipeline).
_pipeline_lock = threading.Lock()


def get_pipeline():
    """Return the shared MonsterGenerationPipeline, creating it on first use.

    The first call tries initialize_systems(). If that raises, the error is
    printed and a CPU pipeline, MonsterGenerationPipeline(device="cpu"), is
    built instead.

    Creation happens under a lock with a re-check. The app is queued with
    default_concurrency_limit=10, so several handlers can call this at once;
    without the lock, simultaneous first requests would each build (and load
    models for) their own pipeline.

    Returns:
        The cached pipeline instance.

    Raises:
        Whatever MonsterGenerationPipeline(device="cpu") raises if the CPU
        fallback also fails. The cache then stays empty, so a later call
        retries.
    """
    global pipeline
    if pipeline is None:
        with _pipeline_lock:
            # Re-check: another thread may have finished creating it while we waited.
            if pipeline is None:
                try:
                    pipeline = initialize_systems()
                except Exception as e:
                    print(f"GPU initialization failed, falling back to CPU: {e}")
                    pipeline = MonsterGenerationPipeline(device="cpu")
    return pipeline
# Module-level service singletons shared by the event handlers defined below.
game_mechanics = GameMechanics()  # monster creation, training and feeding rules
state_manager = StateManager(DATA_DIR)  # per-user monster storage rooted at DATA_DIR
auth_manager = AuthManager()  # NOTE(review): not referenced anywhere in the visible code
# Main generation function
# NOTE(review): the emoji in the string literals below appear mis-encoded in
# this copy of the file; they are left unchanged.
def generate_monster(oauth_profile: gr.OAuthProfile | None, audio_input=None, text_input=None, reference_images=None,
                     training_focus="balanced", care_level="normal"):
    """Generate a new monster for the logged-in user and save it.

    This handler is wired to five positional outputs (message, image, 3D
    model, stats, dialogue), so every path returns a 5-tuple in that order.

    If the main pipeline path raises, the error is printed and a
    quick-generation fallback is used instead (see
    _fallback_generation_outputs).

    Args:
        oauth_profile: Gradio-injected OAuth profile; ``None`` when the visitor
            is not logged in.
        audio_input: Path to recorded/uploaded audio, or ``None``.
        text_input: Free-text monster description, or ``None``/empty.
        reference_images: Uploaded reference image files, or ``None``.
        training_focus: Initial training focus forwarded to GameMechanics.
        care_level: Initial care level forwarded to GameMechanics.

    Returns:
        ``(message, image, model_3d, stats, dialogue)``.
    """
    if oauth_profile is None:
        # Positional tuple, like every other path. A string-keyed dict cannot
        # be mapped onto the five output components and makes Gradio error out
        # instead of showing this prompt.
        return ("๐ Please log in to create monsters!", None, None, None, None)

    user_id = oauth_profile.username
    try:
        # Generate monster using AI pipeline
        result = get_pipeline().generate_monster(
            audio_input=audio_input,
            text_input=text_input,
            reference_images=reference_images,
            user_id=user_id,
        )
        # Create game monster from AI result
        monster = game_mechanics.create_monster(result, {
            "training_focus": training_focus,
            "care_level": care_level,
        }, user_id)
        # Save to persistent storage
        state_manager.save_monster(user_id, monster)
        # Building the response stays inside the try, so a failure here also
        # routes to the fallback.
        return (
            f"โจ {monster.name} has been created!",
            result.get('image'),
            result.get('model_3d'),
            monster.get_stats_display(),
            result.get('dialogue', "๐ค๐1๏ธโฃ0๏ธโฃ0๏ธโฃ"),
        )
    except Exception as e:
        print(f"Error generating monster: {e}")
        return _fallback_generation_outputs(text_input)


def _fallback_generation_outputs(text_input):
    """Return generate_monster's 5-tuple using the pipeline's quick fallback.

    If the fallback itself fails (for example because no pipeline could be
    created at all), an error message tuple is returned. This keeps the
    exception from escaping the Gradio event handler.
    """
    try:
        fallback_result = get_pipeline().fallback_generation(text_input or "friendly digital creature")
        return (
            "โก Created using quick generation mode",
            fallback_result.get('image'),
            None,
            fallback_result.get('stats'),
            "๐คโ9๏ธโฃ9๏ธโฃ",
        )
    except Exception as e:
        print(f"Fallback generation failed: {e}")
        return ("Monster generation failed. Please try again later.", None, None, None, None)
# Training function
def train_monster(oauth_profile: gr.OAuthProfile | None, training_type, intensity):
    """Run one training session on the user's current monster.

    Args:
        oauth_profile: Gradio-injected OAuth profile; ``None`` when not logged in.
        training_type: Training category picked in the UI (e.g. "Strength").
        intensity: Value of the training-intensity slider.

    Returns:
        ``(message, stats, evolution_check)``. The last two are ``None``
        unless the training succeeded.
    """
    if oauth_profile is None:
        return "๐ Please log in to train monsters!", None, None

    username = oauth_profile.username
    if not (monster := state_manager.get_current_monster(username)):
        return "โ No active monster to train!", None, None

    outcome = game_mechanics.train_monster(monster, training_type, intensity)
    if not outcome['success']:
        return outcome['message'], None, None

    # Only a successful session is written back to storage.
    state_manager.update_monster(username, monster)
    return (
        outcome['message'],
        monster.get_stats_display(),
        outcome.get('evolution_check'),
    )
# Care functions
def feed_monster(oauth_profile: gr.OAuthProfile | None, food_type):
    """Feed the user's current monster and save its updated state.

    Args:
        oauth_profile: Gradio-injected OAuth profile; ``None`` when not logged in.
        food_type: Food picked in the UI (e.g. "Meat").

    Returns:
        A status message string.
    """
    if oauth_profile is None:
        return "๐ Please log in to care for monsters!"

    username = oauth_profile.username
    if not (monster := state_manager.get_current_monster(username)):
        return "โ No active monster to feed!"

    feeding = game_mechanics.feed_monster(monster, food_type)
    # Unlike training, the monster is saved whatever the feeding outcome.
    state_manager.update_monster(username, monster)
    return feeding['message']
# Build the Gradio interface.
# Components appear on screen in the order they are created inside each
# Row/Column/Tab context, so the statements below are order-sensitive.
with gr.Blocks(
    theme=get_cyberpunk_theme(),
    css=CYBERPUNK_CSS,
    title="DigiPal - Digital Monster Companion"
) as demo:
    # Header with cyberpunk styling
    gr.HTML("""
<div class="cyber-header">
<h1 class="glitch-text">๐ค DigiPal ๐ค</h1>
<p class="cyber-subtitle">Your AI-Powered Digital Monster Companion</p>
<div class="pulse-line"></div>
</div>
""")
    # Authentication: Hugging Face login; the handlers read the resulting OAuth profile.
    with gr.Row():
        login_btn = gr.LoginButton("๐ Connect to Digital World", size="lg")
        # NOTE(review): nothing in the visible code writes to user_display.
        user_display = gr.Markdown("", elem_classes=["user-status"])
    # Main interface tabs
    with gr.Tabs(elem_classes=["cyber-tabs"]):
        # Monster Creation Tab: inputs on the left, generation results on the right.
        with gr.TabItem("๐งฌ Create Monster", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                # Input Column
                with gr.Column(scale=1):
                    gr.Markdown("### ๐๏ธ Voice Input")
                    # type="filepath": the handler receives a path to the audio file.
                    audio_input = gr.Audio(
                        label="Describe your monster",
                        sources=["microphone", "upload"],
                        type="filepath",
                        elem_classes=["cyber-input"]
                    )
                    gr.Markdown("### ๐ฌ Text Input")
                    text_input = gr.Textbox(
                        label="Or type a description",
                        placeholder="Describe your ideal digital monster...",
                        lines=3,
                        elem_classes=["cyber-input"]
                    )
                    gr.Markdown("### ๐ผ๏ธ Reference Images")
                    reference_images = gr.File(
                        label="Upload reference images (optional)",
                        file_count="multiple",
                        file_types=["image"],
                        elem_classes=["cyber-input"]
                    )
                    with gr.Row():
                        # Values match the training_focus default of generate_monster ("balanced").
                        training_focus = gr.Radio(
                            choices=["balanced", "strength", "defense", "speed", "intelligence"],
                            label="Training Focus",
                            value="balanced",
                            elem_classes=["cyber-radio"]
                        )
                    generate_btn = gr.Button(
                        "โก Generate Monster",
                        variant="primary",
                        size="lg",
                        elem_classes=["cyber-button", "generate-button"]
                    )
                # Output Column: filled by generate_monster's 5-tuple (see generate_btn.click below).
                with gr.Column(scale=1):
                    generation_message = gr.Markdown("", elem_classes=["cyber-message"])
                    monster_image = gr.Image(
                        label="Monster Appearance",
                        type="pil",
                        elem_classes=["monster-display"]
                    )
                    monster_model = gr.Model3D(
                        label="3D Model",
                        height=400,
                        elem_classes=["monster-display"]
                    )
                    monster_dialogue = gr.Textbox(
                        label="Monster Says",
                        interactive=False,
                        elem_classes=["cyber-dialogue"]
                    )
                    monster_stats = gr.JSON(
                        label="Stats",
                        elem_classes=["cyber-stats"]
                    )
        # Monster Status Tab
        # NOTE(review): none of the components in this tab are an output of any
        # event wired in the visible code.
        with gr.TabItem("๐ Monster Status", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                with gr.Column():
                    current_monster_display = gr.Model3D(
                        label="Your Digital Monster",
                        height=400,
                        elem_classes=["monster-display"]
                    )
                    monster_communication = gr.Textbox(
                        label="Monster Communication",
                        placeholder="Your monster speaks in emojis and numbers...",
                        interactive=False,
                        elem_classes=["cyber-dialogue"]
                    )
                with gr.Column():
                    stats_display = gr.JSON(
                        label="Current Stats",
                        elem_classes=["cyber-stats"]
                    )
                    care_metrics = gr.JSON(
                        label="Care Status",
                        elem_classes=["cyber-stats"]
                    )
                    evolution_progress = gr.HTML(
                        elem_classes=["evolution-display"]
                    )
                    # NOTE(review): no click handler is attached to refresh_btn in the visible code.
                    refresh_btn = gr.Button(
                        "๐ Refresh Status",
                        elem_classes=["cyber-button"]
                    )
        # Training Tab: inputs feed train_monster; results go to the right-hand column.
        with gr.TabItem("๐ช Training", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                with gr.Column():
                    training_type = gr.Radio(
                        choices=["Strength", "Defense", "Speed", "Intelligence", "Special"],
                        label="Training Type",
                        value="Strength",
                        elem_classes=["cyber-radio"]
                    )
                    training_intensity = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Training Intensity",
                        elem_classes=["cyber-slider"]
                    )
                    train_btn = gr.Button(
                        "๐๏ธ Start Training",
                        variant="primary",
                        elem_classes=["cyber-button"]
                    )
                with gr.Column():
                    training_result = gr.Textbox(
                        label="Training Result",
                        interactive=False,
                        elem_classes=["cyber-output"]
                    )
                    updated_stats = gr.JSON(
                        label="Updated Stats",
                        elem_classes=["cyber-stats"]
                    )
                    evolution_check = gr.HTML(
                        elem_classes=["evolution-display"]
                    )
        # Care Tab
        with gr.TabItem("โค๏ธ Care", elem_classes=["cyber-tab-content"]):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### ๐ Feeding")
                    food_type = gr.Radio(
                        choices=["Meat", "Fish", "Vegetable", "Treat", "Medicine"],
                        label="Select Food",
                        value="Meat",
                        elem_classes=["cyber-radio"]
                    )
                    feed_btn = gr.Button(
                        "๐ฝ๏ธ Feed Monster",
                        elem_classes=["cyber-button"]
                    )
                    feeding_result = gr.Textbox(
                        label="Feeding Result",
                        interactive=False,
                        elem_classes=["cyber-output"]
                    )
                with gr.Column():
                    gr.Markdown("### ๐ฎ Interaction")
                    # NOTE(review): play/praise/scold buttons have no handlers in the visible code.
                    play_btn = gr.Button(
                        "๐พ Play",
                        elem_classes=["cyber-button"]
                    )
                    praise_btn = gr.Button(
                        "๐ Praise",
                        elem_classes=["cyber-button"]
                    )
                    scold_btn = gr.Button(
                        "๐ Scold",
                        elem_classes=["cyber-button"]
                    )
                    interaction_result = gr.Textbox(
                        label="Monster Response",
                        interactive=False,
                        elem_classes=["cyber-output"]
                    )
    # Event handlers (must be registered inside the Blocks context).
    # oauth_profile is not listed in any `inputs`; Gradio supplies it from the
    # handlers' gr.OAuthProfile type hint.
    generate_btn.click(
        fn=generate_monster,
        inputs=[
            audio_input,
            text_input,
            reference_images,
            training_focus,
            gr.State("normal") # care_level (fixed value; there is no UI control for it)
        ],
        # Order must match the tuple returned by generate_monster.
        outputs=[
            generation_message,
            monster_image,
            monster_model,
            monster_stats,
            monster_dialogue
        ]
    )
    train_btn.click(
        fn=train_monster,
        inputs=[
            training_type,
            training_intensity
        ],
        # Order must match train_monster's (message, stats, evolution_check) tuple.
        outputs=[
            training_result,
            updated_stats,
            evolution_check
        ]
    )
    feed_btn.click(
        fn=feed_monster,
        inputs=[
            food_type
        ],
        outputs=[feeding_result]
    )
# Launch the app
if __name__ == "__main__":
    import warnings

    # Silence UserWarnings emitted from the gradio.mcp module.
    warnings.filterwarnings("ignore", category=UserWarning, module="gradio.mcp")

    # At most 10 events run concurrently per listener; up to 100 requests may wait in the queue.
    queued_app = demo.queue(max_size=100, default_concurrency_limit=10)
    queued_app.launch(
        show_error=True,
        show_api=False,
        server_port=7860,
        server_name="0.0.0.0",
    )