437 lines
16 KiB
Python
437 lines
16 KiB
Python
|
|
"""
|
|||
|
|
MinerU Tianshu - SQLite Task Database Manager
|
|||
|
|
天枢任务数据库管理器
|
|||
|
|
|
|||
|
|
负责任务的持久化存储、状态管理和原子性操作
|
|||
|
|
"""
|
|||
|
|
import sqlite3
|
|||
|
|
import json
|
|||
|
|
import uuid
|
|||
|
|
from contextlib import contextmanager
|
|||
|
|
from typing import Optional, List, Dict
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TaskDB:
    """Task database manager backed by a single SQLite file.

    Responsible for persisting tasks, managing their status transitions and
    handing pending tasks to workers atomically.
    """

    def __init__(self, db_path='mineru_tianshu.db'):
        """Store the database path and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()

    def _get_conn(self):
        """Open a new SQLite connection (one per call, never cached).

        Creating a fresh connection each time avoids pickling problems,
        because no connection object is stored on the instance.

        Concurrency notes:
            - ``check_same_thread=False`` is acceptable here because each call
              creates its own connection, the connection is closed right
              after use (see ``get_cursor``) and there is no pool that could
              share one connection between threads.
            - ``timeout=30.0`` bounds how long a locked database is waited
              for; sqlite3 raises once the wait exceeds 30 seconds.

        Returns:
            sqlite3.Connection: Connection whose rows are ``sqlite3.Row``.
        """
        conn = sqlite3.connect(
            self.db_path,
            check_same_thread=False,
            timeout=30.0,
        )
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def get_cursor(self):
        """Yield a cursor on a fresh connection.

        The transaction is committed when the ``with`` body finishes
        normally and rolled back when it raises an ``Exception`` (which is
        then re-raised). The connection is always closed afterwards.

        Yields:
            sqlite3.Cursor: Cursor bound to the new connection.
        """
        conn = self._get_conn()
        try:
            # Created inside the try so the connection is closed even if
            # cursor creation itself fails.
            cursor = conn.cursor()
            yield cursor
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _init_db(self):
        """Create the ``tasks`` table and its indexes if they do not exist."""
        with self.get_cursor() as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    file_name TEXT NOT NULL,
                    file_path TEXT,
                    status TEXT DEFAULT 'pending',
                    priority INTEGER DEFAULT 0,
                    backend TEXT DEFAULT 'pipeline',
                    options TEXT,
                    result_path TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    started_at TIMESTAMP,
                    completed_at TIMESTAMP,
                    worker_id TEXT,
                    retry_count INTEGER DEFAULT 0
                )
            ''')

            # Indexes to speed up the queries issued by this class.
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_priority ON tasks(priority DESC)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_created_at ON tasks(created_at)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_worker_id ON tasks(worker_id)')

    def create_task(self, file_name: str, file_path: str,
                    backend: str = 'pipeline', options: Optional[dict] = None,
                    priority: int = 0) -> str:
        """Insert a new task and return its id.

        Args:
            file_name: Name of the file.
            file_path: Path of the file.
            backend: Processing backend
                (pipeline/vlm-transformers/vlm-vllm-engine).
            options: Processing options; stored as JSON text. ``None`` (or
                any falsy value) is stored as ``{}``.
            priority: Priority; larger numbers are served first.

        Returns:
            str: The generated task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        row = (task_id, file_name, file_path, backend,
               json.dumps(options or {}), priority)
        with self.get_cursor() as cursor:
            cursor.execute('''
                INSERT INTO tasks (task_id, file_name, file_path, backend, options, priority)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', row)
        return task_id

    def get_next_task(self, worker_id: str, max_retries: int = 3) -> Optional[Dict]:
        """Atomically claim the next pending task for a worker.

        Args:
            worker_id: Id of the worker claiming the task.
            max_retries: Maximum number of attempts when the selected task
                turns out to have been taken by another worker.

        Returns:
            The claimed task as a dict reflecting its state after the claim
            (``status='processing'``, ``worker_id`` and ``started_at`` set),
            or ``None`` when no task is pending or all attempts were used.

        Concurrency notes:
            1. ``BEGIN IMMEDIATE`` takes the write lock up front.
            2. The UPDATE re-checks ``status = 'pending'`` so a task cannot
               be claimed twice.
            3. ``rowcount`` is checked to confirm the claim succeeded.
            4. If the task was taken, the loop retries immediately instead
               of returning ``None``, since other tasks may still be queued.
        """
        for _ in range(max_retries):
            with self.get_cursor() as cursor:
                # Explicit transaction so select + update act as one unit.
                cursor.execute('BEGIN IMMEDIATE')

                # Highest priority first, oldest first within a priority.
                cursor.execute('''
                    SELECT task_id FROM tasks
                    WHERE status = 'pending'
                    ORDER BY priority DESC, created_at ASC
                    LIMIT 1
                ''')
                candidate = cursor.fetchone()
                if candidate is None:
                    # Nothing is waiting in the queue.
                    return None

                claimed_id = candidate['task_id']
                cursor.execute('''
                    UPDATE tasks
                    SET status = 'processing',
                        started_at = CURRENT_TIMESTAMP,
                        worker_id = ?
                    WHERE task_id = ? AND status = 'pending'
                ''', (worker_id, claimed_id))

                if cursor.rowcount == 0:
                    # Another worker got it first; try the next candidate.
                    continue

                # Re-read the row so the returned dict shows the claimed
                # state rather than the pre-update snapshot.
                cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (claimed_id,))
                return dict(cursor.fetchone())

        # Attempts exhausted without claiming a task (heavy contention).
        return None

    def _build_update_clauses(self, status: str, result_path: Optional[str] = None,
                              error_message: Optional[str] = None,
                              worker_id: Optional[str] = None,
                              task_id: Optional[str] = None):
        """Build the SET and WHERE fragments for a status update.

        Args:
            status: New status.
            result_path: Result path; only used when status is ``completed``.
            error_message: Error text; only used when status is ``failed``.
            worker_id: Worker id; for ``completed``/``failed`` it is added as
                an ownership condition.
            task_id: Task id; added as a condition when truthy.

        Returns:
            tuple: ``(update_clauses, update_params, where_clauses,
            where_params)``. ``where_clauses`` is empty when ``task_id`` is
            falsy and ``status`` is neither ``completed`` nor ``failed``.
        """
        update_clauses = ['status = ?']
        update_params = [status]
        where_clauses = []
        where_params = []

        if task_id:
            where_clauses.append('task_id = ?')
            where_params.append(task_id)

        if status in ('completed', 'failed'):
            update_clauses.append('completed_at = CURRENT_TIMESTAMP')

            if status == 'completed':
                extra_column, extra_value = 'result_path', result_path
            else:
                extra_column, extra_value = 'error_message', error_message
            if extra_value:
                update_clauses.append(f'{extra_column} = ?')
                update_params.append(extra_value)

            # Terminal states may only be reached from 'processing'.
            where_clauses.append("status = 'processing'")
            if worker_id:
                where_clauses.append('worker_id = ?')
                where_params.append(worker_id)

        return update_clauses, update_params, where_clauses, where_params

    def update_task_status(self, task_id: str, status: str,
                           result_path: Optional[str] = None,
                           error_message: Optional[str] = None,
                           worker_id: Optional[str] = None):
        """Update the status of one task.

        Args:
            task_id: Task id.
            status: New status
                (pending/processing/completed/failed/cancelled).
            result_path: Result path (optional).
            error_message: Error message (optional).
            worker_id: Worker id (optional, used as a concurrency check).

        Returns:
            bool: ``True`` if a row was updated. ``False`` means no row
            matched, e.g. the task was modified by another process, or
            ``task_id`` was empty/``None``.

        Concurrency notes:
            1. Moving to completed/failed requires the current status to be
               ``processing``.
            2. When ``worker_id`` is given for completed/failed, the task
               must belong to that worker.
        """
        if not task_id:
            # Without a task id the WHERE clause would lose its task filter
            # and could match every processing task (or be empty and
            # invalid), so refuse to run the UPDATE at all.
            return False

        update_clauses, update_params, where_clauses, where_params = \
            self._build_update_clauses(status, result_path, error_message, worker_id, task_id)

        # Placeholders are bound in order: SET values first, then WHERE.
        all_params = update_params + where_params

        sql = f'''
                UPDATE tasks
                SET {', '.join(update_clauses)}
                WHERE {' AND '.join(where_clauses)}
            '''

        with self.get_cursor() as cursor:
            cursor.execute(sql, all_params)
            success = cursor.rowcount > 0

            # Debug log, only for failed terminal updates.
            if not success and status in ['completed', 'failed']:
                from loguru import logger
                logger.debug(
                    f"Status update failed: task_id={task_id}, status={status}, "
                    f"worker_id={worker_id}, SQL: {sql}, params: {all_params}"
                )

            return success

    def get_task(self, task_id: str) -> Optional[Dict]:
        """Look up a single task.

        Args:
            task_id: Task id.

        Returns:
            The task as a dict, or ``None`` if it does not exist.
        """
        with self.get_cursor() as cursor:
            cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
            row = cursor.fetchone()
            if row is None:
                return None
            return dict(row)

    def get_queue_stats(self) -> Dict[str, int]:
        """Count tasks per status.

        Returns:
            Mapping of status to number of tasks; statuses with no tasks
            are absent.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT status, COUNT(*) as count
                FROM tasks
                GROUP BY status
            ''')
            return {row['status']: row['count'] for row in cursor.fetchall()}

    def get_tasks_by_status(self, status: str, limit: int = 100) -> List[Dict]:
        """List tasks with a given status, newest first.

        Args:
            status: Task status to filter on.
            limit: Maximum number of tasks returned.

        Returns:
            List of task dicts ordered by ``created_at`` descending.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT * FROM tasks
                WHERE status = ?
                ORDER BY created_at DESC
                LIMIT ?
            ''', (status, limit))
            return [dict(row) for row in cursor.fetchall()]

    def cleanup_old_task_files(self, days: int = 7):
        """Delete result directories of old finished tasks, keeping the rows.

        Args:
            days: Tasks completed more than this many days ago are cleaned.

        Returns:
            int: Number of result directories removed.

        Notes:
            - Only result files are deleted; database records are kept.
            - ``result_path`` is set to NULL for every directory removed.
            - Task status and history therefore remain queryable.
            - A failure to delete one directory is logged as a warning and
              does not stop the remaining clean-up.
        """
        import shutil

        with self.get_cursor() as cursor:
            # Finished tasks old enough that still reference result files.
            cursor.execute('''
                SELECT task_id, result_path FROM tasks
                WHERE completed_at < datetime('now', '-' || ? || ' days')
                AND status IN ('completed', 'failed')
                AND result_path IS NOT NULL
            ''', (days,))
            old_tasks = cursor.fetchall()

            file_count = 0
            for task in old_tasks:
                # Skip empty strings: Path('') would resolve to the
                # current directory.
                if not task['result_path']:
                    continue

                result_path = Path(task['result_path'])
                if not (result_path.exists() and result_path.is_dir()):
                    continue

                try:
                    shutil.rmtree(result_path)
                    file_count += 1

                    # NULL result_path marks the files as cleaned up.
                    cursor.execute('''
                        UPDATE tasks
                        SET result_path = NULL
                        WHERE task_id = ?
                    ''', (task['task_id'],))
                except Exception as e:
                    from loguru import logger
                    logger.warning(f"Failed to delete result files for task {task['task_id']}: {e}")

            return file_count

    def cleanup_old_task_records(self, days: int = 30):
        """Permanently delete very old finished task records (optional).

        Args:
            days: Records completed more than this many days ago are deleted.

        Returns:
            int: Number of records deleted.

        Notes:
            - This permanently removes database rows.
            - A long retention period (e.g. 30-90 days) is recommended.
            - Normally this method does not need to be called.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                DELETE FROM tasks
                WHERE completed_at < datetime('now', '-' || ? || ' days')
                AND status IN ('completed', 'failed')
            ''', (days,))
            return cursor.rowcount

    def reset_stale_tasks(self, timeout_minutes: int = 60):
        """Put timed-out ``processing`` tasks back into ``pending``.

        The owning worker is cleared and ``retry_count`` is incremented for
        every task that is reset.

        Args:
            timeout_minutes: Age of ``started_at`` (in minutes) after which
                a processing task is considered stale.

        Returns:
            int: Number of tasks reset.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                UPDATE tasks
                SET status = 'pending',
                    worker_id = NULL,
                    retry_count = retry_count + 1
                WHERE status = 'processing'
                AND started_at < datetime('now', '-' || ? || ' minutes')
            ''', (timeout_minutes,))
            return cursor.rowcount
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Manual smoke test: exercises create / get / stats against a scratch
    # database file in the current directory.
    try:
        db = TaskDB('test_tianshu.db')

        # Create a test task.
        task_id = db.create_task(
            file_name='test.pdf',
            file_path='/tmp/test.pdf',
            backend='pipeline',
            options={'lang': 'ch', 'formula_enable': True},
            priority=1
        )
        print(f"Created task: {task_id}")

        # Look the task up again.
        task = db.get_task(task_id)
        print(f"Task details: {task}")

        # Fetch queue statistics.
        stats = db.get_queue_stats()
        print(f"Queue stats: {stats}")
    finally:
        # Remove the scratch database even when a step above raised, so a
        # failed run does not leave the file behind.
        Path('test_tianshu.db').unlink(missing_ok=True)
    print("Test completed!")
|
|||
|
|
|