eliphatfs committed
Commit 22e326e · Parent: 7c9515a
Updates.

app.py CHANGED
@@ -69,11 +69,10 @@ def sq(kc, vc):
 
 
 def reset_3d_shape_input(key):
-    objaid_key = key + "_objaid"
+    # this is not working due to streamlit problems, don't use it
     model_key = key + "_model"
     npy_key = key + "_npy"
     swap_key = key + "_swap"
-    sq(objaid_key, "")
     sq(model_key, None)
     sq(npy_key, None)
     sq(swap_key, "Y is up (for most Objaverse shapes)")
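Note on the hunk above: only the signature of the `sq(kc, vc)` helper appears in this diff; its body is defined elsewhere in app.py. A hypothetical sketch of what such a session-state setter might look like is below (an assumption, not the author's code); if it works this way, the new "not working due to streamlit problems" comment likely refers to Streamlit refusing to overwrite a widget's state once that widget has been rendered in the current rerun.

import streamlit as st

def sq(kc, vc):
    # Hypothetical body -- only `def sq(kc, vc):` is visible in this commit.
    # Force-set a widget's session-state entry. Streamlit raises
    # StreamlitAPIException if the widget with key `kc` has already been
    # instantiated in the current rerun, which may be the kind of
    # "streamlit problems" the new comment warns about.
    st.session_state[kc] = vc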
@@ -121,43 +120,40 @@ def image_examples(samples, ncols, return_key=None):
     return trigger
 
 
-def text_examples(samples):
-    return st.selectbox("Or pick an example", samples)
-
-
 def demo_classification():
-    [32 removed lines (previous demo_classification body) not rendered in this view]
+    with st.form("clsform"):
+        load_data = misc_utils.input_3d_shape('cls')
+        cats = st.text_input("Custom Categories (64 max, separated with comma)")
+        cats = [a.strip() for a in cats.split(',')]
+        if len(cats) > 64:
+            st.error('Maximum 64 custom categories supported in the demo')
+            return
+        lvis_run = st.form_submit_button("Run Classification on LVIS Categories")
+        custom_run = st.form_submit_button("Run Classification on Custom Categories")
+        if lvis_run or auto_submit("clsauto"):
+            pc = load_data(prog)
+            col2 = misc_utils.render_pc(pc)
+            prog.progress(0.5, "Running Classification")
+            pred = classification.pred_lvis_sims(model_g14, pc)
+            with col2:
+                for i, (cat, sim) in zip(range(5), pred.items()):
+                    st.text(cat)
+                    st.caption("Similarity %.4f" % sim)
+            prog.progress(1.0, "Idle")
+        if custom_run:
+            pc = load_data(prog)
+            col2 = misc_utils.render_pc(pc)
+            prog.progress(0.5, "Computing Category Embeddings")
+            device = clip_model.device
+            tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76).to(device)
+            feats = clip_model.get_text_features(**tn).float().cpu()
+            prog.progress(0.5, "Running Classification")
+            pred = classification.pred_custom_sims(model_g14, pc, cats, feats)
+            with col2:
+                for i, (cat, sim) in zip(range(5), pred.items()):
+                    st.text(cat)
+                    st.caption("Similarity %.4f" % sim)
+            prog.progress(1.0, "Idle")
     if image_examples(samples_index.classification, 3):
         queue_auto_submit("clsauto")
 
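Note: the rewritten demo_classification relies on `auto_submit(...)` and `queue_auto_submit(...)`, which are defined elsewhere in app.py and are not part of this commit. A plausible sketch of the pattern they implement, a one-shot session-state flag so that clicking an example gallery item can trigger the form logic on the next rerun, is below; the bodies are assumptions, not the actual implementation.

import streamlit as st

def queue_auto_submit(flag_key):
    # Arm a one-shot flag; Streamlit reruns the script after the example
    # click, and the form code checks the flag on that rerun.
    st.session_state[flag_key] = True

def auto_submit(flag_key):
    # Consume the flag so the auto-submission fires only once.
    fired = st.session_state.get(flag_key, False)
    st.session_state[flag_key] = False
    return fired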
@@ -226,18 +222,25 @@ def demo_retrieval():
     with tab_text:
         with st.form("rtextform"):
             k = st.slider("Shapes to Retrieve", 1, 100, 16, key='rtext')
-            text = st.text_input("Input Text")
-            [1 removed line not rendered in this view]
-            if st.form_submit_button("Run with Text"):
+            text = st.text_input("Input Text", key="inputrtext")
+            if st.form_submit_button("Run with Text") or auto_submit("rtextauto"):
                 prog.progress(0.49, "Computing Embeddings")
                 device = clip_model.device
                 tn = clip_prep(
-                    text=[text
+                    text=[text], return_tensors='pt', truncation=True, max_length=76
                 ).to(device)
                 enc = clip_model.get_text_features(**tn).float().cpu()
                 prog.progress(0.7, "Running Retrieval")
                 retrieval_results(retrieval.retrieve(enc, k))
                 prog.progress(1.0, "Idle")
+        picked_sample = st.selectbox("Examples", ["Select..."] + samples_index.retrieval_texts)
+        text_last_example = st.session_state.get('text_last_example', None)
+        if text_last_example is None:
+            st.session_state.text_last_example = picked_sample
+        elif text_last_example != picked_sample and picked_sample != "Select...":
+            st.session_state.text_last_example = picked_sample
+            sq("inputrtext", picked_sample)
+            queue_auto_submit("rtextauto")
 
     with tab_img:
         submit = False
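Note: the text branch embeds the query with the app's CLIP text encoder; `clip_prep` and `clip_model` are created elsewhere in app.py. For reference, a standalone sketch of the same embedding step with Hugging Face Transformers is below; the checkpoint name is only an illustrative assumption, not the one the demo actually loads.

import torch
from transformers import CLIPModel, CLIPProcessor

# Illustrative checkpoint -- the demo's actual CLIP weights are configured elsewhere.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_prep = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

with torch.no_grad():
    tn = clip_prep(text=["a wooden chair"], return_tensors="pt", truncation=True, max_length=76)
    enc = clip_model.get_text_features(**tn).float().cpu()  # (1, projection_dim) query embedding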
@@ -246,19 +249,21 @@ def demo_retrieval():
             pic = st.file_uploader("Upload an Image", key='rimageinput')
             if st.form_submit_button("Run with Image"):
                 submit = True
+        results_container = st.container()
         sample_got = image_examples(samples_index.iret, 4, 'rimageinput')
         if sample_got:
             pic = sample_got
         if sample_got or submit:
             img = Image.open(pic)
-            [8 removed lines (previous image-retrieval handling) not rendered in this view]
+            with results_container:
+                st.image(img)
+                prog.progress(0.49, "Computing Embeddings")
+                device = clip_model.device
+                tn = clip_prep(images=[img], return_tensors="pt").to(device)
+                enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
+                prog.progress(0.7, "Running Retrieval")
+                retrieval_results(retrieval.retrieve(enc, k))
+                prog.progress(1.0, "Idle")
 
     with tab_pc:
         with st.form("rpcform"):
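Note: `results_container = st.container()` is created before the example gallery so that, when an example or upload is submitted, the image and retrieval results render above the gallery rather than below it. A minimal, self-contained illustration of that Streamlit pattern follows; the button is just a stand-in for the `image_examples(...)` gallery.

import streamlit as st

results_container = st.container()  # reserve a slot early in the page
clicked = st.button("Example thumbnail (stand-in for image_examples)")
if clicked:
    with results_container:
        # Rendering inside the container places this output at the reserved
        # slot, i.e. above the button, even though the code runs after it.
        st.write("Retrieval results would appear here, above the examples.")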