import sys
import uuid
import os
import re
import tempfile
import asyncio
import threading
import hashlib
import json
import uvicorn
import click
import zipfile
import urllib.request
import urllib.error
from pathlib import Path
import glob

# PyTorch CUDA allocator option. It is set before the mineru imports below,
# presumably so it is already in the environment when torch gets loaded.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form, APIRouter, Header
from pydantic import BaseModel
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from starlette.background import BackgroundTask
from typing import List, Optional
from loguru import logger
from base64 import b64encode

# MinerU internal imports
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes, word_suffixes
from mineru.utils.cli_parser import arg_parse
from mineru.utils.config_reader import get_device
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
from mineru.utils.model_utils import clean_memory
from mineru.version import __version__

# --- Logging configuration ---
log_level = os.getenv("MINERU_LOG_LEVEL", "INFO").upper()
logger.remove()
logger.add(sys.stderr, level=log_level)

# --- Globals ---
# Request-limiting semaphore; stays None (no limit) unless configured.
_request_semaphore: Optional[asyncio.Semaphore] = None

# --- Task progress tracking ---
# In-process cache of task states. The JSON files under _task_progress_dir are
# what other API worker processes read.
_task_progress: dict = {}
_task_progress_dir = Path(
    os.getenv("MINERU_TASK_PROGRESS_DIR", Path(tempfile.gettempdir()) / "mineru-task-progress")
)


def _task_progress_path(task_id: str) -> Path:
    """Return the JSON file that stores the state of *task_id*.

    The id is hashed (SHA-256 hex) so arbitrary client-supplied ids always map
    to a plain file name inside the progress directory.
    """
    task_key = hashlib.sha256(task_id.encode("utf-8")).hexdigest()
    return _task_progress_dir / f"{task_key}.json"


def _store_task_progress(task_id: str, state: dict) -> None:
    """Store progress in memory and on disk so API worker processes share it.

    A copy of *state* (with ``task_id`` added) is cached in memory, then written
    to a per-process/per-thread temp file and moved over the real file with
    ``os.replace``. Readers therefore never observe a half-written JSON document.
    Persistence failures are logged rather than raised. Any temp file left behind
    by a failed write is removed.
    """
    state = dict(state)
    state["task_id"] = task_id
    _task_progress[task_id] = state
    temp_path = None
    try:
        _task_progress_dir.mkdir(parents=True, exist_ok=True)
        progress_path = _task_progress_path(task_id)
        temp_path = progress_path.with_name(
            f"{progress_path.name}.{os.getpid()}.{threading.get_ident()}.tmp"
        )
        with open(temp_path, "w", encoding="utf-8") as fp:
            json.dump(state, fp, ensure_ascii=False)
        os.replace(temp_path, progress_path)
    except Exception as exc:
        logger.warning(f"Failed to persist task progress for {task_id}: {exc}")
        # Don't let failed writes accumulate stray .tmp files in the shared dir.
        if temp_path is not None:
            try:
                temp_path.unlink()
            except OSError:
                pass


def _get_task_progress(task_id: str) -> Optional[dict]:
    """Return the latest known state of *task_id*, or None if it is unknown.

    The on-disk copy wins, because another worker process may have written it,
    and it refreshes the in-memory cache. The cache is used as a fallback when
    the file is missing or unreadable.
    """
    try:
        with open(_task_progress_path(task_id), "r", encoding="utf-8") as fp:
            state = json.load(fp)
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return _task_progress.get(task_id)
    _task_progress[task_id] = state
    return state


def _update_task_progress(task_id: Optional[str], progress: int, stage: str):
    """Update a task's progress percentage and stage label.

    This is a no-op when *task_id* is falsy or the task is unknown. Progress is
    monotonic: it never moves backwards and is capped at 100. The stage label is
    always replaced.
    """
    if not task_id:
        return
    state = _get_task_progress(task_id)
    if state is not None:
        current = int(state.get("progress", 0) or 0)
        state["progress"] = min(max(progress, current), 100)
        state["stage"] = stage
        _store_task_progress(task_id, state)


def _cleanup_runtime_memory():
    """Release transient accelerator caches after each API parse request (best effort)."""
    try:
        clean_memory(get_device())
    except Exception as exc:
        logger.warning(f"Failed to clean runtime memory: {exc}")


def _format_parse_error(exc: Exception) -> str:
    """Turn a parse exception into a user-facing error message.

    Out-of-memory errors get a localized hint. Everything else is reported as
    ``Internal Error: <message>``.
    """
    message = str(exc)
    if "CUDA out of memory" in message or exc.__class__.__name__ in ("OutOfMemoryError", "RuntimeError"):
        if "out of memory" in message.lower():
            return "GPU显存不足,已尝试清理缓存。请减少最大页数、关闭公式/表格识别,或稍后重试。"
    return f"Internal Error: {message}"


class ProgressTracker:
    """Minimal holder for a progress value and a status label.

    It is not referenced by the other code visible in this module.
    """

    def __init__(self):
        self.progress = 0
        self.status = "初始化"

    def update(self, progress: float, status: str):
        self.progress = progress
        self.status = status

    def get_progress(self):
        return {"progress": self.progress, "status": self.status}


# --- Currently active task id (used by the log-capture sink) ---
# This is process-global, so concurrent parse requests in the same process
# share (and overwrite) it.
_current_task_id: Optional[str] = None


