wenbemi committed on
Commit
40e0b51
·
verified ·
1 Parent(s): 6a7a44d

Update chat_a.py

Browse files
Files changed (1) hide show
  1. chat_a.py +13 -9
chat_a.py CHANGED
@@ -15,7 +15,7 @@ from css import log_and_render
15
  HOME = pathlib.Path.home()
16
 
17
  # ✅ ENV가 있으면 따르고, 없으면 홈 밑 .cache/hf-cache 사용
18
- CACHE_DIR = os.getenv("TRANSFORMERS_CACHE") or str(HOME / ".cache" / "hf-cache")
19
  os.makedirs(CACHE_DIR, exist_ok=True)
20
 
21
  HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data")
@@ -1038,16 +1038,20 @@ def override_emotion_if_needed(text):
1038
  return [(emotion_label, 50.0)], [emotion_group]
1039
  return None
1040
 
1041
- def analyze_emotion(user_input):
1042
- sentiment_model = load_sentiment_model()
1043
  tokenizer = load_tokenizer()
1044
- override = override_emotion_if_needed(user_input)
1045
- if override:
1046
- return override
1047
- inputs = tokenizer(user_input, return_tensors="pt", truncation=True)
1048
-
 
 
 
 
1049
  with torch.no_grad():
1050
- probs = F.softmax(sentiment_model(**inputs).logits, dim=1)[0]
 
1051
  top_indices = torch.topk(probs, k=5).indices.tolist()
1052
  top_emotions = [(klue_emotions[i], float(probs[i]) * 100) for i in top_indices]
1053
  top_emotion_groups = list(dict.fromkeys([klue_to_general[i] for i in top_indices if probs[i] > 0.05]))
 
15
  HOME = pathlib.Path.home()
16
 
17
  # ✅ ENV가 있으면 따르고, 없으면 홈 밑 .cache/hf-cache 사용
18
+ CACHE_DIR = os.getenv("TRANSFORMERS_CACHE") or os.path.expanduser("~/.cache/hf-cache")
19
  os.makedirs(CACHE_DIR, exist_ok=True)
20
 
21
  HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data")
 
1038
  return [(emotion_label, 50.0)], [emotion_group]
1039
  return None
1040
 
1041
+ def analyze_emotion(text: str):
 
1042
  tokenizer = load_tokenizer()
1043
+ model = load_sentiment_model()
1044
+
1045
+ inputs = tokenizer(
1046
+ text,
1047
+ return_tensors="pt",
1048
+ truncation=True,
1049
+ padding=False,
1050
+ max_length=256,
1051
+ )
1052
  with torch.no_grad():
1053
+ logits = model(**inputs).logits
1054
+ probs = F.softmax(logits, dim=1)[0].tolist()
1055
  top_indices = torch.topk(probs, k=5).indices.tolist()
1056
  top_emotions = [(klue_emotions[i], float(probs[i]) * 100) for i in top_indices]
1057
  top_emotion_groups = list(dict.fromkeys([klue_to_general[i] for i in top_indices if probs[i] > 0.05]))