bangla-disaster / app.py
pr0ximaCent's picture
Update app.py
d9a6cb8 verified
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import os
import onnx
import onnxruntime as ort
import numpy as np
# === Model Path ===
MODEL_PATH = "bangla_disaster_model.pth"
ONNX_PATH = "bangla_disaster_model.onnx"
# === Class Labels ===
classes = ['HYD', 'MET', 'FD', 'EQ', 'OTHD']
# === Model Architecture (used only for export) ===
class MultimodalBanglaClassifier(nn.Module):
def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
super(MultimodalBanglaClassifier, self).__init__()
self.text_model = AutoModel.from_pretrained(text_model_name)
for param in self.text_model.encoder.layer[:6].parameters():
param.requires_grad = False
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
self.image_model = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
self.image_model.classifier = nn.Identity()
self.proj = nn.Linear(768 + 1536, 512)
self.transformer_fusion = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=4, batch_first=True),
num_layers=2
)
self.classifier = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
)
def forward(self, input_ids, attention_mask, image):
text_feat = self.text_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
image_feat = self.image_model(image)
fused = self.proj(torch.cat((text_feat, image_feat), dim=1)).unsqueeze(1)
fused = self.transformer_fusion(fused).squeeze(1)
return self.classifier(fused)
# === ONNX Export ===
def export_to_onnx_if_needed(model):
if os.path.exists(ONNX_PATH):
return
dummy_input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.long)
dummy_attention_mask = torch.ones((1, 128), dtype=torch.long)
dummy_image = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
(dummy_input_ids, dummy_attention_mask, dummy_image),
ONNX_PATH,
input_names=["input_ids", "attention_mask", "image"],
output_names=["output"],
dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}, "image": {0: "batch"}, "output": {0: "batch"}},
opset_version=14,
do_constant_folding=True
)
# === Load Model and Tokenizer for Exporting Only ===
@st.cache_resource
def load_model_and_tokenizer():
model = MultimodalBanglaClassifier()
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()
tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
export_to_onnx_if_needed(model)
return tokenizer
# === Load ONNX Session ===
@st.cache_resource
def load_onnx_session():
return ort.InferenceSession(ONNX_PATH)
# === ONNX Prediction ===
def predict_with_onnx(session, tokenizer, image, caption):
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0).numpy()
encoded = tokenizer(caption, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
input_ids = encoded["input_ids"].numpy()
attention_mask = encoded["attention_mask"].numpy()
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"image": image_tensor
}
outputs = session.run(None, inputs)
logits = outputs[0]
# Softmax to get probabilities
exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
pred_class = np.argmax(probs, axis=1)[0]
return classes[pred_class], probs[0].tolist()
# === Bangla Labels ===
def get_bangla_response(class_name):
responses = {
'HYD': "🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
'MET': "🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
'FD': "🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
'EQ': "🌍 ভূমিকম্প শনাক্ত হয়েছে (Earthquake)! নিরাপদ স্থানে যান!",
'OTHD': "😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!"
}
return responses.get(class_name, "🤔 শ্রেণিবিন্যাস করা যায়নি।")
# === Streamlit UI ===
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")
tokenizer = load_model_and_tokenizer()
onnx_session = load_onnx_session()
uploaded_file = st.file_uploader(
"🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
type=['jpg', 'png', 'jpeg'],
key="disaster_image_uploader",
help="ছবি আপলোড করতে এখানে ক্লিক করুন অথবা drag & drop করুন"
)
if uploaded_file:
st.success(f"✅ ছবি আপলোড সফল: {uploaded_file.name}")
else:
st.info("📁 অনুগ্রহ করে একটি ছবি আপলোড করুন")
caption = st.text_area("✍️ বাংলায় একটি ক্যাপশন লিখুন", "")
col1, col2 = st.columns([1, 1])
submit = col1.button("🔍 পূর্বাভাস দিন")
clear = col2.button("🧹 রিসেট করুন")
if clear:
st.rerun()
if submit and uploaded_file and caption:
img = Image.open(uploaded_file).convert("RGB")
st.image(img, caption="আপলোড করা ছবি", use_container_width=True)
with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
progress_bar = st.progress(0, text="প্রক্রিয়াকরণ শুরু হচ্ছে...")
progress_bar.progress(50, text="বিশ্লেষণ চলছে...")
prediction, probs = predict_with_onnx(onnx_session, tokenizer, img, caption)
progress_bar.progress(100, text="✅ বিশ্লেষণ সম্পন্ন!")
progress_bar.empty()
st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")
col1, col2 = st.columns([2, 1])
with col1:
st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
with col2:
st.caption("🎯 উচ্চ নির্ভুলতা মোড (ONNX)")
with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
class_names = {
'HYD': 'জলসম্পর্কিত দুর্যোগ',
'MET': 'আবহাওয়া দুর্যোগ',
'FD': 'অগ্নিদুর্ঘটনা',
'EQ': 'ভূমিকম্প',
'OTHD': 'কোনো দুর্যোগ নয়'
}
for i, class_code in enumerate(classes):
percentage = probs[i]
st.write(f"**{class_names[class_code]}**: {percentage:.1%}")
st.progress(min(max(percentage, 0.0), 1.0)) # ensure range [0.0, 1.0]