def _log_progress_sink(message):
    """loguru sink that advances the active task's progress from known log lines.

    Messages are matched by substring against milestones emitted while the
    engines load and run. Each match bumps the progress of ``_current_task_id``
    (progress never goes backwards; see _update_task_progress).
    """
    task_id = _current_task_id
    if not task_id or task_id not in _task_progress:
        return
    try:
        msg = str(message)
    except Exception:
        return
    # VLM engine loading (about 5% .. 40%)
    if "Automatically detected platform" in msg:
        _update_task_progress(task_id, 7, "检测计算平台")
    elif "Resolved architecture" in msg:
        _update_task_progress(task_id, 10, "解析模型架构")
    elif "Starting to load model" in msg:
        _update_task_progress(task_id, 12, "开始加载模型权重")
    elif "Loading weights took" in msg:
        _update_task_progress(task_id, 18, "模型权重加载完成")
    elif "Model loading took" in msg:
        _update_task_progress(task_id, 22, "模型加载完成,初始化引擎")
    elif "Dynamo bytecode transform" in msg:
        _update_task_progress(task_id, 27, "编译优化计算图")
    elif "torch.compile takes" in msg:
        _update_task_progress(task_id, 30, "计算图编译完成")
    elif "Available KV cache memory" in msg:
        _update_task_progress(task_id, 32, "分配显存缓存")
    elif "Graph capturing finished" in msg:
        _update_task_progress(task_id, 36, "CUDA图捕获完成")
    elif "init engine" in msg and "took" in msg:
        _update_task_progress(task_id, 38, "VLM引擎初始化完成")
    elif "get vllm" in msg and "predictor cost" in msg:
        _update_task_progress(task_id, 40, "VLM预测器就绪")
    elif "hybrid batch ratio" in msg:
        _update_task_progress(task_id, 42, "开始文档分析")
    # VLM inference (40% .. 65%) is reported through the parse progress_callback.
    # Pipeline model loading and inference (65% .. 97%):
    elif "Target directory already exists" in msg:
        cur = _task_progress.get(task_id, {}).get("progress", 0)
        if cur < 75:
            _update_task_progress(task_id, max(cur, 66), "加载Pipeline模型")
    elif "MFD Predict:" in msg:
        _update_task_progress(task_id, 78, "数学公式检测")
    elif "MFR Predict:" in msg:
        _update_task_progress(task_id, 85, "数学公式识别")
    elif "OCR-det:" in msg:
        _update_task_progress(task_id, 90, "文字区域检测")
    elif "OCR-rec Predict:" in msg:
        _update_task_progress(task_id, 95, "文字识别")
    elif "local output dir is" in msg:
        _update_task_progress(task_id, 97, "保存输出结果")


