# code-explain-viz / inference.py
# hmnshudhmn24's picture
# Upload 12 files
# b91943a verified
# inference.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import textwrap
from viz_generator import code_to_mermaid
DEFAULT_MODEL = "Salesforce/codet5-small"
class CodeExplainViz:
def __init__(self, model_name_or_path=DEFAULT_MODEL):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
def explain(self, code: str, max_length: int = 256) -> dict:
prompt = "explain: " + code
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
outputs = self.model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True)
text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
lines = [l.strip() for l in text.splitlines() if l.strip()]
short = lines[0] if lines else textwrap.shorten(text, width=120)
detailed = "\n".join(lines[1:]) if len(lines) > 1 else text
mermaid = code_to_mermaid(code)
unit_tests = self._make_unit_test_template(code)
return {"short": short, "detailed": detailed, "mermaid": mermaid, "unit_tests": unit_tests}
def _make_unit_test_template(self, code: str) -> str:
import re
m = re.search(r"def\s+([A-Za-z0-9_]+)\s*\((.*?)\):", code)
fn = m.group(1) if m else "function_under_test"
params = m.group(2) if m else ""
param_count = len([p for p in params.split(',') if p.strip()]) if params.strip() else 0
args = ", ".join(["0"] * param_count)
template = f"""import pytest
from your_module import {fn}
def test_{fn}_basic():
# TODO: replace with real inputs and expected outputs
assert {fn}({args}) == ...
def test_{fn}_edge_cases():
# Example edge-case tests
with pytest.raises(Exception):
{fn}(...)"""
return template