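"""Streamlit demo: Bangla multimodal disaster classifier.

Takes an uploaded image plus a Bangla caption, fuses Bangla-BERT text features
with EfficientNet-B3 image features, and predicts one of five disaster classes.
The PyTorch checkpoint is exported to ONNX once; all inference runs through
onnxruntime.
"""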
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import os
import onnx
import onnxruntime as ort
import numpy as np
# === Model Paths ===
MODEL_PATH = "bangla_disaster_model.pth"
ONNX_PATH = "bangla_disaster_model.onnx"
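# The .pth checkpoint is loaded only to export the ONNX graph below;
# all predictions are served from the ONNX session.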
# === Class Labels ===
classes = ['HYD', 'MET', 'FD', 'EQ', 'OTHD']
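# HYD = hydrological, MET = meteorological, FD = fire, EQ = earthquake, OTHD = other / no disaster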
# === Model Architecture (used only for export) ===
class MultimodalBanglaClassifier(nn.Module):
    def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
        super(MultimodalBanglaClassifier, self).__init__()
        self.text_model = AutoModel.from_pretrained(text_model_name)
        # Freeze the first 6 BERT encoder layers to reduce trainable parameters
        for param in self.text_model.encoder.layer[:6].parameters():
            param.requires_grad = False
        from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
        self.image_model = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        self.image_model.classifier = nn.Identity()
        # Project concatenated text (768-d) and image (1536-d) features into a shared 512-d space
        self.proj = nn.Linear(768 + 1536, 512)
        self.transformer_fusion = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=512, nhead=4, batch_first=True),
            num_layers=2
        )
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, image):
        # Use the [CLS] token embedding as the text representation
        text_feat = self.text_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        image_feat = self.image_model(image)
        fused = self.proj(torch.cat((text_feat, image_feat), dim=1)).unsqueeze(1)
        fused = self.transformer_fusion(fused).squeeze(1)
        return self.classifier(fused)
# === ONNX Export ===
def export_to_onnx_if_needed(model):
    # Skip export if an ONNX graph already exists on disk
    if os.path.exists(ONNX_PATH):
        return
    # Dummy inputs matching the tokenizer's max_length (128) and 224x224 RGB images
    dummy_input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.long)
    dummy_attention_mask = torch.ones((1, 128), dtype=torch.long)
    dummy_image = torch.randn(1, 3, 224, 224)
    torch.onnx.export(
        model,
        (dummy_input_ids, dummy_attention_mask, dummy_image),
        ONNX_PATH,
        input_names=["input_ids", "attention_mask", "image"],
        output_names=["output"],
        dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}, "image": {0: "batch"}, "output": {0: "batch"}},
        opset_version=14,
        do_constant_folding=True
    )
# === Load Model and Tokenizer for Exporting Only ===
@st.cache_resource
def load_model_and_tokenizer():
    # The PyTorch model is instantiated only to export the ONNX graph;
    # the tokenizer is the only object needed afterwards.
    model = MultimodalBanglaClassifier()
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
    export_to_onnx_if_needed(model)
    return tokenizer
# === Load ONNX Session ===
@st.cache_resource
def load_onnx_session():
    return ort.InferenceSession(ONNX_PATH)
# === ONNX Prediction ===
def predict_with_onnx(session, tokenizer, image, caption):
    # Standard ImageNet preprocessing to match EfficientNet-B3's training statistics
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transform(image).unsqueeze(0).numpy()
    encoded = tokenizer(caption, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    input_ids = encoded["input_ids"].numpy()
    attention_mask = encoded["attention_mask"].numpy()
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "image": image_tensor
    }
    outputs = session.run(None, inputs)
    logits = outputs[0]
    # Numerically stable softmax to turn logits into probabilities
    exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
    pred_class = np.argmax(probs, axis=1)[0]
    return classes[pred_class], probs[0].tolist()
# === Bangla Labels ===
def get_bangla_response(class_name):
    # User-facing Bangla responses (water-related, weather-related, fire, earthquake, not a disaster)
    responses = {
        'HYD': "🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
        'MET': "🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
        'FD': "🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
        'EQ': "🌍 ভূমিকম্প শনাক্ত হয়েছে (Earthquake)! নিরাপদ স্থানে যান!",
        'OTHD': "😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!"
    }
    return responses.get(class_name, "🤔 শ্রেণিবিন্যাস করা যায়নি।")  # fallback: "Could not be classified."
# === Streamlit UI ===
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")
tokenizer = load_model_and_tokenizer()
onnx_session = load_onnx_session()
uploaded_file = st.file_uploader(
    "🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
    type=['jpg', 'png', 'jpeg'],
    key="disaster_image_uploader",
    help="ছবি আপলোড করতে এখানে ক্লিক করুন অথবা drag & drop করুন"
)
if uploaded_file:
    st.success(f"✅ ছবি আপলোড সফল: {uploaded_file.name}")
else:
    st.info("📁 অনুগ্রহ করে একটি ছবি আপলোড করুন")
caption = st.text_area("✍️ বাংলায় একটি ক্যাপশন লিখুন", "")
col1, col2 = st.columns([1, 1])
submit = col1.button("🔍 পূর্বাভাস দিন")
clear = col2.button("🧹 রিসেট করুন")
if clear:
    st.rerun()
if submit and uploaded_file and caption:
    img = Image.open(uploaded_file).convert("RGB")
    st.image(img, caption="আপলোড করা ছবি", use_container_width=True)
    with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
        progress_bar = st.progress(0, text="প্রক্রিয়াকরণ শুরু হচ্ছে...")
        progress_bar.progress(50, text="বিশ্লেষণ চলছে...")
        prediction, probs = predict_with_onnx(onnx_session, tokenizer, img, caption)
        progress_bar.progress(100, text="✅ বিশ্লেষণ সম্পন্ন!")
        progress_bar.empty()
    st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
    with col2:
        st.caption("🎯 উচ্চ নির্ভুলতা মোড (ONNX)")
    with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
        class_names = {
            'HYD': 'জলসম্পর্কিত দুর্যোগ',
            'MET': 'আবহাওয়া দুর্যোগ',
            'FD': 'অগ্নিদুর্ঘটনা',
            'EQ': 'ভূমিকম্প',
            'OTHD': 'কোনো দুর্যোগ নয়'
        }
        for i, class_code in enumerate(classes):
            percentage = probs[i]
            st.write(f"**{class_names[class_code]}**: {percentage:.1%}")
            st.progress(min(max(percentage, 0.0), 1.0))  # ensure range [0.0, 1.0]