Spaces:
Sleeping
Sleeping
add edit function
Browse files- Home.py +1 -0
- pages/Gallery.py +50 -45
Home.py
CHANGED
|
@@ -43,6 +43,7 @@ def logout():
|
|
| 43 |
st.session_state.pop('selected_dict', None)
|
| 44 |
st.session_state.pop('score_weights', None)
|
| 45 |
st.session_state.pop('gallery_state', None)
|
|
|
|
| 46 |
st.session_state.pop('progress', None)
|
| 47 |
st.session_state.pop('gallery_focus', None)
|
| 48 |
st.session_state.pop('assigned_rank_mode', None)
|
|
|
|
| 43 |
st.session_state.pop('selected_dict', None)
|
| 44 |
st.session_state.pop('score_weights', None)
|
| 45 |
st.session_state.pop('gallery_state', None)
|
| 46 |
+
st.session_state.pop('edit_state', None)
|
| 47 |
st.session_state.pop('progress', None)
|
| 48 |
st.session_state.pop('gallery_focus', None)
|
| 49 |
st.session_state.pop('assigned_rank_mode', None)
|
pages/Gallery.py
CHANGED
|
@@ -300,7 +300,7 @@ class GalleryApp:
|
|
| 300 |
# save tag to session state on change
|
| 301 |
tag = st.radio('Select a tag', prompt_tags, index=tag_focus_idx, horizontal=True, key='tag', label_visibility='collapsed')
|
| 302 |
|
| 303 |
-
print('current state: ', st.session_state.gallery_state)
|
| 304 |
|
| 305 |
if st.session_state.gallery_state == 'graph':
|
| 306 |
|
|
@@ -364,13 +364,12 @@ class GalleryApp:
|
|
| 364 |
if has_selection:
|
| 365 |
checkout = st.button('Check out selections', use_container_width=True, type='primary')
|
| 366 |
if checkout:
|
| 367 |
-
print('checkout')
|
| 368 |
# add focus to session state
|
| 369 |
st.session_state.gallery_focus['tag'] = tag
|
| 370 |
st.session_state.gallery_focus['prompt'] = selected_prompt
|
| 371 |
|
| 372 |
st.session_state.gallery_state = 'check out'
|
| 373 |
-
print(st.session_state.gallery_state)
|
| 374 |
st.experimental_rerun()
|
| 375 |
else:
|
| 376 |
st.write(':orange[👇 **Select images you like below**]')
|
|
@@ -389,25 +388,6 @@ class GalleryApp:
|
|
| 389 |
|
| 390 |
self.checkout_mode(tag, items)
|
| 391 |
|
| 392 |
-
|
| 393 |
-
# items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
|
| 394 |
-
# drop=True)
|
| 395 |
-
# self.gallery_mode(prompt_id, items)
|
| 396 |
-
#
|
| 397 |
-
# with subset_selector[-1]:
|
| 398 |
-
# state_operations = st.columns([1, 1])
|
| 399 |
-
# with state_operations[0]:
|
| 400 |
-
# back = st.button('Back to 🖼️', use_container_width=True)
|
| 401 |
-
# if back:
|
| 402 |
-
# st.session_state.gallery_state[prompt_id] = 'graph'
|
| 403 |
-
# st.experimental_rerun()
|
| 404 |
-
#
|
| 405 |
-
# with state_operations[1]:
|
| 406 |
-
# forward = st.button('Check out', use_container_width=True, type='primary', on_click=self.submit_actions, args=('Continue', prompt_id))
|
| 407 |
-
# if forward:
|
| 408 |
-
# switch_page('ranking')
|
| 409 |
-
|
| 410 |
-
|
| 411 |
def graph_mode(self, prompt_id, items):
|
| 412 |
graph_cols = st.columns([3, 1])
|
| 413 |
# prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
|
|
@@ -466,10 +446,6 @@ class GalleryApp:
|
|
| 466 |
infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
|
| 467 |
st.table(infos_df)
|
| 468 |
|
| 469 |
-
# for info in infos:
|
| 470 |
-
# st.write(f"**{info}**:")
|
| 471 |
-
# st.write(item[info])
|
| 472 |
-
|
| 473 |
else:
|
| 474 |
st.info('Please click on an image to show')
|
| 475 |
|
|
@@ -525,31 +501,60 @@ class GalleryApp:
|
|
| 525 |
default_expand = True if st.session_state.gallery_focus['prompt'] == prompt else False
|
| 526 |
with st.expander(f'**{prompt}**', expanded=default_expand):
|
| 527 |
# st.caption('select info to show')
|
| 528 |
-
checkout_panel = st.columns([
|
| 529 |
with checkout_panel[0]:
|
| 530 |
pass
|
| 531 |
info = st.multiselect('Show Info',
|
| 532 |
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
|
| 533 |
'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
|
| 534 |
'nsfw_score', 'norm_nsfw'],
|
| 535 |
-
label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
st.
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
else:
|
| 554 |
# with st.form(key=f'checkout_{tag}'):
|
| 555 |
st.info('No selection under this tag')
|
|
|
|
| 300 |
# save tag to session state on change
|
| 301 |
tag = st.radio('Select a tag', prompt_tags, index=tag_focus_idx, horizontal=True, key='tag', label_visibility='collapsed')
|
| 302 |
|
| 303 |
+
# print('current state: ', st.session_state.gallery_state)
|
| 304 |
|
| 305 |
if st.session_state.gallery_state == 'graph':
|
| 306 |
|
|
|
|
| 364 |
if has_selection:
|
| 365 |
checkout = st.button('Check out selections', use_container_width=True, type='primary')
|
| 366 |
if checkout:
|
|
|
|
| 367 |
# add focus to session state
|
| 368 |
st.session_state.gallery_focus['tag'] = tag
|
| 369 |
st.session_state.gallery_focus['prompt'] = selected_prompt
|
| 370 |
|
| 371 |
st.session_state.gallery_state = 'check out'
|
| 372 |
+
# print(st.session_state.gallery_state)
|
| 373 |
st.experimental_rerun()
|
| 374 |
else:
|
| 375 |
st.write(':orange[👇 **Select images you like below**]')
|
|
|
|
| 388 |
|
| 389 |
self.checkout_mode(tag, items)
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
def graph_mode(self, prompt_id, items):
|
| 392 |
graph_cols = st.columns([3, 1])
|
| 393 |
# prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
|
|
|
|
| 446 |
infos_df = infos_df.rename(index={'model_name': 'Model', 'modelVersion_name': 'Version', 'model_download_count': 'Downloads', 'clip_score': 'Clip Score', 'mcos_score': 'mcos Score', 'nsfw_score': 'NSFW Score'})
|
| 447 |
st.table(infos_df)
|
| 448 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
else:
|
| 450 |
st.info('Please click on an image to show')
|
| 451 |
|
|
|
|
| 501 |
default_expand = True if st.session_state.gallery_focus['prompt'] == prompt else False
|
| 502 |
with st.expander(f'**{prompt}**', expanded=default_expand):
|
| 503 |
# st.caption('select info to show')
|
| 504 |
+
checkout_panel = st.columns([5, 3])
|
| 505 |
with checkout_panel[0]:
|
| 506 |
pass
|
| 507 |
info = st.multiselect('Show Info',
|
| 508 |
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id',
|
| 509 |
'weighted_score_sum', 'model_download_count', 'clip_score', 'mcos_score',
|
| 510 |
'nsfw_score', 'norm_nsfw'],
|
| 511 |
+
label_visibility='collapsed', key=f'info_{prompt_id}', placeholder='Select what infos to show')
|
| 512 |
+
|
| 513 |
+
with checkout_panel[-1]:
|
| 514 |
+
checkout_buttons = st.columns([1, 1, 1])
|
| 515 |
+
with checkout_buttons[0]:
|
| 516 |
+
back = st.button('Back to 🖼️', key=f'checkout_back_{prompt_id}', use_container_width=True)
|
| 517 |
+
if back:
|
| 518 |
+
st.session_state.gallery_focus['tag'] = tag
|
| 519 |
+
st.session_state.gallery_focus['prompt'] = prompt
|
| 520 |
+
print(st.session_state.gallery_focus)
|
| 521 |
+
st.session_state.gallery_state = 'graph'
|
| 522 |
+
st.experimental_rerun()
|
| 523 |
+
|
| 524 |
+
with checkout_buttons[1]:
|
| 525 |
+
# init edit state
|
| 526 |
+
if 'edit_state' not in st.session_state:
|
| 527 |
+
st.session_state.edit_state = False
|
| 528 |
+
|
| 529 |
+
if not st.session_state.edit_state:
|
| 530 |
+
edit = st.button('Edit', key=f'checkout_edit_{prompt_id}', use_container_width=True)
|
| 531 |
+
if edit:
|
| 532 |
+
st.session_state.edit_state = True
|
| 533 |
+
st.experimental_rerun()
|
| 534 |
+
else:
|
| 535 |
+
done = st.button('Done', key=f'checkout_done_{prompt_id}', use_container_width=True)
|
| 536 |
+
if done:
|
| 537 |
+
st.session_state.selected_dict[prompt_id] = []
|
| 538 |
+
for key in st.session_state:
|
| 539 |
+
|
| 540 |
+
# update selected_dict with edited selection
|
| 541 |
+
keys = key.split('_')
|
| 542 |
+
if keys[0] == 'select' and keys[1] == str(prompt_id):
|
| 543 |
+
if st.session_state[key]:
|
| 544 |
+
st.session_state.selected_dict[prompt_id].append(int(keys[2]))
|
| 545 |
+
|
| 546 |
+
st.session_state.edit_state = False
|
| 547 |
+
st.experimental_rerun()
|
| 548 |
+
|
| 549 |
+
with checkout_buttons[-1]:
|
| 550 |
+
proceed = st.button('Proceed ➡️', key=f'checkout_proceed_{prompt_id}', use_container_width=True,
|
| 551 |
+
type='primary')
|
| 552 |
+
if proceed:
|
| 553 |
+
st.session_state.gallery_focus['tag'] = tag
|
| 554 |
+
st.session_state.gallery_focus['prompt'] = prompt
|
| 555 |
+
switch_page('ranking')
|
| 556 |
+
|
| 557 |
+
self.gallery_standard(items[items['prompt_id'] == prompt_id].reset_index(drop=True), 4, info, show_checkbox=st.session_state.edit_state)
|
| 558 |
else:
|
| 559 |
# with st.form(key=f'checkout_{tag}'):
|
| 560 |
st.info('No selection under this tag')
|