# Hugging Face Space "Img2html", running on ZeroGPU.
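# Demo: take a screenshot of a web page and ask
# HuggingFaceM4/VLM_WebSight_finetuned to reproduce it as HTML/CSS, then render
# the generated code with Playwright for a side-by-side preview.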
import os
import subprocess
import torch
import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from playwright.sync_api import sync_playwright
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from typing import List
from PIL import Image
from transformers.image_transforms import resize, to_channel_dimension_format
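# The HF auth token is read from the Space secrets; a CUDA device is assumed
# to be available.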
API_TOKEN = os.getenv("HF_AUTH_TOKEN")
DEVICE = torch.device("cuda")
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    token=API_TOKEN,
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    token=API_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
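# Number of `<image>` placeholder tokens the prompt must carry: the number of
# perceiver-resampler latents when a resampler is used, otherwise one token per
# vision patch, i.e. (image_size / patch_size) ** 2.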
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
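# Forbid the image placeholder tokens in the generated output: the model should
# produce HTML, never re-emit its own image markers.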
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids

## Utils
def convert_to_rgb(image):
    # `image.convert("RGB")` only works for .jpg images: for transparent images it
    # would produce a wrong background. Compositing onto a white background via
    # `alpha_composite` handles that case.
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite
# The processor is the same as the Idefics processor, except for the BICUBIC
# interpolation inside siglip, hence this hack to redefine ONLY the transform method.
def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x
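# Hypothetical sanity check (assumes a local screenshot `page.png`):
#   pixel_tensor = custom_transform(Image.open("page.png"))  # shape (3, 960, 960)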
## End of Utils

IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}"
    for ex_image in os.listdir("example_images")
]
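# Playwright ships without browser binaries; download them once at startup so
# `render_webpage` can launch headless Chromium.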
def install_playwright():
    try:
        subprocess.run(["playwright", "install"], check=True)
        print("Playwright installation successful.")
    except subprocess.CalledProcessError as e:
        print(f"Error during Playwright installation: {e}")

install_playwright()
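# Despite the `List[str]` hint, Gradio delivers the gallery value as a root
# model whose items expose `.image.path`, hence the `.root` access below.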
def add_file_gallery(
    selected_state: gr.SelectData,
    gallery_list: List[str],
):
    return Image.open(gallery_list.root[selected_state.index].image.path)
def render_webpage(
    html_css_code,
):
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        context = browser.new_context(
            user_agent=(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                " Safari/537.36"
            )
        )
        page = context.new_page()
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")
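        # `hash()` of a str is salted per Python process; it is only used here to
        # give the screenshot a unique filename for the lifetime of this run.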
        output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
        _ = page.screenshot(path=output_path_screenshot, full_page=True)
        context.close()
        browser.close()
    return Image.open(output_path_screenshot)
def model_inference(
    image,
):
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")
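    # Build the text prompt: BOS, then the run of <image> placeholders bracketed
    # by <fake_token_around_image>, matching the format the fine-tuned model expects.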
    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generated_ids = MODEL.generate(
        **inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=4096,
    )
    generated_text = PROCESSOR.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0]
    rendered_page = render_webpage(generated_text)
    return generated_text, rendered_page
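# Hypothetical local test outside the Gradio UI (assumes `screenshot.png` exists):
#   html, preview = model_inference(Image.open("screenshot.png"))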
generated_html = gr.Code(
    label="Extracted HTML",
    elem_id="generated_html",
)
rendered_html = gr.Image(
    label="Rendered HTML",
)
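# Alternative output kept for reference: render the generated HTML directly in
# the page instead of screenshotting it with Playwright.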
# rendered_html = gr.HTML(
#     label="Rendered HTML"
# )

css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
with gr.Blocks(title="Img2html", theme=gr.themes.Base(), css=css) as demo:
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4) as result_area:
            rendered_html.render()
            with gr.Row():
                generated_html.render()
    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=4,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
        queue=False,
    )
    # regenerate_btn.click is already among the `gr.on` triggers above; registering
    # it a second time would run inference twice per click.
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
        queue=False,
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
        queue=False,
    )
    demo.load(queue=False)

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)