File size: 8,208 Bytes
1350413
 
 
 
 
 
277340e
70d9c39
 
d9a6cb8
277340e
3cbfd11
277340e
d9a6cb8
277340e
d9a6cb8
277340e
 
d9a6cb8
efdbf4b
1350413
efdbf4b
 
 
1350413
 
efdbf4b
 
1350413
efdbf4b
 
 
 
 
 
1350413
efdbf4b
1350413
efdbf4b
 
1350413
 
 
efdbf4b
 
 
 
 
 
d9a6cb8
816bccd
d9a6cb8
7a5fda6
816bccd
 
d9a6cb8
816bccd
 
 
d9a6cb8
816bccd
 
d9a6cb8
868397b
816bccd
 
 
d9a6cb8
f64a78f
efdbf4b
 
d9a6cb8
efdbf4b
 
7a5fda6
d9a6cb8
65d90d0
d9a6cb8
7a5fda6
 
d9a6cb8
277340e
d9a6cb8
70d9c39
efdbf4b
70d9c39
efdbf4b
d9a6cb8
 
efdbf4b
7a5fda6
d9a6cb8
 
 
7a5fda6
 
 
 
 
 
 
 
 
d9a6cb8
 
 
 
 
 
 
7a5fda6
 
70d9c39
 
 
d9a6cb8
70d9c39
7a5fda6
d9a6cb8
1350413
277340e
efdbf4b
70d9c39
277340e
1350413
d9a6cb8
7a5fda6
65d90d0
ea4b0ea
7a5fda6
ea4b0ea
 
 
 
 
d9a6cb8
ea4b0ea
 
 
c914526
70d9c39
65d90d0
efdbf4b
 
70d9c39
277340e
 
816bccd
277340e
 
1350413
816bccd
65d90d0
efdbf4b
d9a6cb8
 
 
70d9c39
d9a6cb8
7a5fda6
f64a78f
7a5fda6
f64a78f
ca56987
 
f64a78f
ca56987
70d9c39
7a5fda6
efdbf4b
ca56987
 
7a5fda6
ca56987
 
 
 
 
d9a6cb8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import os
import onnx
import onnxruntime as ort
import numpy as np

# === Model Paths ===
# Both paths are relative, so they resolve against the process working directory.
# PyTorch checkpoint (a state dict) that is read to produce the ONNX file.
MODEL_PATH = "bangla_disaster_model.pth"
# ONNX graph used for inference; written by export_to_onnx_if_needed() when absent.
ONNX_PATH = "bangla_disaster_model.onnx"

# === Class Labels ===
# Short codes for the five output classes. predict_with_onnx() indexes this list
# with the argmax of the model output, so the order here must match the label
# order used in training (assumed to be the case -- TODO confirm against the
# training code).
classes = ['HYD', 'MET', 'FD', 'EQ', 'OTHD']

# === Model Architecture (used only for export) ===
class MultimodalBanglaClassifier(nn.Module):
    """Two-branch (text + image) classifier that produces one logit per class.

    The text branch is a pretrained transformer loaded by name; the image
    branch is an ImageNet-pretrained EfficientNet-B3 whose classification head
    is replaced by an identity. Both feature vectors are concatenated,
    projected to 512 dimensions, passed through a two-layer transformer
    encoder as a length-1 sequence, and classified by a small MLP.

    Args:
        text_model_name: Hugging Face model id for the text encoder.
        num_classes: Number of output logits.
    """

    def __init__(self, text_model_name='sagorsarker/bangla-bert-base', num_classes=5):
        super().__init__()

        # Text branch: the first six encoder layers are frozen.
        self.text_model = AutoModel.from_pretrained(text_model_name)
        for frozen_layer in self.text_model.encoder.layer[:6]:
            for weight in frozen_layer.parameters():
                weight.requires_grad = False

        # Image branch: drop the classification head so the backbone returns
        # its pooled feature vector instead of ImageNet logits.
        from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
        self.image_model = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        self.image_model.classifier = nn.Identity()

        # Fusion: the projection expects the concatenated text (768) and
        # image (1536) features.
        fused_width = 512
        self.proj = nn.Linear(768 + 1536, fused_width)
        fusion_layer = nn.TransformerEncoderLayer(
            d_model=fused_width, nhead=4, batch_first=True
        )
        self.transformer_fusion = nn.TransformerEncoder(fusion_layer, num_layers=2)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(fused_width, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_ids, attention_mask, image):
        """Return class logits for a batch of (caption tokens, image) pairs.

        Args:
            input_ids: Token ids forwarded to the text model.
            attention_mask: Attention mask forwarded to the text model.
            image: Image batch forwarded to the image model.

        Returns:
            The output of the classification head for each batch element.
        """
        encoded = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the caption embedding.
        caption_embedding = encoded.last_hidden_state[:, 0, :]
        visual_embedding = self.image_model(image)

        joint = torch.cat((caption_embedding, visual_embedding), dim=1)
        # The fusion encoder works on sequences, so wrap the fused vector as a
        # sequence of length one and unwrap it afterwards.
        as_sequence = self.proj(joint).unsqueeze(1)
        mixed = self.transformer_fusion(as_sequence).squeeze(1)
        return self.classifier(mixed)

