Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
import os
|
| 3 |
-
import shutil
|
| 4 |
import glob
|
| 5 |
import base64
|
| 6 |
import streamlit as st
|
|
@@ -31,7 +30,7 @@ st.set_page_config(
|
|
| 31 |
initial_sidebar_state="expanded",
|
| 32 |
menu_items={
|
| 33 |
'Get Help': 'https://huggingface.co/awacke1',
|
| 34 |
-
'Report a
|
| 35 |
'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
|
| 36 |
}
|
| 37 |
)
|
|
@@ -177,9 +176,9 @@ class DiffusionBuilder:
|
|
| 177 |
total_loss = 0
|
| 178 |
for batch in dataloader:
|
| 179 |
optimizer.zero_grad()
|
| 180 |
-
image = batch["image"].to(self.pipeline.device)
|
| 181 |
-
text = batch["text"]
|
| 182 |
-
latents = self.pipeline.vae.encode(image).latent_dist.sample()
|
| 183 |
noise = torch.randn_like(latents)
|
| 184 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
| 185 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
|
@@ -220,9 +219,10 @@ def get_model_files(model_type="causal_lm"):
|
|
| 220 |
def get_gallery_files(file_types):
|
| 221 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 222 |
|
|
|
|
| 223 |
def mock_search(query: str) -> str:
|
| 224 |
if "superhero" in query.lower():
|
| 225 |
-
return "Latest trends
|
| 226 |
return "No relevant results found."
|
| 227 |
|
| 228 |
class PartyPlannerAgent:
|
|
@@ -291,6 +291,7 @@ if media_files:
|
|
| 291 |
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
| 292 |
with cols[idx % 2]:
|
| 293 |
st.image(Image.open(file), caption=file, use_column_width=True)
|
|
|
|
| 294 |
|
| 295 |
st.sidebar.subheader("Model Management 🗂️")
|
| 296 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
|
@@ -351,7 +352,7 @@ with tab2:
|
|
| 351 |
f.write(img.getvalue())
|
| 352 |
st.session_state['cam0_frames'].append(filename)
|
| 353 |
logger.info(f"Saved frame {i} from Camera 0: {filename}")
|
| 354 |
-
time.sleep(1.0 / slice_count)
|
| 355 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
| 356 |
update_gallery()
|
| 357 |
for frame in st.session_state['cam0_frames']:
|
|
@@ -379,7 +380,7 @@ with tab2:
|
|
| 379 |
f.write(img.getvalue())
|
| 380 |
st.session_state['cam1_frames'].append(filename)
|
| 381 |
logger.info(f"Saved frame {i} from Camera 1: {filename}")
|
| 382 |
-
time.sleep(1.0 / slice_count)
|
| 383 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
| 384 |
update_gallery()
|
| 385 |
for frame in st.session_state['cam1_frames']:
|
|
@@ -420,6 +421,13 @@ with tab3:
|
|
| 420 |
zip_path = f"{new_config.model_path}.zip"
|
| 421 |
zip_directory(new_config.model_path, zip_path)
|
| 422 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
with tab4:
|
| 425 |
st.header("Test Titan 🧪")
|
|
@@ -456,4 +464,11 @@ with tab5:
|
|
| 456 |
st.dataframe(plan_df)
|
| 457 |
for _, row in plan_df.iterrows():
|
| 458 |
image = agent.generate(row["Image Idea"])
|
| 459 |
-
st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
import os
|
|
|
|
| 3 |
import glob
|
| 4 |
import base64
|
| 5 |
import streamlit as st
|
|
|
|
| 30 |
initial_sidebar_state="expanded",
|
| 31 |
menu_items={
|
| 32 |
'Get Help': 'https://huggingface.co/awacke1',
|
| 33 |
+
'Report a Bug': 'https://huggingface.co/spaces/awacke1',
|
| 34 |
'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
|
| 35 |
}
|
| 36 |
)
|
|
|
|
| 176 |
total_loss = 0
|
| 177 |
for batch in dataloader:
|
| 178 |
optimizer.zero_grad()
|
| 179 |
+
image = batch["image"][0].to(self.pipeline.device)
|
| 180 |
+
text = batch["text"][0]
|
| 181 |
+
latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
|
| 182 |
noise = torch.randn_like(latents)
|
| 183 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
| 184 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
|
|
|
| 219 |
def get_gallery_files(file_types):
|
| 220 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 221 |
|
| 222 |
+
# Mock Search Tool for RAG
|
| 223 |
def mock_search(query: str) -> str:
|
| 224 |
if "superhero" in query.lower():
|
| 225 |
+
return "Latest trends: Gold-plated Batman statues, VR superhero battles."
|
| 226 |
return "No relevant results found."
|
| 227 |
|
| 228 |
class PartyPlannerAgent:
|
|
|
|
| 291 |
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
| 292 |
with cols[idx % 2]:
|
| 293 |
st.image(Image.open(file), caption=file, use_column_width=True)
|
| 294 |
+
st.markdown(get_download_link(file, "image/png", "Download Image"), unsafe_allow_html=True)
|
| 295 |
|
| 296 |
st.sidebar.subheader("Model Management 🗂️")
|
| 297 |
model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
|
|
|
|
| 352 |
f.write(img.getvalue())
|
| 353 |
st.session_state['cam0_frames'].append(filename)
|
| 354 |
logger.info(f"Saved frame {i} from Camera 0: {filename}")
|
| 355 |
+
time.sleep(1.0 / slice_count)
|
| 356 |
st.session_state['captured_images'].extend(st.session_state['cam0_frames'])
|
| 357 |
update_gallery()
|
| 358 |
for frame in st.session_state['cam0_frames']:
|
|
|
|
| 380 |
f.write(img.getvalue())
|
| 381 |
st.session_state['cam1_frames'].append(filename)
|
| 382 |
logger.info(f"Saved frame {i} from Camera 1: {filename}")
|
| 383 |
+
time.sleep(1.0 / slice_count)
|
| 384 |
st.session_state['captured_images'].extend(st.session_state['cam1_frames'])
|
| 385 |
update_gallery()
|
| 386 |
for frame in st.session_state['cam1_frames']:
|
|
|
|
| 421 |
zip_path = f"{new_config.model_path}.zip"
|
| 422 |
zip_directory(new_config.model_path, zip_path)
|
| 423 |
st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
|
| 424 |
+
csv_path = f"sft_dataset_{int(time.time())}.csv"
|
| 425 |
+
with open(csv_path, "w", newline="") as f:
|
| 426 |
+
writer = csv.writer(f)
|
| 427 |
+
writer.writerow(["image", "text"])
|
| 428 |
+
for _, row in edited_data.iterrows():
|
| 429 |
+
writer.writerow([row["image"], row["text"]])
|
| 430 |
+
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
| 431 |
|
| 432 |
with tab4:
|
| 433 |
st.header("Test Titan 🧪")
|
|
|
|
| 464 |
st.dataframe(plan_df)
|
| 465 |
for _, row in plan_df.iterrows():
|
| 466 |
image = agent.generate(row["Image Idea"])
|
| 467 |
+
st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
|
| 468 |
+
|
| 469 |
+
# Main App
|
| 470 |
+
st.sidebar.subheader("Action Logs 📜")
|
| 471 |
+
log_container = st.sidebar.empty()
|
| 472 |
+
with log_container:
|
| 473 |
+
for record in logger.handlers[0].buffer:
|
| 474 |
+
st.write(f"{record.asctime} - {record.levelname} - {record.message}")
|