Spaces:
Sleeping
Sleeping
wenbemi
committed on
Update chat_a.py
Browse files
chat_a.py
CHANGED
|
@@ -15,7 +15,7 @@ from css import log_and_render
|
|
| 15 |
HOME = pathlib.Path.home()
|
| 16 |
|
| 17 |
# ✅ ENV가 있으면 따르고, 없으면 홈 밑 .cache/hf-cache 사용
|
| 18 |
-
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE") or
|
| 19 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 20 |
|
| 21 |
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data")
|
|
@@ -1038,16 +1038,20 @@ def override_emotion_if_needed(text):
|
|
| 1038 |
return [(emotion_label, 50.0)], [emotion_group]
|
| 1039 |
return None
|
| 1040 |
|
| 1041 |
-
def analyze_emotion(
|
| 1042 |
-
sentiment_model = load_sentiment_model()
|
| 1043 |
tokenizer = load_tokenizer()
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1049 |
with torch.no_grad():
|
| 1050 |
-
|
|
|
|
| 1051 |
top_indices = torch.topk(probs, k=5).indices.tolist()
|
| 1052 |
top_emotions = [(klue_emotions[i], float(probs[i]) * 100) for i in top_indices]
|
| 1053 |
top_emotion_groups = list(dict.fromkeys([klue_to_general[i] for i in top_indices if probs[i] > 0.05]))
|
|
|
|
# Home directory of the current user; also the root for the default cache path.
HOME = pathlib.Path.home()

# Model/tokenizer cache directory: honor TRANSFORMERS_CACHE when set,
# otherwise fall back to <home>/.cache/hf-cache (derived from HOME for
# consistency instead of a second expanduser() call).
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favor of HF_HOME — confirm against the installed transformers version.
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE") or str(HOME / ".cache" / "hf-cache")
# Ensure the cache directory exists before any download is attempted.
os.makedirs(CACHE_DIR, exist_ok=True)

# Hugging Face dataset repo used for persisted app data; overridable via env.
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data")
|
|
|
|
| 1038 |
return [(emotion_label, 50.0)], [emotion_group]
|
| 1039 |
return None
|
| 1040 |
|
| 1041 |
+
def analyze_emotion(text: str):
|
|
|
|
| 1042 |
tokenizer = load_tokenizer()
|
| 1043 |
+
model = load_sentiment_model()
|
| 1044 |
+
|
| 1045 |
+
inputs = tokenizer(
|
| 1046 |
+
text,
|
| 1047 |
+
return_tensors="pt",
|
| 1048 |
+
truncation=True,
|
| 1049 |
+
padding=False,
|
| 1050 |
+
max_length=256,
|
| 1051 |
+
)
|
| 1052 |
with torch.no_grad():
|
| 1053 |
+
logits = model(**inputs).logits
|
| 1054 |
+
probs = F.softmax(logits, dim=1)[0].tolist()
|
| 1055 |
top_indices = torch.topk(probs, k=5).indices.tolist()
|
| 1056 |
top_emotions = [(klue_emotions[i], float(probs[i]) * 100) for i in top_indices]
|
| 1057 |
top_emotion_groups = list(dict.fromkeys([klue_to_general[i] for i in top_indices if probs[i] > 0.05]))
|