# UnisMindMap/projects/mineru_tianshu/task_db.py

"""
MinerU Tianshu - SQLite Task Database Manager
天枢任务数据库管理器
负责任务的持久化存储状态管理和原子性操作
"""
import sqlite3
import json
import uuid
from contextlib import contextmanager
from typing import Optional, List, Dict
from pathlib import Path
class TaskDB:
    """SQLite-backed task queue for MinerU Tianshu.

    An instance stores only ``db_path``; every operation opens a fresh
    connection through :meth:`get_cursor` and closes it when done.
    """

    def __init__(self, db_path='mineru_tianshu.db'):
        """Remember the database path and create the schema if it is missing.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()

    def _get_conn(self):
        """Open a new SQLite connection (one per call, never shared).

        A new connection is created on every call, so an instance holds no
        live connection object (the original rationale: avoid pickling issues).

        Concurrency notes:
        - ``check_same_thread=False`` is acceptable because each call creates
          its own connection, :meth:`get_cursor` closes it right after use,
          and no pool hands one connection to several threads.
        - ``timeout=30.0`` is how long SQLite waits on a locked database
          before raising ``sqlite3.OperationalError``.
        - Rows come back as ``sqlite3.Row`` so callers can use ``row['col']``
          and ``dict(row)``.
        """
        conn = sqlite3.connect(
            self.db_path,
            check_same_thread=False,
            timeout=30.0,
        )
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def get_cursor(self):
        """Yield a cursor on a fresh connection.

        Commits when the ``with`` body finishes normally, rolls back and
        re-raises on any ``Exception``, and always closes the connection.
        """
        conn = self._get_conn()
        try:
            # Created inside ``try`` so the connection is closed even if this fails.
            cursor = conn.cursor()
            yield cursor
            conn.commit()
        except Exception:
            conn.rollback()
            raise  # bare raise preserves the original traceback
        finally:
            conn.close()

    def _init_db(self):
        """Create the ``tasks`` table and its indexes if they do not exist."""
        with self.get_cursor() as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    file_name TEXT NOT NULL,
                    file_path TEXT,
                    status TEXT DEFAULT 'pending',
                    priority INTEGER DEFAULT 0,
                    backend TEXT DEFAULT 'pipeline',
                    options TEXT,
                    result_path TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    started_at TIMESTAMP,
                    completed_at TIMESTAMP,
                    worker_id TEXT,
                    retry_count INTEGER DEFAULT 0
                )
            ''')
            # Indexes to speed up the common lookups.
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_priority ON tasks(priority DESC)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_created_at ON tasks(created_at)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_worker_id ON tasks(worker_id)')

    def create_task(self, file_name: str, file_path: str,
                    backend: str = 'pipeline', options: Optional[dict] = None,
                    priority: int = 0) -> str:
        """Insert a new pending task and return its generated ID.

        Args:
            file_name: Original file name.
            file_path: Path of the input file.
            backend: Processing backend (e.g. pipeline / vlm-transformers /
                vlm-vllm-engine).
            options: Processing options; stored as a JSON string
                (``{}`` when None).
            priority: Larger numbers are claimed first.

        Returns:
            The new task's ID (a random UUID4 string).
        """
        task_id = str(uuid.uuid4())
        with self.get_cursor() as cursor:
            cursor.execute('''
                INSERT INTO tasks (task_id, file_name, file_path, backend, options, priority)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
        return task_id

    def get_next_task(self, worker_id: str, max_retries: int = 3) -> Optional[Dict]:
        """Atomically claim the next pending task for ``worker_id``.

        Tasks are chosen by highest ``priority``, then oldest ``created_at``,
        then lowest ``rowid``. ``created_at`` has one-second resolution, and
        SQLite normally assigns rowids in increasing insertion order, so the
        ``rowid`` tiebreak keeps same-second tasks FIFO.

        Args:
            worker_id: ID of the claiming worker; stored on the task row.
            max_retries: How many attempts to make when the selected task was
                claimed by another worker before our UPDATE (default 3).

        Returns:
            The claimed task as a dict, or None when no task is pending or
            every attempt lost its race. The dict holds the row as read
            *before* the claim, so its ``status`` is still ``'pending'``.

        Concurrency notes:
        1. ``BEGIN IMMEDIATE`` takes the write lock before the SELECT.
        2. The UPDATE re-checks ``status = 'pending'`` so a task cannot be
           claimed twice.
        3. ``rowcount`` confirms that the UPDATE actually claimed the row.
        4. A lost race retries immediately instead of returning None, since
           other pending tasks may still exist.
        """
        for _attempt in range(max_retries):
            with self.get_cursor() as cursor:
                # The connection is fresh, so no transaction is open yet.
                cursor.execute('BEGIN IMMEDIATE')
                cursor.execute('''
                    SELECT * FROM tasks
                    WHERE status = 'pending'
                    ORDER BY priority DESC, created_at ASC, rowid ASC
                    LIMIT 1
                ''')
                task = cursor.fetchone()
                if task is None:
                    # Queue is empty.
                    return None

                cursor.execute('''
                    UPDATE tasks
                    SET status = 'processing',
                        started_at = CURRENT_TIMESTAMP,
                        worker_id = ?
                    WHERE task_id = ? AND status = 'pending'
                ''', (worker_id, task['task_id']))
                if cursor.rowcount > 0:
                    return dict(task)
                # Someone else claimed it first; try again with the next task.
        # Retries exhausted (heavy contention).
        return None

    def _build_update_clauses(self, status: str, result_path: Optional[str] = None,
                              error_message: Optional[str] = None,
                              worker_id: Optional[str] = None,
                              task_id: Optional[str] = None):
        """Build the SET and WHERE fragments for :meth:`update_task_status`.

        Args:
            status: New status value.
            result_path: Stored only when ``status == 'completed'`` and truthy.
            error_message: Stored only when ``status == 'failed'`` and truthy.
            worker_id: For completed/failed, restricts the update to this
                worker's task when truthy.
            task_id: Adds a ``task_id = ?`` filter whenever it is not None.

        Returns:
            tuple: ``(update_clauses, update_params, where_clauses, where_params)``.
            ``where_clauses`` may be empty; callers must not run an UPDATE
            with an empty WHERE.
        """
        update_clauses = ['status = ?']
        update_params = [status]
        where_clauses = []
        where_params = []

        # ``is not None`` so that an empty-string id matches nothing instead
        # of silently dropping the filter.
        if task_id is not None:
            where_clauses.append('task_id = ?')
            where_params.append(task_id)

        if status in ('completed', 'failed'):
            update_clauses.append('completed_at = CURRENT_TIMESTAMP')
            if status == 'completed' and result_path:
                update_clauses.append('result_path = ?')
                update_params.append(result_path)
            elif status == 'failed' and error_message:
                update_clauses.append('error_message = ?')
                update_params.append(error_message)
            # Only finish tasks that are currently being processed.
            where_clauses.append("status = 'processing'")
            if worker_id:
                where_clauses.append('worker_id = ?')
                where_params.append(worker_id)

        return update_clauses, update_params, where_clauses, where_params

    def update_task_status(self, task_id: str, status: str,
                           result_path: Optional[str] = None,
                           error_message: Optional[str] = None,
                           worker_id: Optional[str] = None) -> bool:
        """Update a task's status.

        Args:
            task_id: Task ID. An empty or None ID is rejected (returns False)
                rather than updating every matching task.
            status: New status (pending/processing/completed/failed/cancelled).
            result_path: Result location (used for 'completed').
            error_message: Error text (used for 'failed').
            worker_id: Optional owner check for completed/failed updates.

        Returns:
            bool: True if a row was updated.

        Concurrency notes:
        1. Transitions to completed/failed require the current status to be
           'processing'.
        2. If ``worker_id`` is given, the task must belong to that worker.
        3. False means the task was missing or modified by another process.
        """
        if not task_id:
            # Without an ID the WHERE would match every 'processing' task
            # (or be empty, i.e. invalid SQL); refuse instead.
            return False

        update_clauses, update_params, where_clauses, where_params = \
            self._build_update_clauses(status, result_path, error_message, worker_id, task_id)
        # Parameter order must follow placeholder order: SET first, then WHERE.
        all_params = update_params + where_params
        sql = f'''
            UPDATE tasks
            SET {', '.join(update_clauses)}
            WHERE {' AND '.join(where_clauses)}
        '''
        with self.get_cursor() as cursor:
            cursor.execute(sql, all_params)
            success = cursor.rowcount > 0

        # Debug log only for failed terminal transitions.
        if not success and status in ['completed', 'failed']:
            from loguru import logger
            logger.debug(
                f"Status update failed: task_id={task_id}, status={status}, "
                f"worker_id={worker_id}, SQL: {sql}, params: {all_params}"
            )
        return success

    def get_task(self, task_id: str) -> Optional[Dict]:
        """Return the task row as a dict, or None if no such task exists."""
        with self.get_cursor() as cursor:
            cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
            task = cursor.fetchone()
            return dict(task) if task else None

    def get_queue_stats(self) -> Dict[str, int]:
        """Return ``{status: count}`` for every status present in the table."""
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT status, COUNT(*) as count
                FROM tasks
                GROUP BY status
            ''')
            return {row['status']: row['count'] for row in cursor.fetchall()}

    def get_tasks_by_status(self, status: str, limit: int = 100) -> List[Dict]:
        """Return up to ``limit`` tasks with the given status, newest first."""
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT * FROM tasks
                WHERE status = ?
                ORDER BY created_at DESC
                LIMIT ?
            ''', (status, limit))
            return [dict(row) for row in cursor.fetchall()]

    def cleanup_old_task_files(self, days: int = 7) -> int:
        """Delete result directories of old finished tasks, keeping the DB rows.

        Args:
            days: Only tasks with status 'completed' or 'failed' whose
                ``completed_at`` is more than this many days ago are considered.

        Returns:
            int: Number of result directories actually deleted.

        Notes:
            - Only files are removed; task rows remain queryable.
            - ``result_path`` is set to NULL for every task whose directory
              was deleted here or no longer exists, so it is not rescanned.
            - Paths that exist but are not directories are left untouched.
            - Directories are removed outside any transaction, so slow
              filesystem work does not hold SQLite's write lock (which
              ``get_next_task``'s ``BEGIN IMMEDIATE`` needs).
        """
        import shutil

        # Phase 1: collect candidates (read only, connection closed afterwards).
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT task_id, result_path FROM tasks
                WHERE completed_at < datetime('now', '-' || ? || ' days')
                  AND status IN ('completed', 'failed')
                  AND result_path IS NOT NULL
            ''', (days,))
            candidates = [(row['task_id'], row['result_path']) for row in cursor.fetchall()]

        # Phase 2: filesystem work with no database lock held.
        file_count = 0
        cleared = []  # (task_id, result_path) pairs whose result_path gets NULLed
        for task_id, result_path in candidates:
            # Never act on an empty path: Path('') is the current directory.
            if not result_path:
                continue
            path = Path(result_path)
            if path.is_dir():
                try:
                    shutil.rmtree(path)
                except Exception as e:
                    from loguru import logger
                    logger.warning(f"Failed to delete result files for task {task_id}: {e}")
                    continue
                file_count += 1
                cleared.append((task_id, result_path))
            elif not path.exists():
                # Already gone (e.g. removed by hand): forget the stale path.
                cleared.append((task_id, result_path))

        # Phase 3: one short write transaction. Matching on the old path too
        # avoids clearing a path that was changed in the meantime.
        if cleared:
            with self.get_cursor() as cursor:
                cursor.executemany('''
                    UPDATE tasks
                    SET result_path = NULL
                    WHERE task_id = ? AND result_path = ?
                ''', cleared)
        return file_count

    def cleanup_old_task_records(self, days: int = 30) -> int:
        """Permanently delete old finished task rows (optional maintenance).

        Args:
            days: Rows with status 'completed' or 'failed' whose
                ``completed_at`` is more than this many days ago are deleted.

        Returns:
            int: Number of deleted rows.

        Note:
            This removes records irreversibly; prefer a long retention period
            (e.g. 30-90 days). Normally this does not need to be called.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                DELETE FROM tasks
                WHERE completed_at < datetime('now', '-' || ? || ' days')
                AND status IN ('completed', 'failed')
            ''', (days,))
            return cursor.rowcount

    def reset_stale_tasks(self, timeout_minutes: int = 60) -> int:
        """Return timed-out 'processing' tasks to 'pending'.

        A task is stale when its ``started_at`` is more than
        ``timeout_minutes`` ago. Reset tasks lose their ``worker_id`` and get
        ``retry_count`` incremented.

        Args:
            timeout_minutes: Processing time after which a task counts as stale.

        Returns:
            int: Number of tasks reset.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                UPDATE tasks
                SET status = 'pending',
                    worker_id = NULL,
                    retry_count = retry_count + 1
                WHERE status = 'processing'
                AND started_at < datetime('now', '-' || ? || ' minutes')
            ''', (timeout_minutes,))
            return cursor.rowcount
if __name__ == '__main__':
    # Smoke test: create a task, read it back, and print queue stats
    # against a throwaway database file.
    test_db_path = Path('test_tianshu.db')
    # Start clean in case an earlier run crashed before removing the file.
    test_db_path.unlink(missing_ok=True)
    try:
        db = TaskDB(str(test_db_path))

        # Create a test task
        task_id = db.create_task(
            file_name='test.pdf',
            file_path='/tmp/test.pdf',
            backend='pipeline',
            options={'lang': 'ch', 'formula_enable': True},
            priority=1
        )
        print(f"Created task: {task_id}")

        # Query the task
        task = db.get_task(task_id)
        print(f"Task details: {task}")

        # Queue statistics
        stats = db.get_queue_stats()
        print(f"Queue stats: {stats}")
    finally:
        # Always remove the test database, even if a step above raised.
        test_db_path.unlink(missing_ok=True)
    print("Test completed!")