Upload 22 files

- LICENSE +1 -0
- README.md +2 -12
- app.py +104 -0
- car_advisor/__init__.py +1 -0
- car_advisor/config.py +34 -0
- car_advisor/cost_estimator.py +45 -0
- car_advisor/fusion.py +12 -0
- car_advisor/nlp_model.py +55 -0
- car_advisor/reporter.py +69 -0
- car_advisor/scheduler.py +15 -0
- car_advisor/suggestions.py +33 -0
- car_advisor/utils.py +11 -0
- car_advisor/vision_model.py +74 -0
- configs/issues.yaml +11 -0
- configs/parts_costs.yaml +46 -0
- data/sample_data/annotations.csv +2 -0
- data/sample_data/images/example.jpg +0 -0
- requirements.txt +15 -0
- training/dataset.py +37 -0
- training/train_fusion.py +42 -0
- training/train_nlp.py +38 -0
- training/train_vision.py +64 -0
LICENSE
ADDED
@@ -0,0 +1 @@
MIT License - 2025
README.md
CHANGED
@@ -1,13 +1,3 @@
----
-title: Serviceadvisor
-emoji: 📉
-colorFrom: gray
-colorTo: purple
-sdk: gradio
-sdk_version: 5.44.1
-app_file: app.py
-pinned: false
-license: apache-2.0
----
+# Workshop Car Service Advisor (Hugging Face)
 
-
+See instructions inside.
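The trimmed README only says "See instructions inside." As orientation for the files below, here is a minimal sketch of the pipeline they implement, assuming the repo root is the working directory and no trained checkpoints are present (both models then fall back to their built-in heuristics):

from PIL import Image
from car_advisor.vision_model import VisionInference
from car_advisor.nlp_model import NLPInference
from car_advisor.fusion import fuse
from car_advisor.cost_estimator import estimate_costs

# Heuristic fallbacks kick in when checkpoints/vision and checkpoints/nlp are absent.
vision = VisionInference()
nlp = NLPInference()
vp = vision.predict(Image.open("data/sample_data/images/example.jpg"))
tp = nlp.predict("Scratches on left door, brakes squeal when stopping")
fused = fuse(vp, tp)                      # label -> {"prob", "severity"}, sorted by prob
top = dict(list(fused.items())[:4])
print(estimate_costs(top, "configs/parts_costs.yaml", top_k=4)["grand_total"])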
app.py
ADDED
@@ -0,0 +1,104 @@
import os, json, io, base64
from typing import List, Dict, Any
from PIL import Image
import gradio as gr

from car_advisor.vision_model import VisionInference
from car_advisor.nlp_model import NLPInference
from car_advisor.fusion import fuse
from car_advisor.cost_estimator import estimate_costs
from car_advisor.suggestions import predictive_maintenance, advanced_suggestions
from car_advisor.reporter import export_pdf, export_json
from car_advisor.scheduler import create_service_ics

vision = VisionInference()
nlp = NLPInference()

def _to_image(obj):
    if isinstance(obj, dict) and "image" in obj:
        return Image.open(io.BytesIO(base64.b64decode(obj["image"].split(",")[-1])))
    if isinstance(obj, str):
        return Image.open(obj)
    return obj

def analyze(images: list, customer_text: str, make: str, model: str, year: int, mileage_km: int, vin: str, name: str, phone: str):
    # Vision aggregation
    agg = None
    valid = 0
    for it in images or []:
        try:
            img = _to_image(it)
            vp = vision.predict(img)
            valid += 1
            if agg is None:
                agg = {k: v for k, v in vp.items()}
            else:
                for k in agg:
                    agg[k] += vp.get(k, 0.0)
        except Exception:
            pass
    if agg is None:
        agg = {k: 0.0 for k in vision.labels}
    else:
        for k in agg:
            agg[k] /= max(1, valid)

    tp = nlp.predict(customer_text or "")

    fused = fuse(agg, tp)
    top = dict(list(fused.items())[:4])
    estimate = estimate_costs(top, "configs/parts_costs.yaml", top_k=4)
    pm = predictive_maintenance(car_year=int(year) if year else None, mileage_km=int(mileage_km) if mileage_km else None)
    adv = advanced_suggestions(top_issues=top)

    payload = {
        "customer": {"name": name, "phone": phone},
        "vehicle": {"make": make, "model": model, "year": year, "mileage_km": mileage_km, "vin": vin},
        "complaint_text": customer_text,
        "issues_ranked": fused,
        "estimate": estimate,
        "predictive_maintenance": pm,
        "advanced_suggestions": adv
    }

    os.makedirs("exports", exist_ok=True)
    pdf_path = "exports/service_report.pdf"
    json_path = "exports/service_report.json"
    ics_path = "exports/service_appointment.ics"
    export_pdf(payload, pdf_path)
    export_json(payload, json_path)
    create_service_ics(ics_path, hours_from_now=48, duration_minutes=60)

    def to_dl(path):
        with open(path, "rb") as f:
            return (os.path.basename(path), f.read())

    return payload, to_dl(pdf_path), to_dl(json_path), to_dl(ics_path)

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("## 🚗 Workshop Car Service Advisor")
    with gr.Row():
        with gr.Column(scale=1):
            imgs = gr.File(label="Upload car image(s)", file_count="multiple", file_types=["image"])
            cust = gr.Textbox(label="Customer reported issue", placeholder="Describe the problem...")
            with gr.Row():
                make = gr.Textbox(label="Make", value="Toyota")
                model = gr.Textbox(label="Model", value="Corolla")
                year = gr.Number(label="Year", value=2017, precision=0)
            with gr.Row():
                mileage = gr.Number(label="Mileage (km)", value=60000, precision=0)
                vin = gr.Textbox(label="VIN", placeholder="Optional")
            with gr.Row():
                name = gr.Textbox(label="Customer Name", value="")
                phone = gr.Textbox(label="Phone", value="")
            run = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            out_json = gr.JSON(label="Structured output")
            pdf_file = gr.File(label="Download PDF report")
            json_file = gr.File(label="Download JSON")
            ics_file = gr.File(label="Download .ics (appointment)")
    run.click(analyze, inputs=[imgs, cust, make, model, year, mileage, vin, name, phone],
              outputs=[out_json, pdf_file, json_file, ics_file])

if __name__ == "__main__":
    demo.launch()
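A sketch of exercising the callback without the UI (assumes the repo root is the working directory; importing app builds the Blocks layout but does not launch it). One call writes the PDF, JSON and .ics files under exports/ and returns the structured payload first:

from app import analyze

payload, pdf_dl, json_dl, ics_dl = analyze(
    images=[], customer_text="flat tyre and a slow oil leak",
    make="Toyota", model="Corolla", year=2017, mileage_km=60000,
    vin="", name="Walk-in", phone="")
print(list(payload["issues_ranked"].items())[:2])   # two most probable issues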
car_advisor/__init__.py
ADDED
@@ -0,0 +1 @@
__version__ = '0.1.0'
car_advisor/config.py
ADDED
@@ -0,0 +1,34 @@
from dataclasses import dataclass, field
from typing import List, Dict

DEFAULT_LABELS: List[str] = [
    "scratch_dent","paint_damage","cracked_windshield","flat_tire","engine_leak",
    "brake_wear","headlight_fault","battery_corrosion","rust","bumper_damage"
]

SEVERITY_DEFAULTS: Dict[str, int] = {
    "scratch_dent": 2,
    "paint_damage": 2,
    "cracked_windshield": 4,
    "flat_tire": 3,
    "engine_leak": 5,
    "brake_wear": 4,
    "headlight_fault": 3,
    "battery_corrosion": 2,
    "rust": 2,
    "bumper_damage": 3,
}

PM_THRESHOLDS = {
    "engine_oil": 10000,
    "brake_pads": 30000,
    "coolant": 40000,
    "battery_check": 25000,
    "tire_rotation": 8000,
}

@dataclass
class AppSettings:
    labels: List[str] = field(default_factory=lambda: list(DEFAULT_LABELS))  # a bare list default would raise "mutable default" at import time
    labor_rate_per_hour: float = 1200.0
    diagnostic_fee: float = 500.0
car_advisor/cost_estimator.py
ADDED
@@ -0,0 +1,45 @@
import yaml
from typing import Dict, Any

def estimate_costs(fused: Dict[str, Any], parts_yaml_path: str, top_k: int = 4) -> Dict[str, Any]:
    cfg = yaml.safe_load(open(parts_yaml_path, "r", encoding="utf-8"))
    labor_rate = float(cfg.get("labor_rate_per_hour", 1200.0))
    diagnostic_fee = float(cfg.get("diagnostic_fee", 500.0))
    parts_cfg = cfg.get("parts", {})
    items = []
    total = 0.0
    for i, (label, rec) in enumerate(fused.items()):
        if i >= top_k:
            break
        part_info = parts_cfg.get(label, {})
        hours = float(part_info.get("hours", 1.0))
        parts_list = part_info.get("parts_list", [])
        parts_cost = sum(float(p.get("cost", 0.0)) for p in parts_list)
        labor_cost = hours * labor_rate
        line_total = parts_cost + labor_cost
        items.append({
            "issue": label,
            "probability": round(float(rec["prob"]), 3),
            "severity": int(rec.get("severity", 3)),
            "labor_hours": hours,
            "labor_cost": round(labor_cost, 2),
            "parts": parts_list,
            "parts_cost": round(parts_cost, 2),
            "line_total": round(line_total, 2)
        })
        total += line_total
    if not items:
        items.append({
            "issue": "diagnostic_only",
            "probability": 0.3,
            "severity": 1,
            "labor_hours": 0.0,
            "labor_cost": 0.0,
            "parts": [],
            "parts_cost": 0.0,
            "line_total": diagnostic_fee
        })
        total += diagnostic_fee
    tax = round(0.18 * total, 2)
    grand = round(total + tax, 2)
    return {"items": items, "subtotal": round(total, 2), "tax": tax, "grand_total": grand}
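A worked example against configs/parts_costs.yaml below: scratch_dent is billed as 1.5 h × ₹1200/h labor plus a ₹1800 parts kit, and tax is a flat 18% of the subtotal:

from car_advisor.cost_estimator import estimate_costs

est = estimate_costs({"scratch_dent": {"prob": 0.7, "severity": 2}},
                     "configs/parts_costs.yaml", top_k=4)
print(est["subtotal"], est["tax"], est["grand_total"])   # 3600.0 648.0 4248.0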
car_advisor/fusion.py
ADDED
@@ -0,0 +1,12 @@
from typing import Dict, Any
from .config import SEVERITY_DEFAULTS
import math

def fuse(vision_probs: Dict[str, float], text_probs: Dict[str, float]) -> Dict[str, Any]:
    fused = {}
    for label in sorted(set(list(vision_probs.keys()) + list(text_probs.keys()))):
        pv = max(1e-6, vision_probs.get(label, 0.0))
        pt = max(1e-6, text_probs.get(label, 0.0))
        p = math.sqrt(pv * pt) * 1.2 + 0.1 * pt + 0.05 * pv
        fused[label] = {"prob": float(min(1.0, p)), "severity": SEVERITY_DEFAULTS.get(label, 3)}
    return dict(sorted(fused.items(), key=lambda kv: kv[1]["prob"], reverse=True))
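The fused score is a geometric mean of the two modality probabilities (so agreement between image and complaint is rewarded) plus small additive terms that keep single-modality evidence alive, clamped to 1.0. For example:

import math

pv, pt = 0.4, 0.5                          # vision and text probabilities for one label
p = math.sqrt(pv * pt) * 1.2 + 0.1 * pt + 0.05 * pv
print(min(1.0, p))                         # ~0.607, higher than either input alone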
car_advisor/nlp_model.py
ADDED
@@ -0,0 +1,55 @@
from typing import Dict, List
from .config import DEFAULT_LABELS

KEYWORD_MAP = {
    "brake_wear": ["brake", "squeal", "screech", "stopping", "pads"],
    "flat_tire": ["flat", "puncture", "tyre", "tire", "pressure", "air"],
    "engine_leak": ["leak", "oil", "puddle", "drip", "smell burning"],
    "cracked_windshield": ["crack", "windshield", "glass"],
    "paint_damage": ["scratch", "scrape", "paint", "scuff"],
    "scratch_dent": ["dent", "dented", "bent"],
    "headlight_fault": ["headlight", "bulb", "beam", "lamp"],
    "battery_corrosion": ["battery", "corrosion", "terminal", "start"],
    "rust": ["rust", "oxid"],
    "bumper_damage": ["bumper", "fender"]
}

def _contains_any(text: str, keywords: List[str]) -> bool:
    t = (text or "").lower()
    return any(kw in t for kw in keywords)

class NLPInference:
    def __init__(self, labels: List[str] = None, ckpt_dir: str = "checkpoints/nlp"):
        self.labels = labels or DEFAULT_LABELS
        self.ckpt_dir = ckpt_dir
        self.trained = False
        try:
            import joblib, os
            clf_p = os.path.join(ckpt_dir, "best.joblib")
            mlb_p = os.path.join(ckpt_dir, "mlb.joblib")
            if os.path.exists(clf_p) and os.path.exists(mlb_p):
                self.clf = joblib.load(clf_p)
                self.mlb = joblib.load(mlb_p)
                self.trained = True
        except Exception:
            self.trained = False

    def predict(self, text: str) -> Dict[str, float]:
        if not text:
            return {l: 0.0 for l in self.labels}
        if self.trained:
            probs = self.clf.predict_proba([text])[0]
            out = {}
            for i, lbl in enumerate(self.mlb.classes_):
                val = probs[i] if isinstance(probs[i], (float, int)) else probs[i][1]
                out[lbl] = float(val)
            for l in self.labels:
                out.setdefault(l, 0.0)
            return out
        else:
            scores = {l: 0.01 for l in self.labels}
            for label, kws in KEYWORD_MAP.items():
                if _contains_any(text, kws):
                    scores[label] += 0.5
            s = sum(scores.values())
            return {k: v/s for k, v in scores.items()}
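Without checkpoints/nlp/best.joblib the class degrades to the keyword fallback: every label starts at 0.01, each matched keyword list adds 0.5, and the scores are normalised to sum to 1. For example:

from car_advisor.nlp_model import NLPInference

nlp = NLPInference()                       # no checkpoint, so the keyword fallback is used
probs = nlp.predict("the brakes squeal when stopping")
print(max(probs, key=probs.get))           # "brake_wear"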
car_advisor/reporter.py
ADDED
@@ -0,0 +1,69 @@
import json
from typing import Dict, Any
from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
from datetime import datetime

def export_json(payload: Dict[str, Any], out_path: str) -> str:
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    return out_path

def export_pdf(payload: Dict[str, Any], out_path: str) -> str:
    doc = SimpleDocTemplate(out_path, pagesize=A4)
    styles = getSampleStyleSheet()
    story = []
    story.append(Paragraph("<b>Workshop Car Service Advisor Report</b>", styles["Title"]))
    story.append(Spacer(1, 10))

    meta = payload.get("vehicle", {})
    cust = payload.get("customer", {})
    details = f"""
    <b>Customer:</b> {cust.get('name','N/A')} | <b>Phone:</b> {cust.get('phone','N/A')}<br/>
    <b>Vehicle:</b> {meta.get('make','N/A')} {meta.get('model','')} {meta.get('year','')} | <b>VIN:</b> {meta.get('vin','N/A')}<br/>
    <b>Mileage:</b> {meta.get('mileage_km','N/A')} km | <b>Date:</b> {datetime.now().strftime('%Y-%m-%d %H:%M')}
    """
    story.append(Paragraph(details, styles["Normal"]))
    story.append(Spacer(1, 10))

    story.append(Paragraph("<b>Detected Issues</b>", styles["Heading2"]))
    data = [["Issue", "Probability", "Severity", "Labor (hrs)", "Labor Cost", "Parts Cost", "Line Total"]]
    for item in payload["estimate"]["items"]:
        data.append([
            item["issue"],
            f"{item['probability']:.2f}",
            str(item["severity"]),
            f"{item['labor_hours']:.2f}",
            f"₹{item['labor_cost']:.2f}",
            f"₹{item['parts_cost']:.2f}",
            f"₹{item['line_total']:.2f}",
        ])
    table = Table(data, hAlign="LEFT")
    table.setStyle(TableStyle([
        ('BACKGROUND',(0,0),(-1,0),colors.lightblue),
        ('TEXTCOLOR',(0,0),(-1,0),colors.whitesmoke),
        ('ALIGN',(0,0),(-1,-1),'CENTER'),
        ('GRID',(0,0),(-1,-1),0.25,colors.grey),
        ('FONTNAME',(0,0),(-1,0),'Helvetica-Bold'),
    ]))
    story.append(table)
    story.append(Spacer(1, 8))

    story.append(Paragraph(f"<b>Subtotal:</b> ₹{payload['estimate']['subtotal']:.2f}", styles["Normal"]))
    story.append(Paragraph(f"<b>Tax:</b> ₹{payload['estimate']['tax']:.2f}", styles["Normal"]))
    story.append(Paragraph(f"<b>Grand Total:</b> ₹{payload['estimate']['grand_total']:.2f}", styles["Heading3"]))

    story.append(Spacer(1, 10))
    story.append(Paragraph("<b>Predictive Maintenance</b>", styles["Heading2"]))
    for tip in payload.get("predictive_maintenance", []):
        story.append(Paragraph(f"• {tip}", styles["Normal"]))

    story.append(Spacer(1, 10))
    story.append(Paragraph("<b>Advanced Suggestions</b>", styles["Heading2"]))
    for tip in payload.get("advanced_suggestions", []):
        story.append(Paragraph(f"• {tip}", styles["Normal"]))

    doc.build(story)
    return out_path
car_advisor/scheduler.py
ADDED
@@ -0,0 +1,15 @@
from ics import Calendar, Event
from datetime import datetime, timedelta

def create_service_ics(out_path: str, summary="Car Service Appointment", hours_from_now: int = 48, duration_minutes: int = 60):
    cal = Calendar()
    e = Event()
    start = datetime.now() + timedelta(hours=hours_from_now)
    e.name = summary
    e.begin = start
    e.duration = timedelta(minutes=duration_minutes)
    e.description = "Auto-suggested appointment from Workshop Car Service Advisor."
    cal.events.add(e)
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(cal.serialize_iter())
    return out_path
car_advisor/suggestions.py
ADDED
@@ -0,0 +1,33 @@
from typing import Dict, Any, List
from .config import PM_THRESHOLDS

def predictive_maintenance(car_year: int = None, mileage_km: int = None) -> List[str]:
    tips = []
    if mileage_km is not None:
        if mileage_km % PM_THRESHOLDS["engine_oil"] > PM_THRESHOLDS["engine_oil"] - 1000:
            tips.append("Engine oil service due soon based on mileage.")
        if mileage_km % PM_THRESHOLDS["tire_rotation"] > PM_THRESHOLDS["tire_rotation"] - 500:
            tips.append("Consider tire rotation and balancing.")
        if mileage_km > 50000:
            tips.append("Inspect suspension components (shocks/struts) for wear.")
        if mileage_km > 80000:
            tips.append("Check timing belt/chain and water pump as per manufacturer schedule.")
    if car_year is not None and car_year < 2015:
        tips.append("Vehicle age suggests comprehensive electrical & rubber parts inspection.")
    if not tips:
        tips.append("No immediate predictive maintenance items flagged.")
    return tips

def advanced_suggestions(top_issues: Dict[str, Any]) -> List[str]:
    tips = []
    if "engine_leak" in top_issues:
        tips.append("After fixing leak, clean engine bay and monitor oil level weekly for 1 month.")
    if "brake_wear" in top_issues:
        tips.append("Bed-in new pads and avoid hard braking for first 200 km.")
    if "flat_tire" in top_issues:
        tips.append("Check alignment and inspect other tires for embedded nails/screws.")
    if "rust" in top_issues:
        tips.append("Apply rust protection and inspect underbody after monsoon season.")
    if "cracked_windshield" in top_issues:
        tips.append("Avoid potholes and sudden temperature changes until replacement.")
    return tips
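A worked example with the vehicle from data/sample_data/annotations.csv (2017, 65,000 km): neither mileage-modulo check triggers, so only the 50,000 km suspension inspection is suggested:

from car_advisor.suggestions import predictive_maintenance

print(predictive_maintenance(car_year=2017, mileage_km=65000))
# ['Inspect suspension components (shocks/struts) for wear.']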
car_advisor/utils.py
ADDED
@@ -0,0 +1,11 @@
import yaml

def load_yaml(path):
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def softmax(xs):
    import numpy as np
    x = np.array(xs, dtype=float)
    e = np.exp(x - x.max())
    return (e / e.sum()).tolist()
car_advisor/vision_model.py
ADDED
@@ -0,0 +1,74 @@
from typing import List, Dict
from PIL import Image
import torch, os
import torch.nn as nn
import torchvision.transforms as T
from .config import DEFAULT_LABELS
from .utils import softmax

class SimpleVisionModel(nn.Module):
    """
    Wrapper around a lightweight classifier. For training, use training/train_vision.py.
    At inference, if checkpoint absent or downloads fail, we return rule-based scores.
    """
    def __init__(self, num_classes: int):
        super().__init__()
        try:
            import timm
            self.net = timm.create_model("mobilenetv3_small_100", pretrained=True, num_classes=num_classes)
        except Exception:
            self.net = nn.Sequential(
                nn.AdaptiveAvgPool2d((8,8)),
                nn.Flatten(),
                nn.Linear(8*8*3, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes)
            )

    def forward(self, x):
        return self.net(x)

class VisionInference:
    def __init__(self, labels: List[str] = None, ckpt_path: str = "checkpoints/vision/best.pt"):
        self.labels = labels or DEFAULT_LABELS
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SimpleVisionModel(num_classes=len(self.labels)).to(self.device)
        self.transform = T.Compose([T.Resize((224,224)), T.ToTensor()])
        self.ready = False
        if os.path.exists(ckpt_path):
            try:
                state = torch.load(ckpt_path, map_location=self.device)
                self.model.load_state_dict(state["model"] if "model" in state else state)
                self.ready = True
            except Exception:
                self.ready = False

    @torch.no_grad()
    def predict(self, image: Image.Image) -> Dict[str, float]:
        if image is None:
            return {l: 0.0 for l in self.labels}
        try:
            x = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)
            logits = self.model(x)[0].detach().cpu().tolist()
            probs = softmax(logits)
            return {lbl: float(p) for lbl, p in zip(self.labels, probs)}
        except Exception:
            import numpy as np
            img = image.convert("RGB").resize((64,64))
            arr = np.array(img).astype("float32")/255.0
            gray = arr.mean(axis=2)
            contrast = float(gray.std())
            red_mean = float(arr[:,:,0].mean())
            green_mean = float(arr[:,:,1].mean())
            blue_mean = float(arr[:,:,2].mean())
            scores = {l: 0.01 for l in self.labels}
            if contrast > 0.22:
                scores["scratch_dent"] += 0.2
                scores["paint_damage"] += 0.15
                scores["bumper_damage"] += 0.1
            if blue_mean < 0.35 and green_mean < 0.35:
                scores["rust"] += 0.2
            if red_mean > 0.55:
                scores["engine_leak"] += 0.15
            s = sum(scores.values())
            return {k: v/s for k, v in scores.items()}
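Either branch of VisionInference.predict returns one score per label in DEFAULT_LABELS; without checkpoints/vision/best.pt the scores come from an untrained backbone (or the colour/contrast heuristic if even that fails), so treat them as rough priors. A quick smoke test with the bundled sample image:

from PIL import Image
from car_advisor.vision_model import VisionInference

vision = VisionInference()                 # loads the checkpoint only if it exists
probs = vision.predict(Image.open("data/sample_data/images/example.jpg"))
print(sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3])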
configs/issues.yaml
ADDED
@@ -0,0 +1,11 @@
labels:
  - scratch_dent
  - paint_damage
  - cracked_windshield
  - flat_tire
  - engine_leak
  - brake_wear
  - headlight_fault
  - battery_corrosion
  - rust
  - bumper_damage
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
labor_rate_per_hour: 1200
|
| 2 |
+
diagnostic_fee: 500
|
| 3 |
+
parts:
|
| 4 |
+
scratch_dent:
|
| 5 |
+
parts_list:
|
| 6 |
+
- {name: "Body filler/paint kit", cost: 1800}
|
| 7 |
+
hours: 1.5
|
| 8 |
+
paint_damage:
|
| 9 |
+
parts_list:
|
| 10 |
+
- {name: "Paint & materials", cost: 2500}
|
| 11 |
+
hours: 2.0
|
| 12 |
+
cracked_windshield:
|
| 13 |
+
parts_list:
|
| 14 |
+
- {name: "Windshield glass", cost: 8000}
|
| 15 |
+
- {name: "Sealant kit", cost: 900}
|
| 16 |
+
hours: 2.5
|
| 17 |
+
flat_tire:
|
| 18 |
+
parts_list:
|
| 19 |
+
- {name: "New tire", cost: 4500}
|
| 20 |
+
hours: 0.6
|
| 21 |
+
engine_leak:
|
| 22 |
+
parts_list:
|
| 23 |
+
- {name: "Gasket/seal kit", cost: 3200}
|
| 24 |
+
- {name: "Engine oil", cost: 1800}
|
| 25 |
+
hours: 3.0
|
| 26 |
+
brake_wear:
|
| 27 |
+
parts_list:
|
| 28 |
+
- {name: "Brake pads (pair)", cost: 3500}
|
| 29 |
+
hours: 1.4
|
| 30 |
+
headlight_fault:
|
| 31 |
+
parts_list:
|
| 32 |
+
- {name: "Headlight bulb/assembly", cost: 2200}
|
| 33 |
+
hours: 0.8
|
| 34 |
+
battery_corrosion:
|
| 35 |
+
parts_list:
|
| 36 |
+
- {name: "Battery terminals/cleaner", cost: 600}
|
| 37 |
+
hours: 0.5
|
| 38 |
+
rust:
|
| 39 |
+
parts_list:
|
| 40 |
+
- {name: "Rust converter & primer", cost: 1000}
|
| 41 |
+
hours: 2.0
|
| 42 |
+
bumper_damage:
|
| 43 |
+
parts_list:
|
| 44 |
+
- {name: "Bumper cover", cost: 7000}
|
| 45 |
+
- {name: "Clips/fasteners", cost: 500}
|
| 46 |
+
hours: 2.2
|
data/sample_data/annotations.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
image_path,issue_label,customer_text,car_make,car_model,car_year,mileage_km
|
| 2 |
+
images/example.jpg,paint_damage,"Scratches on left door, visible scuff marks",Maruti,Swift,2017,65000
|
data/sample_data/images/example.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,15 @@
torch>=2.1.0
torchvision>=0.16.0
torchaudio>=2.1.0
timm>=1.0.3
transformers>=4.42.0
tokenizers>=0.15.2
gradio>=4.36.1
pydantic>=2.7.0
pillow>=10.3.0
numpy>=1.26.4
pandas>=2.2.2
scikit-learn>=1.5.0
pyyaml>=6.0.1
reportlab>=4.1.0
ics>=0.7.2
training/dataset.py
ADDED
@@ -0,0 +1,37 @@
import torch, pandas as pd
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T

class CarIssuesDataset(Dataset):
    def __init__(self, csv_path, img_root, labels, transform=None, text_col="customer_text"):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.labels = labels
        self.transform = transform or T.Compose([T.Resize((224,224)), T.ToTensor()])
        self.text_col = text_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['image_path']
        if not str(img_path).startswith(self.img_root):
            import os
            img_path = os.path.join(self.img_root, img_path)
        try:
            img = Image.open(img_path).convert("RGB")
        except Exception:
            import numpy as np
            img = Image.fromarray((np.zeros((224,224,3))+255).astype("uint8"))
        x = self.transform(img)
        y = torch.tensor([1 if row["issue_label"] == l else 0 for l in self.labels], dtype=torch.float32)  # tensor so the DataLoader can stack batches
        text = str(row.get(self.text_col, ""))
        meta = {
            "car_make": row.get("car_make", ""),
            "car_model": row.get("car_model", ""),
            "car_year": row.get("car_year", ""),
            "mileage_km": row.get("mileage_km", ""),
        }
        return x, y, text, meta
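A sketch of how the dataset is consumed (the labels list must cover the issue_label values in the CSV; paths are relative to the repo root; the one-hot target comes back as a float tensor):

from torch.utils.data import DataLoader
from training.dataset import CarIssuesDataset

ds = CarIssuesDataset("data/sample_data/annotations.csv", "data/sample_data", labels=["paint_damage"])
xb, yb, texts, meta = next(iter(DataLoader(ds, batch_size=1)))
print(xb.shape, yb.shape, texts[0])        # torch.Size([1, 3, 224, 224]) torch.Size([1, 1]) ...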
training/train_fusion.py
ADDED
@@ -0,0 +1,42 @@
import argparse, os, pandas as pd, numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import joblib

def main(args):
    df = pd.read_csv(args.annotations)
    labels = sorted(df["issue_label"].unique().tolist())
    label_to_idx = {l: i for i, l in enumerate(labels)}

    X = []
    y = []
    for _, row in df.iterrows():
        text = str(row.get("customer_text","")).lower()
        features = [
            len(text),
            int("brake" in text),
            int("leak" in text),
            int("tire" in text or "tyre" in text),
            int("scratch" in text or "dent" in text),
        ]
        X.append(features)
        y.append(label_to_idx[row["issue_label"]])
    X = np.array(X); y = np.array(y)

    Xtr, Xv, ytr, yv = train_test_split(X, y, test_size=0.2, random_state=42)
    clf = LogisticRegression(max_iter=200).fit(Xtr, ytr)
    yp = clf.predict(Xv)
    print("fusion macro F1:", f1_score(yv, yp, average="macro"))
    os.makedirs(args.out_dir, exist_ok=True)
    joblib.dump({"clf": clf, "labels": labels}, os.path.join(args.out_dir, "best.joblib"))
    print("Saved", args.out_dir)

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--annotations", required=True)
    ap.add_argument("--vision_ckpt", required=False)
    ap.add_argument("--nlp_ckpt", required=False)
    ap.add_argument("--out_dir", default="checkpoints/fusion")
    args = ap.parse_args()
    main(args)
training/train_nlp.py
ADDED
@@ -0,0 +1,38 @@
import argparse, os, pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import joblib

def main(args):
    df = pd.read_csv(args.annotations)
    X = df["customer_text"].fillna("")
    y = df["issue_label"].fillna("").apply(lambda x: [x])
    mlb = MultiLabelBinarizer()
    Y = mlb.fit_transform(y)

    X_tr, X_v, Y_tr, Y_v = train_test_split(X, Y, test_size=0.2, random_state=42)

    pipe = Pipeline([
        ("tfidf", TfidfVectorizer(ngram_range=(1,2), max_features=40000)),
        ("clf", OneVsRestClassifier(LogisticRegression(max_iter=200)))
    ])
    pipe.fit(X_tr, Y_tr)
    Yp = pipe.predict(X_v)
    print("macro F1:", f1_score(Y_v, Yp, average="macro"))
    os.makedirs(args.out_dir, exist_ok=True)
    joblib.dump(pipe, os.path.join(args.out_dir, "best.joblib"))
    joblib.dump(mlb, os.path.join(args.out_dir, "mlb.joblib"))
    print("Saved to", args.out_dir)

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--annotations", required=True)
    ap.add_argument("--out_dir", default="checkpoints/nlp")
    ap.add_argument("--epochs", type=int, default=3)
    args = ap.parse_args()
    main(args)
training/train_vision.py
ADDED
@@ -0,0 +1,64 @@
import argparse, os
import torch, torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as T
import timm
import pandas as pd
from training.dataset import CarIssuesDataset

def main(args):
    labels = pd.read_csv(args.annotations)["issue_label"].unique().tolist()
    transform = T.Compose([T.Resize((224,224)), T.ToTensor()])
    ds = CarIssuesDataset(args.annotations, os.path.dirname(args.annotations), labels, transform=transform)
    n = len(ds)
    n_val = max(1, int(0.2 * n))
    tr, val = random_split(ds, [n - n_val, n_val])
    tl = DataLoader(tr, batch_size=16, shuffle=True)
    vl = DataLoader(val, batch_size=16)

    model = timm.create_model("mobilenetv3_small_100", pretrained=True, num_classes=len(labels))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    crit = nn.CrossEntropyLoss()

    best = 0.0
    os.makedirs(args.out_dir, exist_ok=True)

    for epoch in range(args.epochs):
        model.train()
        for xb, yb, _, _ in tl:
            xb = xb.to(device)
            yb = yb.argmax(dim=1).to(device)
            opt.zero_grad()
            out = model(xb)
            loss = crit(out, yb)
            loss.backward()
            opt.step()

        model.eval()
        correct = 0; total = 0
        with torch.no_grad():
            for xb, yb, _, _ in vl:
                xb = xb.to(device)
                y_true = yb.argmax(dim=1).to(device)
                logits = model(xb)
                preds = logits.argmax(dim=1)
                correct += (preds == y_true).sum().item()
                total += y_true.numel()
        acc = correct/total if total else 0
        print(f"Epoch {epoch+1}: val_acc={acc:.3f}")
        if acc > best:
            best = acc
            torch.save({"model": model.state_dict(), "labels": labels}, os.path.join(args.out_dir, "best.pt"))
    print("Done. Best acc:", best)

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--annotations", required=True)
    ap.add_argument("--out_dir", default="checkpoints/vision")
    ap.add_argument("--epochs", type=int, default=5)
    args = ap.parse_args()
    main(args)
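For reference, the invocations implied by the argparse definitions in the three training scripts (paths are examples); the default --out_dir values match the checkpoint locations that NLPInference and VisionInference load at startup, while the fusion checkpoint is written but not read by the app as uploaded. The single-row sample annotations.csv only illustrates the format; the train/validation splits in these scripts need more rows.

python training/train_nlp.py --annotations data/sample_data/annotations.csv --out_dir checkpoints/nlp
python training/train_vision.py --data_root data/sample_data --annotations data/sample_data/annotations.csv --out_dir checkpoints/vision
python training/train_fusion.py --annotations data/sample_data/annotations.csv --out_dir checkpoints/fusion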