Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
4a9f0a0
1
Parent(s):
521b81b
some cleaning and on the path to having token streaming
Browse files
app.py
CHANGED
|
@@ -6,7 +6,8 @@ import gradio as gr
|
|
| 6 |
|
| 7 |
from gradio_client.client import DEFAULT_TEMP_DIR
|
| 8 |
from playwright.sync_api import sync_playwright
|
| 9 |
-
from
|
|
|
|
| 10 |
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
|
| 11 |
from typing import List
|
| 12 |
from PIL import Image
|
|
@@ -14,15 +15,12 @@ from PIL import Image
|
|
| 14 |
from transformers.image_transforms import resize, to_channel_dimension_format
|
| 15 |
|
| 16 |
|
| 17 |
-
API_TOKEN = os.getenv("HF_AUTH_TOKEN")
|
| 18 |
DEVICE = torch.device("cuda")
|
| 19 |
PROCESSOR = AutoProcessor.from_pretrained(
|
| 20 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
| 21 |
-
token=API_TOKEN,
|
| 22 |
)
|
| 23 |
MODEL = AutoModelForCausalLM.from_pretrained(
|
| 24 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
| 25 |
-
token=API_TOKEN,
|
| 26 |
trust_remote_code=True,
|
| 27 |
torch_dtype=torch.bfloat16,
|
| 28 |
).to(DEVICE)
|
|
@@ -134,20 +132,35 @@ def model_inference(
|
|
| 134 |
k: v.to(DEVICE)
|
| 135 |
for k, v in inputs.items()
|
| 136 |
}
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
bad_words_ids=BAD_WORDS_IDS,
|
| 140 |
-
max_length=4096
|
|
|
|
| 141 |
)
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
rendered_page = render_webpage(generated_text)
|
| 148 |
return generated_text, rendered_page
|
| 149 |
|
| 150 |
-
|
| 151 |
generated_html = gr.Code(
|
| 152 |
label="Extracted HTML",
|
| 153 |
elem_id="generated_html",
|
|
@@ -189,7 +202,7 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
|
|
| 189 |
regenerate_btn = gr.Button(
|
| 190 |
value="🔄 Regenerate", visible=True, min_width=120
|
| 191 |
)
|
| 192 |
-
with gr.Column(scale=4)
|
| 193 |
rendered_html.render()
|
| 194 |
|
| 195 |
with gr.Row():
|
|
|
|
| 6 |
|
| 7 |
from gradio_client.client import DEFAULT_TEMP_DIR
|
| 8 |
from playwright.sync_api import sync_playwright
|
| 9 |
+
from threading import Thread
|
| 10 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
|
| 11 |
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
|
| 12 |
from typing import List
|
| 13 |
from PIL import Image
|
|
|
|
| 15 |
from transformers.image_transforms import resize, to_channel_dimension_format
|
| 16 |
|
| 17 |
|
|
|
|
| 18 |
DEVICE = torch.device("cuda")
|
| 19 |
PROCESSOR = AutoProcessor.from_pretrained(
|
| 20 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
|
|
|
| 21 |
)
|
| 22 |
MODEL = AutoModelForCausalLM.from_pretrained(
|
| 23 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
|
|
|
| 24 |
trust_remote_code=True,
|
| 25 |
torch_dtype=torch.bfloat16,
|
| 26 |
).to(DEVICE)
|
|
|
|
| 132 |
k: v.to(DEVICE)
|
| 133 |
for k, v in inputs.items()
|
| 134 |
}
|
| 135 |
+
|
| 136 |
+
streamer = TextIteratorStreamer(
|
| 137 |
+
PROCESSOR.tokenizer,
|
| 138 |
+
decode_kwargs=dict(
|
| 139 |
+
skip_special_tokens=True
|
| 140 |
+
),
|
| 141 |
+
skip_prompt=True,
|
| 142 |
+
)
|
| 143 |
+
generation_kwargs = dict(
|
| 144 |
+
inputs,
|
| 145 |
bad_words_ids=BAD_WORDS_IDS,
|
| 146 |
+
max_length=4096,
|
| 147 |
+
streamer=streamer,
|
| 148 |
)
|
| 149 |
+
thread = Thread(
|
| 150 |
+
target=MODEL.generate,
|
| 151 |
+
kwargs=generation_kwargs,
|
| 152 |
+
)
|
| 153 |
+
thread.start()
|
| 154 |
+
generated_text = ""
|
| 155 |
+
for new_text in streamer:
|
| 156 |
+
generated_text += new_text
|
| 157 |
+
print("before yield")
|
| 158 |
+
# yield generated_text, image
|
| 159 |
+
print("after yield")
|
| 160 |
|
| 161 |
rendered_page = render_webpage(generated_text)
|
| 162 |
return generated_text, rendered_page
|
| 163 |
|
|
|
|
| 164 |
generated_html = gr.Code(
|
| 165 |
label="Extracted HTML",
|
| 166 |
elem_id="generated_html",
|
|
|
|
| 202 |
regenerate_btn = gr.Button(
|
| 203 |
value="🔄 Regenerate", visible=True, min_width=120
|
| 204 |
)
|
| 205 |
+
with gr.Column(scale=4):
|
| 206 |
rendered_html.render()
|
| 207 |
|
| 208 |
with gr.Row():
|