# === ONNX Export ===
def export_to_onnx_if_needed(model):
    """Export ``model`` to ``ONNX_PATH`` unless that file already exists.

    The model is traced with batch-size-1 dummy inputs: token ids and an
    attention mask of shape (1, 128) and an image of shape (1, 3, 224, 224).
    Only axis 0 (batch) is declared dynamic for the inputs and the output.

    The graph is first written to a temporary sibling file and then moved into
    place, so a failed or interrupted export never leaves a partial file at
    ``ONNX_PATH`` that a later call would mistake for a finished export.

    Args:
        model: Module whose forward takes (input_ids, attention_mask, image).

    Raises:
        Whatever ``torch.onnx.export`` raises; the temporary file is removed.
    """
    if os.path.exists(ONNX_PATH):
        return

    # 30522 is only the upper bound for the random token ids used in tracing;
    # it must not exceed the text model's vocabulary size (not checked here).
    dummy_input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.long)
    dummy_attention_mask = torch.ones((1, 128), dtype=torch.long)
    dummy_image = torch.randn(1, 3, 224, 224)

    tmp_path = ONNX_PATH + ".tmp"
    try:
        torch.onnx.export(
            model,
            (dummy_input_ids, dummy_attention_mask, dummy_image),
            tmp_path,
            input_names=["input_ids", "attention_mask", "image"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": {0: "batch"},
                "attention_mask": {0: "batch"},
                "image": {0: "batch"},
                "output": {0: "batch"},
            },
            opset_version=14,
            do_constant_folding=True,
        )
        # Publish the finished file in a single rename step.
        os.replace(tmp_path, ONNX_PATH)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up after a failed export.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# === Load Model and Tokenizer for Exporting Only ===
@st.cache_resource
def load_model_and_tokenizer():
    """Ensure the ONNX file exists and return the caption tokenizer.

    The PyTorch model is needed only to produce the ONNX export, so it is
    built (and the checkpoint at ``MODEL_PATH`` read) only when ``ONNX_PATH``
    is missing. When the ONNX file is already present, startup skips loading
    the pretrained backbones and does not require the ``.pth`` checkpoint.

    Returns:
        The tokenizer for ``sagorsarker/bangla-bert-base``. The model itself
        is never returned; inference uses the ONNX session.
    """
    if not os.path.exists(ONNX_PATH):
        model = MultimodalBanglaClassifier()
        model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
        # Switch to inference mode before tracing so dropout is disabled.
        model.eval()
        export_to_onnx_if_needed(model)
    return AutoTokenizer.from_pretrained("sagorsarker/bangla-bert-base")

# === Load ONNX Session ===
@st.cache_resource
def load_onnx_session():
    """Create the ONNX Runtime inference session for the file at ``ONNX_PATH``.

    Wrapped in ``st.cache_resource`` so the session is created once and reused.
    """
    session = ort.InferenceSession(ONNX_PATH)
    return session

