#!/usr/bin/env python3
import io
import json
import logging
import os
import random
import re
import traceback
from pathlib import Path
from typing import Any, Dict, List, Set
from urllib.parse import urlparse

import faiss
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch
from git import Repo
from huggingface_hub import InferenceClient, snapshot_download
from openai import AzureOpenAI
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.manifold import TSNE
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_env():
    from dotenv import load_dotenv
    env_path = Path(__file__).parent.parent / '.env'
    load_dotenv(dotenv_path=env_path)

load_env()

# Centralized env parameters
HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
MODEL_NAME = "gpt-4o-mini"
DEPLOYMENT = "gpt-4o-mini"
API_VERSION = "2024-12-01-preview"

FILE_REGEX = re.compile(r"^diff --git a/(.+?) b/(.+)")
LINE_HUNK = re.compile(r"@@ -(?P<old_start>\d+),(?P<old_len>\d+) \+(?P<new_start>\d+),(?P<new_len>\d+) @@")
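# Example inputs these regexes match (illustrative):
#   FILE_REGEX: "diff --git a/src/app.py b/src/app.py"  -> groups: old path, new path
#   LINE_HUNK:  "@@ -10,7 +10,8 @@"                     -> old_start=10, old_len=7, new_start=10, new_len=8
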
# Configure logging to capture all output
log_stream = io.StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
log_formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
log_handler.setFormatter(log_formatter)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[log_handler, logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

class InferenceContext:
    def __init__(self, repo_url: str):
        self.repo_url = repo_url
        owner, name = self._parse_owner_repo(repo_url)
        self.repo_id = f"{owner}/{name}"
        self.repo_dir = f"{owner}-{name}"
        self.hf_repo_id = "kotlarmilos/repository-learning"

        # Local paths for downloaded models
        self.base = Path("artifacts") / self.repo_dir
        self.model_dirs = {
            'fine_tune': self.base / 'fine_tune',
            'contrastive': self.base / 'contrastive',
            'index': self.base / 'index'
        }
        self.code_dir = self.base / 'code'

        # Create directories
        for d in (*self.model_dirs.values(), self.code_dir):
            d.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _parse_owner_repo(url: str) -> tuple[str, str]:
        parts = urlparse(url).path.strip("/").split("/")
        if len(parts) < 2:
            raise ValueError(f"Invalid GitHub URL: {url}")
        return parts[-2], parts[-1]

class InferencePipeline:
    def __init__(self, ctx: InferenceContext):
        self.ctx = ctx
        self.tokenizer = None
        self.local_llm = None
        self.enterprise_llm = None
        self.embedder = None
        self.faiss_index = None
        self.faiss_metadata = None
        self.download_artifacts()
        self.load_models()

    def download_artifacts(self):
        """Download models and index from Hugging Face if they don't exist locally."""
        self.repo_files = self._clone_or_pull()
        snapshot_download(
            repo_id=self.ctx.hf_repo_id,
            allow_patterns=f"{self.ctx.repo_dir}/**",
            local_dir=str(self.ctx.base.parent),
            local_dir_use_symlinks=False,
            token=HUGGINGFACE_HUB_TOKEN
        )
        logger.info("All artifacts downloaded.")

    def _clone_or_pull(self) -> List[str]:
        """Clone the target repository (or pull if it already exists) and return its file list."""
        dest = self.ctx.code_dir
        git_dir = dest / ".git"
        if git_dir.exists():
            Repo(dest).remotes.origin.pull()
            logger.info("Pulled latest code into %s", dest)
        else:
            Repo.clone_from(self.ctx.repo_url, dest)
            logger.info("Cloned repo %s into %s", self.ctx.repo_url, dest)
        return [str(f.relative_to(dest)) for f in dest.rglob("*") if f.is_file()]
    def load_models(self):
        """Load the fine-tuned LLM, the Azure OpenAI client, the contrastive embedder, and the FAISS index."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.ctx.model_dirs['fine_tune'])
        self.local_llm = AutoModelForCausalLM.from_pretrained(
            self.ctx.model_dirs['fine_tune'],
            device_map="auto",
            torch_dtype=torch.bfloat16
        )
        self.enterprise_llm = AzureOpenAI(
            api_version=API_VERSION,
            azure_endpoint=AZURE_OPENAI_ENDPOINT,
            api_key=AZURE_OPENAI_API_KEY,
        )
        self.embedder = SentenceTransformer(str(self.ctx.model_dirs['contrastive']))
        self.faiss_index = faiss.read_index(str(self.ctx.model_dirs['index'] / "index.faiss"))
        self.faiss_metadata = json.loads((self.ctx.model_dirs['index'] / "metadata.json").read_text())
        logger.info("FAISS index loaded successfully")
    def _extract_pr_data(self, pr_url: str) -> dict:
        """Collect PR data (title, body, review comments, diff) via the GitHub API."""
        match = re.search(r'/pull/(\d+)', pr_url)
        if not match:
            raise ValueError(f"Could not parse a PR number from URL: {pr_url}")
        pr_number = int(match.group(1))
        api_url = f"https://api.github.com/repos/{self.ctx.repo_id}/pulls/{pr_number}"
        comments_url = f"https://api.github.com/repos/{self.ctx.repo_id}/pulls/{pr_number}/comments"

        headers = {
            "Authorization": f"token {GITHUB_TOKEN}",
            "Accept": "application/vnd.github.v3+json"
        }
        try:
            logger.info(f"Fetching PR #{pr_number} details...")
            pr_response = requests.get(api_url, headers=headers)
            pr_response.raise_for_status()
            pr_data = pr_response.json()

            logger.info(f"Fetching PR #{pr_number} review comments...")
            comments_response = requests.get(comments_url, headers=headers)
            comments_response.raise_for_status()
            comments_data = comments_response.json()

            # Group review comments by the diff hunk they were left on
            grouped = {}
            for comment in comments_data:
                hunk = comment.get("diff_hunk", "")
                grouped.setdefault(hunk, []).append(comment.get("body", ""))
            review_comments = [
                {"diff_hunk": hunk, "comments": comments}
                for hunk, comments in grouped.items()
            ]

            logger.info(f"Fetching PR #{pr_number} diff...")
            diff_headers = headers.copy()
            diff_headers["Accept"] = "application/vnd.github.v3.diff"
            diff_response = requests.get(api_url, headers=diff_headers)
            diff_response.raise_for_status()

            parsed_diff = self.parse_diff_with_lines(diff_response.text)
            result = {
                "title": pr_data.get("title", ""),
                "body": pr_data.get("body", ""),
                "review_comments": review_comments,
                "diff": diff_response.text,
                "changed_files": list(parsed_diff['changed_files']),
                "diff_hunks": parsed_diff['diff_hunks']
            }
            logger.info(f"Successfully collected PR #{pr_number} data")
            return result
        except Exception as e:
            logger.error(f"Error processing PR #{pr_number} data: {e}")
            raise
    def parse_diff_with_lines(self, diff_text: str) -> Dict[str, Any]:
        """Parse a unified diff into changed files and per-file hunks with new-file line ranges."""
        lines = diff_text.splitlines()
        result = {
            'changed_files': set(),
            'diff_hunks': {}
        }
        current_file = None
        current_hunk_content = []
        current_line_range = None
        file_header_lines = []

        for line in lines:
            # Check if this is a new file header
            file_match = FILE_REGEX.match(line)
            if file_match:
                # Save the previous file's data if it exists
                if current_file and current_hunk_content and current_line_range:
                    if current_file not in result['diff_hunks']:
                        result['diff_hunks'][current_file] = []
                    result['diff_hunks'][current_file].append({
                        'line_range': current_line_range,
                        'content': '\n'.join(current_hunk_content)
                    })
                # Start a new file
                current_file = file_match.group(2)  # Use the 'b/' file path (new file)
                result['changed_files'].add(current_file)
                file_header_lines = [line]
                current_hunk_content = []
                current_line_range = None
            elif current_file:  # Only process if we're inside a file
                # Check for hunk headers to extract line ranges
                hunk_match = LINE_HUNK.match(line)
                if hunk_match:
                    # Save the previous hunk if it exists
                    if current_hunk_content and current_line_range:
                        if current_file not in result['diff_hunks']:
                            result['diff_hunks'][current_file] = []
                        result['diff_hunks'][current_file].append({
                            'line_range': current_line_range,
                            'content': '\n'.join(current_hunk_content)
                        })
                    # Start a new hunk
                    old_start = int(hunk_match.group('old_start'))
                    old_len = int(hunk_match.group('old_len'))
                    new_start = int(hunk_match.group('new_start'))
                    new_len = int(hunk_match.group('new_len'))
                    # Calculate the range of changed lines in the new file
                    if new_len > 0:
                        line_start = new_start
                        line_end = new_start + new_len - 1
                        current_line_range = (line_start, line_end)
                    else:
                        current_line_range = (new_start, new_start)
                    # Start fresh hunk content with the file headers and the current hunk header
                    current_hunk_content = file_header_lines + [line]
                else:
                    # Add the content line to the current hunk
                    if current_hunk_content is not None:
                        current_hunk_content.append(line)

        # Save the last hunk
        if current_file and current_hunk_content and current_line_range:
            if current_file not in result['diff_hunks']:
                result['diff_hunks'][current_file] = []
            result['diff_hunks'][current_file].append({
                'line_range': current_line_range,
                'content': '\n'.join(current_hunk_content)
            })
        return result
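    # Illustrative shape of the value parse_diff_with_lines returns for a
    # one-file, one-hunk diff (values are made up):
    #   {
    #       'changed_files': {'src/app.py'},
    #       'diff_hunks': {
    #           'src/app.py': [
    #               {'line_range': (10, 17),
    #                'content': 'diff --git a/src/app.py b/src/app.py\n@@ -10,7 +10,8 @@\n ...'}
    #           ]
    #       }
    #   }
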
    def analyze_file_similarity(self, changed_files: List[str]) -> Dict[str, Any]:
        """Cluster the changed files by embedding similarity and flag files that look anomalous."""
        result = {
            'similar_file_groups': [],
            'anomalous_files': [],
            'analysis_summary': {
                'total_files': len(changed_files),
                'num_groups': 0,
                'num_anomalies': 0,
                'avg_group_size': 0
            }
        }
        # Handle edge cases
        if len(changed_files) == 0:
            logger.info("No changed files to analyze")
            return result
        if len(changed_files) == 1:
            logger.info(f"Only one file changed: {changed_files[0]} - no similarity analysis needed")
            result['analysis_summary']['num_anomalies'] = 1
            result['anomalous_files'].append({
                'file': changed_files[0],
                'reason': 'single_file',
                'max_similarity_to_others': 0.0,
                'most_similar_file': None,
                'is_anomaly': False
            })
            return result
        # Encode all changed files
        file_embeddings = self.embedder.encode(changed_files, convert_to_tensor=True)
        similarity_matrix = util.pytorch_cos_sim(file_embeddings, file_embeddings)

        # Convert the similarity matrix to a distance matrix for clustering
        distance_matrix = 1 - similarity_matrix.cpu().numpy()

        # Perform hierarchical clustering
        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=0.3,  # 1 - 0.7 = 0.3 (similarity threshold of 0.7)
            metric='precomputed',
            linkage='average'
        )
        cluster_labels = clustering.fit_predict(distance_matrix)

        # Group files by cluster
        clusters = {}
        for i, label in enumerate(cluster_labels):
            if label not in clusters:
                clusters[label] = []
            clusters[label].append((changed_files[i], i))  # Store the file and its index

        # Process clusters to identify groups and anomalies
        for cluster_id, files_with_indices in clusters.items():
            files_in_cluster = [f[0] for f in files_with_indices]
            if len(files_in_cluster) > 1:
                # This is a group of similar files
                group_similarities = []
                pairwise_similarities = []
                for i in range(len(files_with_indices)):
                    for j in range(i + 1, len(files_with_indices)):
                        file_i_idx = files_with_indices[i][1]
                        file_j_idx = files_with_indices[j][1]
                        similarity = float(similarity_matrix[file_i_idx][file_j_idx])
                        group_similarities.append(similarity)
                        pairwise_similarities.append({
                            'file1': files_with_indices[i][0],
                            'file2': files_with_indices[j][0],
                            'similarity': similarity
                        })
                avg_similarity = sum(group_similarities) / len(group_similarities) if group_similarities else 0
                min_similarity = min(group_similarities) if group_similarities else 0
                max_similarity = max(group_similarities) if group_similarities else 0
                result['similar_file_groups'].append({
                    'cluster_id': cluster_id,
                    'files': files_in_cluster,
                    'avg_similarity': avg_similarity,
                    'min_similarity': min_similarity,
                    'max_similarity': max_similarity,
                    'pairwise_similarities': pairwise_similarities,
                    'coherence': 'high' if min_similarity > 0.6 else 'medium' if min_similarity > 0.4 else 'low'
                })
            else:
                # This is a singleton cluster - potentially anomalous
                file = files_in_cluster[0]
                file_idx = files_with_indices[0][1]
                # Calculate the maximum similarity to any other file
                max_similarity = 0
                most_similar_file = None
                similarities_to_others = []
                for other_idx, other_file in enumerate(changed_files):
                    if other_idx != file_idx:
                        similarity = float(similarity_matrix[file_idx][other_idx])
                        similarities_to_others.append({
                            'file': other_file,
                            'similarity': similarity
                        })
                        if similarity > max_similarity:
                            max_similarity = similarity
                            most_similar_file = other_file
                result['anomalous_files'].append({
                    'file': file,
                    'cluster_id': cluster_id,
                    'max_similarity_to_others': max_similarity,
                    'most_similar_file': most_similar_file,
                    'similarities_to_others': similarities_to_others,
                    'is_anomaly': max_similarity < 0.5,  # Strong anomaly threshold
                    'anomaly_strength': 'strong' if max_similarity < 0.3 else 'medium' if max_similarity < 0.5 else 'weak',
                    'reason': 'isolated_cluster'
                })
        # Additional anomaly detection: files that are far from the group average
        if len(changed_files) >= 3:
            # Calculate the average embedding of all changed files
            avg_embedding = torch.mean(file_embeddings, dim=0)
            # Find files that are far from the average
            for i, file in enumerate(changed_files):
                file_embedding = file_embeddings[i]
                similarity_to_avg = float(util.pytorch_cos_sim(file_embedding.unsqueeze(0), avg_embedding.unsqueeze(0))[0][0])
                # Check if this file is already in anomalous_files
                existing_anomaly = next((a for a in result['anomalous_files'] if a['file'] == file), None)
                if existing_anomaly:
                    # Update the existing anomaly record
                    existing_anomaly['similarity_to_group_avg'] = similarity_to_avg
                    existing_anomaly['is_strong_anomaly'] = (
                        similarity_to_avg < 0.4 and existing_anomaly['max_similarity_to_others'] < 0.5
                    )
                    if existing_anomaly['is_strong_anomaly']:
                        existing_anomaly['anomaly_strength'] = 'very_strong'
                elif similarity_to_avg < 0.4:  # Low similarity to the group average
                    # Calculate similarities to all other files
                    similarities_to_others = []
                    max_sim = 0
                    most_sim_file = None
                    for j, other_file in enumerate(changed_files):
                        if i != j:
                            sim = float(similarity_matrix[i][j])
                            similarities_to_others.append({
                                'file': other_file,
                                'similarity': sim
                            })
                            if sim > max_sim:
                                max_sim = sim
                                most_sim_file = other_file
                    result['anomalous_files'].append({
                        'file': file,
                        'cluster_id': None,
                        'max_similarity_to_others': max_sim,
                        'most_similar_file': most_sim_file,
                        'similarities_to_others': similarities_to_others,
                        'similarity_to_group_avg': similarity_to_avg,
                        'is_anomaly': True,
                        'is_strong_anomaly': max_sim < 0.5,
                        'anomaly_strength': 'very_strong' if max_sim < 0.3 else 'strong' if max_sim < 0.5 else 'medium',
                        'reason': 'distant_from_group_average'
                    })
        # Update the analysis summary
        result['analysis_summary']['num_groups'] = len(result['similar_file_groups'])
        result['analysis_summary']['num_anomalies'] = len(result['anomalous_files'])
        if result['similar_file_groups']:
            total_files_in_groups = sum(len(group['files']) for group in result['similar_file_groups'])
            result['analysis_summary']['avg_group_size'] = total_files_in_groups / len(result['similar_file_groups'])

        # Log results
        logger.info("File similarity analysis complete:")
        logger.info(f"  Total files: {result['analysis_summary']['total_files']}")
        logger.info(f"  Similar groups: {result['analysis_summary']['num_groups']}")
        logger.info(f"  Anomalous files: {result['analysis_summary']['num_anomalies']}")
        for i, group in enumerate(result['similar_file_groups']):
            logger.info(f"  Group {i+1} ({group['coherence']} coherence): {group['files']} (avg: {group['avg_similarity']:.3f})")
        for anomaly in result['anomalous_files']:
            logger.info(f"  {anomaly['anomaly_strength'].upper()} ANOMALY: {anomaly['file']} (reason: {anomaly['reason']}, max_sim: {anomaly['max_similarity_to_others']:.3f})")
        return result
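    # Illustrative example (file names are made up): for changed_files like
    # ["src/runner/A.cs", "src/runner/B.cs", "docs/README.md"], the two source files would
    # typically land in one similar_file_groups entry, while docs/README.md would surface
    # under anomalous_files as an isolated cluster, assuming its embedding similarity to
    # the others stays below the 0.7 clustering threshold.
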
    # TODO: Add local LLM reasoning
    # def generate_llm_response(self, prompt: str, max_new_tokens: int = 256) -> str:
    #     """Generate a response using the fine-tuned LLM."""
    #     if not self.tokenizer or not self.local_llm:
    #         raise ValueError("LLM not loaded. Call load_models() first.")
    #     inputs = self.tokenizer(prompt, return_tensors="pt").to(self.local_llm.device)
    #     outputs = self.local_llm.generate(
    #         **inputs,
    #         max_new_tokens=max_new_tokens,
    #         pad_token_id=self.tokenizer.eos_token_id
    #     )
    #     return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    def search_code_snippets(self, diff_hunks) -> list:
        """Map each diff hunk to the indexed functions whose line ranges overlap it."""
        metadata_file = self.ctx.model_dirs["index"] / "metadata.json"
        with open(metadata_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        result = []
        # Process each file's diff hunks
        for file_path, hunks in diff_hunks.items():
            logger.info(f"Searching functions for file: {file_path}")
            for hunk in hunks:
                line_range = hunk.get('line_range')
                if not line_range:
                    continue
                start_line, end_line = line_range
                logger.debug(f"Processing hunk at lines {start_line}-{end_line}")

                # Find functions that overlap with this line range
                overlapping_functions = []
                for func_metadata in metadata:
                    func_file = func_metadata.get('file', '')
                    func_start = func_metadata.get('start_line')
                    func_end = func_metadata.get('end_line')
                    func_name = func_metadata.get('name', 'unknown')
                    func_description = func_metadata.get('llm_description', '')
                    # Only consider functions in the same file
                    if func_file != file_path:
                        continue
                    # Check if the function's line range overlaps the diff hunk's line range:
                    # the function overlaps if it starts before the diff ends and ends after the diff starts
                    if func_start is not None and func_end is not None:
                        if func_start <= end_line and func_end >= start_line:
                            overlap_start = max(func_start, start_line)
                            overlap_end = min(func_end, end_line)
                            overlapping_functions.append({
                                'function_name': func_name,
                                'function_description': func_description,
                                'function_start_line': func_start,
                                'function_end_line': func_end,
                                # 'overlap_start': overlap_start,
                                # 'overlap_end': overlap_end,
                                # 'overlap_lines': overlap_end - overlap_start + 1
                            })
                # if len(overlapping_functions) > 0:
                hunk_result = {
                    'file_name': file_path,
                    'diff_hunk': hunk.get('content', ''),
                    'overlapping_functions': overlapping_functions
                }
                result.append(hunk_result)

        total_hunks = sum(len(hunks) for hunks in diff_hunks.values())
        total_functions = sum(len(entry['overlapping_functions']) for entry in result)
        logger.info(f"Processed {total_hunks} diff hunks across {len(diff_hunks)} files, found {total_functions} overlapping functions")
        return result
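    # Illustrative shape of one entry in the list search_code_snippets returns (values made up):
    #   {
    #       'file_name': 'src/app.py',
    #       'diff_hunk': 'diff --git a/src/app.py b/src/app.py\n@@ -10,7 +10,8 @@\n ...',
    #       'overlapping_functions': [
    #           {'function_name': 'main', 'function_description': '...',
    #            'function_start_line': 5, 'function_end_line': 40}
    #       ]
    #   }
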
    def _select_files_around_changed(self, changed_files: List[str] = None, max_files: int = 500) -> List[str]:
        """Select files to visualize, prioritizing changed files and semantically similar ones."""
        logger.info(f"Selecting {max_files} files around {len(changed_files)} changed files...")
        # Start with the changed files
        selected_files = set(changed_files)

        # Find files similar to the changed files using embeddings
        try:
            # Encode the changed files
            changed_embeddings = self.embedder.encode(changed_files, convert_to_tensor=False)

            # Target number of similar files to find
            target_similar = min(max_files - len(changed_files), 200)  # Leave room for random files

            # Compare against a sample of repo files (for performance)
            sample_size = min(2000, len(self.repo_files))
            repo_sample = self.repo_files[:sample_size]
            # Remove already selected files from the sample
            repo_sample = [f for f in repo_sample if f not in selected_files]

            if len(repo_sample) > 0:
                # Encode the sample files
                sample_embeddings = self.embedder.encode(repo_sample, convert_to_tensor=False, show_progress_bar=False)

                # For each sampled file, compute its maximum cosine similarity to any changed file
                similarities = []
                for i, repo_file in enumerate(repo_sample):
                    max_sim = 0
                    for changed_emb in changed_embeddings:
                        sim = np.dot(changed_emb, sample_embeddings[i]) / (
                            np.linalg.norm(changed_emb) * np.linalg.norm(sample_embeddings[i])
                        )
                        max_sim = max(max_sim, sim)
                    similarities.append((repo_file, max_sim))

                # Sort by similarity and take the top ones, avoiding duplicates
                added = 0
                for file_path, sim in sorted(similarities, key=lambda x: x[1], reverse=True):
                    if file_path not in selected_files:
                        selected_files.add(file_path)
                        added += 1
                    if len(selected_files) >= max_files or added >= target_similar:
                        break
                logger.info(f"Added {added} similar files to visualization")
        except Exception as e:
            logger.warning(f"Could not compute file similarities: {e}")

        # Fill the remaining slots with random files
        remaining_slots = max_files - len(selected_files)
        if remaining_slots > 0:
            remaining_files = [f for f in self.repo_files if f not in selected_files]
            random.shuffle(remaining_files)
            for file_path in remaining_files[:remaining_slots]:
                selected_files.add(file_path)

        result = list(selected_files)
        logger.info(f"Selected {len(result)} files total: {len(changed_files)} changed, {len(result) - len(changed_files)} related/random")
        return result
    def create_repo_visualization(self, changed_files: List[str] = None, max_files: int = 500):
        files_to_plot = self._select_files_around_changed(changed_files, max_files * len(changed_files))
        logger.info(f"Creating visualization for {len(files_to_plot)} files...")
        if len(files_to_plot) < 2:
            return self._create_dummy_plot(f"Only {len(files_to_plot)} files available")

        embeddings = self.embedder.encode(files_to_plot, convert_to_tensor=False, show_progress_bar=False)
        logger.info(f"Embeddings computed successfully: shape {getattr(embeddings, 'shape', None)}")

        n = len(files_to_plot)
        perplexity = min(30, max(1, n - 1))
        tsne = TSNE(n_components=3, perplexity=perplexity, init='random', random_state=42)
        reduced = tsne.fit_transform(embeddings)

        fig = go.Figure()
        colors = []
        sizes = []
        hover_texts = []
        for i, file_path in enumerate(files_to_plot):
            if changed_files and file_path in changed_files:
                colors.append('red')
            else:
                # Color by file type
                ext = os.path.splitext(file_path)[1].lower()
                if ext in ['.py', '.js', '.ts', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.rs']:
                    colors.append('blue')
                elif ext in ['.md', '.txt', '.rst', '.doc']:
                    colors.append('green')
                elif ext in ['.json', '.yaml', '.yml', '.xml', '.toml', '.ini']:
                    colors.append('orange')
                elif ext in ['.html', '.css', '.scss', '.sass']:
                    colors.append('purple')
                else:
                    colors.append('gray')
            sizes.append(8)
            hover_texts.append(os.path.basename(file_path))

        fig.add_trace(go.Scatter3d(
            x=reduced[:, 0].tolist(),
            y=reduced[:, 1].tolist(),
            z=reduced[:, 2].tolist(),
            mode='markers+text',
            marker=dict(size=sizes, color=colors),
            text=[os.path.basename(f) for f in files_to_plot],
            hovertext=hover_texts,
            textposition='middle center',
            name='Repository Files'
        ))

        title_text = 'Repository File Embeddings (3D t-SNE)'
        if changed_files:
            title_text += f' - {len(changed_files)} Changed Files Highlighted in Red'
        fig.update_layout(
            title=title_text,
            scene=dict(
                xaxis_title='t-SNE 1',
                yaxis_title='t-SNE 2',
                zaxis_title='t-SNE 3',
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            ),
            width=800,
            height=600,
            margin=dict(r=20, b=10, l=10, t=60)
        )
        return fig
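    # NOTE: create_repo_visualization calls self._create_dummy_plot, but that helper is not
    # defined anywhere in this file. The following is an assumed minimal sketch that simply
    # returns an empty annotated Plotly figure so the fallback path can run.
    def _create_dummy_plot(self, message: str):
        fig = go.Figure()
        fig.add_annotation(text=message, x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False)
        fig.update_layout(title='Repository File Embeddings (3D t-SNE)', width=800, height=600)
        return fig
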
    def build_structured_prompt(self, data: dict, sim_analysis: dict, code_desc: list) -> str:
        """Build the comprehensive review prompt from PR data, similarity analysis, and grounding snippets."""
        # Grouped clusters and anomalies from the similarity analysis
        clusters = sim_analysis['similar_file_groups']
        anomalies = sim_analysis['anomalous_files']

        # Header
        prompt = []
        prompt.append("You are an expert reviewer. First give group summaries, then detailed line-by-line feedback.")
        prompt.append(f"Title: {data['title']}")
        prompt.append(f"Description: {data['body']}")

        # Clusters
        for c in clusters:
            prompt.append(f"## Group {c['cluster_id']} ({len(c['files'])} files, avg_sim={c['avg_similarity']:.2f}): {', '.join(c['files'])}")
            prompt.append("Files:")
            for f in c['files']:
                prompt.append(f"- {f}")
            prompt.append("Summary: Changes in these files share a semantic pattern. Focus on shared logic.")

        # Anomalies
        if anomalies:
            prompt.append("## Isolated Files (low similarity with changed files)")
            for a in anomalies:
                prompt.append(f"- {a['file']} (reason: {a['reason']}, strength: {a.get('anomaly_strength')})")

        # Grounding diffs per cluster/files
        prompt.append("## Diff Hunks and Context:")
        for entry in code_desc:
            prompt.append(f"File: {entry['file_name']}\n{entry['diff_hunk']}")
            if entry['overlapping_functions']:
                prompt.append("Affected functions:")
                for f in entry['overlapping_functions']:
                    prompt.append(f"- {f['function_name']}: {f['function_description']}")

        # Request
        prompt.append("Provide feedback on groups, then isolated files. After that provide line-by-line feedback in diff format.")
        return "\n".join(prompt)

def get_current_logs():
    return log_stream.getvalue()

# Pipeline
pipeline = InferencePipeline(InferenceContext("https://github.com/dotnet/xharness"))

def analyze_pr_streaming(pr_url):
    log_stream.seek(0)
    log_stream.truncate()

    base_review = ""
    final_review = ""
    visualization = None

    data = pipeline._extract_pr_data(pr_url)
    yield base_review, final_review, get_current_logs(), visualization

    visualization = pipeline.create_repo_visualization(list(data["changed_files"]), max_files=20)
    yield "", "", get_current_logs(), visualization

    similarity_analysis = pipeline.analyze_file_similarity(list(data["changed_files"]))
    similar_file_groups = similarity_analysis['similar_file_groups']
    anomalous_files = similarity_analysis['anomalous_files']
    yield "", "", get_current_logs(), visualization

    code_description = pipeline.search_code_snippets(data["diff_hunks"])
    comprehensive_prompt = pipeline.build_structured_prompt(data, similarity_analysis, code_description)

    # Base prompt
    base_prompt = f"""You are an expert reviewer. Provide detailed line-by-line feedback.
Title: {data['title']}
Description: {data['body']}
Diff: {data['diff']}
"""
    # Earlier manual prompt assembly, now superseded by pipeline.build_structured_prompt:
    # similar_file_groups_formatted = []
    # for i, group in enumerate(similar_file_groups):
    #     files_str = ", ".join(group['files'])
    #     similar_file_groups_formatted.append(f"group {i}: {files_str}")
    # anomalous_files_formatted = []
    # for anomaly in anomalous_files:
    #     anomalous_files_formatted.append(f"anomaly: {anomaly['file']} (reason: {anomaly['reason']}, strength: {anomaly['anomaly_strength']})")
    # grounding_formatted = ""
    # for entry in code_description:
    #     file_name = entry['file_name']
    #     overlapping_functions = entry['overlapping_functions']
    #     diff_hunk = entry['diff_hunk']
    #     if len(overlapping_functions) > 0:
    #         grounding_formatted += f"In file {file_name}, the following changes were made: {diff_hunk}\n"
    #         grounding_formatted += "These changes affected the following functions:\n"
    #         for func in overlapping_functions:
    #             grounding_formatted += f"{func['function_name']} - {func['function_description']}\n"
    #     else:
    #         grounding_formatted += f"In file {file_name}, the following changes were made: {diff_hunk}\n"
    #     grounding_formatted += "\n"
    # # Create formatted strings for the f-string
    # similar_groups_text = "\n".join(similar_file_groups_formatted)
    # anomalous_files_text = "\n".join(anomalous_files_formatted)
    # # TODO: Add local LLM reasoning
    # # TODO: Add relevant files from the directory not included
    # comprehensive_prompt = f"""{base_prompt}
    # FILES CHANGED IN THIS PR THAT ARE SEMANTICALLY CLOSE:
    # {similar_groups_text}
    # UNEXPECTED CHANGES IN FILES:
    # {anomalous_files_text}
    # GROUNDING DATA: The following provides specific information about which functions are affected by each diff hunk:
    # {grounding_formatted}
    # """

    logger.info(f"Base prompt word count: {len(base_prompt.split())}")
    logger.info(f"Base prompt: {base_prompt}")
    logger.info(f"Comprehensive prompt word count: {len(comprehensive_prompt.split())}")
    logger.info(f"Comprehensive prompt: {comprehensive_prompt}")
| logger.info("Calling Azure OpenAI...") | |
| yield "", "", get_current_logs(), visualization | |
| base_review_response = pipeline.enterprise_llm.chat.completions.create( | |
| model=DEPLOYMENT, | |
| messages=[ | |
| {"role": "system", "content": "You are an expert code reviewer. Provide thorough, constructive feedback."}, | |
| {"role": "user", "content": base_prompt} | |
| ], | |
| max_tokens=8192, | |
| temperature=0.3 | |
| ) | |
| base_review = base_review_response.choices[0].message.content | |
| logger.info("Base review completed") | |
| final_review_response = pipeline.enterprise_llm.chat.completions.create( | |
| model=DEPLOYMENT, | |
| messages=[ | |
| {"role": "system", "content": "You are an expert code reviewer. Provide thorough, constructive feedback."}, | |
| {"role": "user", "content": comprehensive_prompt} | |
| ], | |
| max_tokens=8192, | |
| temperature=0.3 | |
| ) | |
| final_review = final_review_response.choices[0].message.content | |
| logger.info("Final review completed") | |
| yield base_review, final_review, get_current_logs(), visualization | |
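# analyze_pr_streaming is a generator: each `yield` pushes intermediate values to the four
# outputs wired up below, so the logs and visualization update while the reviews are computed.
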
with gr.Blocks(title="PR Code Review Assistant") as demo:
    gr.Markdown("# PR Code Review Assistant")
    gr.Markdown("Enter a GitHub PR URL to get comprehensive code review analysis with interactive repository visualization.")

    with gr.Row():
        pr_url_input = gr.Textbox(
            label="GitHub PR URL",
            placeholder="https://github.com/owner/repo/pull/123",
            value="https://github.com/dotnet/xharness/pull/1416"
        )
        analyze_btn = gr.Button("Analyze PR", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            base_review_output = gr.Textbox(
                label="Base Review",
                lines=15,
                max_lines=30,
                interactive=False
            )
        with gr.Column(scale=1):
            final_review_output = gr.Textbox(
                label="Comprehensive Review",
                lines=15,
                max_lines=30,
                interactive=False
            )

    with gr.Row():
        with gr.Column(scale=1):
            visualization_output = gr.Plot(
                label="Repository Files Visualization (3D)",
                value=None
            )
        with gr.Column(scale=1):
            logs_output = gr.Textbox(
                label="Analysis Logs",
                lines=15,
                max_lines=25,
                interactive=False,
                show_copy_button=True
            )

    analyze_btn.click(
        fn=analyze_pr_streaming,
        inputs=[pr_url_input],
        outputs=[base_review_output, final_review_output, logs_output, visualization_output],
        show_progress=True
    )

if __name__ == "__main__":
    demo.launch()