jrrade commited on
Commit
74c02ff
·
1 Parent(s): db189f7

added necessary files for app to run on HF

Browse files
app.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- START OF FILE app_main.py ---
2
+
3
+ import streamlit as st
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torchvision import transforms
7
+ # Ensure classification_model is importable (e.g., it's in the same directory orPYTHONPATH)
8
+ try:
9
+ from classification_model.model import VGG16Model
10
+ except ImportError:
11
+ st.error("Could not import VGG16Model. Make sure 'classification_model/model.py' exists and is importable.")
12
+ st.stop()
13
+
14
+ from PIL import Image
15
+ import os
16
+ from dotenv import load_dotenv
17
+ from langchain.schema import SystemMessage, HumanMessage, AIMessage, BaseMessage
18
+ import io
19
+
20
+ # Import graph logic from the other file
21
+ try:
22
+ from graph_logic import initialize_llm, create_chat_graph, GraphState
23
+ except ImportError:
24
+ st.error("Could not import from graph_logic.py. Make sure the file exists in the same directory.")
25
+ st.stop()
26
+
27
+
28
+ # Load environment variables from .env file
29
+ load_dotenv()
30
+
31
+ # --- Helper Functions (Specific to App/UI) ---
32
+
33
+ @st.cache_resource
34
+ def load_vgg16(path='./classification_model/best_model.pth', device='cpu'):
35
+ """Loads the pre-trained VGG16 model."""
36
+ model = VGG16Model()
37
+ if not os.path.exists(path):
38
+ st.error(f"Model file not found at {path}. Please ensure the model is present.")
39
+ st.stop()
40
+ try:
41
+ model.load_state_dict(torch.load(path, map_location=device))
42
+ model.eval()
43
+ except Exception as e:
44
+ st.error(f"Error loading model state dict: {e}")
45
+ st.stop()
46
+ return model
47
+
48
+ @st.cache_resource
49
+ def load_class_labels() -> dict:
50
+ """Loads the class labels."""
51
+ return {0: 'good_images', 1: 'Imaging Artifact', 2: 'Not Tracking', 3: 'Tip Contamination'}
52
+
53
+ def preprocess_image(img: Image.Image):
54
+ """Preprocesses the image for the VGG16 model."""
55
+ preprocess = transforms.Compose([
56
+ transforms.Resize((224, 224)),
57
+ transforms.ToTensor(),
58
+ transforms.Normalize(mean=[0.3718, 0.1738, 0.0571], std=[0.2095, 0.2124, 0.1321]),
59
+ ])
60
+ img_tensor = preprocess(img).unsqueeze(0)
61
+ return img_tensor
62
+
63
+ def predict_image_class(img: Image.Image, model, class_names: dict) -> str:
64
+ """Predicts the class label for a given image."""
65
+ img_tensor = preprocess_image(img)
66
+ with torch.no_grad():
67
+ outputs = model(img_tensor)
68
+ probs = F.softmax(outputs, dim=1)
69
+ top_prob, top_idx = torch.topk(probs, 1)
70
+ class_label = class_names.get(top_idx.item(), "Unknown Class")
71
+ return class_label
72
+
73
+ def download_chat_history() -> str:
74
+ """Generates the chat history text for download."""
75
+ if "messages" not in st.session_state or not st.session_state.messages:
76
+ return ""
77
+ output = io.StringIO()
78
+ start_index = 1 if isinstance(st.session_state.messages[0], SystemMessage) else 0
79
+ for msg in st.session_state.messages[start_index:]:
80
+ role = "User" if isinstance(msg, HumanMessage) else "Assistant"
81
+ output.write(f"{role}: {msg.content}\n")
82
+ return output.getvalue()
83
+
84
+ # --- Streamlit UI ---
85
+
86
+ st.set_page_config(page_title="AFM Defect Assistant (LangGraph)", page_icon="🔬")
87
+
88
+ st.title("🔬 AFM Image Defect Classification + LLM-based AFM Assistant")
89
+ st.write("Upload an AFM image, get a classification, and chat with an AI assistant about the result.")
90
+
91
+ # --- Sidebar Controls ---
92
+ st.sidebar.header("Settings")
93
+
94
+ # Model Selection
95
+ provider = st.sidebar.selectbox("LLM Provider", ["OpenAI", "Anthropic"], key="provider_select")
96
+
97
+ api_key = None
98
+ api_key_name = ""
99
+ if provider == "OpenAI":
100
+ default_model = "gpt-4o"
101
+ available_models = ["gpt-4o", "o3-mini"]
102
+ api_key = os.getenv("OPENAI_API_KEY")
103
+ api_key_name = "OPENAI_API_KEY"
104
+ elif provider == "Anthropic":
105
+ default_model = "claude-3-5-sonnet-latest"
106
+ available_models = ["claude-3-5-sonnet-latest", "claude-3-7-sonnet-latest"]
107
+ api_key = os.getenv("ANTHROPIC_API_KEY")
108
+ api_key_name = "ANTHROPIC_API_KEY"
109
+ else:
110
+ st.sidebar.error("Invalid provider selected.")
111
+ st.stop()
112
+
113
+ # Display warning if API key is missing
114
+ if not api_key:
115
+ st.sidebar.warning(f"{provider} API key not found. Please set the {api_key_name} environment variable.")
116
+
117
+ model_name = st.sidebar.selectbox(f"Choose {provider} Model", available_models, index=available_models.index(default_model), key="model_name_select")
118
+ temperature = st.sidebar.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.05, key="temp_slider")
119
+
120
+ # Clear Chat Button
121
+ if st.sidebar.button("Start New Session", key="clear_chat_button"):
122
+ keys_to_clear = ["messages", "current_label", "uploaded_file_state", "llm", "graph"]
123
+ for key in keys_to_clear:
124
+ if key in st.session_state:
125
+ del st.session_state[key]
126
+ st.rerun()
127
+
128
+ # --- Main Page Logic ---
129
+
130
+ # File Uploader
131
+ uploaded_file = st.file_uploader("Upload an AFM image", type=["jpg", "jpeg", "png"], key="file_uploader")
132
+
133
+ # Manage state based on uploaded file
134
+ if uploaded_file is not None:
135
+ new_file_id = uploaded_file.file_id
136
+ # Check if it's a new file or the same one to avoid re-processing
137
+ if "uploaded_file_state" not in st.session_state or st.session_state.uploaded_file_state["id"] != new_file_id:
138
+ st.session_state.uploaded_file_state = {"id": new_file_id, "name": uploaded_file.name}
139
+ # Clear previous chat/state if a new file is uploaded
140
+ keys_to_reset = ["messages", "current_label", "llm", "graph"]
141
+ for key in keys_to_reset:
142
+ if key in st.session_state:
143
+ del st.session_state[key]
144
+
145
+ # --- Image Processing and Classification ---
146
+ try:
147
+ img = Image.open(uploaded_file).convert("RGB")
148
+ st.image(img, caption=f"Uploaded: {st.session_state.uploaded_file_state['name']}", width=200)
149
+
150
+ model = load_vgg16()
151
+ class_names = load_class_labels()
152
+
153
+ with st.spinner("Classifying image..."):
154
+ class_label = predict_image_class(img, model, class_names)
155
+ st.success(f"**Predicted Class Label:** {class_label}")
156
+
157
+ # --- LLM and Graph Initialization ---
158
+ label_changed = ("current_label" not in st.session_state or
159
+ st.session_state.current_label != class_label)
160
+
161
+ # Initialize LLM and Graph if not present or if label changed
162
+ if "llm" not in st.session_state or "graph" not in st.session_state or label_changed:
163
+ if not api_key:
164
+ st.error(f"Cannot proceed without {api_key_name}. Please set it in your environment variables.")
165
+ st.stop()
166
+ try:
167
+ # Initialize LLM using the function from graph_logic.py
168
+ st.session_state.llm = initialize_llm(provider, model_name, temperature, api_key)
169
+ # Create the graph using the function from graph_logic.py
170
+ st.session_state.graph = create_chat_graph(st.session_state.llm)
171
+ st.session_state.current_label = class_label
172
+
173
+ # Define the system prompt and initial state for the graph
174
+ system_prompt_content = (
175
+ f"You are an expert in atomic force microscopy (AFM). "
176
+ f"The user has uploaded an image and it has '{class_label}' defect. "
177
+ "Your role is to help the user understand this defect, potential causes, "
178
+ "and how to potentially avoid or address the issue represented by this defect. "
179
+ "Provide concise, technically accurate, and helpful answers. Avoid speculation if unsure."
180
+ )
181
+ system_message = SystemMessage(content=system_prompt_content)
182
+ st.session_state.messages = [system_message] # Initialize message history
183
+
184
+ except Exception as e:
185
+ st.error(f"Failed to initialize LLM or Graph: {e}")
186
+ st.stop()
187
+
188
+ # --- Chat Interface ---
189
+ st.divider()
190
+ st.header(f"Chat about '{st.session_state.current_label}'")
191
+
192
+ # Display existing messages (skip system message)
193
+ if "messages" in st.session_state:
194
+ for i, msg in enumerate(st.session_state.messages):
195
+ if i == 0 and isinstance(msg, SystemMessage):
196
+ continue
197
+ with st.chat_message("user" if isinstance(msg, HumanMessage) else "assistant"):
198
+ st.markdown(msg.content)
199
+
200
+ # Chat input
201
+ if prompt := st.chat_input("Ask a question about the detected defect..."):
202
+ # Add user message to state and display it
203
+ st.session_state.messages.append(HumanMessage(content=prompt))
204
+ with st.chat_message("user"):
205
+ st.markdown(prompt)
206
+
207
+ # Prepare the input state for the graph
208
+ current_graph_state: GraphState = {"messages": st.session_state.messages}
209
+
210
+ # Invoke the graph
211
+ with st.chat_message("assistant"):
212
+ with st.spinner("Thinking..."):
213
+ try:
214
+ # Invoke the graph with the current state
215
+ response_state = st.session_state.graph.invoke(current_graph_state)
216
+
217
+ # Update session state with the full response history from the graph
218
+ st.session_state.messages = response_state['messages']
219
+ ai_response_content = st.session_state.messages[-1].content
220
+ st.markdown(ai_response_content)
221
+
222
+ except Exception as e:
223
+ st.error(f"Error during chat generation: {e}")
224
+ # Roll back user message if AI fails
225
+ if st.session_state.messages and isinstance(st.session_state.messages[-1], HumanMessage):
226
+ st.session_state.messages.pop()
227
+
228
+
229
+ # --- Download Chat Button ---
230
+ if len(st.session_state.get("messages", [])) > 1: # Show only if conversation started
231
+ st.divider()
232
+ chat_text = download_chat_history()
233
+ st.download_button(
234
+ label="Download Chat History",
235
+ data=chat_text,
236
+ file_name=f"afm_chat_{st.session_state.current_label.replace(' ', '_')}.txt",
237
+ mime="text/plain"
238
+ )
239
+
240
+ except Exception as e:
241
+ st.error(f"An error occurred processing the image or during chat setup: {e}")
242
+ if "uploaded_file_state" in st.session_state:
243
+ del st.session_state.uploaded_file_state # Reset if critical error occurs
244
+
245
+ elif "uploaded_file_state" in st.session_state:
246
+ # If file uploader is cleared by the user after a file was processed
247
+ keys_to_clear = ["messages", "current_label", "uploaded_file_state", "llm", "graph"]
248
+ for key in keys_to_clear:
249
+ if key in st.session_state:
250
+ del st.session_state[key]
251
+ st.rerun()
252
+
253
+ else:
254
+ st.info("Please upload an image to start the analysis and chat.")
255
+
256
+ # --- END OF FILE app_main.py ---
classification_model/model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ from torchvision.models import VGG16_Weights
5
+
6
+ class VGG16Model(nn.Module):
7
+ def __init__(self, num_classes=4):
8
+ super(VGG16Model, self).__init__()
9
+ #self.vgg16 = models.vgg16(pretrained=True)
10
+ # Replace pretrained=True with weights=VGG16_Weights.DEFAULT
11
+ #self.vgg16 = models.vgg16(weights=VGG16_Weights.DEFAULT)
12
+ # Manually load the weights
13
+ self.vgg16 = models.vgg16()
14
+ # state_dict = torch.load("/work/mech-ai/angona3/Trial/vgg16_weights.pth", map_location=torch.device("cpu"))
15
+ # self.vgg16.load_state_dict(state_dict)
16
+ self.vgg16.classifier[6] = nn.Linear(4096, num_classes) ##In PyTorch CrossEntropyLoss Handles Softmax Internally
17
+ ###Adding an explicit softmax layer would create numerical instability by double-applying softmax (once in your model, once in the loss),
18
+ # leading to incorrect gradient calculations during training.
19
+
20
+ self.vgg16.classifier[2] = nn.Dropout(p=0.6) # or 0.7 ##introducing dropout for regularization
21
+ self.vgg16.classifier[5] = nn.Dropout(p=0.6) # or 0.7
22
+
23
+ # Freeze all the convolutional layers (feature extractor part)
24
+ # The classifier layers (fully connected layers) remain trainable
25
+ for param in self.vgg16.features.parameters():
26
+ param.requires_grad = False
27
+ #param.requires_grad = True # Unfreeze the last two fully connected layers
28
+
29
+ # Unfreeze Conv Block 4 and Conv Block 5 (512 filters, 3x3 filters, same padding)
30
+ # best: conv_layers_to_unfreeze = [17, 19, 21, 24, 26, 28]
31
+ conv_layers_to_unfreeze = [17, 19, 21, 24, 26, 28]
32
+ for layer_idx in conv_layers_to_unfreeze:
33
+ for param in self.vgg16.features[layer_idx].parameters():
34
+ param.requires_grad = True
35
+
36
+
37
+
38
+
39
+ # Unfreeze the last two fully connected layers
40
+ # Unfreeze all fully connected layers
41
+ for param in self.vgg16.classifier.parameters():
42
+ param.requires_grad = True
43
+
44
+ def forward(self, x):
45
+ return self.vgg16(x)
46
+
47
+ #Fine-tuning the entire network can lead to better performance compared to freezing layers
48
+ ##because the model can adjust both the feature extractor and classifier to your specific dataset.
49
+ ##Yes, nn.CrossEntropyLoss in PyTorch is explicitly designed to handle raw logits (ranging from −∞to +∞) directly and efficiently.
50
+ ##CrossEntropyLoss=Softmax(logits)+Log+NLLLoss
51
+
52
+ ##### In newer versions of torchvision, the weights argument replaces pretrained.
53
+ # #The VGG16_Weights.DEFAULT is the new way to specify that anyone want the pretrained weights, and it is the preferred method.
54
+ ##Deprecation Warning Fix: Replace pretrained=True with weights=VGG16_Weights.DEFAULT.
55
+ ##Security Warning Fix: You don't need to use torch.load in this case because you're not loading a pre-trained model from a .pth file.
56
+ ##from torchvision.models import VGG16_Weights
57
+ ## self.vgg16 = models.vgg16(weights=VGG16_Weights.DEFAULT)
58
+
59
+
60
+ ## adding dropout to VGG16
61
+ #class VGG16Model(nn.Module):
62
+ # def __init__(self, num_classes=4):
63
+ # super(VGG16Model, self).__init__()
64
+ # self.vgg16 = models.vgg16(pretrained=True)
65
+ # self.vgg16.classifier[6] = nn.Sequential(
66
+ # nn.Linear(4096, 1024), # First layer reduces dimensions to 1024
67
+ # nn.ReLU(), # Adds non-linearity to increase learning capacity
68
+ # nn.Dropout(0.5), # Introduces dropout for regularization
69
+ # nn.Linear(1024, num_classes) # Second layer maps to the desired number of classes
70
+ # )
71
+
72
+ # def forward(self, x):
73
+ # return self.vgg16(x)
74
+
75
+
76
+ #import torch.nn as nn
77
+ #from torchvision import models
78
+
79
+ #def get_resnet18_model(num_classes):
80
+ # model = models.resnet18(pretrained=True)
81
+ # model.fc = nn.Linear(model.fc.in_features, num_classes)
82
+ # return model
83
+
84
+ ##VGG16 structure
85
+ #(vgg16.classifier): Sequential(
86
+ # (0): Linear(25088, 4096)
87
+ # (1): ReLU(inplace=True)
88
+ # (2): Dropout(p=0.5, inplace=False)
89
+ # (3): Linear(4096, 4096)
90
+ # (4): ReLU(inplace=True)
91
+ # (5): Dropout(p=0.5, inplace=False)
92
+ # (6): Linear(4096, num_classes)
93
+ #)
classification_model/predict.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from model import VGG16Model
6
+ # Load your PyTorch VGG16 model (pre-trained)
7
+
8
+ def load_vgg16():
9
+ model = VGG16Model()
10
+ model.load_state_dict(torch.load("/work/mech-ai/jrrade/AFM/AFM-LLM-Defect-Guidance/classification_model/best_model.pth", map_location=torch.device('cpu'))) # Or GPU if available
11
+ model.eval()
12
+ return model
13
+
14
+ # Load ImageNet class labels (you need a mapping file)
15
+
16
+ def load_class_labels():
17
+ class_names = {0:'good_images', 1:'Imaging Artifact', 2:'Not Tracking', 3:'Tip Contamination'}
18
+ return class_names
19
+
20
+ # Preprocess uploaded image
21
+ def preprocess_image(img):
22
+ preprocess = transforms.Compose([
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(
26
+ mean=[0.3718, 0.1738, 0.0571],
27
+ std=[0.2095, 0.2124, 0.1321]
28
+ ),
29
+ ])
30
+ img_tensor = preprocess(img).unsqueeze(0) # Add batch dimension
31
+ return img_tensor
32
+
33
+ # Predict using PyTorch model
34
+ def predict_image_class(img, model, class_names):
35
+ img_tensor = preprocess_image(img)
36
+ with torch.no_grad():
37
+ outputs = model(img_tensor)
38
+ _, preds = torch.max(outputs, dim=1)
39
+ print(_, preds)
40
+ probs = F.softmax(outputs, dim=1)
41
+ top_prob, top_idx = torch.topk(probs, 1)
42
+ print(top_prob, top_idx)
43
+ class_label = class_names[top_idx.item()]
44
+ return class_label
45
+
46
+ img_path = '/work/mech-ai/angona3/Trial/image/Not_Tracking/Not_Tracking_21.jpg'
47
+ # img_path = '/work/mech-ai/angona3/Trial/image/Tip_Contamination/Tip_Contamination_17.jpg'
48
+ img = Image.open(img_path)
49
+ model = load_vgg16()
50
+ class_names = load_class_labels()
51
+ class_label = predict_image_class(img, model, class_names)
52
+ print(class_label)
graph_logic.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- START OF FILE graph_logic.py ---
2
+
3
+ import os
4
+ from typing import TypedDict, Annotated, List
5
+ import operator
6
+ from dotenv import load_dotenv
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_anthropic import ChatAnthropic
9
+ from langchain.schema import BaseMessage, AIMessage
10
+ from langgraph.graph import StateGraph, END
11
+ from langgraph.checkpoint.memory import MemorySaver # Optional for checkpointing
12
+
13
+ # Load environment variables (optional here, but good practice if testing independently)
14
+ # load_dotenv() # Can be commented out if only app_main.py loads it
15
+
16
+ # --- LangGraph State Definition ---
17
+
18
+ class GraphState(TypedDict):
19
+ """
20
+ Represents the state of our graph.
21
+
22
+ Attributes:
23
+ messages: The list of messages comprising the conversation.
24
+ operator.add indicates messages should be appended.
25
+ """
26
+ messages: Annotated[List[BaseMessage], operator.add]
27
+
28
+ # --- LLM Initialization ---
29
+
30
+ def initialize_llm(provider: str, model_name: str, temperature: float, api_key: str):
31
+ """Initializes the appropriate LangChain Chat Model."""
32
+ if provider == "OpenAI":
33
+ if not api_key:
34
+ raise ValueError("OpenAI API key is missing. Please set OPENAI_API_KEY.")
35
+ return ChatOpenAI(api_key=api_key, model_name=model_name, temperature=temperature)
36
+ elif provider == "Anthropic":
37
+ if not api_key:
38
+ raise ValueError("Anthropic API key is missing. Please set ANTHROPIC_API_KEY.")
39
+ return ChatAnthropic(api_key=api_key, model_name=model_name, temperature=temperature)
40
+ else:
41
+ raise ValueError(f"Unsupported LLM provider: {provider}")
42
+
43
+ # --- LangGraph Node and Graph Building ---
44
+
45
+ def create_chat_graph(llm):
46
+ """
47
+ Builds and compiles the LangGraph conversational graph.
48
+
49
+ Args:
50
+ llm: An initialized LangChain Chat Model instance.
51
+
52
+ Returns:
53
+ A compiled LangGraph application.
54
+ """
55
+
56
+ # Define the function that calls the LLM - it closes over the 'llm' variable
57
+ def call_model(state: GraphState) -> dict:
58
+ """Invokes the provided LLM with the current conversation state."""
59
+ messages = state['messages']
60
+ response = llm.invoke(messages)
61
+ # Return the AIMessage list to be added to the state
62
+ return {"messages": [response]}
63
+
64
+ # Build the graph workflow
65
+ workflow = StateGraph(GraphState)
66
+
67
+ # Add the single node that runs the LLM
68
+ workflow.add_node("llm_node", call_model)
69
+
70
+ # Set the entry point and the only edge
71
+ workflow.set_entry_point("llm_node")
72
+ workflow.add_edge("llm_node", END) # Conversation ends after one LLM call per turn
73
+
74
+ # Compile the graph
75
+ # Optional: Add memory for checkpointing if needed
76
+ # memory = MemorySaver()
77
+ # graph = workflow.compile(checkpointer=memory)
78
+ graph = workflow.compile()
79
+
80
+ return graph
81
+
82
+ # --- END OF FILE graph_logic.py ---
requirements.txt ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core LangChain packages
2
+ langchain
3
+ langchain-core
4
+
5
+ # LangGraph
6
+ langgraph
7
+
8
+ # Pydantic
9
+ pydantic
10
+
11
+ # OpenAI integration (GPT-4, GPT-4o)
12
+ langchain-openai
13
+ openai
14
+
15
+ # Anthropic integration (Claude 3)
16
+ langchain-anthropic
17
+ anthropic
18
+
19
+ # Google Gemini integration (Gemini Pro / Flash)
20
+ langchain-google-genai
21
+ google-generativeai
22
+
23
+ # Groq integration (Groq)
24
+ langchain-groq
25
+ groq
26
+
27
+ # Data processing
28
+ pandas
29
+ PyPDF2
30
+
31
+ # Optional: for improved dev type checking
32
+ typing-extensions
33
+
34
+ # Environment variables
35
+ python-dotenv
36
+
37
+ # for streamlit
38
+ streamlit
39
+
40
+ # for classification model
41
+ torch
42
+ torchvision
43
+ numpy
44
+ PIL