import os
import logging
import gradio as gr
from fastapi import FastAPI
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from langchain_core.runnables import RunnableLambda
import uvicorn
import asyncio
import base64
from typing import Optional, List

from utils import getconfig, build_conversation_context
from nodes import (
    detect_file_type_node, ingest_node, direct_output_node, 
    retrieve_node, generate_node_streaming, route_workflow, 
    process_query_streaming
)
from models import GraphState, ChatUIInput, ChatUIFileInput

config = getconfig("params.cfg")
MAX_TURNS = config.getint("conversation_history", "MAX_TURNS")
MAX_CHARS = config.getint("conversation_history", "MAX_CHARS")
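
# Expected params.cfg shape for the values read above (a sketch: the section
# and key names come from this file, the numbers are illustrative):
#
#   [conversation_history]
#   MAX_TURNS = 5
#   MAX_CHARS = 4000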

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


#----------------------------------------
# LANGGRAPH WORKFLOW SETUP
#----------------------------------------
# Workflow handles direct output caching automatically in nodes

workflow = StateGraph(GraphState)

# Add all nodes
workflow.add_node("detect_file_type", detect_file_type_node)
workflow.add_node("ingest", ingest_node)
workflow.add_node("direct_output", direct_output_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node_streaming)

# Simple linear path - node logic handles routing
workflow.add_edge(START, "detect_file_type")
workflow.add_edge("detect_file_type", "ingest")

# Route after ingestion based on direct output mode
# Direct output mode routes to direct_output (return ingestor results)
# Standard mode routes to retrieve + generate (full RAG pipeline)
workflow.add_conditional_edges(
    "ingest",
    route_workflow,
    {"direct_output": "direct_output", "standard": "retrieve"}
)

# Standard RAG path
workflow.add_edge("retrieve", "generate")

# Terminal nodes
workflow.add_edge("generate", END)
workflow.add_edge("direct_output", END)

compiled_graph = workflow.compile()
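
# Resulting graph topology (read directly from the edges above):
#
#   START -> detect_file_type -> ingest --(direct_output)--> direct_output -> END
#                                       \-(standard)-------> retrieve -> generate -> END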


#----------------------------------------
# CHATUI ADAPTERS
#----------------------------------------
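# Both adapters consume the event dicts yielded by process_query_streaming.
# Shape as inferred from the handling code below (see nodes.py for the source):
#
#   {"type": "data",    "content": "<streamed text chunk>"}
#   {"type": "sources", "content": [{"title": ..., "link": ...}, ...]}
#   {"type": "end"}                      # triggers the sources footer
#   {"type": "error",   "content": "<error message>"}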

async def chatui_adapter(data):
    """Text-only adapter for ChatUI with structured message support"""
    try:
        # Handle both dict and object access patterns
        if isinstance(data, dict):
            text_value = data.get('text', '')
            messages_value = data.get('messages', None)
            preprompt_value = data.get('preprompt', None)
        else:
            text_value = getattr(data, 'text', '')
            messages_value = getattr(data, 'messages', None)
            preprompt_value = getattr(data, 'preprompt', None)
        
        # Convert dict messages to lightweight objects if needed;
        # guard against messages_value being None
        messages = []
        for msg in (messages_value or []):
            if isinstance(msg, dict):
                messages.append(type('Message', (), {
                    'role': msg.get('role', 'unknown'),
                    'content': msg.get('content', '')
                })())
            else:
                messages.append(msg)

        logger.info(f"Context: {messages}")
        # Extract latest user query
        user_messages = [msg for msg in messages if msg.role == 'user']
        query = user_messages[-1].content if user_messages else text_value
        
        # Log conversation context
        logger.info(f"Processing query: {query}")
        logger.info(f"Total messages in conversation: {len(messages)}")
        logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")
        
        # Build conversation context for generation (last N turns)
        conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)
        logger.info(f"Context: {conversation_context}")
        full_response = ""
        sources_collected = None
        
        async for result in process_query_streaming(
            query=query,
            file_upload=None,
            reports_filter="",
            sources_filter="",
            subtype_filter="",
            year_filter="",
            conversation_context=conversation_context
        ):
            if isinstance(result, dict):
                result_type = result.get("type", "data")
                content = result.get("content", "")
                
                if result_type == "data":
                    full_response += content
                    yield content
                elif result_type == "sources":
                    sources_collected = content
                elif result_type == "end":
                    if sources_collected:
                        sources_text = "\n\n**Sources:**\n"
                        for i, source in enumerate(sources_collected, 1):
                            sources_text += f"{i}. [{source.get('title', 'Unknown')}]({source.get('link', '#')})\n"
                        yield sources_text
                elif result_type == "error":
                    yield f"Error: {content}"
            else:
                yield str(result)
            
            await asyncio.sleep(0)
        
    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        logger.error(f"Full traceback:", exc_info=True)
        yield f"Error: {str(e)}"


async def chatui_file_adapter(data):
    """File upload adapter for ChatUI with structured message support"""
    try:
        # Handle both dict and object access patterns
        if isinstance(data, dict):
            text_value = data.get('text', '')
            messages_value = data.get('messages', None)
            files_value = data.get('files', None)
            preprompt_value = data.get('preprompt', None)
        else:
            text_value = getattr(data, 'text', '')
            messages_value = getattr(data, 'messages', None)
            files_value = getattr(data, 'files', None)
            preprompt_value = getattr(data, 'preprompt', None)
        
        # Extract query - prefer structured messages
        if messages_value and len(messages_value) > 0:
            logger.info("✓ Using structured messages")
            
            # Convert dict messages to objects
            messages = []
            for msg in messages_value:
                if isinstance(msg, dict):
                    messages.append(type('Message', (), {
                        'role': msg.get('role', 'unknown'),
                        'content': msg.get('content', '')
                    })())
                else:
                    messages.append(msg)
            
            user_messages = [msg for msg in messages if msg.role == 'user']
            query = user_messages[-1].content if user_messages else text_value
            
            logger.info(f"Processing query: {query}")
            logger.info(f"Total messages in conversation: {len(messages)}")
            logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")
            
            conversation_context = build_conversation_context(messages, max_turns=MAX_TURNS, max_chars=MAX_CHARS)
        else:
            # Fall back to the raw text field so that query and
            # conversation_context are always defined before the
            # process_query_streaming call below
            logger.info("No structured messages; using raw text input")
            query = text_value
            conversation_context = ""

        file_content = None
        filename = None
        
        if files_value and len(files_value) > 0:
            file_info = files_value[0]
            logger.info(f"Processing file: {file_info.get('name', 'unknown')}")
            
            if file_info.get('type') == 'base64' and file_info.get('content'):
                try:
                    file_content = base64.b64decode(file_info['content'])
                    filename = file_info.get('name', 'uploaded_file')
                except Exception as e:
                    logger.error(f"Error decoding base64 file: {str(e)}")
                    yield f"Error: Failed to decode uploaded file - {str(e)}"
                    return
        
        sources_collected = None
        
        async for result in process_query_streaming(
            query=query,
            file_content=file_content,
            filename=filename,
            reports_filter="",
            sources_filter="",
            subtype_filter="",
            year_filter="",
            output_format="structured",
            conversation_context=conversation_context
        ):
            if isinstance(result, dict):
                result_type = result.get("type", "data")
                content = result.get("content", "")
                
                if result_type == "data":
                    yield content
                elif result_type == "sources":
                    sources_collected = content
                elif result_type == "end":
                    if sources_collected:
                        sources_text = "\n\n**Sources:**\n"
                        for i, source in enumerate(sources_collected, 1):
                            if isinstance(source, dict):
                                title = source.get('title', 'Unknown')
                                link = source.get('link', '#')
                                sources_text += f"{i}. [{title}]({link})\n"
                            else:
                                sources_text += f"{i}. {source}\n"
                        yield sources_text
                elif result_type == "error":
                    yield f"Error: {content}"
            else:
                yield str(result)
            
            await asyncio.sleep(0)
        
    except Exception as e:
        logger.error(f"ChatUI file adapter error: {str(e)}")
        yield f"Error: {str(e)}"


#----------------------------------------
# FASTAPI SETUP - for future use
#----------------------------------------

app = FastAPI(title="ChatFed Orchestrator", version="1.0.0")

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/")
async def root():
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed-ui-stream": "/chatfed-ui-stream (LangServe)",  
            "chatfed-with-file-stream": "/chatfed-with-file-stream (LangServe)",
            "gradio": "/gradio"
        }
    }


#----------------------------------------
# LANGSERVE ROUTES - endpoints for ChatUI
#----------------------------------------

# Text-only endpoint
add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)

# File upload endpoint
add_routes(
    app,
    RunnableLambda(chatui_file_adapter),
    path="/chatfed-with-file-stream",
    input_type=ChatUIFileInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)
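
# LangServe mounts its standard runnable endpoints under each path above,
# e.g. POST /chatfed-ui-stream/invoke and POST /chatfed-ui-stream/stream.
# A minimal streaming call might look like this (illustrative payload; the
# accepted fields are defined by ChatUIInput in models.py):
#
#   curl -N -X POST http://localhost:7860/chatfed-ui-stream/stream \
#        -H "Content-Type: application/json" \
#        -d '{"input": {"text": "What changed in the 2024 reports?", "messages": []}}'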


#----------------------------------------
# GRADIO INTERFACE - for local testing
# ACCESS: https://[ORG_NAME]-[SPACE_NAME].hf.space/gradio/
#----------------------------------------

def create_gradio_interface():
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context.")
        
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                file_input = gr.File(
                    label="Upload Document (PDF/DOCX/GeoJSON)", 
                    file_types=[".pdf", ".docx", ".geojson", ".json"]
                )
                
                with gr.Accordion("Filters (Optional)", open=False):
                    reports_filter = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                    sources_filter = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                    subtype_filter = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                    year_filter = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                
                submit_btn = gr.Button("Submit", variant="primary")
            
            with gr.Column():
                output = gr.Textbox(label="Response", lines=15, show_copy_button=True)

        async def gradio_handler(query, file, reports, sources, subtype, year):
            """Handler for Gradio interface"""
            result = ""
            async for chunk in process_query_streaming(
                query=query,
                file_upload=file,
                reports_filter=reports,
                sources_filter=sources,
                subtype_filter=subtype,
                year_filter=year,
                output_format="gradio"
            ):
                result = chunk
                yield result
        
        submit_btn.click(
            fn=gradio_handler,
            inputs=[query_input, file_input, reports_filter, sources_filter, subtype_filter, year_filter],
            outputs=output,
        )
    
    return demo


#----------------------------------------

if __name__ == "__main__":
    demo = create_gradio_interface()
    app = gr.mount_gradio_app(app, demo, path="/gradio")
    
    host = os.getenv("HOST", "0.0.0.0")  
    port = int(os.getenv("PORT", "7860"))
    
    logger.info(f"Starting ChatFed Orchestrator on {host}:{port}")
    logger.info(f"Gradio UI: http://{host}:{port}/gradio")
    logger.info(f"API Docs: http://{host}:{port}/docs")
    
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)