import uuid
import threading
import asyncio
import json
import re
import os
import time
import pickle
import random
from datetime import datetime

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
from langchain_core.callbacks.base import BaseCallbackHandler
from langgraph.graph import StateGraph, START, END
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

from tools import extract_json_from_response, apply_filters_partial, rule_based_extract, format_property_data, estateKeywords


class CallbackTextStreamer(TextStreamer):
    """Forwards decoded text chunks from TextStreamer to LangChain-style callbacks."""

    def __init__(self, tokenizer, callbacks, skip_prompt=True, skip_special_tokens=True):
        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
        self.callbacks = callbacks

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # TextStreamer delivers decoded text through on_finalized_text, so this
        # is the hook to override (a custom on_new_token method is never called
        # by TextStreamer). Forward each chunk to the registered callbacks.
        for callback in self.callbacks:
            callback.on_llm_new_token(text)


class ChatQwen:
    def __init__(self, temperature=0.3, streaming=False, max_new_tokens=512, callbacks=None):
        self.temperature = temperature
        self.streaming = streaming
        self.max_new_tokens = max_new_tokens
        self.callbacks = callbacks
        self.model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map="auto",
        )
    def generate_text(self, messages: list) -> str:
        """
        Given a list of messages, create a prompt and generate text using the Qwen model.
        In streaming mode, uses a TextIteratorStreamer and iterates over tokens to call callbacks.
        """
        # Create the prompt from messages using the tokenizer's chat template.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
        if self.streaming:
            from transformers import TextIteratorStreamer
            from threading import Thread

            # The streamer collects decoded tokens as they are generated.
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs = dict(
                **model_inputs,
                max_new_tokens=self.max_new_tokens,
                streamer=streamer,
                temperature=self.temperature,
                do_sample=True,
            )
            # Run generation in a separate thread so tokens can be consumed here.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            generated_text = ""
            # Iterate over tokens as they arrive and forward them to callbacks.
            for token in streamer:
                generated_text += token
                if self.callbacks:
                    for callback in self.callbacks:
                        callback.on_llm_new_token(token)
            thread.join()
            # Streaming consumers already received every token via callbacks;
            # the full text is returned as well for callers that need it.
            return generated_text
        else:
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True,
            )
            # Remove the prompt tokens from the output before decoding.
            prompt_length = model_inputs.input_ids.shape[-1]
            generated_ids = outputs[0][prompt_length:]
            return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
    def invoke(self, messages: list, config: dict = None) -> AIMessage:
        config = config or {}
        # Use callbacks from the config if provided, else the instance's own.
        callbacks = config.get("callbacks", self.callbacks)
        original_callbacks = self.callbacks
        self.callbacks = callbacks
        try:
            output_text = self.generate_text(messages)
        finally:
            # Restore the instance callbacks even if generation raises.
            self.callbacks = original_callbacks
        if self.streaming:
            # Tokens were already delivered through the callbacks.
            return AIMessage(content="")
        return AIMessage(content=output_text)

    def __call__(self, messages: list) -> AIMessage:
        return self.invoke(messages)
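

# A minimal usage sketch (never called): any object exposing an
# `on_llm_new_token` method can serve as a callback, mirroring the handlers
# defined below. The `_PrintToken` helper and the sample prompt are
# illustrative assumptions, not part of the application flow.
def _demo_chatqwen():
    class _PrintToken:
        def on_llm_new_token(self, token: str, **kwargs):
            print(token, end="", flush=True)

    chat = ChatQwen(temperature=0.3, streaming=True, callbacks=[_PrintToken()])
    chat.invoke([{"role": "user", "content": "Describe a 2BHK apartment in one sentence."}])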


class WebSocketStreamingCallbackHandler(BaseCallbackHandler):
    def __init__(self, connection_id: str, loop):
        self.connection_id = connection_id
        self.loop = loop

    def on_llm_new_token(self, token: str, **kwargs):
        # Generation runs in a worker thread, so schedule the send on the
        # event loop that owns the WebSocket.
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(self.connection_id, token),
            self.loop,
        )


llm = ChatQwen(temperature=0.3, streaming=True, max_new_tokens=512)

index = faiss.read_index("./faiss.index")
with open("./metadata.pkl", "rb") as f:
    docs = pickle.load(f)

st_model = SentenceTransformer('all-MiniLM-L6-v2')
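

# The FAISS index and metadata pickle must already exist on disk. Below is a
# minimal build sketch (never called); the assumption that each document is a
# dict with a "description" text field is illustrative only.
def _build_index_sketch(documents, index_path="./faiss.index", meta_path="./metadata.pkl"):
    embeddings = st_model.encode([d.get("description", "") for d in documents])
    idx = faiss.IndexFlatL2(embeddings.shape[1])  # L2 index, matching index.search() usage below
    idx.add(np.asarray(embeddings, dtype=np.float32))
    faiss.write_index(idx, index_path)
    with open(meta_path, "wb") as f:
        pickle.dump(documents, f)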


def make_system_prompt(suffix: str) -> str:
    return (
        "You are EstateGuru, a real estate expert created by Abhishek Pathak from SwavishTek. "
        "Your role is to help customers buy properties using the available data. "
        "Only use the provided data—do not make up any information. "
        "The default currency is AED. If a query uses a different currency, convert the amount to AED "
        "(for example, $10k becomes 36726.50 AED and $1 becomes 3.67 AED). "
        "If a customer is interested in a property, wants to buy, or needs to contact an agent or customer care, "
        "instruct them to call +91 8766268285."
        f"\n{suffix}"
    )


general_query_prompt = make_system_prompt(
    "You are EstateGuru, a helpful real estate assistant. Answer the user's query accurately using the available data. "
    "Do not invent any details or go beyond the real estate domain. "
    "If the user shows interest in a property or contacting an agent, ask them to call +91 8766268285."
)


# ------------------------ Tool Definitions ------------------------

@tool
def extract_filters(query: str) -> dict:
    """Extract property filters from a user query."""
    # Use a non-streaming ChatQwen instance for tool calls.
    llm_local = ChatQwen(temperature=0.3, streaming=False)
    system = (
        "You are an expert in extracting filters from property-related queries. Your task is to extract and return only the keys explicitly mentioned in the query as a valid JSON object (starting with '{' and ending with '}'). Include only those keys that are directly present in the query.\n\n"
        "The possible keys are:\n"
        "  - 'projectName': The name of the project.\n"
        "  - 'developerName': The developer's name.\n"
        "  - 'relationshipManager': The relationship manager.\n"
        "  - 'propertyAddress': The property address.\n"
        "  - 'surroundingArea': The area or nearby landmarks.\n"
        "  - 'propertyType': The type or configuration of the property.\n"
        "  - 'amenities': Any amenities mentioned.\n"
        "  - 'coveredParking': Parking availability.\n"
        "  - 'petRules': Pet policies.\n"
        "  - 'security': Security details.\n"
        "  - 'occupancyRate': Occupancy information.\n"
        "  - 'constructionImpact': Construction or its impact.\n"
        "  - 'propertySize': Size of the property.\n"
        "  - 'propertyView': View details.\n"
        "  - 'propertyCondition': Condition of the property.\n"
        "  - 'serviceCharges': Service or maintenance charges.\n"
        "  - 'ownershipType': Ownership type.\n"
        "  - 'totalCosts': A cost threshold or cost amount.\n"
        "  - 'paymentPlans': Payment or financing plans.\n"
        "  - 'expectedRentalYield': Expected rental yield.\n"
        "  - 'rentalHistory': Rental history.\n"
        "  - 'shortTermRentals': Short-term rental information.\n"
        "  - 'resalePotential': Resale potential.\n"
        "  - 'uniqueId': A unique identifier.\n\n"
        "Important instructions regarding cost thresholds:\n"
        "  - If the query contains phrases like 'under 10k', 'below 2m', or 'less than 5k', interpret these as cost thresholds.\n"
        "  - Convert any shorthand cost values to pure numbers (for example, '10k' becomes 10000, '2m' becomes 2000000) and assign them to the key 'totalCosts'.\n"
        "  - Do not use 'propertySize' for cost thresholds.\n"
        "  - The default currency is AED; if the query uses a different currency symbol, convert the amount to its AED equivalent (e.g., $10k becomes 36726.50, $1 becomes 3.67).\n\n"
        "Example:\n"
        "  For the query: \"properties near dubai mall under 43k\"\n"
        "  The expected output should be:\n"
        "  { \"surroundingArea\": \"dubai mall\", \"totalCosts\": 43000 }\n\n"
        "Return ONLY a valid JSON object with the extracted keys and their corresponding values, with no additional text."
    )
    human_str = f"Here is the query:\n{query}"
    filter_prompt = [
        {"role": "system", "content": system},
        {"role": "user", "content": human_str},
    ]
    response = llm_local.invoke(messages=filter_prompt)
    response_text = response.content if isinstance(response, AIMessage) else str(response)
    try:
        model_filters = extract_json_from_response(response_text)
    except Exception as e:
        print(f"JSON parsing error: {e}")
        model_filters = {}
    # Merge model output with rule-based extraction; rule-based values win.
    rule_filters = rule_based_extract(query)
    print("Rule-based extraction:", rule_filters)
    final_filters = {**model_filters, **rule_filters}
    print("Final extraction:", final_filters)
    return {"filters": final_filters}


@tool
def determine_route(query: str) -> dict:
    """Determine the route for a query, with keyword-based fallback logic."""
    # Keywords that are strong indicators of a real estate query.
    real_estate_keywords = estateKeywords
    pattern = re.compile("|".join(re.escape(keyword) for keyword in real_estate_keywords), re.IGNORECASE)
    positive_signal = bool(pattern.search(query))
    # Proceed with LLM classification regardless; the positive signal is used in the fallback.
    llm_local = ChatQwen(temperature=0.3, streaming=False)
    transform_suggest_to_list = query.lower().replace("suggest ", "list ")
    system = """
    Classify the user query as:
    - **"search"**: if it requests property listings with specific filters (e.g., location, price, property type like "2bhk", service charges, pet policies, etc.).
    - **"suggest"**: if it asks for property suggestions without filters.
    - **"detail"**: if it is asking for more information about a previously provided property (for example, "tell me more about property 5" or "I want more information regarding 4BHK").
    - **"general"**: for all other real estate-related questions.
    - **"out_of_domain"**: if the query is not related to real estate (for example, tourist attractions, restaurants, etc.).
    Keep in mind that queries mentioning terms like "service charge", "allow pets", "pet rules", etc., are considered real estate queries.
    Return only the keyword: search, suggest, detail, general, or out_of_domain.
    """
    human_str = f"Here is the query:\n{transform_suggest_to_list}"
    router_prompt = [
        {"role": "system", "content": system},
        {"role": "user", "content": human_str},
    ]
    response = llm_local.invoke(messages=router_prompt)
    response_text = response.content if isinstance(response, AIMessage) else str(response)
    route_value = str(response_text).strip().lower()

    # Fallback: if the query looks like a detail request, override the route.
    detail_phrases = [
        "more information",
        "tell me more",
        "more details",
        "give me more details",
        "i need more details",
        "can you provide more details",
        "additional details",
        "further information",
        "expand on that",
        "explain further",
        "elaborate more",
        "more specifics",
        "i want to know more",
        "could you elaborate",
        "need more info",
        "provide more details",
        "detail it further",
        "in-depth information",
        "break it down further",
        "further explanation",
    ]
    if any(phrase in query.lower() for phrase in detail_phrases):
        route_value = "detail"
    if route_value not in {"search", "suggest", "detail", "general", "out_of_domain"}:
        route_value = "general"
    # Treat out-of-domain answers as general when real estate keywords are present.
    if route_value == "out_of_domain" and positive_signal:
        route_value = "general"
    return {"route": route_value}


# ------------------------ Workflow Setup ------------------------

workflow = StateGraph(state_schema=dict)


def route_query(state: dict) -> dict:
    new_state = state.copy()
    try:
        new_state["route"] = determine_route.invoke(new_state.get("query", "")).get("route", "general")
        print(new_state["route"])
    except Exception as e:
        print(f"Routing error: {e}")
        new_state["route"] = "general"
    return new_state


def hybrid_extract(state: dict) -> dict:
    new_state = state.copy()
    new_state["filters"] = extract_filters.invoke(new_state.get("query", "")).get("filters", {})
    return new_state


def search_faiss(state: dict) -> dict:
    new_state = state.copy()
    query_embedding = st_model.encode([state["query"]])
    _, indices = index.search(query_embedding.astype(np.float32), 5)
    new_state["faiss_results"] = [docs[idx] for idx in indices[0] if idx < len(docs)]
    return new_state


def apply_filters(state: dict) -> dict:
    new_state = state.copy()
    new_state["final_results"] = apply_filters_partial(state["faiss_results"], state.get("filters", {}))
    return new_state


def suggest_properties(state: dict) -> dict:
    new_state = state.copy()
    # Guard against a corpus with fewer than five documents.
    new_state["suggestions"] = random.sample(docs, min(5, len(docs)))
    return new_state


def handle_out_of_domain(state: dict) -> dict:
    new_state = state.copy()
    new_state["response"] = "I only handle real estate inquiries. Please ask a question related to properties."
    return new_state


def generate_response(state: dict) -> dict:
    new_state = state.copy()
    messages = []
    # Start with the general system prompt.
    messages.append({"role": "system", "content": general_query_prompt})
    # For a detail query, add a system message that forces a detailed answer.
    if new_state.get("route", "general") == "detail":
        messages.append({
            "role": "system",
            "content": (
                "This is a detail query. Please provide detailed information about the property below. "
                "Do not generate a new list of properties; only use the provided property details to answer the query. "
                "Focus on answering the specific question (for example, whether pets are allowed)."
            ),
        })
    # If property details are available, add them without clearing context.
    if new_state.get("current_properties"):
        property_context = format_property_data(new_state["current_properties"])
        messages.append({"role": "system", "content": "Available Property:\n" + property_context})
        # Do NOT clear current_properties here.
        messages.append({"role": "system", "content": "When responding, use only the provided property details to answer the user's specific question about the property."})
    # Add the conversation history.
    for msg in state.get("messages", []):
        if msg["role"] == "user":
            messages.append({"role": "user", "content": msg["content"]})
        else:
            messages.append({"role": "assistant", "content": msg["content"]})
    # Invoke the LLM, streaming over the WebSocket when a connection is attached.
    connection_id = state.get("connection_id")
    loop = state.get("loop")
    if connection_id and loop:
        callbacks = [WebSocketStreamingCallbackHandler(connection_id, loop)]
        _ = llm.invoke(messages, config={"callbacks": callbacks})
        # Tokens were already streamed to the client.
        new_state["response"] = ""
    else:
        callbacks = [StreamingStdOutCallbackHandler()]
        response = llm.invoke(messages, config={"callbacks": callbacks})
        new_state["response"] = response.content if isinstance(response, AIMessage) else str(response)
    return new_state


def format_final_response(state: dict) -> dict:
    new_state = state.copy()
    # Only override current_properties if this is NOT a detail query.
    if state.get("route", "general") != "detail":
        if state.get("route") in ["search", "suggest"]:
            if "final_results" in state:
                new_state["current_properties"] = state["final_results"]
            elif "suggestions" in state:
                new_state["current_properties"] = state["suggestions"]
    # Format the response based on the (possibly filtered) current_properties.
    if new_state.get("current_properties"):
        formatted = []
        for idx, prop in enumerate(new_state["current_properties"], 1):
            cost = prop.get("totalCosts", "N/A")
            cost_str = f"{cost:,}" if isinstance(cost, (int, float)) else cost
            formatted.append(
                f"{idx}. Type: {prop.get('propertyType', 'N/A')}, Cost: AED {cost_str}, "
                f"Size: {prop.get('propertySize', 'N/A')}, Amenities: {', '.join(map(str, prop.get('amenities', []))) if prop.get('amenities') else 'N/A'}, "
                f"Rental Yield: {prop.get('expectedRentalYield', 'N/A')}, "
                f"Ownership: {prop.get('ownershipType', 'N/A')}\n"
            )
        aggregated_response = "Here are the property details:\n" + "\n".join(formatted)
        connection_id = state.get("connection_id")
        loop = state.get("loop")
        if connection_id and loop:
            # Stream the aggregated response word by word to mimic token streaming.
            for token in aggregated_response.split(" "):
                asyncio.run_coroutine_threadsafe(
                    manager_socket.send_message(connection_id, token + " "),
                    loop,
                )
                time.sleep(0.05)
            new_state["response"] = ""
        else:
            new_state["response"] = aggregated_response
    elif "response" in new_state:
        new_state["response"] = str(new_state["response"])
    return new_state


nodes = [
    ("route_query", route_query),
    ("hybrid_extract", hybrid_extract),
    ("faiss_search", search_faiss),
    ("apply_filters", apply_filters),
    ("suggest_properties", suggest_properties),
    ("handle_out_of_domain", handle_out_of_domain),
    ("generate_response", generate_response),
    ("format_response", format_final_response),
]
for name, node in nodes:
    workflow.add_node(name, node)

workflow.add_edge(START, "route_query")
workflow.add_conditional_edges(
    "route_query",
    lambda state: state.get("route", "general"),
    {
        "search": "hybrid_extract",
        "suggest": "suggest_properties",
        "detail": "generate_response",
        "general": "generate_response",
        "out_of_domain": "handle_out_of_domain",
    },
)
workflow.add_edge("hybrid_extract", "faiss_search")
workflow.add_edge("faiss_search", "apply_filters")
workflow.add_edge("apply_filters", "format_response")
workflow.add_edge("suggest_properties", "format_response")
workflow.add_edge("generate_response", "format_response")
workflow.add_edge("handle_out_of_domain", "format_response")
workflow.add_edge("format_response", END)

workflow_app = workflow.compile()
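

# Illustrative only (never called): the compiled graph can be exercised
# directly, outside the WebSocket path. The state mirrors the one built in
# ConversationManager below; the query text is an assumption.
def _demo_workflow():
    state = {
        "messages": [{"role": "user", "content": "suggest some properties"}],
        "query": "suggest some properties",
        "route": "general",
        "filters": {},
        "current_properties": [],
    }
    final_state = workflow_app.invoke(state)
    print(final_state.get("response", ""))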


# ------------------------ Conversation Manager ------------------------

class ConversationManager:
    def __init__(self):
        self.conversation_history = []
        self.current_properties = []

    def _add_message(self, role: str, content: str):
        self.conversation_history.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })

    def process_query(self, query: str) -> str:
        # Reset context on greetings to avoid reusing off-domain history.
        if query.strip().lower() in {"hi", "hello", "hey"}:
            self.conversation_history = []
            self.current_properties = []
            greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
            self._add_message("assistant", greeting_response)
            return greeting_response
        try:
            self._add_message("user", query)
            initial_state = {
                "messages": self.conversation_history.copy(),
                "query": query,
                "route": "general",
                "filters": {},
                "current_properties": self.current_properties,
            }
            final_state = initial_state
            for event in workflow_app.stream(initial_state, stream_mode="values"):
                final_state = event
            # Persist the property context for follow-up (detail) queries.
            if "final_results" in final_state:
                self.current_properties = final_state["final_results"]
            elif "suggestions" in final_state:
                self.current_properties = final_state["suggestions"]
            response = final_state.get("response", "I couldn't process that request.")
            self._add_message("assistant", response)
            return response
        except Exception as e:
            print(f"Processing error: {e}")
            return "Sorry, I encountered an error processing your request."


conversation_managers = {}

# ------------------------ FastAPI Backend with WebSockets ------------------------

app = FastAPI()


class ConnectionManager:
    def __init__(self):
        self.active_connections = {}

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        connection_id = str(uuid.uuid4())
        self.active_connections[connection_id] = websocket
        print(f"New connection: {connection_id}")
        return connection_id

    def disconnect(self, connection_id: str):
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            print(f"Disconnected: {connection_id}")

    async def send_message(self, connection_id: str, message: str):
        websocket = self.active_connections.get(connection_id)
        if websocket:
            await websocket.send_text(message)


manager_socket = ConnectionManager()


def stream_query(query: str, connection_id: str, loop):
    conv_manager = conversation_managers.get(connection_id)
    if conv_manager is None:
        print(f"No conversation manager found for connection {connection_id}")
        return
    # Handle greetings immediately, resetting any prior context.
    if query.strip().lower() in {"hi", "hello", "hey"}:
        conv_manager.conversation_history = []
        conv_manager.current_properties = []
        greeting_response = "Hello! How can I assist you today with your real estate inquiries?"
        conv_manager._add_message("assistant", greeting_response)
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(connection_id, greeting_response),
            loop,
        )
        return
    conv_manager._add_message("user", query)
    initial_state = {
        "messages": conv_manager.conversation_history.copy(),
        "query": query,
        "route": "general",
        "filters": {},
        "current_properties": conv_manager.current_properties,
        "connection_id": connection_id,
        "loop": loop,
    }
    try:
        workflow_app.invoke(initial_state)
    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        asyncio.run_coroutine_threadsafe(
            manager_socket.send_message(connection_id, error_msg),
            loop,
        )


# WebSocket endpoint; "/ws" is an assumed route path.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    connection_id = await manager_socket.connect(websocket)
    conversation_managers[connection_id] = ConversationManager()
    try:
        while True:
            query = await websocket.receive_text()
            loop = asyncio.get_running_loop()
            # Run the workflow in a worker thread so the event loop stays free
            # to deliver streamed tokens.
            threading.Thread(
                target=stream_query,
                args=(query, connection_id, loop),
                daemon=True,
            ).start()
    except WebSocketDisconnect:
        conv_manager = conversation_managers.get(connection_id)
        if conv_manager:
            os.makedirs("conversations", exist_ok=True)
            filename = f"conversations/conversation_{connection_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(filename, "w") as f:
                json.dump(conv_manager.conversation_history, f, indent=4)
            del conversation_managers[connection_id]
        manager_socket.disconnect(connection_id)


# REST endpoint for one-off queries; "/query" is an assumed route path.
@app.post("/query")
async def post_query(query: str):
    conv_manager = ConversationManager()
    response = conv_manager.process_query(query)
    return {"response": response}