Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| from transformers import AutoTokenizer, AutoModel | |
| import torch.nn as nn | |
| import os | |
| import onnx | |
| import onnxruntime as ort | |
| import numpy as np | |
# === Model Paths ===
# Trained PyTorch checkpoint, and the ONNX file exported from it.
MODEL_PATH = "bangla_disaster_model.pth"
ONNX_PATH = "bangla_disaster_model.onnx"

# === Class Labels ===
# Class codes; a predicted index into the model output is looked up in this list.
classes = ["HYD", "MET", "FD", "EQ", "OTHD"]
| # === Model Architecture (used only for export) === | |
class MultimodalBanglaClassifier(nn.Module):
    """Text + image classifier that fuses both modalities with a small transformer.

    The text branch is a pretrained Hugging Face encoder whose first six
    encoder layers are frozen. The image branch is an ImageNet-pretrained
    EfficientNet-B3 with its classifier head replaced by identity. The two
    feature vectors are concatenated, projected to 512 dimensions, passed
    through a two-layer transformer encoder as a length-one sequence, and
    mapped to ``num_classes`` logits.

    Args:
        text_model_name: Hugging Face model id for the text encoder.
        num_classes: Number of output logits.
    """

    def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
        super().__init__()

        # Text encoder; the first six encoder layers are excluded from training.
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.text_model.encoder.layer[:6].requires_grad_(False)

        # Image encoder with the classification head removed, so it emits features.
        from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
        self.image_model = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        self.image_model.classifier = nn.Identity()

        # Concatenated text (768) + image (1536) features are projected to 512.
        self.proj = nn.Linear(768 + 1536, 512)

        fusion_layer = nn.TransformerEncoderLayer(d_model=512, nhead=4, batch_first=True)
        self.transformer_fusion = nn.TransformerEncoder(fusion_layer, num_layers=2)

        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_ids, attention_mask, image):
        """Return class logits for a batch of tokenized captions and images."""
        text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state at sequence position 0 serves as the sentence representation.
        sentence_embedding = text_output.last_hidden_state[:, 0, :]
        visual_embedding = self.image_model(image)

        joint = torch.cat((sentence_embedding, visual_embedding), dim=1)
        # The fusion encoder expects a sequence axis; use a sequence of length one.
        fusion_token = self.proj(joint).unsqueeze(1)
        fusion_token = self.transformer_fusion(fusion_token).squeeze(1)
        return self.classifier(fusion_token)
| # === ONNX Export === | |
def export_to_onnx_if_needed(model):
    """Export ``model`` to ``ONNX_PATH`` unless that file already exists.

    The export is traced with one dummy sample: 128 token ids, a matching
    attention mask of ones, and a 3x224x224 image. Only the batch axis is
    declared dynamic, so the exported graph keeps the sequence length and
    image size used here.

    The graph is first written to a temporary sibling file and then moved
    into place. An interrupted or failed export therefore never leaves a
    partial file at ``ONNX_PATH`` that later runs would mistake for a
    finished export.

    Args:
        model: Module called as ``model(input_ids, attention_mask, image)``.

    Raises:
        Whatever ``torch.onnx.export`` raises; the temporary file is removed
        before the exception propagates.
    """
    if os.path.exists(ONNX_PATH):
        return

    # Random ids are only used to trace the graph; they just have to be valid
    # embedding indices.
    # NOTE(review): 30522 presumably mirrors a BERT vocabulary size - confirm
    # it does not exceed the text model's vocabulary.
    dummy_input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.long)
    dummy_attention_mask = torch.ones((1, 128), dtype=torch.long)
    dummy_image = torch.randn(1, 3, 224, 224)

    batch_axis = {0: "batch"}
    tmp_path = ONNX_PATH + ".tmp"
    try:
        # Tracing needs no gradients.
        with torch.no_grad():
            torch.onnx.export(
                model,
                (dummy_input_ids, dummy_attention_mask, dummy_image),
                tmp_path,
                input_names=["input_ids", "attention_mask", "image"],
                output_names=["output"],
                dynamic_axes={
                    "input_ids": dict(batch_axis),
                    "attention_mask": dict(batch_axis),
                    "image": dict(batch_axis),
                    "output": dict(batch_axis),
                },
                opset_version=14,
                do_constant_folding=True,
            )
        os.replace(tmp_path, ONNX_PATH)
    finally:
        # After a successful replace the temp file is gone; this only fires
        # when the export failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
| # === Load Model and Tokenizer for Exporting Only === | |
@st.cache_resource
def load_model_and_tokenizer():
    """Return the caption tokenizer, exporting the model to ONNX first if needed.

    The PyTorch model is only needed to produce the ONNX file, so it is built
    and loaded from ``MODEL_PATH`` only when ``ONNX_PATH`` does not exist yet.
    The result is cached with ``st.cache_resource`` because Streamlit re-runs
    the whole script on every widget interaction.

    Returns:
        The tokenizer for ``sagorsarker/bangla-bert-base``.
    """
    if not os.path.exists(ONNX_PATH):
        model = MultimodalBanglaClassifier()
        # NOTE: torch.load unpickles the checkpoint; only load files from a
        # trusted source.
        model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
        model.eval()
        export_to_onnx_if_needed(model)
    return AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
| # === Load ONNX Session === | |
@st.cache_resource
def load_onnx_session():
    """Return an ONNX Runtime inference session for the file at ``ONNX_PATH``.

    Cached with ``st.cache_resource`` so the session is created once instead
    of on every Streamlit script re-run.
    """
    return ort.InferenceSession(ONNX_PATH)
