added necessary files for app to run on HF

Files changed:
- app.py (+256 -0)
- classification_model/model.py (+93 -0)
- classification_model/predict.py (+52 -0)
- graph_logic.py (+82 -0)
- requirements.txt (+44 -0)
app.py
ADDED
@@ -0,0 +1,256 @@
# --- START OF FILE app_main.py ---

import streamlit as st
import torch
import torch.nn.functional as F
from torchvision import transforms
# Ensure classification_model is importable (e.g., it's in the same directory or on PYTHONPATH)
try:
    from classification_model.model import VGG16Model
except ImportError:
    st.error("Could not import VGG16Model. Make sure 'classification_model/model.py' exists and is importable.")
    st.stop()

from PIL import Image
import os
from dotenv import load_dotenv
from langchain.schema import SystemMessage, HumanMessage, AIMessage, BaseMessage
import io

# Import graph logic from the other file
try:
    from graph_logic import initialize_llm, create_chat_graph, GraphState
except ImportError:
    st.error("Could not import from graph_logic.py. Make sure the file exists in the same directory.")
    st.stop()


# Load environment variables from .env file
load_dotenv()

# --- Helper Functions (Specific to App/UI) ---

@st.cache_resource
def load_vgg16(path='./classification_model/best_model.pth', device='cpu'):
    """Loads the pre-trained VGG16 model."""
    model = VGG16Model()
    if not os.path.exists(path):
        st.error(f"Model file not found at {path}. Please ensure the model is present.")
        st.stop()
    try:
        model.load_state_dict(torch.load(path, map_location=device))
        model.eval()
    except Exception as e:
        st.error(f"Error loading model state dict: {e}")
        st.stop()
    return model

@st.cache_resource
def load_class_labels() -> dict:
    """Loads the class labels."""
    return {0: 'good_images', 1: 'Imaging Artifact', 2: 'Not Tracking', 3: 'Tip Contamination'}

def preprocess_image(img: Image.Image):
    """Preprocesses the image for the VGG16 model."""
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3718, 0.1738, 0.0571], std=[0.2095, 0.2124, 0.1321]),
    ])
    img_tensor = preprocess(img).unsqueeze(0)
    return img_tensor

def predict_image_class(img: Image.Image, model, class_names: dict) -> str:
    """Predicts the class label for a given image."""
    img_tensor = preprocess_image(img)
    with torch.no_grad():
        outputs = model(img_tensor)
    probs = F.softmax(outputs, dim=1)
    top_prob, top_idx = torch.topk(probs, 1)
    class_label = class_names.get(top_idx.item(), "Unknown Class")
    return class_label

def download_chat_history() -> str:
    """Generates the chat history text for download."""
    if "messages" not in st.session_state or not st.session_state.messages:
        return ""
    output = io.StringIO()
    start_index = 1 if isinstance(st.session_state.messages[0], SystemMessage) else 0
    for msg in st.session_state.messages[start_index:]:
        role = "User" if isinstance(msg, HumanMessage) else "Assistant"
        output.write(f"{role}: {msg.content}\n")
    return output.getvalue()

# --- Streamlit UI ---

st.set_page_config(page_title="AFM Defect Assistant (LangGraph)", page_icon="🔬")

st.title("🔬 AFM Image Defect Classification + LLM-based AFM Assistant")
st.write("Upload an AFM image, get a classification, and chat with an AI assistant about the result.")

# --- Sidebar Controls ---
st.sidebar.header("Settings")

# Model Selection
provider = st.sidebar.selectbox("LLM Provider", ["OpenAI", "Anthropic"], key="provider_select")

api_key = None
api_key_name = ""
if provider == "OpenAI":
    default_model = "gpt-4o"
    available_models = ["gpt-4o", "o3-mini"]
    api_key = os.getenv("OPENAI_API_KEY")
    api_key_name = "OPENAI_API_KEY"
elif provider == "Anthropic":
    default_model = "claude-3-5-sonnet-latest"
    available_models = ["claude-3-5-sonnet-latest", "claude-3-7-sonnet-latest"]
    api_key = os.getenv("ANTHROPIC_API_KEY")
    api_key_name = "ANTHROPIC_API_KEY"
else:
    st.sidebar.error("Invalid provider selected.")
    st.stop()

# Display warning if API key is missing
if not api_key:
    st.sidebar.warning(f"{provider} API key not found. Please set the {api_key_name} environment variable.")

model_name = st.sidebar.selectbox(f"Choose {provider} Model", available_models, index=available_models.index(default_model), key="model_name_select")
temperature = st.sidebar.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.05, key="temp_slider")

# Clear Chat Button
if st.sidebar.button("Start New Session", key="clear_chat_button"):
    keys_to_clear = ["messages", "current_label", "uploaded_file_state", "llm", "graph"]
    for key in keys_to_clear:
        if key in st.session_state:
            del st.session_state[key]
    st.rerun()

# --- Main Page Logic ---

# File Uploader
uploaded_file = st.file_uploader("Upload an AFM image", type=["jpg", "jpeg", "png"], key="file_uploader")

# Manage state based on uploaded file
if uploaded_file is not None:
    new_file_id = uploaded_file.file_id
    # Check if it's a new file or the same one to avoid re-processing
    if "uploaded_file_state" not in st.session_state or st.session_state.uploaded_file_state["id"] != new_file_id:
        st.session_state.uploaded_file_state = {"id": new_file_id, "name": uploaded_file.name}
        # Clear previous chat/state if a new file is uploaded
        keys_to_reset = ["messages", "current_label", "llm", "graph"]
        for key in keys_to_reset:
            if key in st.session_state:
                del st.session_state[key]

    # --- Image Processing and Classification ---
    try:
        img = Image.open(uploaded_file).convert("RGB")
        st.image(img, caption=f"Uploaded: {st.session_state.uploaded_file_state['name']}", width=200)

        model = load_vgg16()
        class_names = load_class_labels()

        with st.spinner("Classifying image..."):
            class_label = predict_image_class(img, model, class_names)
        st.success(f"**Predicted Class Label:** {class_label}")

        # --- LLM and Graph Initialization ---
        label_changed = ("current_label" not in st.session_state or
                         st.session_state.current_label != class_label)

        # Initialize LLM and Graph if not present or if label changed
        if "llm" not in st.session_state or "graph" not in st.session_state or label_changed:
            if not api_key:
                st.error(f"Cannot proceed without {api_key_name}. Please set it in your environment variables.")
                st.stop()
            try:
                # Initialize LLM using the function from graph_logic.py
                st.session_state.llm = initialize_llm(provider, model_name, temperature, api_key)
                # Create the graph using the function from graph_logic.py
                st.session_state.graph = create_chat_graph(st.session_state.llm)
                st.session_state.current_label = class_label

                # Define the system prompt and initial state for the graph
                system_prompt_content = (
                    f"You are an expert in atomic force microscopy (AFM). "
                    f"The user has uploaded an image and it has a '{class_label}' defect. "
                    "Your role is to help the user understand this defect, potential causes, "
                    "and how to potentially avoid or address the issue represented by this defect. "
                    "Provide concise, technically accurate, and helpful answers. Avoid speculation if unsure."
                )
                system_message = SystemMessage(content=system_prompt_content)
                st.session_state.messages = [system_message]  # Initialize message history

            except Exception as e:
                st.error(f"Failed to initialize LLM or Graph: {e}")
                st.stop()

        # --- Chat Interface ---
        st.divider()
        st.header(f"Chat about '{st.session_state.current_label}'")

        # Display existing messages (skip system message)
        if "messages" in st.session_state:
            for i, msg in enumerate(st.session_state.messages):
                if i == 0 and isinstance(msg, SystemMessage):
                    continue
                with st.chat_message("user" if isinstance(msg, HumanMessage) else "assistant"):
                    st.markdown(msg.content)

        # Chat input
        if prompt := st.chat_input("Ask a question about the detected defect..."):
            # Add user message to state and display it
            st.session_state.messages.append(HumanMessage(content=prompt))
            with st.chat_message("user"):
                st.markdown(prompt)

            # Prepare the input state for the graph
            current_graph_state: GraphState = {"messages": st.session_state.messages}

            # Invoke the graph
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        # Invoke the graph with the current state
                        response_state = st.session_state.graph.invoke(current_graph_state)

                        # Update session state with the full response history from the graph
                        st.session_state.messages = response_state['messages']
                        ai_response_content = st.session_state.messages[-1].content
                        st.markdown(ai_response_content)

                    except Exception as e:
                        st.error(f"Error during chat generation: {e}")
                        # Roll back user message if AI fails
                        if st.session_state.messages and isinstance(st.session_state.messages[-1], HumanMessage):
                            st.session_state.messages.pop()

        # --- Download Chat Button ---
        if len(st.session_state.get("messages", [])) > 1:  # Show only if conversation started
            st.divider()
            chat_text = download_chat_history()
            st.download_button(
                label="Download Chat History",
                data=chat_text,
                file_name=f"afm_chat_{st.session_state.current_label.replace(' ', '_')}.txt",
                mime="text/plain"
            )

    except Exception as e:
        st.error(f"An error occurred processing the image or during chat setup: {e}")
        if "uploaded_file_state" in st.session_state:
            del st.session_state.uploaded_file_state  # Reset if critical error occurs

elif "uploaded_file_state" in st.session_state:
    # If file uploader is cleared by the user after a file was processed
    keys_to_clear = ["messages", "current_label", "uploaded_file_state", "llm", "graph"]
    for key in keys_to_clear:
        if key in st.session_state:
            del st.session_state[key]
    st.rerun()

else:
    st.info("Please upload an image to start the analysis and chat.")

# --- END OF FILE app_main.py ---
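
Before deploying to the Space, it can help to confirm that the checkpoint deserializes outside Streamlit. A minimal sketch, assuming the default path used by load_vgg16() above (the script itself is hypothetical, not part of this commit):

    # smoke_test_checkpoint.py - hypothetical helper, not part of the repo
    import torch
    from classification_model.model import VGG16Model

    model = VGG16Model()
    state = torch.load("./classification_model/best_model.pth", map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    print("Checkpoint OK:", sum(p.numel() for p in model.parameters()), "parameters")

Locally, the app itself launches with `streamlit run app.py`.
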
classification_model/model.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights

class VGG16Model(nn.Module):
    def __init__(self, num_classes=4):
        super(VGG16Model, self).__init__()
        # self.vgg16 = models.vgg16(pretrained=True)
        # Replace pretrained=True with weights=VGG16_Weights.DEFAULT
        # self.vgg16 = models.vgg16(weights=VGG16_Weights.DEFAULT)
        # Manually load the weights
        self.vgg16 = models.vgg16()
        # state_dict = torch.load("/work/mech-ai/angona3/Trial/vgg16_weights.pth", map_location=torch.device("cpu"))
        # self.vgg16.load_state_dict(state_dict)

        # In PyTorch, CrossEntropyLoss handles softmax internally, so the final
        # layer outputs raw logits. Adding an explicit softmax layer would apply
        # softmax twice (once in the model, once in the loss), causing numerical
        # instability and incorrect gradient calculations during training.
        self.vgg16.classifier[6] = nn.Linear(4096, num_classes)

        # Increase dropout for regularization (torchvision default is p=0.5)
        self.vgg16.classifier[2] = nn.Dropout(p=0.6)  # or 0.7
        self.vgg16.classifier[5] = nn.Dropout(p=0.6)  # or 0.7

        # Freeze all the convolutional layers (feature extractor part).
        # The classifier layers (fully connected layers) remain trainable.
        for param in self.vgg16.features.parameters():
            param.requires_grad = False

        # Unfreeze Conv Block 4 and Conv Block 5 (512 channels, 3x3 kernels, same padding)
        # best: conv_layers_to_unfreeze = [17, 19, 21, 24, 26, 28]
        conv_layers_to_unfreeze = [17, 19, 21, 24, 26, 28]
        for layer_idx in conv_layers_to_unfreeze:
            for param in self.vgg16.features[layer_idx].parameters():
                param.requires_grad = True

        # Unfreeze all fully connected layers
        for param in self.vgg16.classifier.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.vgg16(x)

# Notes:
# - Fine-tuning more of the network can outperform freezing all of it, because
#   both the feature extractor and the classifier adapt to the target dataset.
# - nn.CrossEntropyLoss is designed to take raw logits directly
#   (range -inf to +inf): CrossEntropyLoss = Softmax + Log + NLLLoss.
# - Deprecation fix: in newer torchvision versions the `weights` argument
#   replaces `pretrained`; VGG16_Weights.DEFAULT is the preferred way to
#   request pretrained weights:
#     self.vgg16 = models.vgg16(weights=VGG16_Weights.DEFAULT)
#   With that, there is no need to torch.load weights from a .pth file here.

# Alternative: adding a deeper dropout head to VGG16
# class VGG16Model(nn.Module):
#     def __init__(self, num_classes=4):
#         super(VGG16Model, self).__init__()
#         self.vgg16 = models.vgg16(pretrained=True)
#         self.vgg16.classifier[6] = nn.Sequential(
#             nn.Linear(4096, 1024),        # First layer reduces dimensions to 1024
#             nn.ReLU(),                    # Adds non-linearity to increase learning capacity
#             nn.Dropout(0.5),              # Introduces dropout for regularization
#             nn.Linear(1024, num_classes)  # Second layer maps to the desired number of classes
#         )
#
#     def forward(self, x):
#         return self.vgg16(x)

# Alternative: ResNet-18 backbone
# def get_resnet18_model(num_classes):
#     model = models.resnet18(pretrained=True)
#     model.fc = nn.Linear(model.fc.in_features, num_classes)
#     return model

# VGG16 classifier structure, for reference:
# (vgg16.classifier): Sequential(
#     (0): Linear(25088, 4096)
#     (1): ReLU(inplace=True)
#     (2): Dropout(p=0.5, inplace=False)
#     (3): Linear(4096, 4096)
#     (4): ReLU(inplace=True)
#     (5): Dropout(p=0.5, inplace=False)
#     (6): Linear(4096, num_classes)
# )
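
A quick way to verify the freeze/unfreeze scheme is to count trainable versus total parameters and run a dummy forward pass. A minimal sketch (hypothetical check, not part of this commit; assumes the import path above):

    # Sanity check for the partial-unfreeze scheme
    import torch
    from classification_model.model import VGG16Model

    model = VGG16Model(num_classes=4)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,}")

    # One 224x224 RGB image in, 4 raw logits out (no softmax in the model)
    logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 4])
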
classification_model/predict.py
ADDED
@@ -0,0 +1,52 @@
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from model import VGG16Model

# Load the PyTorch VGG16 model (fine-tuned weights)
def load_vgg16():
    model = VGG16Model()
    model.load_state_dict(torch.load(
        "/work/mech-ai/jrrade/AFM/AFM-LLM-Defect-Guidance/classification_model/best_model.pth",
        map_location=torch.device('cpu')))  # Or GPU if available
    model.eval()
    return model

# Load the AFM defect class labels
def load_class_labels():
    class_names = {0: 'good_images', 1: 'Imaging Artifact', 2: 'Not Tracking', 3: 'Tip Contamination'}
    return class_names

# Preprocess uploaded image
def preprocess_image(img):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.3718, 0.1738, 0.0571],
            std=[0.2095, 0.2124, 0.1321]
        ),
    ])
    img_tensor = preprocess(img).unsqueeze(0)  # Add batch dimension
    return img_tensor

# Predict using the PyTorch model
def predict_image_class(img, model, class_names):
    img_tensor = preprocess_image(img)
    with torch.no_grad():
        outputs = model(img_tensor)
        _, preds = torch.max(outputs, dim=1)
        print(_, preds)
        probs = F.softmax(outputs, dim=1)
        top_prob, top_idx = torch.topk(probs, 1)
        print(top_prob, top_idx)
    class_label = class_names[top_idx.item()]
    return class_label

img_path = '/work/mech-ai/angona3/Trial/image/Not_Tracking/Not_Tracking_21.jpg'
# img_path = '/work/mech-ai/angona3/Trial/image/Tip_Contamination/Tip_Contamination_17.jpg'
img = Image.open(img_path)
model = load_vgg16()
class_names = load_class_labels()
class_label = predict_image_class(img, model, class_names)
print(class_label)
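
Since the script already computes softmax probabilities, a small variant can return the confidence alongside the label instead of printing it. A sketch built on the helpers above (predict_with_confidence is a hypothetical name, not in this commit):

    # Variant of predict_image_class that also returns the top-1 confidence
    def predict_with_confidence(img, model, class_names):
        img_tensor = preprocess_image(img)
        with torch.no_grad():
            probs = F.softmax(model(img_tensor), dim=1)
        top_prob, top_idx = torch.topk(probs, 1)
        return class_names[top_idx.item()], top_prob.item()

    # label, confidence = predict_with_confidence(img, model, class_names)
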
graph_logic.py
ADDED
@@ -0,0 +1,82 @@
# --- START OF FILE graph_logic.py ---

import os
from typing import TypedDict, Annotated, List
import operator
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain.schema import BaseMessage, AIMessage
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver  # Optional, for checkpointing

# Load environment variables (optional here, but good practice if testing independently)
# load_dotenv()  # Can be commented out if only app_main.py loads it

# --- LangGraph State Definition ---

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        messages: The list of messages comprising the conversation.
            operator.add indicates new messages should be appended.
    """
    messages: Annotated[List[BaseMessage], operator.add]

# --- LLM Initialization ---

def initialize_llm(provider: str, model_name: str, temperature: float, api_key: str):
    """Initializes the appropriate LangChain Chat Model."""
    if provider == "OpenAI":
        if not api_key:
            raise ValueError("OpenAI API key is missing. Please set OPENAI_API_KEY.")
        return ChatOpenAI(api_key=api_key, model_name=model_name, temperature=temperature)
    elif provider == "Anthropic":
        if not api_key:
            raise ValueError("Anthropic API key is missing. Please set ANTHROPIC_API_KEY.")
        return ChatAnthropic(api_key=api_key, model_name=model_name, temperature=temperature)
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")

# --- LangGraph Node and Graph Building ---

def create_chat_graph(llm):
    """
    Builds and compiles the LangGraph conversational graph.

    Args:
        llm: An initialized LangChain Chat Model instance.

    Returns:
        A compiled LangGraph application.
    """

    # Define the function that calls the LLM - it closes over the 'llm' variable
    def call_model(state: GraphState) -> dict:
        """Invokes the provided LLM with the current conversation state."""
        messages = state['messages']
        response = llm.invoke(messages)
        # Return the AIMessage in a list so it gets appended to the state
        return {"messages": [response]}

    # Build the graph workflow
    workflow = StateGraph(GraphState)

    # Add the single node that runs the LLM
    workflow.add_node("llm_node", call_model)

    # Set the entry point and the only edge
    workflow.set_entry_point("llm_node")
    workflow.add_edge("llm_node", END)  # Conversation ends after one LLM call per turn

    # Compile the graph
    # Optional: Add memory for checkpointing if needed
    # memory = MemorySaver()
    # graph = workflow.compile(checkpointer=memory)
    graph = workflow.compile()

    return graph

# --- END OF FILE graph_logic.py ---
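
The module can be exercised without the Streamlit app: build the graph, seed the state with a system and user message, and invoke it. A minimal sketch, assuming OPENAI_API_KEY is set in the environment (model name and prompts are illustrative):

    # Standalone usage of initialize_llm / create_chat_graph
    import os
    from langchain.schema import SystemMessage, HumanMessage
    from graph_logic import initialize_llm, create_chat_graph

    llm = initialize_llm("OpenAI", "gpt-4o", 0.3, os.getenv("OPENAI_API_KEY"))
    graph = create_chat_graph(llm)

    state = {"messages": [
        SystemMessage(content="You are an expert in atomic force microscopy."),
        HumanMessage(content="What causes tip contamination?"),
    ]}
    result = graph.invoke(state)
    # The operator.add reducer appended the AIMessage to the history
    print(result["messages"][-1].content)
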
requirements.txt
ADDED
@@ -0,0 +1,44 @@
# Core LangChain packages
langchain
langchain-core

# LangGraph
langgraph

# Pydantic
pydantic

# OpenAI integration (GPT-4, GPT-4o)
langchain-openai
openai

# Anthropic integration (Claude 3)
langchain-anthropic
anthropic

# Google Gemini integration (Gemini Pro / Flash)
langchain-google-genai
google-generativeai

# Groq integration
langchain-groq
groq

# Data processing
pandas
PyPDF2

# Optional: improved type checking during development
typing-extensions

# Environment variables
python-dotenv

# Streamlit UI
streamlit

# Classification model
torch
torchvision
numpy
Pillow  # provides the PIL module; "PIL" itself is not pip-installable
|