UnisMindMap/mineru/backend/pipeline/pipeline_magic_model.py

373 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, get_minbox_if_overlap_by_ratio
from mineru.utils.enum_class import CategoryId, ContentType
from mineru.utils.magic_model_utils import tie_up_category_by_distance_v3, reduct_overlap
class MagicModel:
    """Clean up one page of layout-model detections and expose typed getters.

    The constructor normalizes ``page_model_info['layout_dets']`` in place:
    it adds a ``bbox`` to every detection, drops weak or duplicated
    detections, reclassifies some table footnotes as image footnotes and
    merges overlapping image/table bodies.

    Every getter returns an empty list when no matching element is found.
    """

    def __init__(self, page_model_info: dict, scale: float):
        """Store the page detections and run the clean-up passes in order.

        Args:
            page_model_info: Page-level model output. It must contain a
                ``'layout_dets'`` list of detection dicts; the list and its
                dicts are modified in place.
            scale: Divisor applied to ``poly`` coordinates to obtain the
                ``bbox`` coordinates.
        """
        self.__page_model_info = page_model_info
        self.__scale = scale
        # Add bbox information to every detection (scaled poly -> bbox).
        self.__fix_axis()
        # Drop detections with a very low confidence (score <= 0.05).
        self.__fix_by_remove_low_confidence()
        # Among detections with a high IoU (> 0.9), drop the less confident.
        self.__fix_by_remove_high_iou_and_low_confidence()
        # Reclassify some table footnotes as image footnotes.
        self.__fix_footnote()
        # Merge overlapping image bodies and overlapping table bodies.
        self.__fix_by_remove_overlap_image_table_body()

    def __fix_by_remove_overlap_image_table_body(self):
        """Merge image bodies (and table bodies) that largely overlap.

        For each same-category pair for which
        ``get_minbox_if_overlap_by_ratio(..., 0.8)`` reports an overlap, the
        block with the smaller (or equal) area is scheduled for removal and
        the larger block's bbox is grown to the union of both boxes.
        """
        layout_dets = self.__page_model_info['layout_dets']
        need_remove_list = []

        def bbox_area(bbox):
            return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

        def mark_overlapping(blocks):
            for i, block1 in enumerate(blocks):
                for block2 in blocks[i + 1:]:
                    overlap_box = get_minbox_if_overlap_by_ratio(
                        block1['bbox'], block2['bbox'], 0.8
                    )
                    if overlap_box is None:
                        continue
                    # The smaller block is removed; ties remove block1.
                    if bbox_area(block1['bbox']) <= bbox_area(block2['bbox']):
                        block_to_remove, large_block = block1, block2
                    else:
                        block_to_remove, large_block = block2, block1
                    # Equality-based membership test, as dicts are compared
                    # by value here.
                    if block_to_remove in need_remove_list:
                        continue
                    # Grow the larger block so it also covers the removed one.
                    x1, y1, x2, y2 = large_block['bbox']
                    sx1, sy1, sx2, sy2 = block_to_remove['bbox']
                    large_block['bbox'] = [
                        min(x1, sx1),
                        min(y1, sy1),
                        max(x2, sx2),
                        max(y2, sy2),
                    ]
                    need_remove_list.append(block_to_remove)

        # Image-image overlaps.
        mark_overlapping(
            [d for d in layout_dets if d['category_id'] == CategoryId.ImageBody]
        )
        # Table-table overlaps.
        mark_overlapping(
            [d for d in layout_dets if d['category_id'] == CategoryId.TableBody]
        )

        # Drop the scheduled blocks from the page layout.
        for need_remove in need_remove_list:
            if need_remove in layout_dets:
                layout_dets.remove(need_remove)

    def __fix_axis(self):
        """Derive ``bbox`` from ``poly`` for every detection, in place.

        ``poly`` is unpacked as eight numbers of which the 1st/2nd and
        5th/6th are used as the two bbox corners; each is divided by the
        scale and truncated with ``int``. Detections whose resulting width
        or height is <= 0 are removed.
        """
        layout_dets = self.__page_model_info['layout_dets']
        scale = self.__scale
        for layout_det in layout_dets:
            x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
            layout_det['bbox'] = [
                int(x0 / scale),
                int(y0 / scale),
                int(x1 / scale),
                int(y1 / scale),
            ]
        # Filter in place (slice assignment) so any outside reference to the
        # list sees the result; this replaces a quadratic remove() loop.
        layout_dets[:] = [
            det
            for det in layout_dets
            if det['bbox'][2] - det['bbox'][0] > 0
            and det['bbox'][3] - det['bbox'][1] > 0
        ]

    def __fix_by_remove_low_confidence(self):
        """Remove, in place, every detection whose ``score`` is <= 0.05."""
        layout_dets = self.__page_model_info['layout_dets']
        layout_dets[:] = [det for det in layout_dets if det['score'] > 0.05]

    def __fix_by_remove_high_iou_and_low_confidence(self):
        """Drop the less confident of two detections whose IoU exceeds 0.9.

        Only the layout categories listed below take part. When both scores
        are equal, the later detection of the pair is the one removed.
        """
        candidate_category_ids = [
            CategoryId.Title,
            CategoryId.Text,
            CategoryId.ImageBody,
            CategoryId.ImageCaption,
            CategoryId.TableBody,
            CategoryId.TableCaption,
            CategoryId.TableFootnote,
            CategoryId.InterlineEquation_Layout,
            CategoryId.InterlineEquationNumber_Layout,
        ]
        candidates = [
            det
            for det in self.__page_model_info['layout_dets']
            if det['category_id'] in candidate_category_ids
        ]

        need_remove_list = []
        for i, layout_det1 in enumerate(candidates):
            for layout_det2 in candidates[i + 1:]:
                if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) <= 0.9:
                    continue
                if layout_det1['score'] < layout_det2['score']:
                    layout_det_need_remove = layout_det1
                else:
                    layout_det_need_remove = layout_det2
                # Equality-based de-duplication is intentional: it keeps the
                # subsequent remove() calls from failing on equal dicts.
                if layout_det_need_remove not in need_remove_list:
                    need_remove_list.append(layout_det_need_remove)

        for need_remove in need_remove_list:
            self.__page_model_info['layout_dets'].remove(need_remove)

    def __min_distance_per_footnote(self, footnotes, anchors):
        """Map footnote index -> smallest ``_bbox_distance`` to any anchor.

        Anchors for which ``bbox_relative_pos(footnote, anchor)`` sets more
        than one flag are skipped. A footnote with no remaining anchor has no
        entry in the returned dict.
        """
        distances = {}
        for idx, footnote in enumerate(footnotes):
            for anchor in anchors:
                pos_flag_count = sum(
                    1
                    for flag in bbox_relative_pos(footnote['bbox'], anchor['bbox'])
                    if flag
                )
                if pos_flag_count > 1:
                    continue
                distances[idx] = min(
                    self._bbox_distance(anchor['bbox'], footnote['bbox']),
                    distances.get(idx, float('inf')),
                )
        return distances

    def __fix_footnote(self):
        """Reclassify table footnotes that sit closer to an image body.

        A ``TableFootnote`` becomes an ``ImageFootnote`` when its smallest
        distance to an image body is strictly smaller than its smallest
        distance to a table body (a missing table distance counts as
        infinity).
        """
        footnotes = []
        figures = []
        tables = []
        for obj in self.__page_model_info['layout_dets']:
            category_id = obj['category_id']
            if category_id == CategoryId.TableFootnote:
                footnotes.append(obj)
            elif category_id == CategoryId.ImageBody:
                figures.append(obj)
            elif category_id == CategoryId.TableBody:
                tables.append(obj)

        # Without footnotes or without figures nothing can be reclassified.
        # (This guard used to be a no-op ``continue`` inside the loop above.)
        if not footnotes or not figures:
            return

        dis_figure_footnote = self.__min_distance_per_footnote(footnotes, figures)
        dis_table_footnote = self.__min_distance_per_footnote(footnotes, tables)

        for idx, footnote in enumerate(footnotes):
            if idx not in dis_figure_footnote:
                continue
            if dis_table_footnote.get(idx, float('inf')) > dis_figure_footnote[idx]:
                footnote['category_id'] = CategoryId.ImageFootnote

    def _bbox_distance(self, bbox1, bbox2):
        """Return the distance between two boxes, or infinity if unrelated.

        Infinity is returned when ``bbox_relative_pos`` sets more than one
        of its four flags, or when the compared side of ``bbox2`` is more
        than 30% longer than that of ``bbox1``. The compared side is the
        height when either of the first two flags is set, otherwise the
        width. In all other cases the result is ``bbox_distance(bbox1, bbox2)``.
        """
        left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
        flag_count = sum(1 for flag in (left, right, bottom, top) if flag)
        if flag_count > 1:
            return float('inf')

        if left or right:
            # Side-by-side boxes: compare heights.
            l1 = bbox1[3] - bbox1[1]
            l2 = bbox2[3] - bbox2[1]
        else:
            # Otherwise compare widths.
            l1 = bbox1[2] - bbox1[0]
            l2 = bbox2[2] - bbox2[0]

        if l2 > l1 and (l2 - l1) / l1 > 0.3:
            return float('inf')
        return bbox_distance(bbox1, bbox2)

    def __tie_up_category_by_distance_v3(self, subject_category_id, object_category_id):
        """Pair subject-category blocks with object-category blocks.

        Builds two zero-argument callables that each return the
        ``reduct_overlap``-filtered ``{'bbox', 'score'}`` dicts of one
        category, and hands them to ``tie_up_category_by_distance_v3``.
        """

        def make_getter(category_id):
            def getter():
                return reduct_overlap(
                    [
                        {'bbox': det['bbox'], 'score': det['score']}
                        for det in self.__page_model_info['layout_dets']
                        if det['category_id'] == category_id
                    ]
                )

            return getter

        return tie_up_category_by_distance_v3(
            make_getter(subject_category_id),
            make_getter(object_category_id),
        )

    def __collect_body_records(
        self,
        body_category_id,
        caption_category_id,
        footnote_category_id,
        body_key,
        caption_key,
        footnote_key,
    ):
        """Build one record per body with its captions and footnotes.

        Captions and footnotes are tied up in two separate passes and joined
        on ``sub_idx``. A body without a matching footnote entry gets an
        empty footnote list instead of raising ``StopIteration``.
        """
        with_captions = self.__tie_up_category_by_distance_v3(
            body_category_id, caption_category_id
        )
        with_footnotes = self.__tie_up_category_by_distance_v3(
            body_category_id, footnote_category_id
        )

        records = []
        for entry in with_captions:
            sub_idx = entry['sub_idx']
            match = next(
                (item for item in with_footnotes if item['sub_idx'] == sub_idx),
                None,
            )
            records.append(
                {
                    body_key: entry['sub_bbox'],
                    caption_key: entry['obj_bboxes'],
                    footnote_key: match['obj_bboxes'] if match is not None else [],
                }
            )
        return records

    def get_imgs(self):
        """Return image records.

        Each record has the keys ``image_body``, ``image_caption_list`` and
        ``image_footnote_list``.
        """
        return self.__collect_body_records(
            CategoryId.ImageBody,
            CategoryId.ImageCaption,
            CategoryId.ImageFootnote,
            'image_body',
            'image_caption_list',
            'image_footnote_list',
        )

    def get_tables(self) -> list:
        """Return table records.

        Each record has the keys ``table_body``, ``table_caption_list`` and
        ``table_footnote_list``.
        """
        return self.__collect_body_records(
            CategoryId.TableBody,
            CategoryId.TableCaption,
            CategoryId.TableFootnote,
            'table_body',
            'table_caption_list',
            'table_footnote_list',
        )

    def get_equations(self) -> tuple[list, list, list]:
        """Return ``(inline, interline, interline_layout_blocks)`` equations.

        The first two lists carry coordinates and a ``latex`` field; the
        third carries only ``bbox`` and ``score``.
        """
        inline_equations = self.__get_blocks_by_type(
            CategoryId.InlineEquation, ['latex']
        )
        interline_equations = self.__get_blocks_by_type(
            CategoryId.InterlineEquation_YOLO, ['latex']
        )
        interline_equations_blocks = self.__get_blocks_by_type(
            CategoryId.InterlineEquation_Layout
        )
        return inline_equations, interline_equations, interline_equations_blocks

    def get_discarded(self) -> list:
        """Return the ``Abandon`` blocks (coordinates only)."""
        return self.__get_blocks_by_type(CategoryId.Abandon)

    def get_text_blocks(self) -> list:
        """Return the ``Text`` layout blocks (coordinates only, no text)."""
        return self.__get_blocks_by_type(CategoryId.Text)

    def get_title_blocks(self) -> list:
        """Return the ``Title`` layout blocks (coordinates only, no text)."""
        return self.__get_blocks_by_type(CategoryId.Title)

    def get_all_spans(self) -> list:
        """Return de-duplicated spans for span-level categories.

        Image bodies, table bodies, inline/interline equations and OCR text
        are converted to span dicts with ``bbox``, ``score``, ``type`` and,
        depending on the category, ``content``, ``latex`` or ``html``.
        Spans equal to an earlier span are dropped; order is preserved.
        """
        # Categories that are assembled as spans.
        allow_category_id_list = [
            CategoryId.ImageBody,
            CategoryId.TableBody,
            CategoryId.InlineEquation,
            CategoryId.InterlineEquation_YOLO,
            CategoryId.OcrText,
        ]

        all_spans = []
        for layout_det in self.__page_model_info['layout_dets']:
            category_id = layout_det['category_id']
            if category_id not in allow_category_id_list:
                continue

            span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
            if category_id == CategoryId.ImageBody:
                span['type'] = ContentType.IMAGE
            elif category_id == CategoryId.TableBody:
                # Attach the table-model result; latex wins over html.
                latex = layout_det.get('latex', None)
                html = layout_det.get('html', None)
                if latex:
                    span['latex'] = latex
                elif html:
                    span['html'] = html
                span['type'] = ContentType.TABLE
            elif category_id == CategoryId.InlineEquation:
                span['content'] = layout_det['latex']
                span['type'] = ContentType.INLINE_EQUATION
            elif category_id == CategoryId.InterlineEquation_YOLO:
                span['content'] = layout_det['latex']
                span['type'] = ContentType.INTERLINE_EQUATION
            elif category_id == CategoryId.OcrText:
                span['content'] = layout_det['text']
                span['type'] = ContentType.TEXT

            # Spans are dicts (unhashable), so de-duplicate by equality.
            if span not in all_spans:
                all_spans.append(span)
        return all_spans

    def __get_blocks_by_type(
        self, category_type: int, extra_col=None
    ) -> list:
        """Return ``{'bbox', 'score', *extra_col}`` dicts for one category.

        Args:
            category_type: Category id to select.
            extra_col: Optional names of additional detection fields to copy;
                a field missing from a detection is copied as ``None``.

        Returns:
            One dict per matching detection, in page order; an empty list
            when nothing matches or ``'layout_dets'`` is absent.
        """
        if extra_col is None:
            extra_col = []

        blocks = []
        for item in self.__page_model_info.get('layout_dets', []):
            if item.get('category_id', -1) != category_type:
                continue
            block = {
                'bbox': item.get('bbox', None),
                'score': item.get('score'),
            }
            for col in extra_col:
                block[col] = item.get(col, None)
            blocks.append(block)
        return blocks