File size: 1,897 Bytes
b91943a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# inference.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import textwrap
from viz_generator import code_to_mermaid

DEFAULT_MODEL = "Salesforce/codet5-small"

class CodeExplainViz:
    def __init__(self, model_name_or_path=DEFAULT_MODEL):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

    def explain(self, code: str, max_length: int = 256) -> dict:
        prompt = "explain: " + code
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        outputs = self.model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True)
        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        lines = [l.strip() for l in text.splitlines() if l.strip()]
        short = lines[0] if lines else textwrap.shorten(text, width=120)
        detailed = "\n".join(lines[1:]) if len(lines) > 1 else text
        mermaid = code_to_mermaid(code)
        unit_tests = self._make_unit_test_template(code)
        return {"short": short, "detailed": detailed, "mermaid": mermaid, "unit_tests": unit_tests}

    def _make_unit_test_template(self, code: str) -> str:
        import re
        m = re.search(r"def\s+([A-Za-z0-9_]+)\s*\((.*?)\):", code)
        fn = m.group(1) if m else "function_under_test"
        params = m.group(2) if m else ""
        param_count = len([p for p in params.split(',') if p.strip()]) if params.strip() else 0
        args = ", ".join(["0"] * param_count)
        template = f"""import pytest

from your_module import {fn}

def test_{fn}_basic():
    # TODO: replace with real inputs and expected outputs
    assert {fn}({args}) == ...

def test_{fn}_edge_cases():
    # Example edge-case tests
    with pytest.raises(Exception):
        {fn}(...)"""
        return template