devjas1 commited on
Commit
51acb3f
·
1 Parent(s): 9125d98

(refactor): clean up unused inference utilities and FastAPI setup

Browse files

- Removed obsolete inference utilities and FastAPI setup files that were remnants of the Docker environment dismantling.

backend/.gitignore DELETED
@@ -1 +0,0 @@
1
- __pycache__/
 
 
backend/inference_utils.py DELETED
@@ -1,79 +0,0 @@
1
- def load_model(name):
2
- return "mock_model"
3
-
4
- def run_inference(model, spectrum):
5
- return {
6
- "prediction": "Stubbed Output",
7
- "class_index": 0,
8
- "logits": [0.0, 1.0],
9
- "class_labels": ["Stub", "Output"]
10
- }
11
-
12
-
13
- # ---------- ACTUAL MODEL LOADING/INFERENCE CODE ---------------------|
14
- # import torch
15
- # import numpy as np
16
- # from pathlib import Path
17
- # from scripts.preprocess_dataset import resample_spectrum
18
- # from models.figure2_cnn import Figure2CNN
19
- # from models.resnet_cnn import ResNet1D
20
-
21
- # # -- Label Map --
22
- # LABELS = ["Stable (Unweathered)", "Weathered (Degraded)"]
23
-
24
- # # -- Model Paths --
25
- # MODEL_CONFIG = {
26
- # "figure2": {
27
- # "class": Figure2CNN,
28
- # "path": "outputs/figure2_model.pth"
29
- # },
30
- # "resnet": {
31
- # "class": ResNet1D,
32
- # "path": "outputs/resnet_model.pth"
33
- # }
34
- # }
35
-
36
- # def load_model(model_name: str):
37
- # if model_name not in MODEL_CONFIG:
38
- # raise ValueError(f"Unknown model '{model_name}'. Valid options: {list(MODEL_CONFIG.keys())}")
39
-
40
- # config = MODEL_CONFIG[model_name]
41
- # model = config["class"]()
42
- # state_dict = torch.load(config["path"], map_location=torch.device("cpu"), weights_only=True)
43
- # model.load_state_dict(state_dict)
44
- # model.eval()
45
- # return model
46
-
47
- # def run_inference(model, spectrum: list):
48
- # # -- Validate Input --
49
- # if not isinstance(spectrum, list) or len(spectrum) < 10:
50
- # raise ValueError("Spectrum must be a list of floats with reasonable length")
51
-
52
- # # -- Convert to Numpy --
53
- # spectrum = np.array(spectrum, dtype=np.float32)
54
-
55
- # # -- Resample --
56
- # x_vals = np.arange(len(spectrum))
57
- # spectrum = resample_spectrum(x_vals, spectrum, target_len=500)
58
-
59
- # # -- Normalize --
60
- # mean = np.mean(spectrum)
61
- # std = np.std(spectrum)
62
- # if std == 0:
63
- # raise ValueError("Standard deviation of spectrum is zero; normalization will fail.")
64
- # spectrum = (spectrum - mean) / std
65
-
66
- # # -- To Tensor --
67
- # x = torch.tensor(spectrum, dtype=torch.float32).unsqueeze(0).unsqueeze(0) # Shape (1, 1, 500)
68
-
69
- # with torch.no_grad():
70
- # logits = model(x)
71
- # pred_index = torch.argmax(logits, dim=1).item()
72
-
73
- # return {
74
- # "prediction": LABELS[pred_index],
75
- # "class_index": pred_index,
76
- # "logits": logits.squeeze().tolist(),
77
- # "class_labels": LABELS
78
- # }
79
- # ---------- ACTUAL MODEL LOADING/INFERENCE CODE ---------------------|
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/main.py DELETED
@@ -1,34 +0,0 @@
1
- # from fastapi import FastAPI, HTTPException
2
- from fastapi import FastAPI
3
- from pydantic import BaseModel
4
- # import torch
5
-
6
- # from backend.inference_utils import load_model, run_inference
7
-
8
- # -- FastAPI app --
9
- app = FastAPI()
10
-
11
- # -- Input Schema --
12
- class InferenceRequest(BaseModel):
13
- model_name: str
14
- spectrum: list[float]
15
-
16
- @app.get("/")
17
- def root():
18
- return {"message": "Polymer Aging Inference API is online"}
19
-
20
- @app.post("/infer")
21
- def infer(request: InferenceRequest):
22
- return{
23
- "prediction": "Stubbed Output",
24
- "class_index": 0,
25
- "logits": [0.0, 1.0],
26
- "class_labels": ["Stub", "Output"],
27
- }
28
- # def infer(request: InferenceRequest):
29
- # try:
30
- # model = load_model(request.model_name)
31
- # result = run_inference(model, request.spectrum)
32
- # return result
33
- # except Exception as e:
34
- # raise HTTPException(status_code=500, detail=str(e)) from e