class _StderrProgressCapture:
    """Tee sys.stderr and translate tqdm progress-bar output into task progress.

    ``start()`` replaces ``sys.stderr`` with this object. It does not use a
    background thread. Every write is forwarded unchanged to the original stream
    and also split into lines: tqdm redraws a bar with a carriage return, and a
    newline ends it. Each line is matched against known progress-bar names, and
    the bar's percentage is mapped into that stage's slice of the task's overall
    progress.
    """

    # tqdm bar patterns: "<name>: <percent>%|...| <current>/<total>"
    _PATTERNS = [
        (re.compile(r'Layout Preparation:\s*(\d+)%.*?(\d+)/(\d+)'), 'layout_prepare'),
        (re.compile(r'Layout Output Parsing:\s*(\d+)%.*?(\d+)/(\d+)'), 'layout_parse'),
        (re.compile(r'Extract Preparation:\s*(\d+)%.*?(\d+)/(\d+)'), 'extract_prepare'),
        (re.compile(r'Post Processing:\s*(\d+)%.*?(\d+)/(\d+)'), 'post_process'),
        (re.compile(r'Two Step Extraction:\s*(\d+)%.*?(\d+)/(\d+)'), 'vlm_predict'),
        (re.compile(r'MFD Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'mfd'),
        (re.compile(r'MFR Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'mfr'),
        (re.compile(r'OCR-det:\s*(\d+)%.*?(\d+)/(\d+)'), 'ocr_det'),
        (re.compile(r'OCR-rec Predict:\s*(\d+)%.*?(\d+)/(\d+)'), 'ocr_rec'),
        (re.compile(r'Loading safetensors.*?:\s*(\d+)%'), 'load_model'),
        (re.compile(r'Capturing CUDA graphs.*?:\s*(\d+)%'), 'cuda_graph'),
    ]
    # A bare "Predict:" bar is ambiguous; _parse_line attributes it to the
    # preparation stage (layout or extract) seen most recently.
    _GENERIC_PREDICT_PATTERN = re.compile(r'^Predict:\s*(\d+)%.*?(\d+)/(\d+)')
    # Line terminators tqdm/stderr output is split on.
    _LINE_BREAK = re.compile(r'[\r\n]')

    # Overall-progress range [start%, end%] that each stage's 0..100% maps onto.
    _RANGES = {
        'load_model': (12, 18),
        'cuda_graph': (33, 37),
        'layout_prepare': (42, 45),
        'layout_predict': (45, 68),
        'layout_parse': (68, 70),
        'extract_prepare': (70, 72),
        'extract_predict': (72, 88),
        'vlm_predict': (45, 88),
        'post_process': (88, 90),
        'mfd': (90, 92),
        'mfr': (92, 94),
        'ocr_det': (94, 96),
        'ocr_rec': (96, 97),
    }
    _STAGE_LABELS = {
        'load_model': '加载模型权重',
        'cuda_graph': '捕获CUDA计算图',
        'layout_prepare': '准备版面分析',
        'layout_predict': '版面分析',
        'layout_parse': '解析版面结果',
        'extract_prepare': '准备内容抽取',
        'extract_predict': '内容抽取',
        'vlm_predict': 'VLM文档分析',
        'post_process': '后处理',
        'mfd': '数学公式检测',
        'mfr': '数学公式识别',
        'ocr_det': '文字区域检测',
        'ocr_rec': '文字识别',
    }

    def __init__(self, task_id: str):
        self.task_id = task_id
        self._active = False
        self._orig_stderr = None
        self._buf = ""  # trailing text not yet terminated by \r or \n
        self._last_anchor = ""  # "layout" / "extract": last preparation bar seen

    def start(self):
        """Install this object as sys.stderr and begin parsing output."""
        self._active = True
        self._orig_stderr = sys.stderr
        sys.stderr = self

    def stop(self):
        """Flush any pending partial line and restore the original sys.stderr.

        stderr is restored only if nothing else has replaced it since start().
        """
        self._active = False
        if self._buf.strip():
            self._parse_line(self._buf.strip())
        self._buf = ""
        if self._orig_stderr is not None and sys.stderr is self:
            sys.stderr = self._orig_stderr

    def write(self, text):
        """Forward *text* to the real stderr, then parse every completed line."""
        if self._orig_stderr is not None:
            self._orig_stderr.write(text)
        if not self._active:
            return len(text)
        # Everything before the last terminator forms complete lines. The
        # remainder stays buffered until its terminator arrives.
        *complete_lines, self._buf = self._LINE_BREAK.split(self._buf + text)
        for line in complete_lines:
            line = line.strip()
            if line:
                self._parse_line(line)
        return len(text)

    def flush(self):
        if self._orig_stderr is not None:
            self._orig_stderr.flush()

    def isatty(self):
        return bool(self._orig_stderr and self._orig_stderr.isatty())

    def fileno(self):
        if self._orig_stderr is not None:
            return self._orig_stderr.fileno()
        raise OSError("stderr is not available")

    def __getattr__(self, name):
        # Read via __dict__: __getattr__ may run before __init__ has set
        # _orig_stderr (e.g. copy/pickle), and ``self._orig_stderr`` would
        # then recurse back into __getattr__ forever.
        orig = self.__dict__.get("_orig_stderr")
        if orig is not None:
            return getattr(orig, name)
        raise AttributeError(name)

    def _parse_line(self, line: str):
        """Match one completed output line and update progress if it is a known bar."""
        if "Layout Preparation:" in line:
            self._last_anchor = "layout"
        elif "Extract Preparation:" in line:
            self._last_anchor = "extract"

        generic_predict = self._GENERIC_PREDICT_PATTERN.search(line)
        if generic_predict:
            stage = "extract_predict" if self._last_anchor == "extract" else "layout_predict"
            self._update_from_match(generic_predict, stage)
            return

        for pattern, stage in self._PATTERNS:
            m = pattern.search(line)
            if m:
                self._update_from_match(m, stage)
                break

    def _update_from_match(self, match, stage: str):
        """Map the bar's percentage into the stage range and publish it with a label."""
        pct = int(match.group(1))
        lo, hi = self._RANGES.get(stage, (0, 100))
        mapped = lo + int((hi - lo) * pct / 100)
        label = self._STAGE_LABELS.get(stage, stage)
        if len(match.groups()) >= 3:
            cur, total = match.group(2), match.group(3)
            unit = "页" if stage in ("layout_predict", "vlm_predict") else ""
            label = f"{label} ({cur}/{total}{unit})"
        _update_task_progress(self.task_id, mapped, label)


async def limit_concurrency():
    """FastAPI dependency that enforces the optional request-concurrency limit.

    When every slot is taken, it fails fast with 503 instead of queueing.
    Otherwise it holds one semaphore slot until FastAPI finalizes the dependency.
    It does nothing when no semaphore is configured.
    """
    if _request_semaphore is not None:
        if _request_semaphore.locked():
            raise HTTPException(
                status_code=503,
                detail=f"Server is at maximum capacity: {os.getenv('MINERU_API_MAX_CONCURRENT_REQUESTS', '0')}."
            )
        async with _request_semaphore:
            yield
    else:
        yield


def sanitize_filename(filename: str) -> str:
    """Make *filename* safe to use as a single archive path component.

    Runs of 2+ slashes/backslashes/dots and single path separators are removed.
    Remaining characters other than word characters, '.' and '-' become '_'.
    A leading dot is replaced with '_'. Returns 'unnamed' if nothing is left.
    """
    sanitized = re.sub(r'[/\\\.]{2,}|[/\\]', '', filename)
    sanitized = re.sub(r'[^\w.-]', '_', sanitized, flags=re.UNICODE)
    if sanitized.startswith('.'):
        sanitized = '_' + sanitized[1:]
    return sanitized or 'unnamed'


def cleanup_file(file_path: str) -> None:
    """Delete *file_path* if it exists; failures are logged, not raised."""
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
    except Exception as e:
        logger.warning(f"fail clean file {file_path}: {e}")


def encode_image(image_path: str) -> str:
    """Return the base64 (ASCII) encoding of the file at *image_path*."""
    with open(image_path, "rb") as f:
        return b64encode(f.read()).decode()


def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
    """Read ``<parse_dir>/<pdf_name><suffix>`` as UTF-8 text, or return None if absent."""
    result_file_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
    if os.path.exists(result_file_path):
        with open(result_file_path, "r", encoding="utf-8") as fp:
            return fp.read()
    return None


# --- API router (all routes live under /api) ---
api_router = APIRouter(prefix="/api")


class MindmapOrganizeRequest(BaseModel):
    """Request body for creating a mindmap organization task."""

    markdown: str
    mode: str = "smart"
    prompt: Optional[str] = None  # custom prompt template; default prompt when empty


def _extract_json_object(text: str) -> str:
    """Strip surrounding whitespace and an optional Markdown code fence.

    Handles plain fences and json/markdown/md fences. Despite the name, nothing
    is parsed as JSON: the cleaned text is returned as-is.
    """
    content = text.strip()
    if content.startswith("```"):
        content = re.sub(r"^```(?:json|markdown|md)?\s*", "", content, flags=re.IGNORECASE)
        content = re.sub(r"\s*```$", "", content)
    return content.strip()


DEFAULT_MINDMAP_ORGANIZE_PROMPT = """你是文档结构整理助手。请基于用户提供的 Markdown 生成适合思维导图展示的 Markdown。

要求:
1. 保留原文标题结构,不要重写或打乱主要标题层级。
2. 将标题下的段落内容总结为要点,合并相近段落,避免逐段照抄。
3. 保留原文语言:英文内容输出英文,中文内容输出中文,多语言内容按原文语言分别保留。
4. 不要编造原文没有的信息。
5. 保留关键数字、公式、专有名词、步骤和结论。
6. 最大层级不超过 4 层。
7. 每个父节点下最多 8 个子节点。
8. 节点标题尽量简短,正文说明使用短句列表。
9. 只输出 Markdown,不要输出解释、代码块围栏或额外说明。"""

MINDMAP_MERGE_PROMPT = """你是 Markdown 思维导图结构校对助手。下面是多个分块已经整理总结后的局部 Markdown 大纲。

任务:
1. 合并这些局部大纲为一份完整 Markdown。
2. 只检查和调整标题层级结构、顺序、重复标题和父子关系。
3. 不要重新总结正文,不要扩写内容,不要新增原文没有的信息。
4. 保留各局部大纲中的源语言:英文保持英文,中文保持中文,多语言分别保留。
5. 最大层级不超过 4 层。
6. 每个父节点下最多 8 个子节点,必要时只合并相近标题。
7. 只输出 Markdown,不要输出解释、代码块围栏或额外说明。"""


def _estimate_tokens(text: str) -> int:
    """Roughly estimate the token count of *text* (always at least 1).

    Whitespace is ignored. ASCII characters count as 1/4 token each and other
    characters (e.g. CJK) as 1.5 tokens each. This is a heuristic, not a real
    tokenizer.
    """
    ascii_chars = 0
    non_ascii_chars = 0
    for ch in text:
        if ch.isspace():
            continue
        if ord(ch) < 128:
            ascii_chars += 1
        else:
            non_ascii_chars += 1
    return max(1, int(ascii_chars / 4) + int(non_ascii_chars * 1.5))


def _get_mindmap_context_budget(prompt: str, reserve_output_tokens: int = 4096) -> tuple[int, int]:
    """Return ``(max_context_tokens, input_budget_tokens)`` for an LLM call with *prompt*.

    The input budget is the context size minus the prompt estimate, the reserved
    output tokens and a safety margin. It never drops below 2048 tokens.
    """
    max_context_tokens = int(os.getenv("MINDMAP_LLM_MAX_CONTEXT_TOKENS", "32768"))
    prompt_tokens = _estimate_tokens(prompt)
    safety_tokens = int(os.getenv("MINDMAP_LLM_SAFETY_TOKENS", "1024"))
    input_budget_tokens = max(2048, max_context_tokens - prompt_tokens - reserve_output_tokens - safety_tokens)
    logger.info(
        "Mindmap context budget max_context_tokens={} prompt_tokens={} reserve_output_tokens={} safety_tokens={} input_budget_tokens={}",
        max_context_tokens, prompt_tokens, reserve_output_tokens, safety_tokens, input_budget_tokens
    )
    return max_context_tokens, input_budget_tokens


def _split_markdown_blocks(markdown: str) -> list[str]:
    """Split *markdown* into blocks that each start at an ATX heading (``#`` to ``######``).

    Text before the first heading forms its own block. Blocks are stripped, and
    empty ones are dropped.
    """
    lines = markdown.splitlines()
    blocks: list[str] = []
    current: list[str] = []
    heading_pattern = re.compile(r"^#{1,6}\s+")
    for line in lines:
        if heading_pattern.match(line) and current:
            blocks.append("\n".join(current).strip())
            current = [line]
        else:
            current.append(line)
    if current:
        blocks.append("\n".join(current).strip())
    return [block for block in blocks if block]


def _split_large_block(block: str, max_tokens: int) -> list[str]:
    """Greedily pack the blank-line-separated paragraphs of *block* into chunks of <= *max_tokens*.

    A single paragraph larger than *max_tokens* is kept whole as its own chunk.
    Returns ``[block]`` if no non-empty paragraph is found.
    """
    paragraphs = re.split(r"\n{2,}", block)
    chunks: list[str] = []
    current: list[str] = []
    current_tokens = 0
    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        paragraph_tokens = _estimate_tokens(paragraph)
        if current and current_tokens + paragraph_tokens > max_tokens:
            chunks.append("\n\n".join(current))
            current = [paragraph]
            current_tokens = paragraph_tokens
        else:
            current.append(paragraph)
            current_tokens += paragraph_tokens
    if current:
        chunks.append("\n\n".join(current))
    return chunks or [block]


def _chunk_markdown_by_headings(markdown: str, max_tokens: int) -> list[str]:
    """Split *markdown* into chunks of roughly <= *max_tokens*, breaking at headings.

    Heading blocks are packed greedily. A block that alone exceeds the budget is
    split further by paragraphs (see _split_large_block). Returns ``[markdown]``
    when there are no blocks.
    """
    blocks = _split_markdown_blocks(markdown)
    chunks: list[str] = []
    current: list[str] = []
    current_tokens = 0
    for block in blocks:
        block_tokens = _estimate_tokens(block)
        if block_tokens > max_tokens:
            if current:
                chunks.append("\n\n".join(current))
                current = []
                current_tokens = 0
            chunks.extend(_split_large_block(block, max_tokens))
            continue
        if current and current_tokens + block_tokens > max_tokens:
            chunks.append("\n\n".join(current))
            current = [block]
            current_tokens = block_tokens
        else:
            current.append(block)
            current_tokens += block_tokens
    if current:
        chunks.append("\n\n".join(current))
    return chunks or [markdown]


def _call_mindmap_llm(markdown: str, mode: str = "smart", custom_prompt: Optional[str] = None,
                      task_id: Optional[str] = None, request_role: str = "organize") -> str:
    """Send *markdown* to the configured OpenAI-compatible chat endpoint and return its Markdown.

    Configuration comes from the MINDMAP_LLM_* environment variables.
    *custom_prompt* overrides the default organize prompt. *mode*, *task_id* and
    *request_role* are only used for logging.

    Raises:
        RuntimeError: no base URL is configured, the HTTP request fails, or the
            response contains no usable message content.
    """
    base_url = os.getenv("MINDMAP_LLM_BASE_URL", "").rstrip("/")
    model = os.getenv("MINDMAP_LLM_MODEL", "gemma-4-26B")
    api_key = os.getenv("MINDMAP_LLM_API_KEY", "")
    timeout = int(os.getenv("MINDMAP_LLM_TIMEOUT", "180"))
    if not base_url:
        raise RuntimeError("未配置智能整理模型服务,请设置 MINDMAP_LLM_BASE_URL")

    compact_markdown = markdown.strip()
    prompt_template = (custom_prompt or "").strip() or DEFAULT_MINDMAP_ORGANIZE_PROMPT
    prompt = f"""{prompt_template}

原始 Markdown:
{compact_markdown}
"""
    logger.info(
        "Mindmap LLM request start task_id={} role={} model={} base_url={} mode={} input_chars={} input_tokens_est={} prompt_chars={}",
        task_id or "-", request_role, model, base_url, mode, len(compact_markdown),
        _estimate_tokens(compact_markdown), len(prompt_template)
    )
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": "你擅长把长文档整理成结构清晰、层次合理的思维导图 Markdown,并严格保留原文语言。"},
            {"role": "user", "content": prompt},
        ],
        "temperature": float(os.getenv("MINDMAP_LLM_TEMPERATURE", "0.2")),
    }
    data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    url = f"{base_url}/chat/completions"
    req = urllib.request.Request(url, data=data, headers=headers, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            result = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="ignore")
        raise RuntimeError(f"智能整理模型请求失败: HTTP {exc.code} {detail}") from exc
    except Exception as exc:
        raise RuntimeError(f"智能整理模型请求失败: {exc}") from exc

    # Tolerate empty/null "choices", "message" or "content" so a malformed reply
    # surfaces as the "no valid content" RuntimeError below, not an IndexError,
    # TypeError or AttributeError.
    choices = result.get("choices") or [{}]
    message = choices[0].get("message") or {}
    content = message.get("content") or ""
    organized = _extract_json_object(content)
    if not organized:
        raise RuntimeError("智能整理模型未返回有效内容")
    logger.info(
        "Mindmap LLM request completed task_id={} role={} output_chars={} output_tokens_est={}",
        task_id or "-", request_role, len(organized), _estimate_tokens(organized)
    )
    return organized


def _organize_mindmap_markdown(markdown: str, mode: str, custom_prompt: Optional[str], task_id: str) -> str:
    """Organize *markdown* into mindmap Markdown, chunking when it exceeds the context budget.

    Small inputs take a single LLM call. Larger inputs are split at headings,
    each chunk is organized separately, and the partial outlines are then merged
    with MINDMAP_MERGE_PROMPT. If the combined partials are still too large, one
    extra round merges them in groups first. Progress is published to *task_id*.
    """
    prompt_template = (custom_prompt or "").strip() or DEFAULT_MINDMAP_ORGANIZE_PROMPT
    _, input_budget_tokens = _get_mindmap_context_budget(prompt_template)
    source_tokens = _estimate_tokens(markdown)
    logger.info(
        "Mindmap organize strategy task_id={} source_chars={} source_tokens_est={} input_budget_tokens={}",
        task_id, len(markdown), source_tokens, input_budget_tokens
    )
    if source_tokens <= input_budget_tokens:
        _update_task_progress(task_id, 35, "调用智能整理模型")
        return _call_mindmap_llm(markdown, mode, prompt_template, task_id, "single")

    chunks = _chunk_markdown_by_headings(markdown, input_budget_tokens)
    logger.info("Mindmap large input split task_id={} chunks={}", task_id, len(chunks))
    partial_results: list[str] = []
    for index, chunk in enumerate(chunks, start=1):
        # Chunk organization spans roughly 20% .. 75% of the task's progress.
        progress = 20 + int(index / max(len(chunks), 1) * 55)
        _update_task_progress(task_id, progress, f"智能整理分块 {index}/{len(chunks)}")
        logger.info(
            "Mindmap chunk organize task_id={} chunk={}/{} chars={} tokens_est={}",
            task_id, index, len(chunks), len(chunk), _estimate_tokens(chunk)
        )
        partial = _call_mindmap_llm(chunk, mode, prompt_template, task_id, f"chunk-{index}")
        partial_results.append(partial)

    merged_input = "\n\n".join(f"\n{partial}" for partial in partial_results)
    _, merge_budget_tokens = _get_mindmap_context_budget(MINDMAP_MERGE_PROMPT)
    merge_tokens = _estimate_tokens(merged_input)
    if merge_tokens > merge_budget_tokens:
        logger.warning(
            "Mindmap merged outline still exceeds context task_id={} tokens_est={} budget={} chunks={}",
            task_id, merge_tokens, merge_budget_tokens, len(partial_results)
        )
        # One extra round: merge groups of partial outlines before the final merge.
        merge_chunks = _chunk_markdown_by_headings(merged_input, merge_budget_tokens)
        merged_round: list[str] = []
        for index, chunk in enumerate(merge_chunks, start=1):
            _update_task_progress(task_id, 78 + int(index / max(len(merge_chunks), 1) * 10),
                                  f"合并局部大纲 {index}/{len(merge_chunks)}")
            merged_round.append(_call_mindmap_llm(chunk, mode, MINDMAP_MERGE_PROMPT, task_id, f"merge-round-{index}"))
        merged_input = "\n\n".join(merged_round)

    _update_task_progress(task_id, 90, "全局整理标题结构")
    return _call_mindmap_llm(merged_input, mode, MINDMAP_MERGE_PROMPT, task_id, "merge")


async def _run_mindmap_organize_task(task_id: str, markdown: str, mode: str, prompt: Optional[str]):
    """Run a mindmap organization task and record its result or failure in the task state.

    The blocking LLM work runs in a worker thread via ``asyncio.to_thread``.
    Exceptions are logged and stored as ``status="failed"``; they are never
    re-raised.
    """
    try:
        _store_task_progress(task_id, {
            "progress": 10,
            "stage": "准备智能整理",
            "status": "processing",
            "error": None,
            "file_names": "",
            "result_md": None,
        })
        organized = await asyncio.to_thread(_organize_mindmap_markdown, markdown, mode, prompt, task_id)
        state = _get_task_progress(task_id) or {}
        state.update({
            "progress": 100,
            "stage": "智能整理完成",
            "status": "completed",
            "error": None,
            "result_md": organized,
        })
        _store_task_progress(task_id, state)
    except Exception as exc:
        logger.exception(f"Mindmap organize task failed task_id={task_id}: {exc}")
        state = _get_task_progress(task_id) or {}
        state.update({
            "progress": 100,
            "stage": "智能整理失败",
            "status": "failed",
            "error": str(exc),
            "result_md": None,
        })
        _store_task_progress(task_id, state)
@api_router.post("/parse_tasks/{task_id}", status_code=201)
async def create_parse_task(task_id: str):
    """Register a task before the multipart upload starts.

    The task is stored as ``pending`` so its progress can be queried before the
    /api/file_parse request arrives.
    """
    state = {
        "progress": 0,
        "stage": "等待上传",
        "status": "pending",
        "error": None,
        "file_names": "",
    }
    _store_task_progress(task_id, state)
    logger.info(f"Registered parse task pid={os.getpid()} task_id={task_id}")
    return state


@api_router.get("/parse_progress/{task_id}")
async def get_parse_progress(task_id: str):
    """Return the live state of a parse task; 404 if the task is unknown."""
    state = _get_task_progress(task_id)
    if state is None:
        logger.warning(f"Parse task not found pid={os.getpid()} task_id={task_id}")
        raise HTTPException(status_code=404, detail="Task not found")
    return state


# Strong references to fire-and-forget mindmap tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes. Entries are removed when their task completes.
_background_tasks: set = set()


@api_router.post("/mindmap_tasks/{task_id}", status_code=201)
async def create_mindmap_task(task_id: str, request: MindmapOrganizeRequest):
    """Create an async task that organizes Markdown into summarized mindmap Markdown.

    Returns immediately with the ``pending`` state. The work runs in the
    background, and progress and result are read from /api/mindmap_progress.
    """
    markdown = request.markdown.strip()
    if not markdown:
        raise HTTPException(status_code=400, detail="Markdown content is required")
    state = {
        "progress": 0,
        "stage": "等待智能整理",
        "status": "pending",
        "error": None,
        "file_names": "",
        "result_md": None,
    }
    _store_task_progress(task_id, state)
    task = asyncio.create_task(_run_mindmap_organize_task(task_id, markdown, request.mode, request.prompt))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    logger.info(
        "Registered mindmap organize task pid={} task_id={} mode={} input_chars={} custom_prompt={}",
        os.getpid(), task_id, request.mode, len(markdown), bool((request.prompt or "").strip())
    )
    return state


@api_router.get("/mindmap_progress/{task_id}")
async def get_mindmap_progress(task_id: str):
    """Query async mindmap organization progress and result; 404 if unknown."""
    state = _get_task_progress(task_id)
    if state is None:
        logger.warning(f"Mindmap task not found pid={os.getpid()} task_id={task_id}")
        raise HTTPException(status_code=404, detail="Task not found")
    return state


def _reject_parse_task(task_id: str, message: str) -> JSONResponse:
    """Mark *task_id* as failed with *message* and build the 400 response.

    This covers uploads rejected before parsing starts. Clients polling
    /api/parse_progress then see the failure instead of a task stuck in
    "processing".
    """
    failed_state = _get_task_progress(task_id) or {}
    failed_state["status"] = "failed"
    failed_state["error"] = message
    _store_task_progress(task_id, failed_state)
    return JSONResponse(status_code=400, content={"error": message})


def _result_dir(unique_dir: str, pdf_name: str, backend: str, parse_method: str) -> str:
    """Return the per-document result directory read back for *backend*.

    Layouts by backend prefix:
    - pipeline: ``<name>/<parse_method>``
    - vlm: ``<name>/vlm``
    - otherwise (hybrid): ``<name>/hybrid_<parse_method>``
    """
    if backend.startswith("pipeline"):
        return os.path.join(unique_dir, pdf_name, parse_method)
    if backend.startswith("vlm"):
        return os.path.join(unique_dir, pdf_name, "vlm")
    return os.path.join(unique_dir, pdf_name, f"hybrid_{parse_method}")


@api_router.post(path="/file_parse", dependencies=[Depends(limit_concurrency)])
async def parse_pdf(
    files: List[UploadFile] = File(..., description="Upload pdf, image, or Word files for parsing"),
    output_dir: str = Form("./output", description="Output local directory"),
    lang_list: List[str] = Form(["ch"]),
    backend: str = Form("hybrid-auto-engine"),
    parse_method: str = Form("auto"),
    formula_enable: bool = Form(True),
    table_enable: bool = Form(True),
    server_url: Optional[str] = Form(None),
    return_md: bool = Form(True),
    return_middle_json: bool = Form(False),
    return_model_output: bool = Form(False),
    return_content_list: bool = Form(False),
    return_images: bool = Form(False),
    response_format_zip: bool = Form(False),
    start_page_id: int = Form(0),
    end_page_id: int = Form(99999),
    form_task_id: Optional[str] = Form(None, alias="task_id"),
    x_task_id: Optional[str] = Header(None),
):
    """Parse uploaded PDF/image/Word files with MinerU and return the results.

    Progress is published under the task id taken, in order of preference, from
    the ``X-Task-Id`` header, the ``task_id`` form field, or a fresh UUID.

    Returns:
        - 400 when a file is unsupported or cannot be loaded (task marked failed).
        - 500 with a formatted error on parse failure (task marked failed).
        - Otherwise either a ``results.zip`` (``response_format_zip``) or a JSON
          body keyed by document name.
    """
    global _current_task_id

    # Extra CLI options stored by main(); forwarded verbatim to aio_do_parse.
    config = getattr(app.state, "config", {})

    task_id = x_task_id or form_task_id or str(uuid.uuid4())
    file_names_str = ", ".join(f.filename or "unknown" for f in files)
    _store_task_progress(task_id, {
        "progress": 0,
        "stage": "准备中",
        "status": "processing",
        "error": None,
        "file_names": file_names_str,
    })
    logger.info(
        f"Started parse task pid={os.getpid()} task_id={task_id} "
        f"header_task_id={x_task_id} form_task_id={form_task_id}"
    )

    try:
        # NOTE(review): output_dir is client-controlled, so callers choose where
        # files are written on the server. Confirm only trusted clients reach this.
        unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
        os.makedirs(unique_dir, exist_ok=True)

        pdf_file_names = []
        pdf_bytes_list = []
        for file in files:
            content = await file.read()
            # Keep only the final path component of the client-supplied name.
            file_path = Path(file.filename or "unknown")
            temp_path = Path(unique_dir) / file_path.name
            with open(temp_path, "wb") as f:
                f.write(content)

            file_suffix = guess_suffix_by_path(temp_path)
            if file_suffix not in pdf_suffixes + image_suffixes + word_suffixes:
                return _reject_parse_task(task_id, f"Unsupported file type: {file_suffix}")
            try:
                pdf_bytes = read_fn(temp_path)
                pdf_bytes_list.append(pdf_bytes)
                pdf_file_names.append(file_path.stem)
                os.remove(temp_path)
            except Exception as e:
                return _reject_parse_task(task_id, f"Failed to load file: {str(e)}")

        # One language per document; otherwise reuse the first (or "ch") for all.
        actual_lang_list = lang_list
        if len(actual_lang_list) != len(pdf_file_names):
            actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)

        def progress_callback(pct, msg):
            """Map aio_do_parse progress onto the 40%..65% band (assuming pct is 0..100)."""
            msg_str = str(msg)
            if "处理文件" in msg_str:
                _update_task_progress(task_id, 40, f"VLM文档分析: {msg_str}")
            elif "完成文件" in msg_str:
                _update_task_progress(task_id, 65, f"VLM分析完成: {msg_str}")
            else:
                _update_task_progress(task_id, int(40 + pct * 0.25), msg_str)

        _update_task_progress(task_id, 5, "开始解析文档")

        # Route log lines and tqdm stderr output of this parse into task progress.
        _current_task_id = task_id
        sink_id = logger.add(_log_progress_sink, level="DEBUG")
        stderr_capture = _StderrProgressCapture(task_id)
        stderr_capture.start()
        try:
            await aio_do_parse(
                output_dir=unique_dir,
                pdf_file_names=pdf_file_names,
                pdf_bytes_list=pdf_bytes_list,
                p_lang_list=actual_lang_list,
                backend=backend,
                parse_method=parse_method,
                formula_enable=formula_enable,
                table_enable=table_enable,
                server_url=server_url,
                f_draw_layout_bbox=False,
                f_draw_span_bbox=False,
                f_dump_md=return_md,
                f_dump_middle_json=return_middle_json,
                f_dump_model_output=return_model_output,
                f_dump_orig_pdf=False,
                f_dump_content_list=return_content_list,
                start_page_id=start_page_id,
                end_page_id=end_page_id,
                progress_callback=progress_callback,
                **config
            )
        finally:
            # Always detach the hooks, even on failure. Otherwise a failed parse
            # leaves sys.stderr redirected and the sink attached.
            stderr_capture.stop()
            logger.remove(sink_id)
            _current_task_id = None

        _update_task_progress(task_id, 97, "生成结果文件")
        _update_task_progress(task_id, 100, "转换完成")
        completed_state = _get_task_progress(task_id) or {}
        completed_state["status"] = "completed"
        _store_task_progress(task_id, completed_state)
        _cleanup_runtime_memory()

        if response_format_zip:
            zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
            os.close(zip_fd)
            with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
                for pdf_name in pdf_file_names:
                    safe_pdf_name = sanitize_filename(pdf_name)
                    p_dir = _result_dir(unique_dir, pdf_name, backend, parse_method)
                    if not os.path.exists(p_dir):
                        continue
                    if return_md:
                        path = os.path.join(p_dir, f"{pdf_name}.md")
                        if os.path.exists(path):
                            zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}.md"))
                    if return_images:
                        images_dir = os.path.join(p_dir, "images")
                        for img in glob.glob(os.path.join(glob.escape(images_dir), "*.jpg")):
                            zf.write(img, arcname=os.path.join(safe_pdf_name, "images", os.path.basename(img)))
            # The temporary zip is deleted after the response has been sent.
            return FileResponse(path=zip_path, media_type="application/zip", filename="results.zip",
                                background=BackgroundTask(cleanup_file, zip_path))

        result_dict = {}
        for pdf_name in pdf_file_names:
            result_dict[pdf_name] = {}
            data = result_dict[pdf_name]
            p_dir = _result_dir(unique_dir, pdf_name, backend, parse_method)
            if os.path.exists(p_dir):
                if return_md:
                    data["md_content"] = get_infer_result(".md", pdf_name, p_dir)
                if return_images:
                    img_dir = os.path.join(p_dir, "images")
                    data["images"] = {
                        os.path.basename(p): f"data:image/jpeg;base64,{encode_image(p)}"
                        for p in glob.glob(os.path.join(glob.escape(img_dir), "*.jpg"))
                    }
        return JSONResponse(status_code=200,
                            content={"backend": backend, "version": __version__, "results": result_dict})
    except Exception as e:
        logger.exception(e)
        failed_state = _get_task_progress(task_id) or {}
        failed_state["status"] = "failed"
        failed_state["error"] = _format_parse_error(e)
        _store_task_progress(task_id, failed_state)
        _cleanup_runtime_memory()
        return JSONResponse(status_code=500, content={"error": _format_parse_error(e)})


