AlanXian commited on
Commit
76eb9fc
·
1 Parent(s): ae24f1a

update: nougat gpu

Browse files
Files changed (1) hide show
  1. app.py +112 -87
app.py CHANGED
@@ -87,43 +87,107 @@ except:
87
  if not terminators:
88
  terminators = [2] # 使用常见的</s>标记ID作为默认值
89
 
90
- # 优化后的GPU-based Nougat PDF处理
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  @spaces.GPU(stateless=True)
92
- def process_pdf_with_nougat(pdf_path):
93
- """使用Nougat处理PDF文件 (在GPU环境中运行)"""
94
  try:
95
- # 方法1: 使用Nougat Python API
96
- try:
97
- from nougat import NougatModel
98
- from nougat.utils.checkpoint import get_checkpoint
99
- from nougat.dataset.rasterize import rasterize_paper
100
- import torch
101
-
102
- # 初始化模型并移至GPU
103
- ckpt = get_checkpoint()
104
- model = NougatModel.from_pretrained(ckpt)
105
- device = torch.device("cuda")
106
- model = model.to(device)
107
-
108
- # 处理PDF
109
- markdown_content = ""
110
- for page_idx, page in enumerate(rasterize_paper(pdf_path)):
111
- page = page.to(device)
112
- markdown = model.inference(page)
113
- markdown_content += f"--- Page {page_idx+1} ---\n{markdown}\n\n"
114
-
115
- print("成功使用Nougat Python API处理PDF")
116
- return markdown_content
117
- except Exception as api_error:
118
- print(f"Nougat Python API处理失败: {str(api_error)}")
119
- raise api_error
120
-
 
 
 
 
 
 
 
121
  except Exception as e:
122
  import traceback
123
- print(f"GPU PDF处理失败: {str(e)}\n{traceback.format_exc()}")
124
- raise e
 
125
 
126
- # 添加优化后的PDF转换为Markdown函数
127
  def convert_pdf_to_markdown(pdf_file):
128
  """使用Nougat将PDF转换为Markdown (GPU优化版)"""
129
  if pdf_file is None:
@@ -141,73 +205,34 @@ def convert_pdf_to_markdown(pdf_file):
141
  with open(temp_pdf_path, "wb") as f:
142
  f.write(pdf_file)
143
 
144
- # 首先尝试使用GPU命令行方式处理PDF
145
- output_dir = temp_dir
146
- print(f"执行: nougat {temp_pdf_path} -o {output_dir}")
147
-
148
- # 设置GPU环境变量
149
- env = os.environ.copy()
150
- env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
151
 
152
- try:
153
- result = subprocess.run(
154
- ["nougat", temp_pdf_path, "-o", output_dir],
155
- stdout=subprocess.PIPE,
156
- stderr=subprocess.PIPE,
157
- text=True,
158
- timeout=180,
159
- env=env
160
- )
161
-
162
- # 检查命令行转换是否成功
163
- if result.returncode == 0:
164
- # 读取生成的Markdown文件
165
- markdown_path = os.path.join(output_dir, "temp.mmd")
166
- if os.path.exists(markdown_path):
167
- with open(markdown_path, "r", encoding="utf-8") as f:
168
- markdown_content = f.read()
169
-
170
- # 限制文本长度
171
- if len(markdown_content) > 20000:
172
- markdown_content = markdown_content[:20000] + "\n\n...(Markdown内容已截断)"
173
-
174
- status = f"PDF已成功转换为Markdown (GPU命令行): 生成了{len(markdown_content)}个字符"
175
- return markdown_content, status
176
-
177
- # 如果命令行方式失败,尝试空间GPU API
178
- print("命令行转换失败,正在尝试使用GPU API方式处理PDF...")
179
- markdown_content = process_pdf_with_nougat(temp_pdf_path)
180
-
181
- # 限制文本长度
182
- if len(markdown_content) > 20000:
183
- markdown_content = markdown_content[:20000] + "\n\n...(Markdown内容已截断)"
184
-
185
- status = f"PDF已成功转换为Markdown (GPU API): 生成了{len(markdown_content)}个字符"
186
- return markdown_content, status
187
-
188
- except subprocess.TimeoutExpired:
189
- print("命令行处理超时,尝试使用GPU API...")
190
- # 尝试使用GPU API
191
- markdown_content = process_pdf_with_nougat(temp_pdf_path)
192
-
193
  # 限制文本长度
194
  if len(markdown_content) > 20000:
195
  markdown_content = markdown_content[:20000] + "\n\n...(Markdown内容已截断)"
196
 
197
- status = f"PDF已成功转换为Markdown (GPU API): 生成了{len(markdown_content)}个字符"
198
  return markdown_content, status
199
-
200
- except Exception as cmd_error:
201
- print(f"命令行处理失败: {str(cmd_error)}")
202
- # 尝试使用GPU API
203
- markdown_content = process_pdf_with_nougat(temp_pdf_path)
204
-
 
 
205
  # 限制文本长度
206
  if len(markdown_content) > 20000:
207
  markdown_content = markdown_content[:20000] + "\n\n...(Markdown内容已截断)"
208
 
209
  status = f"PDF已成功转换为Markdown (GPU API): 生成了{len(markdown_content)}个字符"
210
  return markdown_content, status
 
 
 
211
 
212
  except Exception as e:
213
  import traceback
 
87
  if not terminators:
88
  terminators = [2] # 使用常见的</s>标记ID作为默认值
89
 
90
+ # 使用CUDA运行NougatPDF处理函数
91
+ def process_pdf_with_nougat_gpu(pdf_path, output_dir=None):
92
+ """使用GPU运行Nougat处理PDF文件"""
93
+ try:
94
+ # 如果未指定输出目录,使用PDF所在目录
95
+ if output_dir is None:
96
+ output_dir = os.path.dirname(pdf_path)
97
+
98
+ # 设置CUDA环境变量
99
+ env = os.environ.copy()
100
+ env["CUDA_VISIBLE_DEVICES"] = "0" # 使用第一个GPU
101
+ env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
102
+
103
+ # 执行带有GPU支持的Nougat命令
104
+ print(f"使用GPU运行Nougat: {pdf_path}")
105
+ cmd = ["nougat", pdf_path, "-o", output_dir, "--device", "cuda"]
106
+
107
+ # 执行命令并捕获输出
108
+ result = subprocess.run(
109
+ cmd,
110
+ stdout=subprocess.PIPE,
111
+ stderr=subprocess.PIPE,
112
+ text=True,
113
+ env=env,
114
+ timeout=300 # 5分钟超时
115
+ )
116
+
117
+ # 检查命令执行结果
118
+ if result.returncode != 0:
119
+ print(f"Nougat GPU处理失败: {result.stderr}")
120
+ return None, result.stderr
121
+
122
+ # 获取生成的markdown文件路径
123
+ base_name = os.path.basename(pdf_path)
124
+ name_without_ext = os.path.splitext(base_name)[0]
125
+ markdown_path = os.path.join(output_dir, f"{name_without_ext}.mmd")
126
+
127
+ # 检查markdown文件是否生成
128
+ if not os.path.exists(markdown_path):
129
+ return None, "Nougat处理完成,但未找到生成的Markdown文件"
130
+
131
+ # 读取markdown内容
132
+ with open(markdown_path, "r", encoding="utf-8") as f:
133
+ markdown_content = f.read()
134
+
135
+ return markdown_content, None
136
+
137
+ except subprocess.TimeoutExpired:
138
+ return None, "Nougat处理超时"
139
+
140
+ except Exception as e:
141
+ import traceback
142
+ error = f"Nougat处理异常: {str(e)}\n{traceback.format_exc()}"
143
+ print(error)
144
+ return None, error
145
+
146
+ # 使用Python API的GPU处理方式
147
  @spaces.GPU(stateless=True)
148
+ def process_pdf_with_nougat_api(pdf_path):
149
+ """使用Nougat Python API与GPU处理PDF文件"""
150
  try:
151
+ # 导入必要的库
152
+ from nougat import NougatModel
153
+ from nougat.utils.checkpoint import get_checkpoint
154
+ from nougat.dataset.rasterize import rasterize_paper
155
+ import torch
156
+
157
+ # 确保GPU可用
158
+ if not torch.cuda.is_available():
159
+ return None, "GPU不可用,无法使用Nougat API处理PDF"
160
+
161
+ # 显示GPU信息
162
+ device_count = torch.cuda.device_count()
163
+ device_name = torch.cuda.get_device_name(0) if device_count > 0 else "Unknown"
164
+ print(f"使用GPU: {device_name}, 可用GPU数量: {device_count}")
165
+
166
+ # 初始化模型并移至GPU
167
+ ckpt = get_checkpoint()
168
+ model = NougatModel.from_pretrained(ckpt)
169
+ device = torch.device("cuda")
170
+ model = model.to(device)
171
+
172
+ # 处理PDF
173
+ markdown_content = ""
174
+ pages = list(rasterize_paper(pdf_path))
175
+
176
+ # 使用tqdm显示进度
177
+ for page_idx, page in enumerate(tqdm(pages, desc="处理PDF页面")):
178
+ page = page.to(device)
179
+ markdown = model.inference(page)
180
+ markdown_content += f"--- Page {page_idx+1} ---\n{markdown}\n\n"
181
+
182
+ return markdown_content, None
183
+
184
  except Exception as e:
185
  import traceback
186
+ error = f"Nougat API处理异常: {str(e)}\n{traceback.format_exc()}"
187
+ print(error)
188
+ return None, error
189
 
190
+ # 添加PDF转换为Markdown函数
191
  def convert_pdf_to_markdown(pdf_file):
192
  """使用Nougat将PDF转换为Markdown (GPU优化版)"""
193
  if pdf_file is None:
 
205
  with open(temp_pdf_path, "wb") as f:
206
  f.write(pdf_file)
207
 
208
+ # 方法1: 首先尝试使用命令行GPU方式
209
+ print("方法1: 尝试使用命令行GPU方式处理PDF...")
210
+ markdown_content, error = process_pdf_with_nougat_gpu(temp_pdf_path, temp_dir)
 
 
 
 
211
 
212
+ if markdown_content is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  # 限制文本长度
214
  if len(markdown_content) > 20000:
215
  markdown_content = markdown_content[:20000] + "\n\n...(Markdown内容已截断)"
216
 
217
+ status = f"PDF已成功转换为Markdown (GPU命令行): 生成了{len(markdown_content)}个字符"
218
  return markdown_content, status
219
+
220
+ # 方法2: 如果命令行方式失败,尝试使用Python API方式
221
+ print(f"方法1失败: {error}")
222
+ print("方法2: 尝试使用Python API GPU方式处理PDF...")
223
+
224
+ markdown_content, api_error = process_pdf_with_nougat_api(temp_pdf_path)
225
+
226
+ if markdown_content is not None:
227
  # 限制文本长度
228
  if len(markdown_content) > 20000:
229
  markdown_content = markdown_content[:20000] + "\n\n...(Markdown内容已截断)"
230
 
231
  status = f"PDF已成功转换为Markdown (GPU API): 生成了{len(markdown_content)}个字符"
232
  return markdown_content, status
233
+
234
+ # 所有方法都失败
235
+ return "", f"PDF转换失败: 所有GPU方法都失败了\n命令行错误: {error}\nAPI错误: {api_error}"
236
 
237
  except Exception as e:
238
  import traceback