|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import textwrap |
|
|
from viz_generator import code_to_mermaid |
|
|
|
|
|
DEFAULT_MODEL = "Salesforce/codet5-small" |
|
|
|
|
|
class CodeExplainViz: |
|
|
def __init__(self, model_name_or_path=DEFAULT_MODEL): |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) |
|
|
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) |
|
|
|
|
|
def explain(self, code: str, max_length: int = 256) -> dict: |
|
|
prompt = "explain: " + code |
|
|
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) |
|
|
outputs = self.model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True) |
|
|
text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
lines = [l.strip() for l in text.splitlines() if l.strip()] |
|
|
short = lines[0] if lines else textwrap.shorten(text, width=120) |
|
|
detailed = "\n".join(lines[1:]) if len(lines) > 1 else text |
|
|
mermaid = code_to_mermaid(code) |
|
|
unit_tests = self._make_unit_test_template(code) |
|
|
return {"short": short, "detailed": detailed, "mermaid": mermaid, "unit_tests": unit_tests} |
|
|
|
|
|
def _make_unit_test_template(self, code: str) -> str: |
|
|
import re |
|
|
m = re.search(r"def\s+([A-Za-z0-9_]+)\s*\((.*?)\):", code) |
|
|
fn = m.group(1) if m else "function_under_test" |
|
|
params = m.group(2) if m else "" |
|
|
param_count = len([p for p in params.split(',') if p.strip()]) if params.strip() else 0 |
|
|
args = ", ".join(["0"] * param_count) |
|
|
template = f"""import pytest |
|
|
|
|
|
from your_module import {fn} |
|
|
|
|
|
def test_{fn}_basic(): |
|
|
# TODO: replace with real inputs and expected outputs |
|
|
assert {fn}({args}) == ... |
|
|
|
|
|
def test_{fn}_edge_cases(): |
|
|
# Example edge-case tests |
|
|
with pytest.raises(Exception): |
|
|
{fn}(...)""" |
|
|
return template |
|
|
|