MagicDash committed
Commit bf7032a · verified · Parent: 4a43403

Update app.py

Files changed (1): app.py (+669, −686)

app.py CHANGED
@@ -1,686 +1,669 @@
- import pandas as pd
- import seaborn as sns
- import matplotlib
- import matplotlib.pyplot as plt
- matplotlib.use('Agg')
- import numpy as np
- import google.generativeai as genai
- from PIL import Image
- from werkzeug.utils import secure_filename
- import os
- import json
- from fpdf import FPDF
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException
- from fastapi.responses import HTMLResponse, FileResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
- from starlette.requests import Request
- from typing import List
- import textwrap
- from IPython.display import display, Markdown
- from PIL import Image
- import shutil
- from werkzeug.utils import secure_filename
- import urllib.parse
- import re
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_community.document_loaders import PyPDFLoader, UnstructuredCSVLoader, UnstructuredExcelLoader, Docx2txtLoader, UnstructuredPowerPointLoader
- from langchain.chains import StuffDocumentsChain
- from langchain.chains.llm import LLMChain
- from langchain.prompts import PromptTemplate
- from langchain.vectorstores import FAISS
- from langchain_google_genai import GoogleGenerativeAIEmbeddings
- from langchain.text_splitter import CharacterTextSplitter
-
- app = FastAPI()
- app.mount("/static", StaticFiles(directory="static"), name="static")
- templates = Jinja2Templates(directory="templates")
-
- sns.set_theme(color_codes=True)
- uploaded_df = None
- document_analyzed = False
- question_responses = []
-
-
- def format_text(text):
-     # Replace **text** with <b>text</b>
-     text = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', text)
-     # Replace any remaining * with <br>
-     text = text.replace('*', '<br>')
-     return text
-
- def clean_data(df):
-     # Step 1: Clean currency-related columns
-     for col in df.columns:
-         if any(x in col.lower() for x in ['value', 'price', 'cost', 'amount']):
-             if df[col].dtype == 'object':
-                 df[col] = df[col].str.replace('$', '').str.replace('£', '').str.replace('€', '').replace('[^\d.-]', '', regex=True).astype(float)
-
-     # Step 2: Drop columns with more than 25% missing values
-     null_percentage = df.isnull().sum() / len(df)
-     columns_to_drop = null_percentage[null_percentage > 0.25].index
-     df.drop(columns=columns_to_drop, inplace=True)
-
-     # Step 3: Fill missing values for remaining columns
-     for col in df.columns:
-         if df[col].isnull().sum() > 0:
-             if null_percentage[col] <= 0.25:
-                 if df[col].dtype in ['float64', 'int64']:
-                     median_value = df[col].median()
-                     df[col].fillna(median_value, inplace=True)
-
-     # Step 4: Convert object-type columns to lowercase
-     for col in df.columns:
-         if df[col].dtype == 'object':
-             df[col] = df[col].str.lower()
-
-     # Step 5: Drop columns with only one unique value
-     unique_value_columns = [col for col in df.columns if df[col].nunique() == 1]
-     df.drop(columns=unique_value_columns, inplace=True)
-
-     return df
-
-
-
-
- def clean_data2(df):
-     for col in df.columns:
-         if 'value' in col or 'price' in col or 'cost' in col or 'amount' in col or 'Value' in col or 'Price' in col or 'Cost' in col or 'Amount' in col:
-             if df[col].dtype == 'object':
-                 df[col] = df[col].str.replace('$', '')
-                 df[col] = df[col].str.replace('£', '')
-                 df[col] = df[col].str.replace('€', '')
-                 df[col] = df[col].replace('[^\d.-]', '', regex=True).astype(float)
-
-     null_percentage = df.isnull().sum() / len(df)
-
-     for col in df.columns:
-         if df[col].isnull().sum() > 0:
-             if null_percentage[col] <= 0.25:
-                 if df[col].dtype in ['float64', 'int64']:
-                     median_value = df[col].median()
-                     df[col].fillna(median_value, inplace=True)
-
-     for col in df.columns:
-         if df[col].dtype == 'object':
-             df[col] = df[col].str.lower()
-
-     return df
-
-
-
- def generate_plot(df, plot_path, plot_type):
-     df = clean_data(df)
-     excluded_words = ["name", "postal", "date", "phone", "address", "code", "id"]
-
-     if plot_type == 'countplot':
-         cat_vars = [col for col in df.select_dtypes(include='object').columns
-                     if all(word not in col.lower() for word in excluded_words) and df[col].nunique() > 1]
-
-         for col in cat_vars:
-             if df[col].nunique() > 10:
-                 top_categories = df[col].value_counts().index[:10]
-                 df[col] = df[col].apply(lambda x: x if x in top_categories else 'Other')
-
-         num_cols = len(cat_vars)
-         num_rows = (num_cols + 1) // 2
-         fig, axs = plt.subplots(nrows=num_rows, ncols=2, figsize=(15, 5*num_rows))
-         axs = axs.flatten()
-
-         for i, var in enumerate(cat_vars):
-             category_counts = df[var].value_counts()
-             top_values = category_counts.index[:10][::-1]
-             filtered_df = df.copy()
-             filtered_df[var] = pd.Categorical(filtered_df[var], categories=top_values, ordered=True)
-             sns.countplot(x=var, data=filtered_df, order=top_values, ax=axs[i])
-             axs[i].set_title(var)
-             axs[i].tick_params(axis='x', rotation=30)
-
-             total = len(filtered_df[var])
-             for p in axs[i].patches:
-                 height = p.get_height()
-                 axs[i].annotate(f'{height/total:.1%}', (p.get_x() + p.get_width() / 2., height), ha='center', va='bottom')
-
-             sample_size = filtered_df.shape[0]
-             axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
-
-         for i in range(num_cols, len(axs)):
-             fig.delaxes(axs[i])
-
-     elif plot_type == 'histplot':
-         num_vars = [col for col in df.select_dtypes(include=['int', 'float']).columns
-                     if all(word not in col.lower() for word in excluded_words)]
-         num_cols = len(num_vars)
-         num_rows = (num_cols + 2) // 3
-         fig, axs = plt.subplots(nrows=num_rows, ncols=min(3, num_cols), figsize=(15, 5*num_rows))
-         axs = axs.flatten()
-
-         plot_index = 0
-
-         for i, var in enumerate(num_vars):
-             if len(df[var].unique()) == len(df):
-                 fig.delaxes(axs[plot_index])
-             else:
-                 sns.histplot(df[var], ax=axs[plot_index], kde=True, stat="percent")
-                 axs[plot_index].set_title(var)
-                 axs[plot_index].set_xlabel('')
-
-                 sample_size = df.shape[0]
-                 axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
-
-                 plot_index += 1
-
-         for i in range(plot_index, len(axs)):
-             fig.delaxes(axs[i])
-
-     fig.tight_layout()
-     fig.savefig(plot_path)
-     plt.close(fig)
-     return plot_path
-
- @app.get("/", response_class=HTMLResponse)
- async def upload_file(request: Request):
-     return templates.TemplateResponse("upload.html", {"request": request})
-
- @app.post("/result")
- async def result(request: Request,
-                  api_key: str = Form(...),
-                  file: UploadFile = File(...),
-                  custom_question: str = Form(...)):
-     global uploaded_df, uploaded_filename, plot1_path, plot2_path, response1, response2, api, question, uploaded_file
-
-     api = api_key
-     uploaded_file = file
-
-     if file.filename == '':
-         raise HTTPException(status_code=400, detail="No file selected")
-
-     # Secure and validate the file name
-     uploaded_filename = secure_filename(file.filename)
-
-     # Determine file path based on file type
-     if uploaded_filename.endswith('.csv'):
-         file_path = 'dataset.csv'
-         # Save the file
-         with open(file_path, 'wb') as buffer:
-             shutil.copyfileobj(file.file, buffer)
-         # Read the file into a DataFrame
-         df = pd.read_csv(file_path, encoding='utf-8')
-
-     elif uploaded_filename.endswith('.xlsx'):
-         file_path = 'dataset.xlsx'
-         # Save the file
-         with open(file_path, 'wb') as buffer:
-             shutil.copyfileobj(file.file, buffer)
-         # Read the file into a DataFrame
-         df = pd.read_excel(file_path)
-
-     else:
-         raise HTTPException(status_code=400, detail="Unsupported file format")
-
-     columns = df.columns.tolist()
-
-     def generate_gemini_response(plot_path):
-         global question
-         question = custom_question
-         genai.configure(api_key=api)
-         img = Image.open(plot_path)
-         model = genai.GenerativeModel('gemini-1.5-flash-latest')
-         response = model.generate_content([
-             question + " As a marketing consultant, I want to understand consumer insights based on the chart and the market context so I can use the key findings to formulate actionable insights",
-             img
-         ])
-         response.resolve()
-         return response.text
-
-     plot1_path = generate_plot(df, 'static/plot1.png', 'countplot')
-     plot2_path = generate_plot(df, 'static/plot2.png', 'histplot')
-
-     response1 = (generate_gemini_response(plot1_path))
-     response2 = (generate_gemini_response(plot2_path))
-
-     uploaded_df = df
-
-     outputs = {
-         "barchart_visualization": plot1_path,
-         "gemini_response1": response1,
-         "histoplot_visualization": plot2_path,
-         "gemini_response2": response2
-     }
-
-     with open("output.json", "w") as outfile:
-         json.dump(outputs, outfile)
-
-     def safe_encode(text):
-         try:
-             return text.encode('latin1', errors='replace').decode('latin1')
-         except Exception as e:
-             return f"Error encoding text: {str(e)}"
-
-     pdf = FPDF()
-     pdf.set_font("Arial", size=12)
-
-     # Single Countplot Barchart and response
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Countplot Barchart", ln=True, align='C')
-     pdf.image(plot1_path, x=10, y=30, w=190)
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Countplot Barchart Google Gemini Response", ln=True, align='C')
-     pdf.ln(10)
-     pdf.multi_cell(0, 10, safe_encode(response1))
-
-     # Single Histplot and response
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Histplot", ln=True, align='C')
-     pdf.image(plot2_path, x=10, y=30, w=190)
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Histplot Google Gemini Response", ln=True, align='C')
-     pdf.ln(10)
-     pdf.multi_cell(0, 10, safe_encode(response2))
-
-     pdf_output_path = 'static/analysis_report.pdf'
-     pdf.output(pdf_output_path)
-
-     return templates.TemplateResponse("upload.html", {
-         "request": request,
-         "response1": response1,
-         "response2": response2,
-         "plot1_path": plot1_path,
-         "plot2_path": plot2_path,
-         "columns": columns})
-
- @app.get("/download_pdf")
- async def download_pdf():
-     pdf_output_path = 'static/analysis_report.pdf'
-     return FileResponse(pdf_output_path, media_type='application/pdf', filename=os.path.basename(pdf_output_path))
-
-
-
-
-
- @app.post("/streamlit")
- async def streamlit(request: Request,
-                     target_variable: str = Form(...),
-                     columns_for_analysis: List[str] = Form(...)):
-     global uploaded_df, uploaded_filename, plot1_path, plot2_path, response1, response2, api, question, document_analyzed, plot3_path, plot4_path, response3, response4
-     target_variable_html = None
-     columns_for_analysis_html = None
-     response3 = None
-     response4 = None
-     plot3_path = None
-     plot4_path = None
-
-
-     if uploaded_df is None:
-         raise HTTPException(status_code=400, detail="No CSV file uploaded")
-
-
-     df = uploaded_df
-
-     # Process the uploaded file
-     if uploaded_filename.endswith('.csv'):
-         df = pd.read_csv('dataset.csv', encoding='utf-8')
-     elif uploaded_filename.endswith('.xlsx'):
-         df = pd.read_excel('dataset.xlsx')
-
-     # Select the target variable and columns for analysis from the original DataFrame
-     target_variable_data = df[target_variable]
-     columns_for_analysis_data = df[columns_for_analysis]
-
-     # Concatenate target variable and columns for analysis into a single DataFrame
-     df = pd.concat([target_variable_data, columns_for_analysis_data], axis=1)
-
-     # Clean the data (if needed)
-     df = clean_data2(df)
-
-     # Generate visualizations
-
-     # Multiclass Barplot
-     excluded_words = ["name", "postal", "date", "phone", "address", "id"]
-
-     # Get the names of all columns with data type 'object' (categorical variables)
-     cat_vars = [col for col in df.select_dtypes(include=['object']).columns
-                 if all(word not in col.lower() for word in excluded_words)]
-
-     # Exclude the target variable from the list if it exists in cat_vars
-     if target_variable in cat_vars:
-         cat_vars.remove(target_variable)
-
-     # Create a figure with subplots, but only include the required number of subplots
-     num_cols = len(cat_vars)
-     num_rows = (num_cols + 2) // 3  # To make sure there are enough rows for the subplots
-     fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows))
-     axs = axs.flatten()
-
-     # Create a count plot for each categorical variable
-     for i, var in enumerate(cat_vars):
-         top_categories = df[var].value_counts().nlargest(5).index
-         filtered_df = df[df[var].notnull() & df[var].isin(top_categories)]  # Exclude rows with NaN values in the variable
-
-         # Replace less frequent categories with "Other" if there are more than 5 unique values
-         if df[var].nunique() > 5:
-             other_categories = df[var].value_counts().index[5:]
-             filtered_df[var] = filtered_df[var].apply(lambda x: x if x in top_categories else 'Other')
-
-         sns.countplot(x=var, hue=target_variable, stat="percent", data=filtered_df, ax=axs[i])
-         axs[i].set_xticklabels(axs[i].get_xticklabels(), rotation=45)
-
-         # Change y-axis label to represent percentage
-         axs[i].set_ylabel('Percentage')
-
-         # Annotate the subplot with sample size
-         sample_size = df.shape[0]
-         axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
-
-     # Remove any remaining blank subplots
-     for i in range(num_cols, len(axs)):
-         fig.delaxes(axs[i])
-
-     plt.xticks(rotation=45)
-     plt.tight_layout()
-     plot3_path = "static/multiclass_barplot.png"
-     plt.savefig(plot3_path)
-     plt.close(fig)
-
-
-     # Multiclass Histplot
-     # Get the names of all columns with data type 'object' (categorical columns)
-     cat_cols = df.columns.tolist()
-
-     # Get the names of all columns with data type 'int'
-     int_vars = df.select_dtypes(include=['int', 'float']).columns.tolist()
-     int_vars = [col for col in int_vars if col != target_variable]
-
-     # Create a figure with subplots
-     num_cols = len(int_vars)
-     num_rows = (num_cols + 2) // 3  # To make sure there are enough rows for the subplots
-     fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows))
-     axs = axs.flatten()
-
-     # Create a histogram for each numeric variable, split by the target variable
-     for i, var in enumerate(int_vars):
-         top_categories = df[var].value_counts().nlargest(10).index
-         filtered_df = df[df[var].notnull() & df[var].isin(top_categories)]
-         sns.histplot(data=df, x=var, hue=target_variable, kde=True, ax=axs[i], stat="percent")
-         axs[i].set_title(var)
-
-         # Annotate the subplot with sample size
-         sample_size = df.shape[0]
-         axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
-
-     # Remove any extra empty subplots if needed
-     if num_cols < len(axs):
-         for i in range(num_cols, len(axs)):
-             fig.delaxes(axs[i])
-
-     # Adjust spacing between subplots
-     fig.tight_layout()
-     plt.xticks(rotation=45)
-     plot4_path = "static/multiclass_histplot.png"
-     plt.savefig(plot4_path)
-     plt.close(fig)
-
-     #response 3
-     def to_markdown(text):
-         text = text.replace('•', ' *')
-         return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))
-
-     genai.configure(api_key=api)
-
-     import PIL.Image
-
-     img = PIL.Image.open("static/multiclass_barplot.png")
-     model = genai.GenerativeModel('gemini-1.5-flash-latest')
-     response = model.generate_content(img)
-     response = model.generate_content([question + "As a marketing consulant, I want to understand consumer insighst based on the chart and the market context so I can use the key findings to formulate actionable insights", img])
-     response.resolve()
-
-
-
-     #response 4
-     def to_markdown(text):
-         text = text.replace('•', ' *')
-         return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))
-
-     genai.configure(api_key=api)
-
-     import PIL.Image
-
-     img = PIL.Image.open("static/multiclass_histplot.png")
-     model = genai.GenerativeModel('gemini-1.5-flash-latest')
-     response5 = model.generate_content(img)
-     response5 = model.generate_content([question + "As a marketing consulant, I want to understand consumer insighst based on the chart and the market context so I can use the key findings to formulate actionable insights", img])
-     response5.resolve()
-
-     # Generate Google Gemini responses
-     response3 = response.text
-     response4 = response5.text
-     document_analyzed = True
-
-     # Create a dictionary to store the outputs
-     outputs = {
-         "barchart_visualization": plot1_path,
-         "gemini_response1": response1,
-         "histoplot_visualization": plot2_path,
-         "gemini_response2": response2,
-         "multiBarchart_visualization": plot3_path,
-         "gemini_response3": response3,
-         "multiHistoplot_visualization": plot4_path,
-         "gemini_response4": response4
-     }
-
-     # Save the dictionary as a JSON file
-     with open("output1.json", "w") as outfile:
-         json.dump(outputs, outfile)
-
-
-
-     # Function to handle encoding to latin1
-     def safe_encode(text):
-         try:
-             return text.encode('latin1', errors='replace').decode('latin1')  # Replace invalid characters
-         except Exception as e:
-             return f"Error encoding text: {str(e)}"
-
-
-
-     # Generate PDF with the results
-     pdf = FPDF()
-     pdf.set_font("Arial", size=12)
-
-     # Single Countplot Barchart and response
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Countplot Barchart", ln=True, align='C')
-     pdf.image(plot1_path, x=10, y=30, w=190)
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Countplot Barchart Google Gemini Response", ln=True, align='C')
-     pdf.ln(10)
-     pdf.multi_cell(0, 10, safe_encode(response1))
-
-     # Single Histplot and response
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Histplot", ln=True, align='C')
-     pdf.image(plot2_path, x=10, y=30, w=190)
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Single Histplot Google Gemini Response", ln=True, align='C')
-     pdf.ln(10)
-     pdf.multi_cell(0, 10, safe_encode(response2))
-
-     # Multiclass Countplot Barchart and response
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Multiclass Countplot Barchart", ln=True, align='C')
-     pdf.image(plot3_path, x=10, y=30, w=190)
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Multiclass Countplot Barchart Google Gemini Response", ln=True, align='C')
-     pdf.ln(10)
-     pdf.multi_cell(0, 10, safe_encode(response3))
-
-     # Multiclass Histplot and response
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Multiclass Histplot", ln=True, align='C')
-     pdf.image(plot4_path, x=10, y=30, w=190)
-     pdf.add_page()
-     pdf.cell(200, 10, txt="Multiclass Histplot Google Gemini Response", ln=True, align='C')
-     pdf.ln(10)
-     pdf.multi_cell(0, 10, safe_encode(response4))
-
-
-     pdf_output_path = 'static/analysis_report_complete.pdf'
-     pdf.output(pdf_output_path)
-
-
-
-     return templates.TemplateResponse("upload.html", {
-         "request": request,
-         "plot1_path": plot1_path,
-         "response1": response1,
-         "plot2_path": plot2_path,
-         "response2": response2,
-         "plot3_path": plot3_path,
-         "response3": response3,
-         "plot4_path": plot4_path,
-         "response4": response4,
-         "show_conversation": document_analyzed,
-         "question_responses": question_responses
-     })
-
-
-
-
- @app.get('/download_pdf2')
- async def download_pdf2():
-     pdf_output_path2 = 'static/analysis_report_complete.pdf'
-     return FileResponse(pdf_output_path2, media_type='application/pdf', filename='analysis_report_complete.pdf')
-
-
- # Route for asking questions
- @app.post("/ask", response_class=HTMLResponse)
- async def ask_question(request: Request, question: str = Form(...)):
-     global uploaded_filename, question_responses, api
-     global plot1_path, plot2_path, plot3_path, plot4_path
-     global response1, response2, response3, response4
-     global document_analyzed
-
-     # Check if a file has been uploaded
-     if not uploaded_filename:
-         raise HTTPException(status_code=400, detail="No file has been uploaded yet.")
-
-     # Initialize the LLM model
-     llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=api)
-
-     # Determine the file extension and select the appropriate loader
-     file_path = ''
-     loader = None
-
-     if uploaded_filename.endswith('.csv'):
-         file_path = 'dataset.csv'
-         loader = UnstructuredCSVLoader(file_path, mode="elements")
-     elif uploaded_filename.endswith('.xlsx'):
-         file_path = 'dataset.xlsx'
-         loader = UnstructuredExcelLoader(file_path, mode="elements")
-     else:
-         raise HTTPException(status_code=400, detail="Unsupported file format")
-
-     # Load and process the document
-     try:
-         docs = loader.load()
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error loading document: {str(e)}")
-
-     # Combine document text
-     text = "\n".join([doc.page_content for doc in docs])
-     os.environ["GOOGLE_API_KEY"] = api
-
-     # Initialize embeddings and create FAISS vector store
-     embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
-     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-     chunks = text_splitter.split_text(text)
-     document_search = FAISS.from_texts(chunks, embeddings)
-
-     # Generate query embedding and perform similarity search
-     query_embedding = embeddings.embed_query(question)
-     results = document_search.similarity_search_by_vector(query_embedding, k=3)
-
-     if results:
-         retrieved_texts = " ".join([result.page_content for result in results])
-
-         # Define the Summarize Chain for the question
-         latest_conversation = request.cookies.get("latest_question_response", "")
-         template1 = (
-             f"{question} Answer the question based on the following:\n\"{text}\"\n:" +
-             (f" Answer the Question with only 3 sentences. Latest conversation: {latest_conversation}" if latest_conversation else "")
-         )
-         prompt1 = PromptTemplate.from_template(template1)
-
-         # Initialize the LLMChain with the prompt
-         llm_chain1 = LLMChain(llm=llm, prompt=prompt1)
-
-         # Invoke the chain to get the summary
-         try:
-             response_chain = llm_chain1.invoke({"text": text})
-             summary1 = response_chain["text"]
-         except Exception as e:
-             raise HTTPException(status_code=500, detail=f"Error invoking LLMChain: {str(e)}")
-
-         # Generate embeddings for the summary
-         try:
-             summary_embedding = embeddings.embed_query(summary1)
-             document_search = FAISS.from_texts([summary1], embeddings)
-         except Exception as e:
-             raise HTTPException(status_code=500, detail=f"Error generating embeddings: {str(e)}")
-
-         # Perform a search on the FAISS vector database
-         try:
-             if document_search:
-                 query_embedding = embeddings.embed_query(question)
-                 results = document_search.similarity_search_by_vector(query_embedding, k=1)
-
-                 if results:
-                     current_response = format_text(results[0].page_content)
-                 else:
-                     current_response = "No matching document found in the database."
-             else:
-                 current_response = "Vector database not initialized."
-         except Exception as e:
-             raise HTTPException(status_code=500, detail=f"Error during similarity search: {str(e)}")
-     else:
-         current_response = "No relevant results found."
-
-     # Append the question and response from FAISS search
-     current_question = f"You asked: {question}"
-     question_responses.append((current_question, current_response))
-
-     # Save all results to output_summary.json
-     save_to_json(question_responses)
-
-     # Prepare the response to render the HTML template
-     response = templates.TemplateResponse("upload.html", {
-         "request": request,
-         "plot1_path": plot1_path,
-         "response1": response1,
-         "plot2_path": plot2_path,
-         "response2": response2,
-         "plot3_path": plot3_path,
-         "response3": response3,
-         "plot4_path": plot4_path,
-         "response4": response4,
-         "show_conversation": document_analyzed,
-         "question_responses": question_responses,
-     })
-     response.set_cookie(key="latest_question_response", value=current_response)
-     return response
-
-
-
- def save_to_json(question_responses):
-     outputs = {
-         "question_responses": question_responses
-     }
-     with open("output_summary.json", "w") as outfile:
-         json.dump(outputs, outfile)
-
-
-
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="127.0.0.1", port=8000)
 
+ import pandas as pd
+ import seaborn as sns
+ import matplotlib
+ import matplotlib.pyplot as plt
+ matplotlib.use('Agg')
+ import numpy as np
+ import google.generativeai as genai
+ from PIL import Image
+ from werkzeug.utils import secure_filename
+ import os
+ import json
+ from fpdf import FPDF
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+ from fastapi.responses import HTMLResponse, FileResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.templating import Jinja2Templates
+ from starlette.requests import Request
+ from typing import List
+ import textwrap
+ from IPython.display import display, Markdown
+ from PIL import Image
+ import shutil
+ from werkzeug.utils import secure_filename
+ import urllib.parse
+ import re
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_community.document_loaders import PyPDFLoader, UnstructuredCSVLoader, UnstructuredExcelLoader, Docx2txtLoader, UnstructuredPowerPointLoader
+ from langchain.chains import StuffDocumentsChain
+ from langchain.chains.llm import LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain.vectorstores import FAISS
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
+ from langchain.text_splitter import CharacterTextSplitter
+
+ app = FastAPI()
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+ templates = Jinja2Templates(directory="templates")
+
+ sns.set_theme(color_codes=True)
+ uploaded_df = None
+ document_analyzed = False
+ question_responses = []
+
+
+ def format_text(text):
+     # Replace **text** with <b>text</b>
+     text = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', text)
+     # Replace any remaining * with <br>
+     text = text.replace('*', '<br>')
+     return text
+
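For reference, a quick sketch of what `format_text` does to typical model output (the sample string below is hypothetical):

```python
# Hypothetical input; format_text converts Markdown bold to <b> tags
# and turns any leftover asterisk bullets into <br> line breaks.
sample = "**Key finding:** repeat buyers skew urban.* Consider loyalty perks."
print(format_text(sample))
# -> <b>Key finding:</b> repeat buyers skew urban.<br> Consider loyalty perks.
```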
+ def clean_data(df):
+     # Step 1: Clean currency-related columns
+     for col in df.columns:
+         if any(x in col.lower() for x in ['value', 'price', 'cost', 'amount']):
+             if df[col].dtype == 'object':
+                 df[col] = df[col].str.replace('$', '').str.replace('£', '').str.replace('€', '').replace('[^\d.-]', '', regex=True).astype(float)
+
+     # Step 2: Drop columns with more than 25% missing values
+     null_percentage = df.isnull().sum() / len(df)
+     columns_to_drop = null_percentage[null_percentage > 0.25].index
+     df.drop(columns=columns_to_drop, inplace=True)
+
+     # Step 3: Fill missing values for remaining columns
+     for col in df.columns:
+         if df[col].isnull().sum() > 0:
+             if null_percentage[col] <= 0.25:
+                 if df[col].dtype in ['float64', 'int64']:
+                     median_value = df[col].median()
+                     df[col].fillna(median_value, inplace=True)
+
+     # Step 4: Convert object-type columns to lowercase
+     for col in df.columns:
+         if df[col].dtype == 'object':
+             df[col] = df[col].str.lower()
+
+     # Step 5: Drop columns with only one unique value
+     unique_value_columns = [col for col in df.columns if df[col].nunique() == 1]
+     df.drop(columns=unique_value_columns, inplace=True)
+
+     return df
+
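One caveat on step 1: in recent pandas releases `Series.str.replace` treats its pattern literally by default (`regex=False` since pandas 2.0), so the chained symbol replacements behave as intended, but the trailing regex pass already subsumes them. A minimal sketch of the consolidated step, assuming string amounts like "$1,234.50":

```python
import pandas as pd

s = pd.Series(["$1,234.50", "£99", "€0.75"])
# One regex pass: strip everything except digits, dots, and minus signs.
print(s.str.replace(r"[^\d.\-]", "", regex=True).astype(float).tolist())
# [1234.5, 99.0, 0.75]
```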
+
+
+
+ def clean_data2(df):
+     for col in df.columns:
+         if 'value' in col or 'price' in col or 'cost' in col or 'amount' in col or 'Value' in col or 'Price' in col or 'Cost' in col or 'Amount' in col:
+             if df[col].dtype == 'object':
+                 df[col] = df[col].str.replace('$', '')
+                 df[col] = df[col].str.replace('£', '')
+                 df[col] = df[col].str.replace('€', '')
+                 df[col] = df[col].replace('[^\d.-]', '', regex=True).astype(float)
+
+     null_percentage = df.isnull().sum() / len(df)
+
+     for col in df.columns:
+         if df[col].isnull().sum() > 0:
+             if null_percentage[col] <= 0.25:
+                 if df[col].dtype in ['float64', 'int64']:
+                     median_value = df[col].median()
+                     df[col].fillna(median_value, inplace=True)
+
+     for col in df.columns:
+         if df[col].dtype == 'object':
+             df[col] = df[col].str.lower()
+
+     return df
+
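`clean_data2` repeats `clean_data` minus the column-dropping steps, and its long `or` chain could reuse the `any(...)` idiom from `clean_data`. A case-insensitive sketch of that check (the helper name is mine):

```python
def is_currency_column(col: str) -> bool:
    # Case-insensitive variant of the or-chain above; also matches
    # spellings like "VALUE" that the original chain misses.
    return any(key in col.lower() for key in ("value", "price", "cost", "amount"))
```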
+
+
+ def generate_plot(df, plot_path, plot_type):
+     df = clean_data(df)
+     excluded_words = ["name", "postal", "date", "phone", "address", "code", "id"]
+
+     if plot_type == 'countplot':
+         cat_vars = [col for col in df.select_dtypes(include='object').columns
+                     if all(word not in col.lower() for word in excluded_words) and df[col].nunique() > 1]
+
+         for col in cat_vars:
+             if df[col].nunique() > 10:
+                 top_categories = df[col].value_counts().index[:10]
+                 df[col] = df[col].apply(lambda x: x if x in top_categories else 'Other')
+
+         num_cols = len(cat_vars)
+         num_rows = (num_cols + 1) // 2
+         fig, axs = plt.subplots(nrows=num_rows, ncols=2, figsize=(15, 5*num_rows))
+         axs = axs.flatten()
+
+         for i, var in enumerate(cat_vars):
+             category_counts = df[var].value_counts()
+             top_values = category_counts.index[:10][::-1]
+             filtered_df = df.copy()
+             filtered_df[var] = pd.Categorical(filtered_df[var], categories=top_values, ordered=True)
+             sns.countplot(x=var, data=filtered_df, order=top_values, ax=axs[i])
+             axs[i].set_title(var)
+             axs[i].tick_params(axis='x', rotation=30)
+
+             total = len(filtered_df[var])
+             for p in axs[i].patches:
+                 height = p.get_height()
+                 axs[i].annotate(f'{height/total:.1%}', (p.get_x() + p.get_width() / 2., height), ha='center', va='bottom')
+
+             sample_size = filtered_df.shape[0]
+             axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
+
+         for i in range(num_cols, len(axs)):
+             fig.delaxes(axs[i])
+
+     elif plot_type == 'histplot':
+         num_vars = [col for col in df.select_dtypes(include=['int', 'float']).columns
+                     if all(word not in col.lower() for word in excluded_words)]
+         num_cols = len(num_vars)
+         num_rows = (num_cols + 2) // 3
+         fig, axs = plt.subplots(nrows=num_rows, ncols=min(3, num_cols), figsize=(15, 5*num_rows))
+         axs = axs.flatten()
+
+         plot_index = 0
+
+         for i, var in enumerate(num_vars):
+             if len(df[var].unique()) == len(df):
+                 fig.delaxes(axs[plot_index])
+             else:
+                 sns.histplot(df[var], ax=axs[plot_index], kde=True, stat="percent")
+                 axs[plot_index].set_title(var)
+                 axs[plot_index].set_xlabel('')
+
+                 sample_size = df.shape[0]
+                 axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
+
+                 plot_index += 1
+
+         for i in range(plot_index, len(axs)):
+             fig.delaxes(axs[i])
+
+     fig.tight_layout()
+     fig.savefig(plot_path)
+     plt.close(fig)
+     return plot_path
+
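A usage sketch for `generate_plot` (the CSV name is hypothetical). Passing a copy is deliberate: `clean_data` drops columns in place, so a copy keeps the caller's frame intact:

```python
import pandas as pd

df = pd.read_csv("dataset.csv", encoding="utf-8")  # hypothetical input file
plot1 = generate_plot(df.copy(), "static/plot1.png", "countplot")
plot2 = generate_plot(df.copy(), "static/plot2.png", "histplot")
```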
+ @app.get("/", response_class=HTMLResponse)
+ async def upload_file(request: Request):
+     return templates.TemplateResponse("upload.html", {"request": request})
+
+ @app.post("/result")
+ async def result(request: Request,
+                  api_key: str = Form(...),
+                  file: UploadFile = File(...),
+                  custom_question: str = Form(...)):
+     global uploaded_df, uploaded_filename, plot1_path, plot2_path, response1, response2, api, question, uploaded_file
+
+     api = api_key
+     uploaded_file = file
+
+     if file.filename == '':
+         raise HTTPException(status_code=400, detail="No file selected")
+
+     # Secure and validate the file name
+     uploaded_filename = secure_filename(file.filename)
+
+     # Determine file path based on file type
+     if uploaded_filename.endswith('.csv'):
+         file_path = 'dataset.csv'
+         # Save the file
+         with open(file_path, 'wb') as buffer:
+             shutil.copyfileobj(file.file, buffer)
+         # Read the file into a DataFrame
+         df = pd.read_csv(file_path, encoding='utf-8')
+
+     elif uploaded_filename.endswith('.xlsx'):
+         file_path = 'dataset.xlsx'
+         # Save the file
+         with open(file_path, 'wb') as buffer:
+             shutil.copyfileobj(file.file, buffer)
+         # Read the file into a DataFrame
+         df = pd.read_excel(file_path)
+
+     else:
+         raise HTTPException(status_code=400, detail="Unsupported file format")
+
+     columns = df.columns.tolist()
+
+     def generate_gemini_response(plot_path):
+         global question
+         question = custom_question
+         genai.configure(api_key=api)
+         img = Image.open(plot_path)
+         model = genai.GenerativeModel('gemini-1.5-flash-latest')
+         response = model.generate_content([
+             question + " As a marketing consultant, I want to understand consumer insights based on the chart and the market context so I can use the key findings to formulate actionable insights",
+             img
+         ])
+         response.resolve()
+         return response.text
+
+     plot1_path = generate_plot(df, 'static/plot1.png', 'countplot')
+     plot2_path = generate_plot(df, 'static/plot2.png', 'histplot')
+
+     response1 = (generate_gemini_response(plot1_path))
+     response2 = (generate_gemini_response(plot2_path))
+
+     uploaded_df = df
+
+     outputs = {
+         "barchart_visualization": plot1_path,
+         "gemini_response1": response1,
+         "histoplot_visualization": plot2_path,
+         "gemini_response2": response2
+     }
+
+     with open("output.json", "w") as outfile:
+         json.dump(outputs, outfile)
+
+     def safe_encode(text):
+         try:
+             return text.encode('latin1', errors='replace').decode('latin1')
+         except Exception as e:
+             return f"Error encoding text: {str(e)}"
+
+     pdf = FPDF()
+     pdf.set_font("Arial", size=12)
+
+     # Single Countplot Barchart and response
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Countplot Barchart", ln=True, align='C')
+     pdf.image(plot1_path, x=10, y=30, w=190)
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Countplot Barchart Google Gemini Response", ln=True, align='C')
+     pdf.ln(10)
+     pdf.multi_cell(0, 10, safe_encode(response1))
+
+     # Single Histplot and response
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Histplot", ln=True, align='C')
+     pdf.image(plot2_path, x=10, y=30, w=190)
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Histplot Google Gemini Response", ln=True, align='C')
+     pdf.ln(10)
+     pdf.multi_cell(0, 10, safe_encode(response2))
+
+     pdf_output_path = 'static/analysis_report.pdf'
+     pdf.output(pdf_output_path)
+
+     return templates.TemplateResponse("upload.html", {
+         "request": request,
+         "response1": response1,
+         "response2": response2,
+         "plot1_path": plot1_path,
+         "plot2_path": plot2_path,
+         "columns": columns})
+
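A client-side sketch of exercising `/result` with the `requests` library (`requests` is not a dependency of this app, and the key, file, and question are placeholders):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8000/result",
    data={"api_key": "YOUR_GEMINI_API_KEY",          # placeholder
          "custom_question": "What drives repeat purchases?"},
    files={"file": ("dataset.csv", open("dataset.csv", "rb"), "text/csv")},
)
print(resp.status_code)  # 200 on success; the rendered upload.html is in resp.text
```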
+ @app.get("/download_pdf")
+ async def download_pdf():
+     pdf_output_path = 'static/analysis_report.pdf'
+     return FileResponse(pdf_output_path, media_type='application/pdf', filename=os.path.basename(pdf_output_path))
+
+
+
+
+
+ @app.post("/streamlit")
+ async def streamlit(request: Request,
+                     target_variable: str = Form(...),
+                     columns_for_analysis: List[str] = Form(...)):
+     global uploaded_df, uploaded_filename, plot1_path, plot2_path, response1, response2, api, question, document_analyzed, plot3_path, plot4_path, response3, response4
+     target_variable_html = None
+     columns_for_analysis_html = None
+     response3 = None
+     response4 = None
+     plot3_path = None
+     plot4_path = None
+
+
+     if uploaded_df is None:
+         raise HTTPException(status_code=400, detail="No CSV file uploaded")
+
+
+     df = uploaded_df
+
+     # Process the uploaded file
+     if uploaded_filename.endswith('.csv'):
+         df = pd.read_csv('dataset.csv', encoding='utf-8')
+     elif uploaded_filename.endswith('.xlsx'):
+         df = pd.read_excel('dataset.xlsx')
+
+     # Select the target variable and columns for analysis from the original DataFrame
+     target_variable_data = df[target_variable]
+     columns_for_analysis_data = df[columns_for_analysis]
+
+     # Concatenate target variable and columns for analysis into a single DataFrame
+     df = pd.concat([target_variable_data, columns_for_analysis_data], axis=1)
+
+     # Clean the data (if needed)
+     df = clean_data2(df)
+
+
+
+     def generate_gemini_response(plot_path):
+         global question
+         genai.configure(api_key=api)
+         img = Image.open(plot_path)
+         model = genai.GenerativeModel('gemini-1.5-flash-latest')
+         response = model.generate_content([
+             question + " As a marketing consultant, I want to understand consumer insights based on the chart and the market context so I can use the key findings to formulate actionable insights",
+             img
+         ])
+         response.resolve()
+         return response.text
+
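Both multiclass charts below reuse this helper, so a failed Gemini call currently fails the whole request with a 500. If graceful degradation is preferred, a small wrapper would do; a sketch (the wrapper name is mine):

```python
def generate_gemini_response_safe(plot_path):
    # Sketch: return a placeholder instead of failing the whole request.
    try:
        return generate_gemini_response(plot_path)
    except Exception as exc:
        return f"Gemini analysis unavailable: {exc}"
```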
+     # Generate visualizations
+
+     # Multiclass Barplot
+     excluded_words = ["name", "postal", "date", "phone", "address", "id"]
+
+     # Get the names of all columns with data type 'object' (categorical variables)
+     cat_vars = [col for col in df.select_dtypes(include=['object']).columns
+                 if all(word not in col.lower() for word in excluded_words)]
+
+     # Exclude the target variable from the list if it exists in cat_vars
+     if target_variable in cat_vars:
+         cat_vars.remove(target_variable)
+
+     # Create a figure with subplots, but only include the required number of subplots
+     num_cols = len(cat_vars)
+     num_rows = (num_cols + 2) // 3  # To make sure there are enough rows for the subplots
+     fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows))
+     axs = axs.flatten()
+
+     # Create a count plot for each categorical variable
+     for i, var in enumerate(cat_vars):
+         top_categories = df[var].value_counts().nlargest(5).index
+         filtered_df = df[df[var].notnull() & df[var].isin(top_categories)]  # Exclude rows with NaN values in the variable
+
+         # Replace less frequent categories with "Other" if there are more than 5 unique values
+         if df[var].nunique() > 5:
+             other_categories = df[var].value_counts().index[5:]
+             filtered_df[var] = filtered_df[var].apply(lambda x: x if x in top_categories else 'Other')
+
+         sns.countplot(x=var, hue=target_variable, stat="percent", data=filtered_df, ax=axs[i])
+         axs[i].set_xticklabels(axs[i].get_xticklabels(), rotation=45)
+
+         # Change y-axis label to represent percentage
+         axs[i].set_ylabel('Percentage')
+
+         # Annotate the subplot with sample size
+         sample_size = df.shape[0]
+         axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
+
+     # Remove any remaining blank subplots
+     for i in range(num_cols, len(axs)):
+         fig.delaxes(axs[i])
+
+     plt.xticks(rotation=45)
+     plt.tight_layout()
+     plot3_path = "static/multiclass_barplot.png"
+     plt.savefig(plot3_path)
+     plt.close(fig)
+
+
+     # Multiclass Histplot
+     # Get the names of all columns with data type 'object' (categorical columns)
+     cat_cols = df.columns.tolist()
+
+     # Get the names of all columns with data type 'int'
+     int_vars = df.select_dtypes(include=['int', 'float']).columns.tolist()
+     int_vars = [col for col in int_vars if col != target_variable]
+
+     # Create a figure with subplots
+     num_cols = len(int_vars)
+     num_rows = (num_cols + 2) // 3  # To make sure there are enough rows for the subplots
+     fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows))
+     axs = axs.flatten()
+
+     # Create a histogram for each numeric variable, split by the target variable
+     for i, var in enumerate(int_vars):
+         top_categories = df[var].value_counts().nlargest(10).index
+         filtered_df = df[df[var].notnull() & df[var].isin(top_categories)]
+         sns.histplot(data=df, x=var, hue=target_variable, kde=True, ax=axs[i], stat="percent")
+         axs[i].set_title(var)
+
+         # Annotate the subplot with sample size
+         sample_size = df.shape[0]
+         axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
+
+     # Remove any extra empty subplots if needed
+     if num_cols < len(axs):
+         for i in range(num_cols, len(axs)):
+             fig.delaxes(axs[i])
+
+     # Adjust spacing between subplots
+     fig.tight_layout()
+     plt.xticks(rotation=45)
+     plot4_path = "static/multiclass_histplot.png"
+     plt.savefig(plot4_path)
+     plt.close(fig)
+
+     response3 = (generate_gemini_response(plot3_path))
+     response4 = (generate_gemini_response(plot4_path))
+
+
+     document_analyzed = True
+
+     # Create a dictionary to store the outputs
+     outputs = {
+         "barchart_visualization": plot1_path,
+         "gemini_response1": response1,
+         "histoplot_visualization": plot2_path,
+         "gemini_response2": response2,
+         "multiBarchart_visualization": plot3_path,
+         "gemini_response3": response3,
+         "multiHistoplot_visualization": plot4_path,
+         "gemini_response4": response4
+     }
+
+     # Save the dictionary as a JSON file
+     with open("output1.json", "w") as outfile:
+         json.dump(outputs, outfile)
+
+
+
+     # Function to handle encoding to latin1
+     def safe_encode(text):
+         try:
+             return text.encode('latin1', errors='replace').decode('latin1')  # Replace invalid characters
+         except Exception as e:
+             return f"Error encoding text: {str(e)}"
+
+
+
+     # Generate PDF with the results
+     pdf = FPDF()
+     pdf.set_font("Arial", size=12)
+
+     # Single Countplot Barchart and response
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Countplot Barchart", ln=True, align='C')
+     pdf.image(plot1_path, x=10, y=30, w=190)
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Countplot Barchart Google Gemini Response", ln=True, align='C')
+     pdf.ln(10)
+     pdf.multi_cell(0, 10, safe_encode(response1))
+
+     # Single Histplot and response
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Histplot", ln=True, align='C')
+     pdf.image(plot2_path, x=10, y=30, w=190)
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Single Histplot Google Gemini Response", ln=True, align='C')
+     pdf.ln(10)
+     pdf.multi_cell(0, 10, safe_encode(response2))
+
+     # Multiclass Countplot Barchart and response
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Multiclass Countplot Barchart", ln=True, align='C')
+     pdf.image(plot3_path, x=10, y=30, w=190)
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Multiclass Countplot Barchart Google Gemini Response", ln=True, align='C')
+     pdf.ln(10)
+     pdf.multi_cell(0, 10, safe_encode(response3))
+
+     # Multiclass Histplot and response
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Multiclass Histplot", ln=True, align='C')
+     pdf.image(plot4_path, x=10, y=30, w=190)
+     pdf.add_page()
+     pdf.cell(200, 10, txt="Multiclass Histplot Google Gemini Response", ln=True, align='C')
+     pdf.ln(10)
+     pdf.multi_cell(0, 10, safe_encode(response4))
+
+
+     pdf_output_path = 'static/analysis_report_complete.pdf'
+     pdf.output(pdf_output_path)
+
+
+
+     return templates.TemplateResponse("upload.html", {
+         "request": request,
+         "plot1_path": plot1_path,
+         "response1": response1,
+         "plot2_path": plot2_path,
+         "response2": response2,
+         "plot3_path": plot3_path,
+         "response3": response3,
+         "plot4_path": plot4_path,
+         "response4": response4,
+         "show_conversation": document_analyzed,
+         "question_responses": question_responses
+     })
+
+
+
+
+ @app.get('/download_pdf2')
+ async def download_pdf2():
+     pdf_output_path2 = 'static/analysis_report_complete.pdf'
+     return FileResponse(pdf_output_path2, media_type='application/pdf', filename='analysis_report_complete.pdf')
+
+
+ # Route for asking questions
+ @app.post("/ask", response_class=HTMLResponse)
+ async def ask_question(request: Request, question: str = Form(...)):
+     global uploaded_filename, question_responses, api
+     global plot1_path, plot2_path, plot3_path, plot4_path
+     global response1, response2, response3, response4
+     global document_analyzed
+
+     # Check if a file has been uploaded
+     if not uploaded_filename:
+         raise HTTPException(status_code=400, detail="No file has been uploaded yet.")
+
+     # Initialize the LLM model
+     llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=api)
+
+     # Determine the file extension and select the appropriate loader
+     file_path = ''
+     loader = None
+
+     if uploaded_filename.endswith('.csv'):
+         file_path = 'dataset.csv'
+         loader = UnstructuredCSVLoader(file_path, mode="elements")
+     elif uploaded_filename.endswith('.xlsx'):
+         file_path = 'dataset.xlsx'
+         loader = UnstructuredExcelLoader(file_path, mode="elements")
+     else:
+         raise HTTPException(status_code=400, detail="Unsupported file format")
+
+     # Load and process the document
+     try:
+         docs = loader.load()
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error loading document: {str(e)}")
+
+     # Combine document text
+     text = "\n".join([doc.page_content for doc in docs])
+     os.environ["GOOGLE_API_KEY"] = api
+
+     # Initialize embeddings and create FAISS vector store
+     embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+     chunks = text_splitter.split_text(text)
+     document_search = FAISS.from_texts(chunks, embeddings)
+
+     # Generate query embedding and perform similarity search
+     query_embedding = embeddings.embed_query(question)
+     results = document_search.similarity_search_by_vector(query_embedding, k=3)
+
+     if results:
+         retrieved_texts = " ".join([result.page_content for result in results])
+
+         # Define the Summarize Chain for the question
+         latest_conversation = request.cookies.get("latest_question_response", "")
+         template1 = (
+             f"{question} Answer the question based on the following:\n\"{text}\"\n:" +
+             (f" Answer the Question with only 3 sentences. Latest conversation: {latest_conversation}" if latest_conversation else "")
+         )
+         prompt1 = PromptTemplate.from_template(template1)
+
+         # Initialize the LLMChain with the prompt
+         llm_chain1 = LLMChain(llm=llm, prompt=prompt1)
+
+         # Invoke the chain to get the summary
+         try:
+             response_chain = llm_chain1.invoke({"text": text})
+             summary1 = response_chain["text"]
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error invoking LLMChain: {str(e)}")
+
+         # Generate embeddings for the summary
+         try:
+             summary_embedding = embeddings.embed_query(summary1)
+             document_search = FAISS.from_texts([summary1], embeddings)
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error generating embeddings: {str(e)}")
+
+         # Perform a search on the FAISS vector database
+         try:
+             if document_search:
+                 query_embedding = embeddings.embed_query(question)
+                 results = document_search.similarity_search_by_vector(query_embedding, k=1)
+
+                 if results:
+                     current_response = format_text(results[0].page_content)
+                 else:
+                     current_response = "No matching document found in the database."
+             else:
+                 current_response = "Vector database not initialized."
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error during similarity search: {str(e)}")
+     else:
+         current_response = "No relevant results found."
+
+     # Append the question and response from FAISS search
+     current_question = f"You asked: {question}"
+     question_responses.append((current_question, current_response))
+
+     # Save all results to output_summary.json
+     save_to_json(question_responses)
+
+     # Prepare the response to render the HTML template
+     response = templates.TemplateResponse("upload.html", {
+         "request": request,
+         "plot1_path": plot1_path,
+         "response1": response1,
+         "plot2_path": plot2_path,
+         "response2": response2,
+         "plot3_path": plot3_path,
+         "response3": response3,
+         "plot4_path": plot4_path,
+         "response4": response4,
+         "show_conversation": document_analyzed,
+         "question_responses": question_responses,
+     })
+     response.set_cookie(key="latest_question_response", value=current_response)
+     return response
+
+
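The retrieval loop in `/ask` (embed the document chunks, search FAISS, summarize with the LLM, then re-embed and re-search the summary) can be exercised in isolation; a minimal sketch, assuming `GOOGLE_API_KEY` is set and the same package versions as above:

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.vectorstores import FAISS

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
# Toy corpus standing in for the uploaded document's chunks.
store = FAISS.from_texts(["revenue grew 12% in Q3", "churn fell to 4%"], embeddings)
# Same call pattern the route uses for both of its searches.
hits = store.similarity_search_by_vector(embeddings.embed_query("How did revenue change?"), k=1)
print(hits[0].page_content)
```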
+
+ def save_to_json(question_responses):
+     outputs = {
+         "question_responses": question_responses
+     }
+     with open("output_summary.json", "w") as outfile:
+         json.dump(outputs, outfile)
+
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="127.0.0.1", port=8000)
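Assuming the file is saved as app.py with the static/ and templates/ directories alongside it, the same server can also be started from the shell with `uvicorn app:app --host 127.0.0.1 --port 8000`.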