# UnisMindMap/projects/mineru_tianshu/litserve_worker.py
# (547 lines, 21 KiB, Python)
#
# Note: this file intentionally contains Unicode characters that may look
# ambiguous in some locales (emoji and full-width CJK punctuation in log
# messages and comments).

"""
MinerU Tianshu - LitServe Worker
天枢 LitServe Worker
使用 LitServe 实现 GPU 资源的自动负载均衡
Worker 主动循环拉取任务并处理
"""
import os
import json
import sys
import time
import threading
import signal
import atexit
from pathlib import Path
import litserve as ls
from loguru import logger
# 添加父目录到路径以导入 MinerU
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from task_db import TaskDB
from mineru.cli.common import do_parse, read_fn
from mineru.utils.config_reader import get_device
from mineru.utils.model_utils import get_vram, clean_memory
# Optional dependency: MarkItDown parses every format that MinerU does not
# handle (Office documents, HTML, plain text, ...). When it is missing, the
# flag below stays False and such tasks fail at parse time.
try:
    from markitdown import MarkItDown
except ImportError:
    MARKITDOWN_AVAILABLE = False
    logger.warning("⚠️ markitdown not available, Office format parsing will be disabled")
else:
    MARKITDOWN_AVAILABLE = True
class MinerUWorkerAPI(ls.LitAPI):
    """
    LitServe API worker for MinerU Tianshu.

    Each worker process pulls tasks from the task database itself, while
    LitServe spreads the worker processes across the available devices.

    Two parsing paths are supported:
      - PDF / image files -> MinerU (GPU accelerated)
      - every other format -> MarkItDown (fast processing)

    Default mode (``enable_worker_loop=True``): after ``setup`` each worker
    starts a background thread that keeps fetching tasks and immediately picks
    up the next one when a task finishes. The HTTP endpoint is then used only
    for health checks and status queries.
    """

    # File extensions parsed by MinerU (PDF and images). Every other extension
    # is routed to MarkItDown.
    PDF_IMAGE_FORMATS = {'.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}

    def __init__(self, output_dir='/tmp/mineru_tianshu_output', worker_id_prefix='tianshu',
                 poll_interval=0.5, enable_worker_loop=True):
        """
        Args:
            output_dir: Root output directory; each task writes to ``output_dir/<task_id>``.
            worker_id_prefix: Prefix of the worker id generated in ``setup``.
            poll_interval: Seconds to wait between polls while the queue is empty.
            enable_worker_loop: Start the background polling thread in ``setup``.
        """
        super().__init__()
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.worker_id_prefix = worker_id_prefix
        self.poll_interval = poll_interval  # seconds between polls of an empty queue
        self.enable_worker_loop = enable_worker_loop  # whether setup() starts the polling thread
        self.db = TaskDB()
        self.worker_id = None
        self.markitdown = None
        self.running = False  # polling thread keeps going while this is True
        self.worker_thread = None  # background polling thread, if started

    def setup(self, device):
        """
        Initialise the worker environment (called once per worker process).

        For CUDA devices the process is restricted to its assigned GPU through
        CUDA_VISIBLE_DEVICES, so one process cannot allocate memory on several
        cards.

        Args:
            device: Device assigned by LitServe, e.g. 'cuda:0', 'cuda:1', 'cpu'.
        """
        import socket

        hostname = socket.gethostname()
        pid = os.getpid()
        self.worker_id = f"{self.worker_id_prefix}-{hostname}-{device}-{pid}"
        logger.info(f"⚙️ Worker {self.worker_id} setting up on device: {device}")

        device_str = str(device)
        if device_str.startswith('cuda:'):
            # 'cuda:N' -> 'N'. N is an index into the GPUs this process can see.
            # If the parent already narrowed CUDA_VISIBLE_DEVICES (e.g. "2,3"),
            # map N back to the N-th entry of that list instead of treating N
            # as a physical id.
            device_id = device_str.split(':')[-1]
            visible = [d.strip() for d in os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',')
                       if d.strip()]
            if device_id.isdigit() and int(device_id) < len(visible):
                device_id = visible[int(device_id)]
            # NOTE: only effective if CUDA has not been initialised in this process yet.
            os.environ['CUDA_VISIBLE_DEVICES'] = device_id
            # Exactly one card is visible now, so its logical id is always 0.
            os.environ['MINERU_DEVICE_MODE'] = 'cuda:0'
            device_mode = os.environ['MINERU_DEVICE_MODE']
            logger.info(f"🔒 CUDA_VISIBLE_DEVICES={device_id} (Physical GPU {device_id} → Logical GPU 0)")
        else:
            # Respect an explicit MINERU_DEVICE_MODE; otherwise derive it.
            if os.getenv('MINERU_DEVICE_MODE', None) is None:
                os.environ['MINERU_DEVICE_MODE'] = device if device != 'auto' else get_device()
            device_mode = os.environ['MINERU_DEVICE_MODE']

        # VRAM budget for MinerU, unless set explicitly by the user.
        if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
            if device_mode.startswith(("cuda", "npu")):
                try:
                    vram = get_vram(device_mode)
                    os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = str(vram)
                except Exception as e:
                    logger.warning(f"⚠️ Could not query VRAM for {device_mode} ({e}), defaulting to 8GB")
                    os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = '8'  # fallback default
            else:
                os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = '1'

        # Initialise MarkItDown when the optional dependency is installed.
        if MARKITDOWN_AVAILABLE:
            self.markitdown = MarkItDown()
            logger.info(f"✅ MarkItDown initialized for Office format parsing")

        logger.info(f"✅ Worker {self.worker_id} ready")
        logger.info(f" Device: {device_mode}")
        logger.info(f" VRAM: {os.environ['MINERU_VIRTUAL_VRAM_SIZE']}GB")

        # Start the task polling loop in its own (daemon) thread.
        if self.enable_worker_loop:
            self.running = True
            self.worker_thread = threading.Thread(
                target=self._worker_loop,
                daemon=True,
                name=f"Worker-{self.worker_id}"
            )
            self.worker_thread.start()
            logger.info(f"🔄 Worker loop started (poll_interval={self.poll_interval}s)")

    def teardown(self):
        """
        Stop the worker loop.

        Clears the ``running`` flag, then joins the polling thread for at most
        ``2 * poll_interval`` seconds. A task that is still being parsed after
        that timeout is not waited for; only a warning is logged.
        """
        if self.enable_worker_loop and self.worker_thread and self.worker_thread.is_alive():
            logger.info(f"🛑 Shutting down worker {self.worker_id}...")
            self.running = False
            timeout = self.poll_interval * 2
            self.worker_thread.join(timeout=timeout)
            if self.worker_thread.is_alive():
                logger.warning(f"⚠️ Worker thread did not stop within {timeout}s, forcing exit")
            else:
                logger.info(f"✅ Worker {self.worker_id} shut down gracefully")

    def _mark_task_failed(self, task_id, error):
        """
        Record ``error`` as the failure reason of ``task_id`` in the task database.

        Returns:
            The result of ``TaskDB.update_task_status``; a falsy value means the
            update was rejected (a warning is logged in that case).
        """
        success = self.db.update_task_status(
            task_id, 'failed',
            error_message=str(error),
            worker_id=self.worker_id
        )
        if not success:
            logger.warning(f"⚠️ Task {task_id} was modified by another process during failure update")
        return success

    def _worker_loop(self):
        """
        Main polling loop, run in the background thread started by ``setup``.

        Fetches the next task from the database and processes it. A failed task
        is marked 'failed'. When the queue is empty, the loop sleeps
        ``poll_interval`` seconds before polling again. Runs until
        ``self.running`` becomes False.
        """
        logger.info(f"🔁 {self.worker_id} started task polling loop")
        idle_count = 0
        while self.running:
            try:
                task = self.db.get_next_task(self.worker_id)
                if task:
                    idle_count = 0
                    task_id = task['task_id']
                    logger.info(f"🔄 {self.worker_id} picked up task {task_id}")
                    try:
                        self._process_task(task)
                    except Exception as e:
                        # logger.exception keeps the traceback for debugging.
                        logger.exception(f"{self.worker_id} failed to process task {task_id}: {e}")
                        self._mark_task_failed(task_id, e)
                else:
                    idle_count += 1
                    # Log only the first idle poll to avoid flooding the log.
                    if idle_count == 1:
                        logger.debug(f"💤 {self.worker_id} is idle, waiting for tasks...")
                    time.sleep(self.poll_interval)
            except Exception as e:
                logger.exception(f"{self.worker_id} loop error: {e}")
                time.sleep(self.poll_interval)
        logger.info(f"⏹️ {self.worker_id} stopped task polling loop")

    @staticmethod
    def _load_task_options(raw_options):
        """
        Decode a task's stored options into a dict.

        Accepts a JSON string, an already-decoded dict, or an empty value
        (None / '' / JSON null), which yields ``{}``.

        Raises:
            json.JSONDecodeError: if the string is not valid JSON.
            ValueError: if the JSON value is not an object.
        """
        if isinstance(raw_options, dict):
            return raw_options
        options = json.loads(raw_options) if raw_options else None
        if options is None:
            return {}
        if not isinstance(options, dict):
            raise ValueError(f"Task options must be a JSON object, got {type(options).__name__}")
        return options

    def _process_task(self, task: dict):
        """
        Process a single task and record success in the task database.

        The input file at ``task['file_path']`` is deleted afterwards, whether
        parsing succeeded or not. Any exception propagates to the caller, which
        is responsible for marking the task as failed.

        Args:
            task: Task record with 'task_id', 'file_path', 'file_name',
                'backend' and 'options' (JSON string) entries.
        """
        task_id = task['task_id']
        file_path = task['file_path']
        file_name = task['file_name']
        backend = task['backend']
        logger.info(f"🔄 Processing task {task_id}: {file_name}")
        try:
            # Decoded inside the try so a malformed payload still triggers cleanup.
            options = self._load_task_options(task['options'])

            output_path = self.output_dir / task_id
            output_path.mkdir(parents=True, exist_ok=True)

            # Choose the parser from the file extension.
            if self._get_file_type(file_path) == 'pdf_image':
                self._parse_with_mineru(
                    file_path=Path(file_path),
                    file_name=file_name,
                    task_id=task_id,
                    backend=backend,
                    options=options,
                    output_path=output_path
                )
                parse_method = 'MinerU'
            else:  # 'markitdown'
                self._parse_with_markitdown(
                    file_path=Path(file_path),
                    file_name=file_name,
                    output_path=output_path
                )
                parse_method = 'MarkItDown'

            success = self.db.update_task_status(
                task_id, 'completed',
                result_path=str(output_path),
                worker_id=self.worker_id
            )
            if success:
                logger.info(f"✅ Task {task_id} completed by {self.worker_id}")
                logger.info(f" Parser: {parse_method}")
                logger.info(f" Output: {output_path}")
            else:
                logger.warning(
                    f"⚠️ Task {task_id} was modified by another process. "
                    f"Worker {self.worker_id} completed the work but status update was rejected."
                )
        finally:
            # Remove the uploaded input file; a missing file is not an error.
            try:
                Path(file_path).unlink(missing_ok=True)
            except Exception as e:
                logger.warning(f"Failed to clean up temp file {file_path}: {e}")

    def decode_request(self, request):
        """
        Extract the requested action from the HTTP request body.

        Used for health checks and for manual triggering (legacy interface).
        Defaults to 'poll' when no action is given.
        """
        return request.get('action', 'poll')

    def _get_file_type(self, file_path: str) -> str:
        """
        Classify a file by its extension (case-insensitive).

        Args:
            file_path: Path of the input file.

        Returns:
            'pdf_image' for PDF / image formats (parsed with MinerU),
            'markitdown' for everything else.
        """
        suffix = Path(file_path).suffix.lower()
        if suffix in self.PDF_IMAGE_FORMATS:
            return 'pdf_image'
        return 'markitdown'

    def _parse_with_mineru(self, file_path: Path, file_name: str, task_id: str,
                           backend: str, options: dict, output_path: Path):
        """
        Parse a PDF or image file with MinerU.

        Args:
            file_path: Input file path.
            file_name: Original file name (its stem names the output).
            task_id: Task id (used in log messages).
            backend: MinerU backend name.
            options: Parse options; reads 'lang', 'method', 'formula_enable', 'table_enable'.
            output_path: Directory for the parse results.
        """
        logger.info(f"📄 Using MinerU to parse: {file_name}")
        try:
            pdf_bytes = read_fn(file_path)
            # MinerU's ModelSingleton reuses already-loaded models.
            do_parse(
                output_dir=str(output_path),
                pdf_file_names=[Path(file_name).stem],
                pdf_bytes_list=[pdf_bytes],
                p_lang_list=[options.get('lang', 'ch')],
                backend=backend,
                parse_method=options.get('method', 'auto'),
                formula_enable=options.get('formula_enable', True),
                table_enable=options.get('table_enable', True),
            )
        finally:
            # MinerU's own cleanup helper; it frees inference intermediates
            # (per the original note it does not unload the models).
            try:
                clean_memory()
            except Exception as e:
                logger.debug(f"Memory cleanup failed for task {task_id}: {e}")

    def _parse_with_markitdown(self, file_path: Path, file_name: str,
                               output_path: Path):
        """
        Convert a document (Office, HTML, text, ...) to Markdown with MarkItDown.

        The result is written to ``output_path/<file stem>.md`` as UTF-8.

        Args:
            file_path: Input file path.
            file_name: Original file name (its stem names the output file).
            output_path: Directory for the Markdown file.

        Raises:
            RuntimeError: if MarkItDown is not installed / not initialised.
        """
        if not MARKITDOWN_AVAILABLE or self.markitdown is None:
            raise RuntimeError("markitdown is not available. Please install it: pip install markitdown")
        logger.info(f"📊 Using MarkItDown to parse: {file_name}")
        result = self.markitdown.convert(str(file_path))
        output_file = output_path / f"{Path(file_name).stem}.md"
        output_file.write_text(result.text_content, encoding='utf-8')
        logger.info(f"📝 Markdown saved to: {output_file}")

    def predict(self, action):
        """
        Handle an HTTP request (mainly health checks and monitoring).

        Actions:
            'health': queue statistics and worker state.
            'poll':   in loop mode, report that tasks are processed automatically;
                      with ``enable_worker_loop=False``, fetch and process one task
                      synchronously (legacy scheduler-driven mode). A failed task
                      is marked 'failed' in the database.
            other:    an error response.
        """
        if action == 'health':
            return {
                'status': 'healthy',
                'worker_id': self.worker_id,
                'worker_loop_enabled': self.enable_worker_loop,
                'worker_running': self.running,
                'queue_stats': self.db.get_queue_stats()
            }

        if action != 'poll':
            return {
                'status': 'error',
                'message': f'Invalid action: {action}. Use "health" or "poll".',
                'worker_id': self.worker_id
            }

        if self.enable_worker_loop:
            return {
                'status': 'auto_mode',
                'message': 'Worker is running in auto-loop mode, tasks are processed automatically',
                'worker_id': self.worker_id,
                'worker_running': self.running
            }

        # Legacy mode: fetch and process one task on demand.
        task = self.db.get_next_task(self.worker_id)
        if not task:
            return {
                'status': 'idle',
                'message': 'No pending tasks in queue',
                'worker_id': self.worker_id
            }

        task_id = task['task_id']
        try:
            self._process_task(task)
        except Exception as e:
            logger.exception(f"{self.worker_id} failed to process task {task_id}: {e}")
            # Previously the task was left in its in-progress state forever.
            try:
                self._mark_task_failed(task_id, e)
            except Exception as db_error:
                logger.error(f"{self.worker_id} could not record failure of task {task_id}: {db_error}")
            return {
                'status': 'failed',
                'task_id': task_id,
                'error': str(e),
                'worker_id': self.worker_id
            }
        return {
            'status': 'completed',
            'task_id': task_id,
            'worker_id': self.worker_id
        }

    def encode_response(self, response):
        """Return the response dict unchanged (LitServe serialises it)."""
        return response
def start_litserve_workers(
    output_dir='/tmp/mineru_tianshu_output',
    accelerator='auto',
    devices='auto',
    workers_per_device=1,
    port=9000,
    poll_interval=0.5,
    enable_worker_loop=True
):
    """
    Build the LitServe worker pool and run it (blocks in ``server.run``).

    Args:
        output_dir: Directory where parsed results are written.
        accelerator: Accelerator type passed to LitServer (auto/cuda/cpu/mps).
        devices: Devices passed to LitServer ('auto' or a list such as [0, 1, 2]).
        workers_per_device: Number of worker processes per device.
        port: HTTP port to listen on.
        poll_interval: Seconds a worker waits between polls of an empty queue.
        enable_worker_loop: Whether workers poll the task queue on their own.
    """
    separator = "=" * 60

    startup_banner = [
        separator,
        "🚀 Starting MinerU Tianshu LitServe Worker Pool",
        separator,
        f"📂 Output Directory: {output_dir}",
        f"🎮 Accelerator: {accelerator}",
        f"💾 Devices: {devices}",
        f"👷 Workers per Device: {workers_per_device}",
        f"🔌 Port: {port}",
        f"🔄 Worker Loop: {'Enabled' if enable_worker_loop else 'Disabled'}",
    ]
    if enable_worker_loop:
        startup_banner.append(f"⏱️ Poll Interval: {poll_interval}s")
    startup_banner.append(separator)
    for banner_line in startup_banner:
        logger.info(banner_line)

    # Template API instance; LitServe manages the actual worker instances.
    api = MinerUWorkerAPI(
        output_dir=output_dir,
        poll_interval=poll_interval,
        enable_worker_loop=enable_worker_loop
    )
    server = ls.LitServer(
        api,
        accelerator=accelerator,
        devices=devices,
        workers_per_device=workers_per_device,
        timeout=False,  # no request timeout
    )

    def _teardown_template_api():
        # Tears down the template instance held in this process. Worker
        # instances live in LitServe-managed processes and run their own
        # teardown. NOTE(review): on the template this is presumably a no-op,
        # since setup() (which starts the polling thread) never ran here — confirm.
        if hasattr(api, 'teardown'):
            return api.teardown()
        return None

    def graceful_shutdown(signum=None, frame=None):
        """Signal handler: stop the template worker, then exit the process."""
        logger.info("🛑 Received shutdown signal, gracefully stopping workers...")
        _teardown_template_api()
        sys.exit(0)

    # Ctrl+C / termination signals, plus normal interpreter exit.
    for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(shutdown_signal, graceful_shutdown)
    atexit.register(_teardown_template_api)

    if enable_worker_loop:
        mode_line = "🔁 Workers will continuously poll and process tasks"
    else:
        mode_line = "🔄 Workers will wait for scheduler triggers"
    for banner_line in (
        "✅ LitServe worker pool initialized",
        f"📡 Listening on: http://0.0.0.0:{port}/predict",
        mode_line,
        separator,
    ):
        logger.info(banner_line)

    server.run(port=port, generate_client_file=False)
