dhruv2842 commited on
Commit
02f47a2
·
verified ·
1 Parent(s): c57b611

Update utils/specialist_predictor.py

Browse files
Files changed (1) hide show
  1. utils/specialist_predictor.py +18 -17
utils/specialist_predictor.py CHANGED
@@ -1,17 +1,18 @@
1
- from sentence_transformers import SentenceTransformer, util
2
- import torch
3
- import joblib
4
-
5
- # Load model components once
6
- bundle = joblib.load("semantic_specialist_model.pkl")
7
- model = SentenceTransformer(bundle["model_name"])
8
- known_embeddings = bundle["known_embeddings"]
9
- symptom_specialist_pairs = bundle["symptom_specialist_pairs"]
10
-
11
- def predict_specialist(symptom_text: str):
12
- input_embedding = model.encode(symptom_text, convert_to_tensor=True)
13
- similarities = util.pytorch_cos_sim(input_embedding, known_embeddings)[0]
14
- top_idx = similarities.argmax().item()
15
- specialist = symptom_specialist_pairs[top_idx][1]
16
- score = similarities[top_idx].item()
17
- return specialist, score
 
 
1
+ from sentence_transformers import SentenceTransformer, util
2
+ import torch
3
+ import joblib
4
+
5
+ # Load model components once
6
+ bundle = joblib.load("semantic_specialist_model.pkl")
7
+ model = SentenceTransformer("./models/all-MiniLM-L6-v2")
8
+
9
+ known_embeddings = bundle["known_embeddings"]
10
+ symptom_specialist_pairs = bundle["symptom_specialist_pairs"]
11
+
12
+ def predict_specialist(symptom_text: str):
13
+ input_embedding = model.encode(symptom_text, convert_to_tensor=True)
14
+ similarities = util.pytorch_cos_sim(input_embedding, known_embeddings)[0]
15
+ top_idx = similarities.argmax().item()
16
+ specialist = symptom_specialist_pairs[top_idx][1]
17
+ score = similarities[top_idx].item()
18
+ return specialist, score