Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
| import tempfile | |
| import torch | |
| import subprocess | |
| import sys | |
| import importlib.util | |
| from tqdm import tqdm | |
# Optional dependencies for PDF -> Markdown conversion with the transformers
# implementation of Nougat. If any of them cannot be imported, the flag below is
# False and the PDF feature is reported as unavailable instead of crashing here.
# NOTE(review): pdf2image also needs the poppler system binaries at conversion
# time; that is not checked by this import test.
try:
    from transformers import VisionEncoderDecoderModel, NougatProcessor, NougatImageProcessor
    from PIL import Image
    import pdf2image
    TRANSFORMERS_NOUGAT_AVAILABLE = True
except ImportError as nougat_import_error:
    TRANSFORMERS_NOUGAT_AVAILABLE = False
    print("Warning: transformers with Nougat support is not installed. PDF to Markdown conversion will not be available.")
    # Say which import actually failed (transformers too old, PIL or pdf2image missing, ...).
    print(f"Import error: {nougat_import_error}")
    print("To install required packages, run: pip install transformers pdf2image Pillow")
# Optional Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Configure PyTorch's CUDA caching allocator to use expandable segments (intended
# to improve GPU memory behaviour for Nougat). setdefault keeps any value the
# deployer already configured instead of silently overwriting it.
# NOTE(review): this only takes effect if CUDA has not been initialised yet.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# HTML header rendered at the top of the page via gr.Markdown.
# NOTE(review): this text advertises Meta Llama 3 8B Instruct, but the chat model
# loaded by this file is "nuojohnchen/shuishanllm" - confirm the description is accurate.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Meta Llama3 8B with Nougat PDF Processing</h1>
<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama3 8b Chat</b></a>. Meta Llama3 is the new open LLM and comes in two sizes: 8b and 70b. Feel free to play with it, or duplicate to run privately!</p>
<p>🔎 For more details about the Llama3 release and how to use the model with <code>transformers</code>, take a look <a href="https://huggingface.co/blog/llama3">at our blog post</a>.</p>
<p>🦕 Looking for an even more powerful model? Check out the <a href="https://huggingface.co/chat/"><b>Hugging Chat</b></a> integration for Meta Llama 3 70b</p>
<p>📝 <b>PDF处理功能:</b> 本应用使用<a href="https://huggingface.co/docs/transformers/model_doc/nougat">Transformers Nougat</a>进行高质量PDF到Markdown的转换。该工具能够很好地保留原始布局、数学公式和表格,提供最佳的PDF文档处理体验。</p>
</div>
'''
# Footer text rendered below the UI.
LICENSE = """
<p/>
---
Built with Meta Llama 3
"""
# HTML shown inside the chatbot before the first message.
# NOTE(review): the logo URL points at a /tmp/gradio file served by another Space,
# which is not a durable location - consider hosting the image in this Space.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/8e75e61cc9bab22b7ce3dec85ab0e6db1da5d107/Meta_lockup_positive%20primary_RGB.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Meta llama3</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
# Page CSS: centres <h1> headings and styles the element with id "duplicate-button"
# (the DuplicateButton created in the UI).
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer for the chat model.
tokenizer = AutoTokenizer.from_pretrained("nuojohnchen/shuishanllm")
# Load the model weights without a device map so CUDA is not initialised in the
# main process; chat_llama3_8b moves the model to the GPU when it runs.
# torch_dtype="auto" keeps the dtype stored in the checkpoint's config instead of
# upcasting everything to float32, which doubles memory for half-precision weights.
model = AutoModelForCausalLM.from_pretrained(
    "nuojohnchen/shuishanllm", device_map=None, torch_dtype="auto"
)

# Make sure there is an end-of-sequence id to stop generation on.
if tokenizer.eos_token_id is None:
    # 2 is a common id for </s>; this is a fallback assumption, not verified for
    # this tokenizer.
    tokenizer.eos_token_id = 2

# Token ids that end generation (passed to generate() as eos_token_id).
terminators = []
if tokenizer.eos_token_id is not None:
    terminators.append(tokenizer.eos_token_id)

# Also stop on the Llama-3 style "<|eot_id|>" token when the vocabulary has it.
# This is best effort: a failed lookup is logged and otherwise ignored.
try:
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
except Exception as exc:
    print(f"Could not look up <|eot_id|>: {exc}")
else:
    # Unknown tokens map to unk_token_id (or None when the tokenizer has no unk
    # token); neither is a real terminator, and duplicates add nothing.
    if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in terminators:
        terminators.append(eot_id)

# Last-resort default so generate() always has a stop token.
if not terminators:
    terminators = [2]
# Convert a PDF to Markdown with the Nougat model from the transformers library.
# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU function is
# running; without the decorator this function may see no GPU and always fail.
# On other hardware the decorator is a no-op. duration is the per-call GPU budget
# in seconds (model loading plus all pages must fit in it).
@spaces.GPU(duration=120)
def process_pdf_with_transformers_nougat(pdf_path, max_length=1024, num_beams=4):
    """Convert a PDF file to Markdown using facebook/nougat-base.

    Each page is rasterised with pdf2image and transcribed independently. Pages
    are joined with "--- Page N ---" headers (N is 1-based).

    Args:
        pdf_path (str): Path of the PDF file on disk.
        max_length (int): ``max_length`` passed to ``generate`` for each page.
        num_beams (int): Beam-search width passed to ``generate`` for each page.

    Returns:
        tuple: ``(markdown, None)`` on success, or ``(None, error_message)`` when
        no GPU is available or any step raises.
    """
    try:
        # Nougat is only run on a GPU.
        if not torch.cuda.is_available():
            return None, "GPU不可用,无法使用Nougat处理PDF"

        device_count = torch.cuda.device_count()
        device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown"
        print(f"使用GPU: {device_name}, 可用GPU数量: {device_count}")

        # NougatProcessor bundles the image processor and the tokenizer, so one
        # object handles both preprocessing and decoding.
        processor = NougatProcessor.from_pretrained("facebook/nougat-base")
        # Named nougat_model so it does not shadow the module-level chat model.
        device = torch.device("cuda")
        nougat_model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base").to(device)

        print(f"将PDF转换为图像: {pdf_path}")
        images = pdf2image.convert_from_path(pdf_path)

        page_chunks = []
        for page_number, image in enumerate(tqdm(images, desc="处理PDF页面"), start=1):
            pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
            outputs = nougat_model.generate(
                pixel_values,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )
            page_markdown = processor.decode(outputs[0], skip_special_tokens=True)
            page_chunks.append(f"--- Page {page_number} ---\n{page_markdown}\n\n")
        return "".join(page_chunks), None
    except Exception as e:
        import traceback
        error = f"Transformers Nougat处理异常: {str(e)}\n{traceback.format_exc()}"
        print(error)
        return None, error
# Gradio callback: turn an uploaded PDF into Markdown for the chat context.
def convert_pdf_to_markdown(pdf_file, max_chars=20000):
    """Convert an uploaded PDF to Markdown using Transformers Nougat.

    Args:
        pdf_file (bytes | None): PDF contents as delivered by
            ``gr.File(type="binary")``; ``None`` when nothing is uploaded.
        max_chars (int): Markdown longer than this is cut to ``max_chars``
            characters and a truncation notice is appended.

    Returns:
        tuple: ``(markdown, status_message)``. ``markdown`` is ``""`` when no file
        was given, Nougat is unavailable, or the conversion failed.
    """
    if pdf_file is None:
        return "", "未上传PDF"
    if not TRANSFORMERS_NOUGAT_AVAILABLE:
        return "", "错误: Transformers Nougat未安装。请执行 'pip install transformers pdf2image Pillow' 安装后重试。"
    try:
        # The PDF only needs to exist on disk while Nougat reads it.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_pdf_path = os.path.join(temp_dir, "temp.pdf")
            with open(temp_pdf_path, "wb") as f:
                f.write(pdf_file)
            print("使用Transformers Nougat处理PDF...")
            markdown_content, error = process_pdf_with_transformers_nougat(temp_pdf_path)

        if markdown_content is None:
            return "", f"PDF转换失败: Transformers Nougat处理失败\n错误: {error}"

        # Report how much Nougat actually produced, measured before truncation.
        generated_chars = len(markdown_content)
        status = f"PDF已成功转换为Markdown (Transformers Nougat): 生成了{generated_chars}个字符"
        if generated_chars > max_chars:
            markdown_content = markdown_content[:max_chars] + "\n\n...(Markdown内容已截断)"
            status += f", 已截断为前{max_chars}个字符"
        return markdown_content, status
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Transformers Nougat转换错误: {str(e)}\n{error_details}")
        return "", f"Markdown转换错误: {str(e)}"
def _history_to_conversation(history):
    """Normalise Gradio chat history into a list of {"role", "content"} dicts.

    Gradio may pass a ChatInterface callback either the "tuples" history format
    ([user, assistant] pairs) or the "messages" format (dicts with "role" and
    "content" keys), depending on version and settings; both are accepted.
    ``None`` contents become "" and other values are converted with ``str``.
    A ``None`` history is treated as empty.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            conversation.append({
                "role": turn.get("role", "user"),
                "content": "" if content is None else str(content),
            })
            continue
        user, assistant = turn
        conversation.append({"role": "user", "content": "" if user is None else str(user)})
        conversation.append({"role": "assistant", "content": "" if assistant is None else str(assistant)})
    return conversation


# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU function runs, so
# moving the model to CUDA outside one fails; elsewhere the decorator is a no-op.
# duration is the per-call GPU budget in seconds.
@spaces.GPU(duration=120)
def chat_llama3_8b(message, history, temperature, max_new_tokens, markdown_content=""):
    """
    Generate a streaming response with the chat model loaded at module level.

    Args:
        message (str): The input message.
        history (list): Conversation history from ChatInterface, in either the
            "tuples" or the "messages" format.
        temperature (float): Sampling temperature; 0 switches to greedy decoding.
        max_new_tokens (int | float): Maximum number of new tokens to generate
            (cast to int because slider values may arrive as floats).
        markdown_content (str): Optional Markdown converted by Nougat. When it is a
            non-blank string, the message is wrapped in a paper-improvement prompt
            that embeds it.

    Yields:
        str: The response generated so far (each yield is the full text so far),
        or a single error message if anything fails.
    """
    global model  # rebound below when the model is moved to the target device
    try:
        conversation = _history_to_conversation(history)

        message = str(message) if message is not None else ""
        if markdown_content and isinstance(markdown_content, str) and markdown_content.strip():
            # NOTE(review): the prompt mentions **SELECTED_CONTENT**, but no such
            # section is supplied - confirm the intended template.
            message = f"""
Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
<PAPER_CONTENT>
{markdown_content}
</PAPER_CONTENT>
<QUESTION>
{message}
</QUESTION>
"""
            print(f"加入Markdown的message", message)
        conversation.append({"role": "user", "content": message})

        # Plain-text prompt: one "用户: ..." / "助手: ..." line per turn (any
        # non-user role is labelled 助手), ending with an open "助手: " turn.
        lines = []
        for turn in conversation:
            speaker = "用户" if turn["role"] == "user" else "助手"
            lines.append(f"{speaker}: {turn['content']}\n")
        prompt = "".join(lines) + "助手: "

        # Use the GPU when one is attached; fall back to CPU instead of crashing.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        encoded = tokenizer(prompt, return_tensors="pt").to(device)

        # With timeout=10.0 the consumer loop below raises if no text arrives for
        # 10 s (e.g. generation failed in the thread); the except branch reports it.
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            streamer=streamer,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=temperature,
        )
        if terminators:
            generate_kwargs["eos_token_id"] = terminators
        # Temperature 0 means greedy decoding (sampling at temperature 0 crashes);
        # the temperature is unused without sampling, so it is dropped.
        if temperature == 0:
            generate_kwargs["do_sample"] = False
            del generate_kwargs["temperature"]

        # generate() runs in a background thread and feeds text into the streamer.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
        thread.join()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"生成错误: {str(e)}\n{error_details}")
        yield f"生成文本时出错: {str(e)}\n\n请尝试使用不同的参数或输入。"
# Gradio UI: PDF upload / conversion controls on the left, chat on the right.
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # Markdown produced by Nougat for the current PDF; passed to chat_llama3_8b as
    # its last additional input.
    markdown_content_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            # PDF upload area; type="binary" delivers the file contents as bytes.
            pdf_file = gr.File(
                label="上传PDF文档(可选)",
                file_types=[".pdf"],
                type="binary",
            )
            pdf_status = gr.Textbox(
                label="PDF状态",
                value="未上传PDF",
                interactive=False,
            )
            clear_pdf_btn = gr.Button("清除PDF")
            if TRANSFORMERS_NOUGAT_AVAILABLE:
                nougat_info = """
<div style="margin-top: 10px; margin-bottom: 10px;">
<p><b>Transformers Nougat PDF处理:</b> 系统将使用Transformers库中的Nougat模型将上传的PDF转换为高质量Markdown。Nougat能够很好地保留原始布局、数学公式和表格,远优于传统的PDF文本提取。</p>
</div>
"""
            else:
                nougat_info = """
<div style="margin-top: 10px; margin-bottom: 10px; color: #d32f2f;">
<p><b>Transformers Nougat未安装:</b> PDF处理功能需要Transformers Nougat。请执行 <code>pip install transformers pdf2image Pillow</code> 安装后重试。</p>
</div>
"""
            gr.Markdown(nougat_info)
            # Collapsible viewer for the converted Markdown.
            with gr.Accordion("查看Markdown内容", open=False):
                markdown_content_display = gr.Textbox(
                    label="Nougat转换的Markdown内容",
                    lines=10,
                    interactive=False,
                )

        with gr.Column(scale=3):
            # ChatInterface renders the chatbot/textbox it is given. Rendering a
            # component that is already rendered in this Blocks raises
            # DuplicateBlockError, so these are created with render=False and laid
            # out by ChatInterface right here.
            chatbot = gr.Chatbot(
                height=450,
                placeholder=PLACEHOLDER,
                label='Gradio ChatInterface',
                render=False,
            )
            msg = gr.Textbox(
                show_label=False,
                placeholder="输入您的问题...",
                container=False,
                render=False,
            )
            # Unrendered sliders are placed by ChatInterface inside the collapsed
            # accordion passed as additional_inputs_accordion below.
            temperature = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.95,
                label="Temperature",
                render=False,
            )
            max_new_tokens = gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=512,
                label="Max new tokens",
                render=False,
            )
            chat_interface = gr.ChatInterface(
                fn=chat_llama3_8b,
                chatbot=chatbot,
                textbox=msg,
                submit_btn="发送",
                additional_inputs=[temperature, max_new_tokens, markdown_content_state],
                additional_inputs_accordion=gr.Accordion("⚙️ 参数设置", open=False, render=False),
            )
            gr.Examples(
                examples=[
                    ['How to setup a human base on Mars? Give short answer.'],
                    ['What is 9,000 * 9,000?'],
                    ['Write a pun-filled happy birthday message to my friend Alex.'],
                    ['Justify why a penguin might make a good king of the jungle.']
                ],
                inputs=msg,
            )

    # On upload: convert the PDF, then mirror the stored Markdown into the viewer.
    # .then() runs the mirror only after the conversion has written the state; a
    # second independent .change listener could read the state before it updates
    # and show stale content.
    pdf_file.change(
        fn=convert_pdf_to_markdown,
        inputs=[pdf_file],
        outputs=[markdown_content_state, pdf_status],
    ).then(
        fn=lambda content: content,
        inputs=[markdown_content_state],
        outputs=[markdown_content_display],
    )

    # "清除PDF": drop the stored Markdown, update the status and empty the viewer.
    clear_pdf_btn.click(
        fn=lambda: ("", "PDF已清除", ""),
        inputs=[],
        outputs=[markdown_content_state, pdf_status, markdown_content_display],
    )

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch()