import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import tempfile
import torch
import subprocess
import sys
import importlib.util
from tqdm import tqdm
# Updated imports for transformers-based Nougat
try:
from transformers import VisionEncoderDecoderModel, NougatProcessor, NougatImageProcessor
from PIL import Image
import pdf2image
TRANSFORMERS_NOUGAT_AVAILABLE = True
except ImportError:
TRANSFORMERS_NOUGAT_AVAILABLE = False
print("Warning: transformers with Nougat support is not installed. PDF to Markdown conversion will not be available.")
print("To install required packages, run: pip install transformers pdf2image Pillow")
# Set an environment variable
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Set CUDA environment variables for better GPU performance with Nougat
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Meta Llama3 8B with Nougat PDF Processing</h1>
<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama3 8b Chat</b></a>. Meta Llama3 is Meta's latest open LLM, available in two sizes: 8b and 70b. Feel free to play with it, or duplicate it to run privately!</p>
<p>🔎 For more details about the Llama3 release and how to use the model with <code>transformers</code>, take a look <a href="https://huggingface.co/blog/llama3">at our blog post</a>.</p>
<p>🦕 Looking for an even more powerful model? Check out the <a href="https://huggingface.co/chat/"><b>Hugging Chat</b></a> integration for Meta Llama 3 70b</p>
<p>📝 <b>PDF processing:</b> This app uses <a href="https://huggingface.co/docs/transformers/model_doc/nougat">Transformers Nougat</a> for high-quality PDF-to-Markdown conversion. Nougat preserves the original layout, mathematical formulas, and tables well, providing the best possible PDF document processing experience.</p>
</div>
'''
LICENSE = """
<p/>
---
Built with Meta Llama 3
"""
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/8e75e61cc9bab22b7ce3dec85ab0e6db1da5d107/Meta_lockup_positive%20primary_RGB.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Meta llama3</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("nuojohnchen/shuishanllm")
# Load model weights but don't initialize CUDA in main process
model = AutoModelForCausalLM.from_pretrained("nuojohnchen/shuishanllm", device_map=None)
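# With device_map=None the weights stay on the CPU; the @spaces.GPU-decorated functions
# below move them onto CUDA inside the GPU worker process.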
# Make sure eos_token_id is not None
if tokenizer.eos_token_id is None:
    tokenizer.eos_token_id = 2  # 2 is usually the ID of the </s> token, a common default
# Define the termination tokens
terminators = []
if tokenizer.eos_token_id is not None:
    terminators.append(tokenizer.eos_token_id)
# Try to add the special end-of-turn token, if it exists
try:
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id != tokenizer.unk_token_id:  # make sure it is not the unknown token
        terminators.append(eot_id)
except Exception:
    pass
# If terminators is empty, add a default value
if not terminators:
    terminators = [2]  # use the common </s> token ID as the default
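# `terminators` is later passed to generate() as eos_token_id, so generation stops at
# either the tokenizer's EOS token or the Llama-3-style <|eot_id|> end-of-turn token.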
# Process PDFs with the Nougat model from the transformers library
@spaces.GPU(stateless=True)
def process_pdf_with_transformers_nougat(pdf_path):
"""使用transformers库中的Nougat模型将PDF转换为Markdown"""
try:
        # Make sure a GPU is available
        if not torch.cuda.is_available():
            return None, "No GPU available; cannot process the PDF with Nougat"
        # Report GPU information
        device_count = torch.cuda.device_count()
        device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown"
        print(f"Using GPU: {device_name}, available GPUs: {device_count}")
        # Load the Nougat processors and model
        # (named nougat_model to avoid shadowing the global chat model)
        processor = NougatProcessor.from_pretrained("facebook/nougat-base")
        image_processor = NougatImageProcessor.from_pretrained("facebook/nougat-base")
        nougat_model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
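        # Note: because this runs inside a stateless @spaces.GPU worker, the Nougat model
        # is re-loaded on every call; from_pretrained() reuses the local download cache.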
        # Move the model to the GPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        nougat_model = nougat_model.to(device)
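        # Note: pdf2image shells out to the Poppler utilities (pdftoppm/pdftocairo), which
        # must be installed on the host, e.g. `apt-get install poppler-utils` on Debian/Ubuntu.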
        # Convert the PDF into one image per page
        print(f"Converting PDF to images: {pdf_path}")
        images = pdf2image.convert_from_path(pdf_path)
        # Process each page and accumulate Markdown
        markdown_content = ""
        for page_idx, image in enumerate(tqdm(images, desc="Processing PDF pages")):
            # Preprocess the page image into pixel values
            pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
            # Generate the page transcription
            outputs = nougat_model.generate(
                pixel_values,
                max_length=1024,
                num_beams=4,
                early_stopping=True
            )
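            # num_beams=4 trades decoding speed for transcription quality; max_length caps
            # the tokens generated per page, so very dense pages may be truncated.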
            # Decode the generated tokens into Markdown text
            page_markdown = processor.decode(outputs[0], skip_special_tokens=True)
            markdown_content += f"--- Page {page_idx+1} ---\n{page_markdown}\n\n"
return markdown_content, None
except Exception as e:
import traceback
error = f"Transformers Nougat处理异常: {str(e)}\n{traceback.format_exc()}"
print(error)
return None, error
# PDF-to-Markdown conversion entry point used by the UI
def convert_pdf_to_markdown(pdf_file):
"""使用Transformers Nougat将PDF转换为Markdown"""
if pdf_file is None:
return "", "未上传PDF"
# 检查Transformers Nougat是否可用
if not TRANSFORMERS_NOUGAT_AVAILABLE:
return "", "错误: Transformers Nougat未安装。请执行 'pip install transformers pdf2image Pillow' 安装后重试。"
    try:
        # Create a temporary directory for the PDF and output files
        with tempfile.TemporaryDirectory() as temp_dir:
            # Write the binary PDF data to a temporary file
            temp_pdf_path = os.path.join(temp_dir, "temp.pdf")
            with open(temp_pdf_path, "wb") as f:
                f.write(pdf_file)
            # Process the PDF with Transformers Nougat
            print("Processing PDF with Transformers Nougat...")
            markdown_content, error = process_pdf_with_transformers_nougat(temp_pdf_path)
            if markdown_content is not None:
                # Cap the text length
                if len(markdown_content) > 20000:
                    markdown_content = markdown_content[:20000] + "\n\n...(Markdown content truncated)"
                status = f"PDF successfully converted to Markdown (Transformers Nougat): {len(markdown_content)} characters generated"
                return markdown_content, status
            # Processing failed
            return "", f"PDF conversion failed: Transformers Nougat processing failed\nError: {error}"
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Transformers Nougat转换错误: {str(e)}\n{error_details}")
return "", f"Markdown转换错误: {str(e)}"
@spaces.GPU(duration=120, stateless=True)
def chat_llama3_8b(message, history, temperature, max_new_tokens, markdown_content=""):
"""
Generate a streaming response using the llama3-8b model.
Args:
message (str): The input message.
history (list): The conversation history used by ChatInterface.
temperature (float): The temperature for generating the response.
max_new_tokens (int): The maximum number of new tokens to generate.
markdown_content (str): Optional Markdown content converted by Nougat to include in the context.
    Yields:
        str: The response generated so far, growing as new tokens stream in.
"""
try:
conversation = []
for user, assistant in history:
            # Make sure all contents are strings
user_msg = str(user) if user is not None else ""
assistant_msg = str(assistant) if assistant is not None else ""
conversation.extend([
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_msg}
])
        # Make sure message is a string
message = str(message) if message is not None else ""
        # If Markdown content is available, fold it into the user message
if markdown_content and isinstance(markdown_content, str) and markdown_content.strip():
message = f"""
Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
<PAPER_CONTENT>
{markdown_content}
</PAPER_CONTENT>
<QUESTION>
{message}
</QUESTION>
"""
print(f"加入Markdown的message", message)
conversation.append({"role": "user", "content": message})
        # Build the prompt with simple text concatenation
prompt = ""
for item in conversation:
role = item["role"]
content = item["content"]
if role == "user":
prompt += f"用户: {content}\n"
else:
prompt += f"助手: {content}\n"
prompt += "助手: " # 添加最后的提示符
        # Encode the prompt
        # In the stateless GPU environment, move the model onto the CUDA device
        global model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
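        # TextIteratorStreamer turns generate() into an iterator of decoded text chunks;
        # timeout=10.0 raises if no new token arrives within ten seconds.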
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
)
        # Only pass eos_token_id when terminators is non-empty
if terminators:
generate_kwargs['eos_token_id'] = terminators
        # Enforce greedy generation (do_sample=False) when temperature is 0, avoiding a crash.
if temperature == 0:
generate_kwargs['do_sample'] = False
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
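        # generate() runs on the background thread while this thread drains the streamer,
        # yielding the accumulated text so Gradio can render it incrementally.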
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"生成错误: {str(e)}\n{error_details}")
yield f"生成文本时出错: {str(e)}\n\n请尝试使用不同的参数或输入。"
# Gradio block
with gr.Blocks(fill_height=True, css=css) as demo:
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    # State holding the Markdown content converted from the uploaded PDF
markdown_content_state = gr.State("")
with gr.Row():
with gr.Column(scale=1):
            # PDF upload area
            pdf_file = gr.File(
                label="Upload a PDF document (optional)",
                file_types=[".pdf"],
                type="binary"
            )
            pdf_status = gr.Textbox(
                label="PDF status",
                value="No PDF uploaded",
                interactive=False
            )
            clear_pdf_btn = gr.Button("Clear PDF")
            if TRANSFORMERS_NOUGAT_AVAILABLE:
                nougat_info = """
                <div style="margin-top: 10px; margin-bottom: 10px;">
                <p><b>Transformers Nougat PDF processing:</b> The system uses the Nougat model from the Transformers library to convert uploaded PDFs into high-quality Markdown. Nougat preserves the original layout, mathematical formulas, and tables far better than traditional PDF text extraction.</p>
                </div>
                """
            else:
                nougat_info = """
                <div style="margin-top: 10px; margin-bottom: 10px; color: #d32f2f;">
                <p><b>Transformers Nougat is not installed:</b> PDF processing requires Transformers Nougat. Run <code>pip install transformers pdf2image Pillow</code> and try again.</p>
                </div>
                """
gr.Markdown(nougat_info)
            # Collapsible viewer for the converted Markdown content
            with gr.Accordion("View Markdown content", open=False):
                markdown_content_display = gr.Textbox(
                    label="Markdown content converted by Nougat",
                    lines=10,
                    interactive=False
                )
with gr.Column(scale=3):
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
with gr.Row():
with gr.Column(scale=8):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Type your question...",
                        container=False
                    )
with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button("Send")
with gr.Accordion("⚙️ 参数设置", open=False):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.95,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=128,
maximum=4096,
step=1,
value=512,
label="Max new tokens"
)
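                # temperature=0 triggers greedy decoding in chat_llama3_8b (see above).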
gr.Examples(
examples=[
['How to setup a human base on Mars? Give short answer.'],
['What is 9,000 * 9,000?'],
['Write a pun-filled happy birthday message to my friend Alex.'],
['Justify why a penguin might make a good king of the jungle.']
],
inputs=msg
)
    # Handle the PDF upload - convert directly with Nougat
pdf_file.change(
fn=convert_pdf_to_markdown,
inputs=[pdf_file],
outputs=[markdown_content_state, pdf_status]
)
    # Update the Markdown content display
pdf_file.change(
fn=lambda content: content,
inputs=[markdown_content_state],
outputs=[markdown_content_display]
)
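    # Note: this assumes Gradio fires change listeners in registration order, so the
    # state has already been refreshed by the converter above when this one reads it.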
    # Clear the PDF content
    clear_pdf_btn.click(
        fn=lambda: ("", "PDF cleared"),
inputs=[],
outputs=[markdown_content_state, pdf_status]
)
    # Clear the Markdown content display
clear_pdf_btn.click(
fn=lambda: "",
inputs=[],
outputs=[markdown_content_display]
)
    # Chat functionality
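    # ChatInterface invokes chat_llama3_8b(message, history, *additional_inputs), so
    # temperature, max_new_tokens, and the Markdown state arrive as the extra arguments.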
chat_interface = gr.ChatInterface(
fn=chat_llama3_8b,
chatbot=chatbot,
textbox=msg,
submit_btn=submit_btn,
additional_inputs=[temperature, max_new_tokens, markdown_content_state],
additional_inputs_accordion=None,
)
gr.Markdown(LICENSE)
if __name__ == "__main__":
demo.launch()