Eliot0110 commited on
Commit
c21e548
·
1 Parent(s): 8ddee1f

fix:处理key

Browse files
modules/config_loader.py CHANGED
@@ -1,33 +0,0 @@
1
- # modules/config_loader.py
2
- import json
3
- from pathlib import Path
4
- from utils.logger import log
5
-
6
- class ConfigLoader:
7
- def __init__(self, config_dir: Path = Path("./config")):
8
- self.config_dir = config_dir
9
- self.cities = {}
10
- self.personas = {}
11
- self.interests = {}
12
- try:
13
- self._load_all()
14
- log.info("✅ 所有配置文件加载成功")
15
- except Exception as e:
16
- log.error(f"❌ 配置文件加载失败: {e}", exc_info=True)
17
- raise
18
-
19
- def _load_all(self):
20
- # 加载城市
21
- with open(self.config_dir / "cities.json", 'r', encoding='utf-8') as f:
22
- cities_data = json.load(f)
23
- for city in cities_data['cities']:
24
- for alias in [city['name']] + city.get('aliases', []):
25
- self.cities[alias.lower()] = city
26
-
27
- # 加载 personas
28
- with open(self.config_dir / "personas.json", 'r', encoding='utf-8') as f:
29
- self.personas = json.load(f)['personas']
30
-
31
- # 加载兴趣
32
- with open(self.config_dir / "interests.json", 'r', encoding='utf-8') as f:
33
- self.interests = json.load(f)['interests']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/info_extractor.py CHANGED
@@ -1,31 +0,0 @@
1
- # modules/info_extractor.py
2
- import re
3
- from .config_loader import ConfigLoader
4
-
5
- class InfoExtractor:
6
- def __init__(self, config_loader: ConfigLoader):
7
- self.configs = config_loader
8
-
9
- def extract(self, user_input: str) -> dict:
10
- """从用户输入中提取目的地、天数和旅行风格"""
11
- extracted_info = {}
12
- user_lower = user_input.lower()
13
-
14
- # 提取目的地
15
- for alias, city_info in self.configs.cities.items():
16
- if alias in user_lower:
17
- extracted_info["destination"] = city_info
18
- break
19
-
20
- # 提取天数
21
- match = re.search(r'(\d+)\s*天', user_input)
22
- if match:
23
- extracted_info["duration"] = {"days": int(match.group(1))}
24
-
25
- # 提取旅行风格 (persona)
26
- for p_name, p_info in self.configs.personas.items():
27
- if p_info['name'] in user_input or p_name in user_input:
28
- extracted_info["persona"] = p_info
29
- break
30
-
31
- return extracted_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/knowledge_base.py CHANGED
@@ -1,29 +0,0 @@
1
- # modules/knowledge_base.py
2
- import json
3
- from pathlib import Path
4
- from utils.logger import log
5
-
6
- class KnowledgeBase:
7
- def __init__(self, file_path: Path = Path("./config/general_travelplan.json")):
8
- self.knowledge = []
9
- try:
10
- with open(file_path, 'r', encoding='utf-8') as f:
11
- self.knowledge = json.load(f).get('clean_knowledge', [])
12
- log.info(f"✅ 知识库加载完成")
13
- except Exception as e:
14
- log.error(f"❌ 知识库加载失败: {e}", exc_info=True)
15
- raise
16
-
17
- def search(self, query: str) -> list:
18
- relevant_knowledge = []
19
- query_lower = query.lower()
20
-
21
- for item in self.knowledge:
22
- # 简单实现:如果查询的城市在知识库的目的地中,则返回该知识
23
- destinations = item.get('knowledge', {}).get('travel_knowledge', {}).get('destination_info', {}).get('primary_destinations', [])
24
- for dest in destinations:
25
- if dest.lower() in query_lower:
26
- if item not in relevant_knowledge:
27
- relevant_knowledge.append(item)
28
- break
29
- return relevant_knowledge
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/response_generator.py CHANGED
@@ -1,45 +0,0 @@
1
- # modules/response_generator.py
2
- from .ai_model import AIModel
3
- from .knowledge_base import KnowledgeBase
4
-
5
- class ResponseGenerator:
6
- def __init__(self, ai_model: AIModel, knowledge_base: KnowledgeBase):
7
- self.ai_model = ai_model
8
- self.kb = knowledge_base
9
-
10
- def generate(self, user_message: str, session_state: dict) -> str:
11
- # 1. 优先使用 RAG (检索增强生成)
12
- # 我们用目的地名称来强化检索查询
13
- search_query = user_message
14
- if session_state.get("destination"):
15
- search_query += f" {session_state['destination']['name']}"
16
-
17
- relevant_knowledge = self.kb.search(search_query)
18
- if relevant_knowledge:
19
- context = self._format_knowledge_context(relevant_knowledge)
20
- return self.ai_model.generate(user_message, context)
21
-
22
- # 2. 如果没有知识库匹配,则使用基于规则的引导式对话
23
- if not session_state.get("destination"):
24
- return "听起来很棒!你想去欧洲的哪个城市呢?比如巴黎, 罗马, 巴塞罗那?"
25
- if not session_state.get("duration"):
26
- return f"好的,{session_state['destination']['name']}是个很棒的选择!你计划玩几天呢?"
27
- if not session_state.get("persona"):
28
- return "最后一个问题,这次旅行对你来说什么最重要呢?(例如:美食、艺术、购物、历史)"
29
-
30
- # 3. 如果信息都收集全了,但没触发RAG,让Gemma生成一个通用计划
31
- plan_prompt = (
32
- f"请为用户生成一个在 {session_state['destination']['name']} 的 "
33
- f"{session_state['duration']['days']} 天旅行计划。"
34
- f"旅行风格侧重于: {session_state['persona']['description']}。"
35
- )
36
- return self.ai_model.generate(plan_prompt, context="用户需要一个详细的旅行计划。")
37
-
38
- def _format_knowledge_context(self, knowledge_items: list) -> str:
39
- if not knowledge_items: return "没有特定的背景知识。"
40
- # 简化处理,只用最相关的一条知识
41
- item = knowledge_items[0]['knowledge']['travel_knowledge']
42
- context = f"相关知识:\n- 目的地: {item['destination_info']['primary_destinations']}\n"
43
- context += f"- 推荐天数: {item['destination_info']['recommended_duration']}天\n"
44
- context += f"- 专业见解: {item['professional_insights']['key_takeaways']}\n"
45
- return context
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/session_manager.py CHANGED
@@ -23,10 +23,22 @@ class SessionManager:
23
  self.sessions[session_id].update(updates)
