Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import logging | |
| import json | |
| import re | |
| from typing import List, Dict | |
| from datetime import datetime | |
# Root logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger shared by everything in this module.
logger = logging.getLogger(__name__)
# --- Model configuration ----------------------------------------------------
# Primary (specialised) checkpoint, tried first.
MODEL_NAME = "wizcodes12/snaxfix-model"
# General-purpose model loaded when the primary one cannot be.
FALLBACK_MODEL = "google/flan-t5-small"

# Languages offered in the UI dropdown and accepted by the fixer.
SUPPORTED_LANGUAGES = [
    "python",
    "javascript",
    "java",
    "c",
    "cpp",
    "csharp",
    "rust",
    "php",
    "html",
    "css",
    "sql",
]

# Token limit used both for truncating the prompt and for generation length.
MAX_LENGTH = 512
# Example snippets, one per supported language, each containing a genuine
# syntax error for the "Load Example" button. Only "broken" is returned to
# the UI; "description" names the intended error.
EXAMPLE_SNIPPETS = {
    "python": {
        "broken": 'def add(a b):\n return a + b',
        "description": "Missing comma in function parameters",
    },
    "javascript": {
        # The closing brace is present; only the ")" of console.log is missing.
        "broken": 'function greet() {\n console.log("Hello"\n}',
        "description": "Missing closing parenthesis",
    },
    "java": {
        "broken": 'public class Hello {\n public static void main(String[] args) {\n System.out.println("Hello World")\n }\n}',
        "description": "Missing semicolon",
    },
    "c": {
        "broken": '#include <stdio.h>\n\nint main() {\n printf("Hello World")\n return 0;\n}',
        "description": "Missing semicolon",
    },
    "cpp": {
        "broken": '#include <iostream>\n\nint main() {\n std::cout << "Hello World" << std::endl\n return 0;\n}',
        "description": "Missing semicolon",
    },
    "csharp": {
        "broken": 'class Program {\n static void Main(string[] args) {\n Console.WriteLine("Hello World")\n }\n}',
        "description": "Missing semicolon",
    },
    "rust": {
        # A block's tail expression may omit its semicolon in Rust, so the
        # error is placed on a `let` statement, which requires one.
        "broken": 'fn main() {\n let greeting = "Hello World"\n println!("{}", greeting);\n}',
        "description": "Missing semicolon",
    },
    "php": {
        # PHP's closing tag implies a semicolon, so the statement directly
        # before `?>` would be valid without one; break an earlier statement.
        "broken": '<?php\n $greeting = "Hello World"\n echo $greeting;\n?>',
        "description": "Missing semicolon",
    },
    "html": {
        "broken": '<div>\n <p>Hello World</div>\n</p>',
        "description": "Incorrect tag nesting",
    },
    "css": {
        # The last declaration in a CSS block may omit its semicolon, so the
        # missing one must be followed by another declaration.
        "broken": 'body {\n background-color: #ffffff\n color: #333333;\n}',
        "description": "Missing semicolon",
    },
    "sql": {
        # `SELECT name age` is valid SQL (implicit column alias); a third bare
        # column turns the missing comma into an actual syntax error.
        "broken": 'SELECT name age email FROM users WHERE age > 18',
        "description": "Missing comma in SELECT clause",
    },
}
class SyntaxFixerApp:
    """Repair syntax errors in code snippets with a seq2seq model.

    ``MODEL_NAME`` is loaded first; if that fails, ``FALLBACK_MODEL`` is
    loaded instead. Each successful fix is recorded in an in-memory history.
    """

    # Upper bound on stored history entries. Older entries are discarded so a
    # long-running server does not keep every snippet ever submitted.
    MAX_HISTORY = 100
    # Number of most recent entries rendered by get_history().
    HISTORY_DISPLAY_COUNT = 5

    def __init__(self):
        """Select a device, load tokenizer + model (with fallback), init history.

        Raises:
            RuntimeError: if neither the primary nor the fallback model loads.
        """
        logger.info("Initializing SyntaxFixerApp...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Load model and tokenizer, falling back to the generic model on failure.
        self.model_name_used = None
        try:
            logger.info(f"Attempting to load primary model: {MODEL_NAME}")
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
            self.model_name_used = MODEL_NAME
            logger.info("Primary model and tokenizer loaded successfully")
        except Exception as e:
            logger.warning(f"Failed to load primary model: {e}")
            logger.info(f"Attempting to load fallback model: {FALLBACK_MODEL}")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(FALLBACK_MODEL)
                self.model_name_used = FALLBACK_MODEL
                logger.info("Fallback model and tokenizer loaded successfully")
            except Exception as fallback_error:
                logger.error(f"Failed to load fallback model: {fallback_error}")
                # RuntimeError subclasses Exception, so existing handlers still
                # catch it; chaining keeps the fallback traceback attached.
                raise RuntimeError(
                    "Failed to load both primary and fallback models. "
                    f"Primary: {e}, Fallback: {fallback_error}"
                ) from fallback_error

        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Using model: {self.model_name_used}")

        # Successful fixes, oldest first, capped at MAX_HISTORY entries.
        self.history: List[Dict[str, str]] = []

    def fix_syntax(self, broken_code: str, language: str) -> str:
        """Return a model-generated fix for ``broken_code``.

        Args:
            broken_code: Source text to repair. ``None``, empty and
                whitespace-only input are rejected with an error message.
            language: One of ``SUPPORTED_LANGUAGES``.

        Returns:
            The decoded model output on success; otherwise a message starting
            with ``"Error:"``. Errors are returned rather than raised so the
            UI can display them directly.
        """
        # Treat None like empty input so .strip() is never called on None.
        if not broken_code or not broken_code.strip():
            return "Error: Please enter code to fix."
        if language not in SUPPORTED_LANGUAGES:
            return f"Error: Language '{language}' is not supported. Choose from: {', '.join(SUPPORTED_LANGUAGES)}"

        try:
            # The primary model is prompted with a <LANG> tag prefix; the
            # fallback model gets the plain instruction only.
            if self.model_name_used == FALLBACK_MODEL:
                input_text = f"Fix the syntax errors in this {language} code: {broken_code}"
            else:
                input_text = f"<{language.upper()}> Fix the syntax errors in this {language} code: {broken_code}"

            inputs = self.tokenizer(
                input_text,
                max_length=MAX_LENGTH,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                # Deterministic beam search. `temperature` is deliberately not
                # passed: it only affects sampling, and with do_sample=False
                # transformers merely warns that it is ignored.
                outputs = self.model.generate(
                    **inputs,
                    max_length=MAX_LENGTH,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_cache=True,
                )
            fixed_code = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            # logger.exception records the traceback as well as the message.
            logger.exception(f"Error fixing code: {e}")
            return f"Error: Failed to fix code - {str(e)}"

        self._record_history(language, broken_code, fixed_code)
        return fixed_code

    def _record_history(self, language: str, broken_code: str, fixed_code: str) -> None:
        """Append one fix to the history, dropping the oldest beyond MAX_HISTORY."""
        self.history.append({
            "timestamp": datetime.now().isoformat(),
            "language": language,
            "broken_code": broken_code,
            "fixed_code": fixed_code,
        })
        if len(self.history) > self.MAX_HISTORY:
            del self.history[:-self.MAX_HISTORY]

    def load_example(self, language: str) -> str:
        """Return the broken example snippet for ``language``.

        Falls back to a placeholder message when no example exists.
        """
        return EXAMPLE_SNIPPETS.get(language, {}).get("broken", "No example available for this language.")

    def get_history(self) -> str:
        """Return the most recent fixes (up to HISTORY_DISPLAY_COUNT) as text."""
        if not self.history:
            return "No history available."
        parts = ["=== Fix History ===\n"]
        for entry in self.history[-self.HISTORY_DISPLAY_COUNT:]:
            parts.append(f"Timestamp: {entry['timestamp']}\n")
            parts.append(f"Language: {entry['language']}\n")
            parts.append(f"Broken Code:\n{entry['broken_code']}\n")
            parts.append(f"Fixed Code:\n{entry['fixed_code']}\n")
            parts.append("-" * 50 + "\n")
        return "".join(parts)

    def clear_history(self) -> str:
        """Discard all recorded fixes and return a confirmation message."""
        self.history = []
        return "History cleared."
def create_gradio_interface():
    """Build and return the SnaxFix Gradio ``Blocks`` interface.

    A ``SyntaxFixerApp`` is constructed first, which loads the model; this can
    be slow and raises if neither the primary nor the fallback model loads.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` demo.
    """
    app = SyntaxFixerApp()
    # NOTE(review): this single app instance, and therefore its fix history,
    # is shared by every visitor of the UI. Per-session history would need
    # gr.State; confirm whether sharing is intended.

    # gr.Code validates its `language` argument against gr.Code.languages
    # (when the installed Gradio exposes that list). Several entries of
    # SUPPORTED_LANGUAGES (e.g. java, rust, php) may be missing from it, and
    # updating the editor with one of them would fail.
    editor_languages = getattr(gr.Code, "languages", None)

    def to_editor_language(language):
        """Return ``language`` if gr.Code accepts it, else None (plain text)."""
        if editor_languages is None or language in editor_languages:
            return language
        return None

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# SnaxFix: Advanced Syntax Error Fixer")
        gr.Markdown("Fix syntax errors in code across multiple programming languages using AI models.")
        # Show which model is being used
        gr.Markdown(f"**Currently using model:** `{app.model_name_used}`")

        with gr.Row():
            with gr.Column(scale=2):
                language_dropdown = gr.Dropdown(
                    choices=SUPPORTED_LANGUAGES,
                    label="Select Programming Language",
                    value="python",
                )
                code_input = gr.Code(
                    label="Enter Code with Syntax Errors",
                    lines=10,
                    language="python",
                )
                with gr.Row():
                    fix_button = gr.Button("Fix Syntax", variant="primary")
                    example_button = gr.Button("Load Example", variant="secondary")
                    clear_button = gr.Button("Clear Input", variant="secondary")
            with gr.Column(scale=2):
                code_output = gr.Code(
                    label="Fixed Code",
                    lines=10,
                    language="python",
                )

        with gr.Accordion("History of Fixes", open=False):
            history_output = gr.Textbox(label="Fix History", lines=10)
            with gr.Row():
                refresh_history_button = gr.Button("Refresh History")
                clear_history_button = gr.Button("Clear History")

        with gr.Accordion("About & License", open=False):
            # Blank lines separate the Markdown paragraphs so the headings
            # render on their own instead of merging with the next sentence.
            gr.Markdown("""
**About SnaxFix**

SnaxFix is an AI-powered tool for fixing syntax errors in multiple programming languages, built with `google/flan-t5-base` and fine-tuned by wizcodes12.

**MIT License**

This project is licensed under the MIT License. See the [LICENSE](https://github.com/wizcodes12/snaxfix-model/blob/main/LICENSE) file for details.
""")

        # Event handlers
        def update_code_language(language):
            """Keep both editors' highlighting in sync with the dropdown."""
            editor_language = to_editor_language(language)
            return gr.update(language=editor_language), gr.update(language=editor_language)

        def fix_and_update_history(code, language):
            """Fix code and return both fixed code and updated history."""
            fixed = app.fix_syntax(code, language)
            history = app.get_history()
            return fixed, history

        # Main fix button - fixes code and updates history
        fix_button.click(
            fn=fix_and_update_history,
            inputs=[code_input, language_dropdown],
            outputs=[code_output, history_output],
        )
        # Load example button
        example_button.click(
            fn=app.load_example,
            inputs=language_dropdown,
            outputs=code_input,
        )
        # Clear input button
        clear_button.click(
            fn=lambda: "",
            inputs=None,
            outputs=code_input,
        )
        # Language dropdown change - updates input and output editor language
        language_dropdown.change(
            fn=update_code_language,
            inputs=language_dropdown,
            outputs=[code_input, code_output],
        )
        # Refresh history button
        refresh_history_button.click(
            fn=app.get_history,
            inputs=None,
            outputs=history_output,
        )
        # Clear history button
        clear_history_button.click(
            fn=app.clear_history,
            inputs=None,
            outputs=history_output,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI (this loads the model), then serve it.
    logger.info("Starting Gradio application...")
    demo = create_gradio_interface()
    try:
        demo.launch()
    except Exception as launch_error:
        # Record the failure in the log, then re-raise so it is not swallowed.
        logger.error(f"Failed to launch Gradio app: {launch_error}")
        raise