|
|
import gradio as gr |
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
|
|
|
from utils import compute_features |
|
|
from scipy.stats import nbinom |
|
|
|
|
|
|
|
|
class NegBinomialModel(nn.Module):
    """Linear regression head producing a positive mean and a scalar ``alpha``.

    The mean is ``mu = exp(clamp(w . x + b, -5, 5))``. ``alpha`` is a single
    learned scalar shared across all samples (presumably the negative-binomial
    dispersion, given the class name -- TODO confirm against training code).
    """

    def __init__(self, in_features):
        """Create the head.

        Args:
            in_features: Size of the last dimension of the input tensor.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, 1)
        # Learned scalar, initialised to 0.5 and clamped in forward().
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        """Compute ``(mu, alpha)`` for ``x`` of shape ``(..., in_features)``.

        Returns:
            mu: Tensor of shape ``x.shape[:-1]`` (0-d for a 1-D ``x``), with
                values in ``[exp(-5), exp(5)]``.
            alpha: 0-d tensor, ``self.alpha`` clamped to ``[1e-3, 10]``.
        """
        # Clamp the linear output before exp() so mu stays in a bounded range.
        mu = torch.exp(torch.clamp(self.linear(x), min=-5, max=5))
        alpha = torch.clamp(self.alpha, min=1e-3, max=10)
        # Drop only the trailing size-1 output dimension; a bare squeeze()
        # would also collapse a batch dimension of size 1.
        return mu.squeeze(-1), alpha
|
|
|
|
|
|
|
|
# Build the network and restore its trained weights onto the CPU.
# in_features=16 presumably equals 2 (lat, lon) plus the feature values that
# predict_score takes from compute_features -- TODO confirm.
# NOTE(review): torch.load unpickles the file; only load trusted weights
# (newer torch versions accept weights_only=True to restrict this).
model = NegBinomialModel(in_features=16)
state_dict = torch.load("model_weights.pt", map_location='cpu')
model.load_state_dict(state_dict)
# Switch the module to evaluation mode for inference.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_score(lat, lon):
    """Score a candidate location from its coordinates.

    Args:
        lat: Latitude of the location.
        lon: Longitude of the location.

    Returns:
        A 3-tuple ``(score, num_banks, ideal_banks)``:
        ``score`` is in [0, 100] and rounded to 3 decimals. It is 50 when the
        model's prediction equals ``num_banks``, and higher when the
        prediction exceeds it.
        ``num_banks`` is the ``"num_banks_in_radius"`` value reported by
        ``compute_features``, or 0 when that key is absent.
        ``ideal_banks`` is the model's predicted mean ``mu``, rounded to 3
        decimals.

    Raises:
        ValueError: if either coordinate is missing (``None``).
    """
    if lat is None or lon is None:
        raise ValueError("Both latitude and longitude must be provided.")

    features = compute_features((lat, lon))
    print("[INPUTS]", features)

    # Work on a copy so popping does not mutate a dict compute_features may reuse.
    features = dict(features)
    num_banks = features.pop("num_banks_in_radius", 0)

    # Feature order follows the dict's insertion order; presumably it matches
    # the order used when the weights were trained -- TODO confirm in utils.
    x = torch.tensor([lat, lon] + list(features.values()), dtype=torch.float32)

    with torch.no_grad():
        mu_pred, _alpha = model(x)

    # A single 1-D input yields a one-element tensor; .item() extracts the
    # scalar without the deprecated ndim>0 ndarray -> float conversion.
    mu = float(mu_pred.item())

    diff = mu - num_banks
    # (1 + tanh(d)) / 2 == sigmoid(2d): maps diff into [0, 1], then scale to 0-100.
    # 1 + tanh(d) is never negative, so no abs() is needed.
    score = (1.0 + np.tanh(diff)) / 2.0 * 100.0

    return (
        round(float(score), 3),
        num_banks,
        round(mu, 3),
    )
|
|
|
|
|
|
|
|
# Gradio UI: two coordinate inputs mapped to predict_score's 3-tuple output,
# in the same order that predict_score returns its values.
interface = gr.Interface(
    fn=predict_score,
    inputs=[
        gr.Number(label="Latitude"),
        gr.Number(label="Longitude"),
    ],
    outputs=[
        gr.Number(label="Score (0 - 100)"),
        gr.Number(label="Number of Current Banks"),
        gr.Number(label="Number of Ideal Banks"),
    ],
    title="Bank Location Scoring Model",
    # Describe the three outputs actually returned: score, current bank count,
    # and the model's predicted ideal bank count.
    description=(
        "Enter latitude and longitude to get a location score (0 - 100), "
        "the number of banks currently nearby, and the model's predicted "
        "ideal number of banks."
    ),
)


interface.launch()
|
|
|