def parse_devices_arg(value):
    """
    Convert the ``--devices`` CLI value into the form LitServer expects.

    Args:
        value: 'auto' or a comma-separated list of GPU indices, e.g. '0,1,2'.
            Whitespace around entries and blank entries are ignored.

    Returns:
        'auto' unchanged, otherwise a list of ints.

    Raises:
        ValueError: if an entry is not an integer or no index is given.
    """
    if value == 'auto':
        return value
    device_ids = [int(part) for part in value.split(',') if part.strip()]
    if not device_ids:
        raise ValueError(f"no device ids in {value!r}")
    return device_ids


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='MinerU Tianshu LitServe Worker Pool')
    parser.add_argument('--output-dir', type=str, default='/tmp/mineru_tianshu_output',
                        help='Output directory for processed files')
    parser.add_argument('--accelerator', type=str, default='auto',
                        choices=['auto', 'cuda', 'cpu', 'mps'],
                        help='Accelerator type')
    parser.add_argument('--devices', type=str, default='auto',
                        help='Devices to use (auto or comma-separated list like 0,1,2)')
    parser.add_argument('--workers-per-device', type=int, default=1,
                        help='Number of workers per device')
    parser.add_argument('--port', type=int, default=9000,
                        help='Server port')
    parser.add_argument('--poll-interval', type=float, default=0.5,
                        help='Worker poll interval in seconds (default: 0.5)')
    parser.add_argument('--disable-worker-loop', action='store_true',
                        help='Disable worker auto-loop mode (use scheduler-driven mode)')
    args = parser.parse_args()

    # An invalid --devices value falls back to 'auto' instead of aborting.
    try:
        devices = parse_devices_arg(args.devices)
    except ValueError:
        logger.warning(f"Invalid devices format: {args.devices}, using 'auto'")
        devices = 'auto'

    start_litserve_workers(
        output_dir=args.output_dir,
        accelerator=args.accelerator,
        devices=devices,
        workers_per_device=args.workers_per_device,
        port=args.port,
        poll_interval=args.poll_interval,
        enable_worker_loop=not args.disable_worker_loop
    )