Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,11 +33,10 @@ def load_openshape(name, to_cpu=False):
|
|
| 33 |
|
| 34 |
|
| 35 |
def load_tripletmix(name, to_cpu=False):
|
| 36 |
-
pce
|
| 37 |
if to_cpu:
|
| 38 |
pce = pce.cpu()
|
| 39 |
-
|
| 40 |
-
return pce, pca
|
| 41 |
|
| 42 |
|
| 43 |
|
|
@@ -81,10 +80,8 @@ def classification_lvis(load_data):
|
|
| 81 |
pc = load_data(prog)
|
| 82 |
col2 = utils.render_pc(pc)
|
| 83 |
prog.progress(0.5, "Running Classification")
|
| 84 |
-
ref_dev = next(
|
| 85 |
-
enc =
|
| 86 |
-
if model_name == "pb-sn-M":
|
| 87 |
-
enc = pc_adapter(enc)
|
| 88 |
sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
| 89 |
argsort = torch.argsort(sim, descending=True)
|
| 90 |
pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
|
|
@@ -104,10 +101,8 @@ def classification_custom(load_data, cats):
|
|
| 104 |
feats = clip_model.get_text_features(**tn).float().cpu()
|
| 105 |
|
| 106 |
prog.progress(0.5, "Running Classification")
|
| 107 |
-
ref_dev = next(
|
| 108 |
-
enc =
|
| 109 |
-
if model_name == "pb-sn-M":
|
| 110 |
-
enc = pc_adapter(enc)
|
| 111 |
sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
| 112 |
argsort = torch.argsort(sim, descending=True)
|
| 113 |
pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
|
|
@@ -207,14 +202,7 @@ try:
|
|
| 207 |
|
| 208 |
st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
|
| 209 |
st.sidebar.title("TripletMix Demo Configuration Panel")
|
| 210 |
-
|
| 211 |
-
'Model Selection',
|
| 212 |
-
("pb-sn-M", "pb-sn")
|
| 213 |
-
)
|
| 214 |
-
if model_name == "pb-sn-M":
|
| 215 |
-
model_g14, pc_adapter = load_tripletmix('tripletmix-pointbert-shapenet')
|
| 216 |
-
elif model_name == "pb-sn":
|
| 217 |
-
model_g14 = load_openshape('openshape-pointbert-shapenet')
|
| 218 |
task = st.sidebar.selectbox(
|
| 219 |
'Task Selection',
|
| 220 |
("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
|
|
@@ -225,6 +213,14 @@ try:
|
|
| 225 |
'Choose the source of categories',
|
| 226 |
("LVIS Categories", "Custom Categories")
|
| 227 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
load_data = utils.input_3d_shape('rpcinput')
|
| 229 |
if cls_mode == "LVIS Categories":
|
| 230 |
st.title("Classification with LVIS Categories")
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def load_tripletmix(name, to_cpu=False):
|
| 36 |
+
pce = openshape.load_pc_encoder_mix(name)
|
| 37 |
if to_cpu:
|
| 38 |
pce = pce.cpu()
|
| 39 |
+
return pce
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
|
|
|
|
| 80 |
pc = load_data(prog)
|
| 81 |
col2 = utils.render_pc(pc)
|
| 82 |
prog.progress(0.5, "Running Classification")
|
| 83 |
+
ref_dev = next(model_classification.parameters()).device
|
| 84 |
+
enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
|
|
|
|
|
|
|
| 85 |
sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
| 86 |
argsort = torch.argsort(sim, descending=True)
|
| 87 |
pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
|
|
|
|
| 101 |
feats = clip_model.get_text_features(**tn).float().cpu()
|
| 102 |
|
| 103 |
prog.progress(0.5, "Running Classification")
|
| 104 |
+
ref_dev = next(model_classification.parameters()).device
|
| 105 |
+
enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
|
|
|
|
|
|
|
| 106 |
sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
|
| 107 |
argsort = torch.argsort(sim, descending=True)
|
| 108 |
pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
|
|
|
|
| 202 |
|
| 203 |
st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
|
| 204 |
st.sidebar.title("TripletMix Demo Configuration Panel")
|
| 205 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
task = st.sidebar.selectbox(
|
| 207 |
'Task Selection',
|
| 208 |
("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
|
|
|
|
| 213 |
'Choose the source of categories',
|
| 214 |
("LVIS Categories", "Custom Categories")
|
| 215 |
)
|
| 216 |
+
model_name = st.sidebar.selectbox(
|
| 217 |
+
'Model Selection',
|
| 218 |
+
("pb-Mix", "pb")
|
| 219 |
+
)
|
| 220 |
+
if model_name == "pb-Mix":
|
| 221 |
+
model_classification = load_tripletmix('tripletmix-pointbert-all-modelnet40')
|
| 222 |
+
elif model_name == "pb":
|
| 223 |
+
model_classification = load_openshape('openshape-pointbert-vitg14-rgb')
|
| 224 |
load_data = utils.input_3d_shape('rpcinput')
|
| 225 |
if cls_mode == "LVIS Categories":
|
| 226 |
st.title("Classification with LVIS Categories")
|