| # === ONNX Prediction === | |
def predict_with_onnx(session, tokenizer, image, caption):
    """Classify one image/caption pair with the ONNX session.

    Args:
        session: Object whose ``run(None, feed)`` returns a sequence with the
            logits array first.
        tokenizer: Tokenizer called on ``caption``; its result must provide
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        image: Image accepted by the torchvision transforms used below.
        caption: Text accompanying the image.

    Returns:
        A ``(class_code, probabilities)`` tuple: the entry of ``classes`` with
        the highest probability, and the softmax probabilities of the first
        sample as a list.
    """
    # Resize to 224x224, convert to a tensor, then normalize per channel.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    pixel_batch = preprocess(image).unsqueeze(0).numpy()

    # Pad or truncate the caption to exactly 128 tokens.
    tokens = tokenizer(
        caption,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    feed = {
        "input_ids": tokens["input_ids"].numpy(),
        "attention_mask": tokens["attention_mask"].numpy(),
        "image": pixel_batch,
    }
    logits = session.run(None, feed)[0]

    # Softmax; subtracting the row maximum first keeps the exponentials bounded.
    row_max = np.max(logits, axis=1, keepdims=True)
    unnormalized = np.exp(logits - row_max)
    probs = unnormalized / np.sum(unnormalized, axis=1, keepdims=True)

    best_index = np.argmax(probs, axis=1)[0]
    return classes[best_index], probs[0].tolist()
| # === Bangla Labels === | |
def get_bangla_response(class_name):
    """Return the Bangla user-facing message for a predicted class code.

    Any code without an entry in the table gets a generic
    "could not be classified" message.
    """
    fallback = "🤔 শ্রেণিবিন্যাস করা যায়নি।"
    messages = {
        'HYD': "🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
        'MET': "🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
        'FD': "🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
        'EQ': "🌍 ভূমিকম্প শনাক্ত হয়েছে (Earthquake)! নিরাপদ স্থানে যান!",
        'OTHD': "😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!",
    }
    try:
        return messages[class_name]
    except KeyError:
        return fallback
| # === Streamlit UI === | |
# Page setup. Streamlit re-executes this whole script on every widget interaction.
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")

tokenizer = load_model_and_tokenizer()
onnx_session = load_onnx_session()

# The input widgets carry a generation counter in their keys. A plain
# st.rerun() keeps widget state, so "reset" bumps the counter instead: the
# widgets get new identities and come back empty.
if "form_generation" not in st.session_state:
    st.session_state["form_generation"] = 0
form_generation = st.session_state["form_generation"]

uploaded_file = st.file_uploader(
    "🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
    type=['jpg', 'png', 'jpeg'],
    key=f"disaster_image_uploader_{form_generation}",
    help="ছবি আপলোড করতে এখানে ক্লিক করুন অথবা drag & drop করুন",
)
if uploaded_file:
    st.success(f"✅ ছবি আপলোড সফল: {uploaded_file.name}")
else:
    st.info("📁 অনুগ্রহ করে একটি ছবি আপলোড করুন")

caption = st.text_area(
    "✍️ বাংলায় একটি ক্যাপশন লিখুন",
    "",
    key=f"caption_{form_generation}",
)

col1, col2 = st.columns([1, 1])
submit = col1.button("🔍 পূর্বাভাস দিন")
clear = col2.button("🧹 রিসেট করুন")

if clear:
    st.session_state["form_generation"] = form_generation + 1
    st.rerun()

# A caption made only of whitespace counts as missing.
inputs_ready = bool(uploaded_file) and bool(caption.strip())

if submit and not inputs_ready:
    # Tell the user why nothing happened instead of silently ignoring the click.
    st.warning("⚠️ পূর্বাভাসের জন্য একটি ছবি আপলোড করুন এবং একটি ক্যাপশন লিখুন।")
elif submit:
    try:
        # convert() forces the image data to be decoded, so corrupt or
        # truncated uploads fail here rather than inside the model call.
        img = Image.open(uploaded_file).convert("RGB")
    except OSError:
        img = None
        st.error("❌ ছবিটি পড়া যায়নি। অনুগ্রহ করে একটি বৈধ ছবি আপলোড করুন।")

    if img is not None:
        st.image(img, caption="আপলোড করা ছবি", use_container_width=True)

        with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
            progress_bar = st.progress(0, text="প্রক্রিয়াকরণ শুরু হচ্ছে...")
            progress_bar.progress(50, text="বিশ্লেষণ চলছে...")
            prediction, probs = predict_with_onnx(onnx_session, tokenizer, img, caption)
            progress_bar.progress(100, text="✅ বিশ্লেষণ সম্পন্ন!")
        progress_bar.empty()

        st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")

        col1, col2 = st.columns([2, 1])
        with col1:
            st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
        with col2:
            st.caption("🎯 উচ্চ নির্ভুলতা মোড (ONNX)")

        with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
            class_names = {
                'HYD': 'জলসম্পর্কিত দুর্যোগ',
                'MET': 'আবহাওয়া দুর্যোগ',
                'FD': 'অগ্নিদুর্ঘটনা',
                'EQ': 'ভূমিকম্প',
                'OTHD': 'কোনো দুর্যোগ নয়',
            }
            for i, class_code in enumerate(classes):
                percentage = probs[i]
                st.write(f"**{class_names[class_code]}**: {percentage:.1%}")
                # st.progress only accepts floats in [0.0, 1.0]; clamp
                # against rounding drift.
                st.progress(min(max(percentage, 0.0), 1.0))