# === ONNX Prediction ===
def predict_with_onnx(session, tokenizer, image, caption):
    """Classify one (image, caption) pair with the ONNX session.

    Args:
        session: Object whose ``run(None, feed)`` returns the model outputs.
        tokenizer: Callable tokenizer that returns "input_ids" and
            "attention_mask" tensors.
        image: Image accepted by the torchvision transform pipeline
            (the caller passes an RGB PIL image).
        caption: Caption text for the image.

    Returns:
        A ``(label, probabilities)`` tuple: the entry of ``classes`` with the
        highest probability, and the per-class probabilities as a list.
    """
    # Image: resize to 224x224, convert to a tensor, normalise per channel,
    # then add a leading batch axis.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    pixel_batch = preprocess(image).unsqueeze(0).numpy()

    # Caption: always pad/truncate to 128 tokens.
    tokens = tokenizer(
        caption,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    feed = {
        "input_ids": tokens["input_ids"].numpy(),
        "attention_mask": tokens["attention_mask"].numpy(),
        "image": pixel_batch,
    }
    logits = session.run(None, feed)[0]

    # Softmax; subtracting the row maximum first keeps exp() from overflowing.
    row_max = np.max(logits, axis=1, keepdims=True)
    unnormalised = np.exp(logits - row_max)
    probs = unnormalised / np.sum(unnormalised, axis=1, keepdims=True)

    best_index = np.argmax(probs, axis=1)[0]
    return classes[best_index], probs[0].tolist()

# === Bangla Labels ===
def get_bangla_response(class_name):
    """Return the Bangla message shown to the user for a class code.

    Args:
        class_name: One of 'HYD', 'MET', 'FD', 'EQ', 'OTHD'.

    Returns:
        The message for that code, or a generic "could not classify" message
        for any other value.
    """
    messages = dict(
        HYD="🌊 এটি একটি জলসম্পর্কিত দুর্যোগ (Hydrological Disaster)। সতর্ক থাকুন!",
        MET="🌪️ এটি একটি আবহাওয়া সংক্রান্ত দুর্যোগ (Meteorological Disaster)। সাবধানে থাকুন!",
        FD="🔥 আগুন লেগেছে! এটি একটি অগ্নিদুর্ঘটনা (Fire Disaster)। দ্রুত ব্যবস্থা নিন!",
        EQ="🌍 ভূমিকম্প শনাক্ত হয়েছে (Earthquake)! নিরাপদ স্থানে যান!",
        OTHD="😌 এটা কোনো দুর্যোগ নয়। চিন্তার কিছু নেই!",
    )
    try:
        return messages[class_name]
    except KeyError:
        # Unknown code: fall back to the generic message.
        return "🤔 শ্রেণিবিন্যাস করা যায়নি।"

# === Streamlit UI ===
st.set_page_config(page_title="Bangla Disaster Classifier", layout="centered")
st.title("🌪️🇧🇩 Bangla Disaster Classifier")
st.markdown("এই অ্যাপটি একটি multimodal deep learning মডেল ব্যবহার করে ছবির সাথে বাংলা ক্যাপশন বিশ্লেষণ করে দুর্যোগ শনাক্ত করে।")

# The tokenizer loader must run first: it creates the ONNX file that the
# session loader opens.
tokenizer = load_model_and_tokenizer()
onnx_session = load_onnx_session()

# Input widgets keep their value across reruns for as long as their key stays
# the same, so a plain rerun cannot clear them. The keys therefore carry a
# generation counter; bumping it makes Streamlit create fresh, empty widgets.
if "reset_generation" not in st.session_state:
    st.session_state["reset_generation"] = 0
reset_generation = st.session_state["reset_generation"]

uploaded_file = st.file_uploader(
    "🖼️ একটি দুর্যোগের ছবি আপলোড করুন",
    type=['jpg', 'png', 'jpeg'],
    key=f"disaster_image_uploader_{reset_generation}",
    help="ছবি আপলোড করতে এখানে ক্লিক করুন অথবা drag & drop করুন"
)

if uploaded_file:
    st.success(f"✅ ছবি আপলোড সফল: {uploaded_file.name}")
else:
    st.info("📁 অনুগ্রহ করে একটি ছবি আপলোড করুন")

caption = st.text_area(
    "✍️ বাংলায় একটি ক্যাপশন লিখুন",
    "",
    key=f"disaster_caption_{reset_generation}",
)

col1, col2 = st.columns([1, 1])
submit = col1.button("🔍 পূর্বাভাস দিন")
clear = col2.button("🧹 রিসেট করুন")

if clear:
    # Switch to a new widget generation, then rerun so the widgets above are
    # rebuilt empty.
    st.session_state["reset_generation"] = reset_generation + 1
    st.rerun()

if submit:
    if not uploaded_file:
        # Tell the user what is missing instead of silently doing nothing.
        st.warning("⚠️ পূর্বাভাসের জন্য আগে একটি ছবি আপলোড করুন।")
    elif not caption.strip():
        st.warning("⚠️ অনুগ্রহ করে একটি ক্যাপশন লিখুন।")
    else:
        try:
            img = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # Pillow reports unreadable or truncated image files as OSError.
            img = None
            st.error("❌ ছবিটি পড়া যায়নি। অনুগ্রহ করে একটি বৈধ JPG/PNG ছবি আপলোড করুন।")

        if img is not None:
            st.image(img, caption="আপলোড করা ছবি", use_container_width=True)

            with st.spinner("🧠 মডেল পূর্বাভাস দিচ্ছে... (Model processing...)"):
                progress_bar = st.progress(0, text="প্রক্রিয়াকরণ শুরু হচ্ছে...")

                progress_bar.progress(50, text="বিশ্লেষণ চলছে...")
                prediction, probs = predict_with_onnx(onnx_session, tokenizer, img, caption)
                progress_bar.progress(100, text="✅ বিশ্লেষণ সম্পন্ন!")

            progress_bar.empty()

            st.markdown(f"### ✅ পূর্বাভাস: {get_bangla_response(prediction)}")
            col1, col2 = st.columns([2, 1])
            with col1:
                st.markdown(f"#### 📊 সম্ভাব্যতা: **{probs[classes.index(prediction)]:.2%}**")
            with col2:
                st.caption("🎯 উচ্চ নির্ভুলতা মোড (ONNX)")

            with st.expander("📈 বিস্তারিত সম্ভাব্যতা (Detailed Probabilities)"):
                class_names = {
                    'HYD': 'জলসম্পর্কিত দুর্যোগ',
                    'MET': 'আবহাওয়া দুর্যোগ',
                    'FD': 'অগ্নিদুর্ঘটনা',
                    'EQ': 'ভূমিকম্প',
                    'OTHD': 'কোনো দুর্যোগ নয়'
                }
                # probs is ordered like classes, so walk both together.
                for class_code, percentage in zip(classes, probs):
                    st.write(f"**{class_names[class_code]}**: {percentage:.1%}")
                    # Clamp to [0.0, 1.0], the range st.progress accepts for floats.
                    st.progress(min(max(percentage, 0.0), 1.0))