24
 
25
  def format_session_info(self, session_state: dict) -> str:
 
26
  parts = [f"ID: {session_state.get('session_id', 'N/A')}"]
27
- if session_state.get('destination'): parts.append(f"目的地: {session_state['destination']['name']}")
28
- if session_state.get('duration'): parts.append(f"天数: {session_state['duration']['days']}")
29
- if session_state.get('persona'): parts.append(f"风格: {session_state['persona']['name']}")
 
 
 
 
 
 
 
 
 
 
 
30
  return " | ".join(parts)
31
 
32
  def reset(self, session_id: str):
 
23
  self.sessions[session_id].update(updates)
24
 
25
  def format_session_info(self, session_state: dict) -> str:
26
+ """格式化会话状态信息 - 修复版本"""
27
  parts = [f"ID: {session_state.get('session_id', 'N/A')}"]
28
+
29
+ if session_state.get('destination'):
30
+ parts.append(f"目的地: {session_state['destination']['name']}")
31
+
32
+ if session_state.get('duration'):
33
+ parts.append(f"天数: {session_state['duration']['days']}")
34
+
35
+ if session_state.get('persona'):
36
+ # 使用name字段而不是假设存在的其他字段
37
+ persona_name = session_state['persona'].get('name', '未知风格')
38
+ # 清理emoji用于显示
39
+ clean_name = ''.join(char for char in persona_name if not char.startswith(('🗓️', '🤝', '🎨')))
40
+ parts.append(f"风格: {clean_name.strip()}")
41
+
42
  return " | ".join(parts)
43
 
44
  def reset(self, session_id: str):
modules/travel_assistant.py CHANGED
@@ -1,45 +0,0 @@
1
- # modules/travel_assistant.py
2
- from .config_loader import ConfigLoader
3
- from .ai_model import AIModel
4
- from .knowledge_base import KnowledgeBase
5
- from .info_extractor import InfoExtractor
6
- from .session_manager import SessionManager
7
- from .response_generator import ResponseGenerator
8
- from utils.logger import log
9
-
10
- class TravelAssistant:
11
- def __init__(self):
12
- # 依赖注入:在这里实例化所有需要的模块
13
- log.info("开始初始化 Travel Assistant 核心模块...")
14
- self.config = ConfigLoader()
15
- self.kb = KnowledgeBase()
16
- self.ai_model = AIModel()
17
- self.session_manager = SessionManager()
18
- self.info_extractor = InfoExtractor(self.config)
19
- self.response_generator = ResponseGenerator(self.ai_model, self.kb)
20
- log.info("✅ Travel Assistant 核心模块全部初始化完成!")
21
-
22
- def chat(self, message: str, session_id: str, history: list):
23
- # 1. 获取或创建会话
24
- session_state = self.session_manager.get_or_create_session(session_id)
25
- current_session_id = session_state['session_id']
26
-
27
- # 2. 从用户输入中提取信息
28
- extracted_info = self.info_extractor.extract(message)
29
-
30
- # 3. 更新会话状态
31
- if extracted_info:
32
- self.session_manager.update_session(current_session_id, extracted_info)
33
- # 重新获取更新后的状态
34
- session_state = self.session_manager.get_or_create_session(current_session_id)
35
-
36
- # 4. 生成回复
37
- bot_response = self.response_generator.generate(message, session_state)
38
-
39
- # 5. 格式化状态信息用于前端显示
40
- status_info = self.session_manager.format_session_info(session_state)
41
-
42
- # 6. 更新对话历史
43
- new_history = history + [[message, bot_response]]
44
-
45
- return bot_response, current_session_id, status_info, new_history