Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,11 +9,11 @@ import pandas as pd
|
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.nn.functional as F
|
| 12 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
|
| 13 |
from diffusers import StableDiffusionPipeline
|
| 14 |
from torch.utils.data import Dataset, DataLoader
|
| 15 |
import csv
|
| 16 |
-
import fitz
|
| 17 |
import requests
|
| 18 |
from PIL import Image
|
| 19 |
import cv2
|
|
@@ -28,10 +28,7 @@ import zipfile
|
|
| 28 |
import math
|
| 29 |
import random
|
| 30 |
import re
|
| 31 |
-
from datetime import datetime
|
| 32 |
-
import pytz
|
| 33 |
|
| 34 |
-
# Logging setup with custom buffer
|
| 35 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 36 |
logger = logging.getLogger(__name__)
|
| 37 |
log_records = []
|
|
@@ -42,7 +39,6 @@ class LogCaptureHandler(logging.Handler):
|
|
| 42 |
|
| 43 |
logger.addHandler(LogCaptureHandler())
|
| 44 |
|
| 45 |
-
# Page Configuration
|
| 46 |
st.set_page_config(
|
| 47 |
page_title="AI Vision & SFT Titans 🚀",
|
| 48 |
page_icon="🤖",
|
|
@@ -55,9 +51,8 @@ st.set_page_config(
|
|
| 55 |
}
|
| 56 |
)
|
| 57 |
|
| 58 |
-
# Initialize st.session_state
|
| 59 |
if 'history' not in st.session_state:
|
| 60 |
-
st.session_state['history'] = []
|
| 61 |
if 'builder' not in st.session_state:
|
| 62 |
st.session_state['builder'] = None
|
| 63 |
if 'model_loaded' not in st.session_state:
|
|
@@ -68,10 +63,7 @@ if 'pdf_checkboxes' not in st.session_state:
|
|
| 68 |
st.session_state['pdf_checkboxes'] = {}
|
| 69 |
if 'downloaded_pdfs' not in st.session_state:
|
| 70 |
st.session_state['downloaded_pdfs'] = {}
|
| 71 |
-
if 'captured_images' not in st.session_state:
|
| 72 |
-
st.session_state['captured_images'] = []
|
| 73 |
|
| 74 |
-
# Model Configuration Classes
|
| 75 |
@dataclass
|
| 76 |
class ModelConfig:
|
| 77 |
name: str
|
|
@@ -88,12 +80,11 @@ class DiffusionConfig:
|
|
| 88 |
name: str
|
| 89 |
base_model: str
|
| 90 |
size: str
|
| 91 |
-
domain: Optional[str] = None
|
| 92 |
@property
|
| 93 |
def model_path(self):
|
| 94 |
return f"diffusion_models/{self.name}"
|
| 95 |
|
| 96 |
-
# Datasets
|
| 97 |
class SFTDataset(Dataset):
|
| 98 |
def __init__(self, data, tokenizer, max_length=128):
|
| 99 |
self.data = data
|
|
@@ -132,7 +123,6 @@ class TinyDiffusionDataset(Dataset):
|
|
| 132 |
def __getitem__(self, idx):
|
| 133 |
return self.images[idx]
|
| 134 |
|
| 135 |
-
# Custom Tiny Diffusion Model
|
| 136 |
class TinyUNet(nn.Module):
|
| 137 |
def __init__(self, in_channels=3, out_channels=3):
|
| 138 |
super(TinyUNet, self).__init__()
|
|
@@ -205,7 +195,6 @@ class TinyDiffusion:
|
|
| 205 |
upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
|
| 206 |
return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
|
| 207 |
|
| 208 |
-
# Model Builders
|
| 209 |
class ModelBuilder:
|
| 210 |
def __init__(self):
|
| 211 |
self.config = None
|
|
@@ -316,10 +305,8 @@ class DiffusionBuilder:
|
|
| 316 |
def generate(self, prompt: str):
|
| 317 |
return self.pipeline(prompt, num_inference_steps=20).images[0]
|
| 318 |
|
| 319 |
-
# Utility Functions
|
| 320 |
def generate_filename(sequence, ext="png"):
|
| 321 |
-
|
| 322 |
-
timestamp = datetime.now(central).strftime("%d%m%Y%H%M%S%p")
|
| 323 |
return f"{sequence}_{timestamp}.{ext}"
|
| 324 |
|
| 325 |
def pdf_url_to_filename(url):
|
|
@@ -342,7 +329,7 @@ def get_model_files(model_type="causal_lm"):
|
|
| 342 |
path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
|
| 343 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
| 344 |
|
| 345 |
-
def get_gallery_files(file_types=["png"
|
| 346 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 347 |
|
| 348 |
def get_pdf_files():
|
|
@@ -360,33 +347,6 @@ def download_pdf(url, output_path):
|
|
| 360 |
logger.error(f"Failed to download {url}: {e}")
|
| 361 |
return False
|
| 362 |
|
| 363 |
-
# Model Loaders for New App Features
|
| 364 |
-
def load_ocr_qwen2vl():
|
| 365 |
-
model_id = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
|
| 366 |
-
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
| 367 |
-
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
|
| 368 |
-
return processor, model
|
| 369 |
-
|
| 370 |
-
def load_ocr_trocr():
|
| 371 |
-
model_id = "microsoft/trocr-small-handwritten"
|
| 372 |
-
processor = TrOCRProcessor.from_pretrained(model_id)
|
| 373 |
-
model = VisionEncoderDecoderModel.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu").eval()
|
| 374 |
-
return processor, model
|
| 375 |
-
|
| 376 |
-
def load_image_gen():
|
| 377 |
-
model_id = "OFA-Sys/small-stable-diffusion-v0"
|
| 378 |
-
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32).to("cpu")
|
| 379 |
-
return pipeline
|
| 380 |
-
|
| 381 |
-
def load_line_drawer():
|
| 382 |
-
def edge_detection(image):
|
| 383 |
-
img_np = np.array(image.convert("RGB"))
|
| 384 |
-
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
|
| 385 |
-
edges = cv2.Canny(gray, 100, 200)
|
| 386 |
-
return Image.fromarray(edges)
|
| 387 |
-
return edge_detection
|
| 388 |
-
|
| 389 |
-
# Async Processing Functions
|
| 390 |
async def process_pdf_snapshot(pdf_path, mode="single"):
|
| 391 |
start_time = time.time()
|
| 392 |
status = st.empty()
|
|
@@ -423,31 +383,17 @@ async def process_pdf_snapshot(pdf_path, mode="single"):
|
|
| 423 |
status.error(f"Failed to process PDF: {str(e)}")
|
| 424 |
return []
|
| 425 |
|
| 426 |
-
async def process_ocr(image,
|
| 427 |
start_time = time.time()
|
| 428 |
status = st.empty()
|
| 429 |
-
status.text(
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 434 |
-
inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True).to("cpu")
|
| 435 |
-
outputs = model.generate(**inputs, max_new_tokens=1024)
|
| 436 |
-
result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 437 |
-
elif model_name == "TrOCR-Small":
|
| 438 |
-
processor, model = load_ocr_trocr()
|
| 439 |
-
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to("cpu")
|
| 440 |
-
outputs = model.generate(pixel_values)
|
| 441 |
-
result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 442 |
-
else: # GOT-OCR2_0 (original from Backup 6)
|
| 443 |
-
tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
|
| 444 |
-
model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
|
| 445 |
-
result = model.chat(tokenizer, image, ocr_type='ocr')
|
| 446 |
elapsed = int(time.time() - start_time)
|
| 447 |
-
status.text(f"
|
| 448 |
async with aiofiles.open(output_file, "w") as f:
|
| 449 |
await f.write(result)
|
| 450 |
-
st.session_state['captured_images'].append(output_file)
|
| 451 |
update_gallery()
|
| 452 |
return result
|
| 453 |
|
|
@@ -455,29 +401,29 @@ async def process_image_gen(prompt, output_file):
|
|
| 455 |
start_time = time.time()
|
| 456 |
status = st.empty()
|
| 457 |
status.text("Processing Image Gen... (0s)")
|
| 458 |
-
pipeline =
|
| 459 |
gen_image = pipeline(prompt, num_inference_steps=20).images[0]
|
| 460 |
elapsed = int(time.time() - start_time)
|
| 461 |
status.text(f"Image Gen completed in {elapsed}s!")
|
| 462 |
gen_image.save(output_file)
|
| 463 |
-
st.session_state['captured_images'].append(output_file)
|
| 464 |
update_gallery()
|
| 465 |
return gen_image
|
| 466 |
|
| 467 |
-
async def
|
| 468 |
start_time = time.time()
|
| 469 |
status = st.empty()
|
| 470 |
-
status.text("
|
| 471 |
-
|
| 472 |
-
|
|
|
|
|
|
|
|
|
|
| 473 |
elapsed = int(time.time() - start_time)
|
| 474 |
-
status.text(f"
|
| 475 |
-
|
| 476 |
-
st.session_state['captured_images'].append(output_file)
|
| 477 |
update_gallery()
|
| 478 |
-
return
|
| 479 |
|
| 480 |
-
# Mock Search Tool for RAG
|
| 481 |
def mock_search(query: str) -> str:
|
| 482 |
if "superhero" in query.lower():
|
| 483 |
return "Latest trends: Gold-plated Batman statues, VR superhero battles."
|
|
@@ -493,7 +439,6 @@ def mock_duckduckgo_search(query: str) -> str:
|
|
| 493 |
"""
|
| 494 |
return "No relevant results found."
|
| 495 |
|
| 496 |
-
# Agent Classes
|
| 497 |
class PartyPlannerAgent:
|
| 498 |
def __init__(self, model, tokenizer):
|
| 499 |
self.model = model
|
|
@@ -558,26 +503,19 @@ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_
|
|
| 558 |
flight_time = (actual_distance / cruising_speed_kmh) + 1.0
|
| 559 |
return round(flight_time, 2)
|
| 560 |
|
| 561 |
-
# Main App
|
| 562 |
st.title("AI Vision & SFT Titans 🚀")
|
| 563 |
|
| 564 |
-
# Sidebar
|
| 565 |
st.sidebar.header("Captured Files 📜")
|
| 566 |
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2)
|
| 567 |
def update_gallery():
|
| 568 |
-
media_files = get_gallery_files(["png"
|
| 569 |
pdf_files = get_pdf_files()
|
| 570 |
if media_files or pdf_files:
|
| 571 |
-
st.sidebar.subheader("Images
|
| 572 |
cols = st.sidebar.columns(2)
|
| 573 |
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
| 574 |
with cols[idx % 2]:
|
| 575 |
-
|
| 576 |
-
st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
|
| 577 |
-
elif file.endswith(".txt"):
|
| 578 |
-
with open(file, "r") as f:
|
| 579 |
-
content = f.read()
|
| 580 |
-
st.text(content[:50] + "..." if len(content) > 50 else content, help=file)
|
| 581 |
st.sidebar.subheader("PDF Downloads 📖")
|
| 582 |
for pdf_file in pdf_files[:gallery_size * 2]:
|
| 583 |
st.markdown(get_download_link(pdf_file, "application/pdf", f"📥 Grab {os.path.basename(pdf_file)}"), unsafe_allow_html=True)
|
|
@@ -607,11 +545,9 @@ with history_container:
|
|
| 607 |
for entry in st.session_state['history'][-gallery_size * 2:]:
|
| 608 |
st.write(entry)
|
| 609 |
|
| 610 |
-
|
| 611 |
-
tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8, tab9, tab10 = st.tabs([
|
| 612 |
"Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Fine-Tune Titan 🔧",
|
| 613 |
-
"Test Titan 🧪", "Agentic RAG Party 🌐", "Test OCR 🔍", "Test Image Gen 🎨",
|
| 614 |
-
"Test Line Drawings ✏️", "Custom Diffusion 🎨🤓"
|
| 615 |
])
|
| 616 |
|
| 617 |
with tab1:
|
|
@@ -622,55 +558,43 @@ with tab1:
|
|
| 622 |
cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
|
| 623 |
if cam0_img:
|
| 624 |
filename = generate_filename("cam0")
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
|
|
|
| 632 |
with cols[1]:
|
| 633 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
| 634 |
if cam1_img:
|
| 635 |
filename = generate_filename("cam1")
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
st.subheader("Burst Capture")
|
| 645 |
-
slice_count = st.number_input("Number of Frames", min_value=1, max_value=20, value=10, key="burst_count")
|
| 646 |
-
if st.button("Start Burst Capture 📸"):
|
| 647 |
-
st.session_state['burst_frames'] = []
|
| 648 |
-
placeholder = st.empty()
|
| 649 |
-
for i in range(slice_count):
|
| 650 |
-
with placeholder.container():
|
| 651 |
-
st.write(f"Capturing frame {i+1}/{slice_count}...")
|
| 652 |
-
img = st.camera_input(f"Frame {i}", key=f"burst_{i}_{time.time()}")
|
| 653 |
-
if img:
|
| 654 |
-
filename = generate_filename(f"burst_{i}")
|
| 655 |
-
if filename not in st.session_state['captured_images']:
|
| 656 |
-
with open(filename, "wb") as f:
|
| 657 |
-
f.write(img.getvalue())
|
| 658 |
-
st.session_state['burst_frames'].append(filename)
|
| 659 |
-
logger.info(f"Saved burst frame {i}: {filename}")
|
| 660 |
-
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
| 661 |
-
time.sleep(0.5)
|
| 662 |
-
st.session_state['captured_images'].extend([f for f in st.session_state['burst_frames'] if f not in st.session_state['captured_images']])
|
| 663 |
-
update_gallery()
|
| 664 |
-
placeholder.success(f"Captured {len(st.session_state['burst_frames'])} frames!")
|
| 665 |
|
| 666 |
with tab2:
|
| 667 |
st.header("Download PDFs 📥")
|
| 668 |
if st.button("Examples 📚"):
|
| 669 |
example_urls = [
|
| 670 |
-
"https://arxiv.org/pdf/2308.03892",
|
| 671 |
-
"https://arxiv.org/pdf/
|
| 672 |
-
"https://arxiv.org/pdf/
|
| 673 |
-
"https://arxiv.org/pdf/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
]
|
| 675 |
st.session_state['pdf_urls'] = "\n".join(example_urls)
|
| 676 |
|
|
@@ -716,7 +640,9 @@ with tab2:
|
|
| 716 |
st.image(img, caption=os.path.basename(pdf_path), use_container_width=True)
|
| 717 |
checkbox_key = f"pdf_{pdf_path}"
|
| 718 |
st.session_state['pdf_checkboxes'][checkbox_key] = st.checkbox(
|
| 719 |
-
"Use for SFT/Input",
|
|
|
|
|
|
|
| 720 |
)
|
| 721 |
st.markdown(get_download_link(pdf_path, "application/pdf", "Snag It! 📥"), unsafe_allow_html=True)
|
| 722 |
if st.button("Zap It! 🗑️", key=f"delete_{pdf_path}"):
|
|
@@ -916,12 +842,13 @@ with tab7:
|
|
| 916 |
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
| 917 |
doc.close()
|
| 918 |
st.image(image, caption="Input Image", use_container_width=True)
|
| 919 |
-
ocr_model = st.selectbox("Select OCR Model", ["Qwen2-VL-OCR-2B", "TrOCR-Small", "GOT-OCR2_0"], key="ocr_model_select")
|
| 920 |
-
prompt = st.text_area("Prompt", "Extract text from the image", key="ocr_prompt")
|
| 921 |
if st.button("Run OCR 🚀", key="ocr_run"):
|
| 922 |
output_file = generate_filename("ocr_output", "txt")
|
| 923 |
st.session_state['processing']['ocr'] = True
|
| 924 |
-
result = asyncio.run(process_ocr(image,
|
|
|
|
|
|
|
|
|
|
| 925 |
st.text_area("OCR Result", result, height=200, key="ocr_result")
|
| 926 |
st.success(f"OCR output saved to {output_file}")
|
| 927 |
st.session_state['processing']['ocr'] = False
|
|
@@ -949,6 +876,9 @@ with tab8:
|
|
| 949 |
output_file = generate_filename("gen_output", "png")
|
| 950 |
st.session_state['processing']['gen'] = True
|
| 951 |
result = asyncio.run(process_image_gen(prompt, output_file))
|
|
|
|
|
|
|
|
|
|
| 952 |
st.image(result, caption="Generated Image", use_container_width=True)
|
| 953 |
st.success(f"Image saved to {output_file}")
|
| 954 |
st.session_state['processing']['gen'] = False
|
|
@@ -956,32 +886,6 @@ with tab8:
|
|
| 956 |
st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
|
| 957 |
|
| 958 |
with tab9:
|
| 959 |
-
st.header("Test Line Drawings ✏️")
|
| 960 |
-
captured_files = get_gallery_files(["png"])
|
| 961 |
-
selected_pdfs = [path for key, path in st.session_state['downloaded_pdfs'].items() if st.session_state['pdf_checkboxes'].get(f"pdf_{path}", False)]
|
| 962 |
-
all_files = captured_files + selected_pdfs
|
| 963 |
-
if all_files:
|
| 964 |
-
selected_file = st.selectbox("Select Image or PDF", all_files, key="line_select")
|
| 965 |
-
if selected_file:
|
| 966 |
-
if selected_file.endswith('.png'):
|
| 967 |
-
image = Image.open(selected_file)
|
| 968 |
-
else:
|
| 969 |
-
doc = fitz.open(selected_file)
|
| 970 |
-
pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
|
| 971 |
-
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
| 972 |
-
doc.close()
|
| 973 |
-
st.image(image, caption="Input Image", use_container_width=True)
|
| 974 |
-
if st.button("Run Line Drawing 🚀", key="line_run"):
|
| 975 |
-
output_file = generate_filename("line_output", "png")
|
| 976 |
-
st.session_state['processing']['line'] = True
|
| 977 |
-
result = asyncio.run(process_line_drawing(image, output_file))
|
| 978 |
-
st.image(result, caption="Line Drawing", use_container_width=True)
|
| 979 |
-
st.success(f"Line drawing saved to {output_file}")
|
| 980 |
-
st.session_state['processing']['line'] = False
|
| 981 |
-
else:
|
| 982 |
-
st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
|
| 983 |
-
|
| 984 |
-
with tab10:
|
| 985 |
st.header("Custom Diffusion 🎨🤓")
|
| 986 |
st.write("Unleash your inner artist with our tiny diffusion models!")
|
| 987 |
captured_files = get_gallery_files(["png"])
|
|
@@ -1027,5 +931,4 @@ with tab10:
|
|
| 1027 |
else:
|
| 1028 |
st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
|
| 1029 |
|
| 1030 |
-
# Initial Gallery Update
|
| 1031 |
update_gallery()
|
|
|
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.nn.functional as F
|
| 12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
|
| 13 |
from diffusers import StableDiffusionPipeline
|
| 14 |
from torch.utils.data import Dataset, DataLoader
|
| 15 |
import csv
|
| 16 |
+
import fitz
|
| 17 |
import requests
|
| 18 |
from PIL import Image
|
| 19 |
import cv2
|
|
|
|
| 28 |
import math
|
| 29 |
import random
|
| 30 |
import re
|
|
|
|
|
|
|
| 31 |
|
|
|
|
| 32 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
log_records = []
|
|
|
|
| 39 |
|
| 40 |
logger.addHandler(LogCaptureHandler())
|
| 41 |
|
|
|
|
| 42 |
st.set_page_config(
|
| 43 |
page_title="AI Vision & SFT Titans 🚀",
|
| 44 |
page_icon="🤖",
|
|
|
|
| 51 |
}
|
| 52 |
)
|
| 53 |
|
|
|
|
| 54 |
if 'history' not in st.session_state:
|
| 55 |
+
st.session_state['history'] = []
|
| 56 |
if 'builder' not in st.session_state:
|
| 57 |
st.session_state['builder'] = None
|
| 58 |
if 'model_loaded' not in st.session_state:
|
|
|
|
| 63 |
st.session_state['pdf_checkboxes'] = {}
|
| 64 |
if 'downloaded_pdfs' not in st.session_state:
|
| 65 |
st.session_state['downloaded_pdfs'] = {}
|
|
|
|
|
|
|
| 66 |
|
|
|
|
| 67 |
@dataclass
|
| 68 |
class ModelConfig:
|
| 69 |
name: str
|
|
|
|
| 80 |
name: str
|
| 81 |
base_model: str
|
| 82 |
size: str
|
| 83 |
+
domain: Optional[str] = None
|
| 84 |
@property
|
| 85 |
def model_path(self):
|
| 86 |
return f"diffusion_models/{self.name}"
|
| 87 |
|
|
|
|
| 88 |
class SFTDataset(Dataset):
|
| 89 |
def __init__(self, data, tokenizer, max_length=128):
|
| 90 |
self.data = data
|
|
|
|
| 123 |
def __getitem__(self, idx):
|
| 124 |
return self.images[idx]
|
| 125 |
|
|
|
|
| 126 |
class TinyUNet(nn.Module):
|
| 127 |
def __init__(self, in_channels=3, out_channels=3):
|
| 128 |
super(TinyUNet, self).__init__()
|
|
|
|
| 195 |
upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
|
| 196 |
return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
|
| 197 |
|
|
|
|
| 198 |
class ModelBuilder:
|
| 199 |
def __init__(self):
|
| 200 |
self.config = None
|
|
|
|
| 305 |
def generate(self, prompt: str):
|
| 306 |
return self.pipeline(prompt, num_inference_steps=20).images[0]
|
| 307 |
|
|
|
|
| 308 |
def generate_filename(sequence, ext="png"):
|
| 309 |
+
timestamp = time.strftime("%d%m%Y%H%M%S")
|
|
|
|
| 310 |
return f"{sequence}_{timestamp}.{ext}"
|
| 311 |
|
| 312 |
def pdf_url_to_filename(url):
|
|
|
|
| 329 |
path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
|
| 330 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
| 331 |
|
| 332 |
+
def get_gallery_files(file_types=["png"]):
|
| 333 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 334 |
|
| 335 |
def get_pdf_files():
|
|
|
|
| 347 |
logger.error(f"Failed to download {url}: {e}")
|
| 348 |
return False
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
async def process_pdf_snapshot(pdf_path, mode="single"):
|
| 351 |
start_time = time.time()
|
| 352 |
status = st.empty()
|
|
|
|
| 383 |
status.error(f"Failed to process PDF: {str(e)}")
|
| 384 |
return []
|
| 385 |
|
| 386 |
+
async def process_ocr(image, output_file):
|
| 387 |
start_time = time.time()
|
| 388 |
status = st.empty()
|
| 389 |
+
status.text("Processing GOT-OCR2_0... (0s)")
|
| 390 |
+
tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
|
| 391 |
+
model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
|
| 392 |
+
result = model.chat(tokenizer, image, ocr_type='ocr')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
elapsed = int(time.time() - start_time)
|
| 394 |
+
status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
|
| 395 |
async with aiofiles.open(output_file, "w") as f:
|
| 396 |
await f.write(result)
|
|
|
|
| 397 |
update_gallery()
|
| 398 |
return result
|
| 399 |
|
|
|
|
| 401 |
start_time = time.time()
|
| 402 |
status = st.empty()
|
| 403 |
status.text("Processing Image Gen... (0s)")
|
| 404 |
+
pipeline = StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cpu")
|
| 405 |
gen_image = pipeline(prompt, num_inference_steps=20).images[0]
|
| 406 |
elapsed = int(time.time() - start_time)
|
| 407 |
status.text(f"Image Gen completed in {elapsed}s!")
|
| 408 |
gen_image.save(output_file)
|
|
|
|
| 409 |
update_gallery()
|
| 410 |
return gen_image
|
| 411 |
|
| 412 |
+
async def process_custom_diffusion(images, output_file, model_name):
|
| 413 |
start_time = time.time()
|
| 414 |
status = st.empty()
|
| 415 |
+
status.text(f"Training {model_name}... (0s)")
|
| 416 |
+
unet = TinyUNet()
|
| 417 |
+
diffusion = TinyDiffusion(unet)
|
| 418 |
+
diffusion.train(images)
|
| 419 |
+
gen_image = diffusion.generate()
|
| 420 |
+
upscaled_image = diffusion.upscale(gen_image, scale_factor=2)
|
| 421 |
elapsed = int(time.time() - start_time)
|
| 422 |
+
status.text(f"{model_name} completed in {elapsed}s!")
|
| 423 |
+
upscaled_image.save(output_file)
|
|
|
|
| 424 |
update_gallery()
|
| 425 |
+
return upscaled_image
|
| 426 |
|
|
|
|
| 427 |
def mock_search(query: str) -> str:
|
| 428 |
if "superhero" in query.lower():
|
| 429 |
return "Latest trends: Gold-plated Batman statues, VR superhero battles."
|
|
|
|
| 439 |
"""
|
| 440 |
return "No relevant results found."
|
| 441 |
|
|
|
|
| 442 |
class PartyPlannerAgent:
|
| 443 |
def __init__(self, model, tokenizer):
|
| 444 |
self.model = model
|
|
|
|
| 503 |
flight_time = (actual_distance / cruising_speed_kmh) + 1.0
|
| 504 |
return round(flight_time, 2)
|
| 505 |
|
|
|
|
| 506 |
st.title("AI Vision & SFT Titans 🚀")
|
| 507 |
|
|
|
|
| 508 |
st.sidebar.header("Captured Files 📜")
|
| 509 |
gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2)
|
| 510 |
def update_gallery():
|
| 511 |
+
media_files = get_gallery_files(["png"])
|
| 512 |
pdf_files = get_pdf_files()
|
| 513 |
if media_files or pdf_files:
|
| 514 |
+
st.sidebar.subheader("Images 📸")
|
| 515 |
cols = st.sidebar.columns(2)
|
| 516 |
for idx, file in enumerate(media_files[:gallery_size * 2]):
|
| 517 |
with cols[idx % 2]:
|
| 518 |
+
st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
st.sidebar.subheader("PDF Downloads 📖")
|
| 520 |
for pdf_file in pdf_files[:gallery_size * 2]:
|
| 521 |
st.markdown(get_download_link(pdf_file, "application/pdf", f"📥 Grab {os.path.basename(pdf_file)}"), unsafe_allow_html=True)
|
|
|
|
| 545 |
for entry in st.session_state['history'][-gallery_size * 2:]:
|
| 546 |
st.write(entry)
|
| 547 |
|
| 548 |
+
tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8, tab9 = st.tabs([
|
|
|
|
| 549 |
"Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Fine-Tune Titan 🔧",
|
| 550 |
+
"Test Titan 🧪", "Agentic RAG Party 🌐", "Test OCR 🔍", "Test Image Gen 🎨", "Custom Diffusion 🎨🤓"
|
|
|
|
| 551 |
])
|
| 552 |
|
| 553 |
with tab1:
|
|
|
|
| 558 |
cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
|
| 559 |
if cam0_img:
|
| 560 |
filename = generate_filename("cam0")
|
| 561 |
+
with open(filename, "wb") as f:
|
| 562 |
+
f.write(cam0_img.getvalue())
|
| 563 |
+
entry = f"Snapshot from Cam 0: {filename}"
|
| 564 |
+
if entry not in st.session_state['history']:
|
| 565 |
+
st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 0:")] + [entry]
|
| 566 |
+
st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
|
| 567 |
+
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
| 568 |
+
update_gallery()
|
| 569 |
with cols[1]:
|
| 570 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
| 571 |
if cam1_img:
|
| 572 |
filename = generate_filename("cam1")
|
| 573 |
+
with open(filename, "wb") as f:
|
| 574 |
+
f.write(cam1_img.getvalue())
|
| 575 |
+
entry = f"Snapshot from Cam 1: {filename}"
|
| 576 |
+
if entry not in st.session_state['history']:
|
| 577 |
+
st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 1:")] + [entry]
|
| 578 |
+
st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
|
| 579 |
+
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
| 580 |
+
update_gallery()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
|
| 582 |
with tab2:
|
| 583 |
st.header("Download PDFs 📥")
|
| 584 |
if st.button("Examples 📚"):
|
| 585 |
example_urls = [
|
| 586 |
+
"https://arxiv.org/pdf/2308.03892",
|
| 587 |
+
"https://arxiv.org/pdf/1912.01703",
|
| 588 |
+
"https://arxiv.org/pdf/2408.11039",
|
| 589 |
+
"https://arxiv.org/pdf/2109.10282",
|
| 590 |
+
"https://arxiv.org/pdf/2112.10752",
|
| 591 |
+
"https://arxiv.org/pdf/2308.11236",
|
| 592 |
+
"https://arxiv.org/pdf/1706.03762",
|
| 593 |
+
"https://arxiv.org/pdf/2006.11239",
|
| 594 |
+
"https://arxiv.org/pdf/2305.11207",
|
| 595 |
+
"https://arxiv.org/pdf/2106.09685",
|
| 596 |
+
"https://arxiv.org/pdf/2005.11401",
|
| 597 |
+
"https://arxiv.org/pdf/2106.10504"
|
| 598 |
]
|
| 599 |
st.session_state['pdf_urls'] = "\n".join(example_urls)
|
| 600 |
|
|
|
|
| 640 |
st.image(img, caption=os.path.basename(pdf_path), use_container_width=True)
|
| 641 |
checkbox_key = f"pdf_{pdf_path}"
|
| 642 |
st.session_state['pdf_checkboxes'][checkbox_key] = st.checkbox(
|
| 643 |
+
"Use for SFT/Input",
|
| 644 |
+
value=st.session_state['pdf_checkboxes'].get(checkbox_key, False),
|
| 645 |
+
key=checkbox_key
|
| 646 |
)
|
| 647 |
st.markdown(get_download_link(pdf_path, "application/pdf", "Snag It! 📥"), unsafe_allow_html=True)
|
| 648 |
if st.button("Zap It! 🗑️", key=f"delete_{pdf_path}"):
|
|
|
|
| 842 |
image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
| 843 |
doc.close()
|
| 844 |
st.image(image, caption="Input Image", use_container_width=True)
|
|
|
|
|
|
|
| 845 |
if st.button("Run OCR 🚀", key="ocr_run"):
|
| 846 |
output_file = generate_filename("ocr_output", "txt")
|
| 847 |
st.session_state['processing']['ocr'] = True
|
| 848 |
+
result = asyncio.run(process_ocr(image, output_file))
|
| 849 |
+
entry = f"OCR Test: {selected_file} -> {output_file}"
|
| 850 |
+
if entry not in st.session_state['history']:
|
| 851 |
+
st.session_state['history'].append(entry)
|
| 852 |
st.text_area("OCR Result", result, height=200, key="ocr_result")
|
| 853 |
st.success(f"OCR output saved to {output_file}")
|
| 854 |
st.session_state['processing']['ocr'] = False
|
|
|
|
| 876 |
output_file = generate_filename("gen_output", "png")
|
| 877 |
st.session_state['processing']['gen'] = True
|
| 878 |
result = asyncio.run(process_image_gen(prompt, output_file))
|
| 879 |
+
entry = f"Image Gen Test: {prompt} -> {output_file}"
|
| 880 |
+
if entry not in st.session_state['history']:
|
| 881 |
+
st.session_state['history'].append(entry)
|
| 882 |
st.image(result, caption="Generated Image", use_container_width=True)
|
| 883 |
st.success(f"Image saved to {output_file}")
|
| 884 |
st.session_state['processing']['gen'] = False
|
|
|
|
| 886 |
st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
|
| 887 |
|
| 888 |
with tab9:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 889 |
st.header("Custom Diffusion 🎨🤓")
|
| 890 |
st.write("Unleash your inner artist with our tiny diffusion models!")
|
| 891 |
captured_files = get_gallery_files(["png"])
|
|
|
|
| 931 |
else:
|
| 932 |
st.warning("No images or PDFs captured yet. Use Camera Snap or Download PDFs first!")
|
| 933 |
|
|
|
|
| 934 |
update_gallery()
|