# --- FastAPI application ---
def _parse_max_concurrent_requests(raw) -> int:
    """Parse a max-concurrency setting; missing or invalid values disable the limit (0)."""
    try:
        return int(raw or 0)
    except (TypeError, ValueError):
        logger.warning(f"Invalid MINERU_API_MAX_CONCURRENT_REQUESTS value: {raw!r}; concurrency limit disabled")
        return 0


def create_app():
    """Build the FastAPI app: API routes, CORS/GZip middleware, /health and the optional web UI."""
    enable_docs = str(os.getenv("MINERU_API_ENABLE_FASTAPI_DOCS", "1")).lower() in ("1", "true", "yes")
    app = FastAPI(
        openapi_url="/openapi.json" if enable_docs else None,
        docs_url="/docs" if enable_docs else None,
        redoc_url="/redoc" if enable_docs else None,
    )

    global _request_semaphore
    mcr = _parse_max_concurrent_requests(os.getenv("MINERU_API_MAX_CONCURRENT_REQUESTS", "0"))
    if mcr > 0:
        _request_semaphore = asyncio.Semaphore(mcr)
        logger.info(f"Concurrency limited to {mcr}")

    app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
                       allow_methods=["*"], allow_headers=["*"])
    app.add_middleware(GZipMiddleware, minimum_size=1000)

    # 1. API routes first (all /api/* requests).
    app.include_router(api_router)

    # 2. Health check.
    @app.get("/health")
    async def health():
        return {"status": "ok"}

    # 3. Static files last, so they only serve what the routes above don't
    #    (/, /index.html, /assets/*, ...).
    static_dir = Path(__file__).parent / "static" / "web"
    if static_dir.exists():
        logger.info(f"Mounting static files from {static_dir}")
        app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
    else:
        logger.warning("Static directory not found, web UI will not be available.")

    return app


app = create_app()


@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1')
@click.option('--port', default=8000, type=int)
@click.option('--reload', is_flag=True)
def main(ctx, host, port, reload, **kwargs):
    """Start the API server with uvicorn.

    Extra CLI arguments are parsed by arg_parse and stored on app.state.config,
    which /api/file_parse forwards to aio_do_parse.
    """
    global _request_semaphore
    kwargs.update(arg_parse(ctx))
    app.state.config = kwargs

    mcr = _parse_max_concurrent_requests(kwargs.get("mineru_api_max_concurrent_requests", "0"))
    # create_app() already ran at import time, before this CLI value was known.
    # Apply it here as well, since uvicorn may reuse this already-imported module.
    if mcr > 0:
        _request_semaphore = asyncio.Semaphore(mcr)
        logger.info(f"Concurrency limited to {mcr}")
    # Also export it for freshly-importing processes (e.g. the --reload worker).
    os.environ["MINERU_API_MAX_CONCURRENT_REQUESTS"] = str(mcr)
    uvicorn.run("mineru.cli.fast_api:app", host=host, port=port, reload=reload)


if __name__ == "__main__":
    main()