373 lines
15 KiB
Python
373 lines
15 KiB
Python
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, get_minbox_if_overlap_by_ratio
|
||
from mineru.utils.enum_class import CategoryId, ContentType
|
||
from mineru.utils.magic_model_utils import tie_up_category_by_distance_v3, reduct_overlap
|
||
|
||
|
||
class MagicModel:
|
||
"""每个函数没有得到元素的时候返回空list."""
|
||
def __init__(self, page_model_info: dict, scale: float):
|
||
self.__page_model_info = page_model_info
|
||
self.__scale = scale
|
||
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
|
||
self.__fix_axis()
|
||
"""删除置信度特别低的模型数据(<0.05),提高质量"""
|
||
self.__fix_by_remove_low_confidence()
|
||
"""删除高iou(>0.9)数据中置信度较低的那个"""
|
||
self.__fix_by_remove_high_iou_and_low_confidence()
|
||
"""将部分tbale_footnote修正为image_footnote"""
|
||
self.__fix_footnote()
|
||
"""处理重叠的image_body和table_body"""
|
||
self.__fix_by_remove_overlap_image_table_body()
|
||
|
||
def __fix_by_remove_overlap_image_table_body(self):
|
||
need_remove_list = []
|
||
layout_dets = self.__page_model_info['layout_dets']
|
||
image_blocks = list(filter(
|
||
lambda x: x['category_id'] == CategoryId.ImageBody, layout_dets
|
||
))
|
||
table_blocks = list(filter(
|
||
lambda x: x['category_id'] == CategoryId.TableBody, layout_dets
|
||
))
|
||
|
||
def add_need_remove_block(blocks):
|
||
for i in range(len(blocks)):
|
||
for j in range(i + 1, len(blocks)):
|
||
block1 = blocks[i]
|
||
block2 = blocks[j]
|
||
overlap_box = get_minbox_if_overlap_by_ratio(
|
||
block1['bbox'], block2['bbox'], 0.8
|
||
)
|
||
if overlap_box is not None:
|
||
# 判断哪个区块的面积更小,移除较小的区块
|
||
area1 = (block1['bbox'][2] - block1['bbox'][0]) * (block1['bbox'][3] - block1['bbox'][1])
|
||
area2 = (block2['bbox'][2] - block2['bbox'][0]) * (block2['bbox'][3] - block2['bbox'][1])
|
||
|
||
if area1 <= area2:
|
||
block_to_remove = block1
|
||
large_block = block2
|
||
else:
|
||
block_to_remove = block2
|
||
large_block = block1
|
||
|
||
if block_to_remove not in need_remove_list:
|
||
# 扩展大区块的边界框
|
||
x1, y1, x2, y2 = large_block['bbox']
|
||
sx1, sy1, sx2, sy2 = block_to_remove['bbox']
|
||
x1 = min(x1, sx1)
|
||
y1 = min(y1, sy1)
|
||
x2 = max(x2, sx2)
|
||
y2 = max(y2, sy2)
|
||
large_block['bbox'] = [x1, y1, x2, y2]
|
||
need_remove_list.append(block_to_remove)
|
||
|
||
# 处理图像-图像重叠
|
||
add_need_remove_block(image_blocks)
|
||
# 处理表格-表格重叠
|
||
add_need_remove_block(table_blocks)
|
||
|
||
# 从布局中移除标记的区块
|
||
for need_remove in need_remove_list:
|
||
if need_remove in layout_dets:
|
||
layout_dets.remove(need_remove)
|
||
|
||
|
||
def __fix_axis(self):
|
||
need_remove_list = []
|
||
layout_dets = self.__page_model_info['layout_dets']
|
||
for layout_det in layout_dets:
|
||
x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
|
||
bbox = [
|
||
int(x0 / self.__scale),
|
||
int(y0 / self.__scale),
|
||
int(x1 / self.__scale),
|
||
int(y1 / self.__scale),
|
||
]
|
||
layout_det['bbox'] = bbox
|
||
# 删除高度或者宽度小于等于0的spans
|
||
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
|
||
need_remove_list.append(layout_det)
|
||
for need_remove in need_remove_list:
|
||
layout_dets.remove(need_remove)
|
||
|
||
def __fix_by_remove_low_confidence(self):
|
||
need_remove_list = []
|
||
layout_dets = self.__page_model_info['layout_dets']
|
||
for layout_det in layout_dets:
|
||
if layout_det['score'] <= 0.05:
|
||
need_remove_list.append(layout_det)
|
||
else:
|
||
continue
|
||
for need_remove in need_remove_list:
|
||
layout_dets.remove(need_remove)
|
||
|
||
def __fix_by_remove_high_iou_and_low_confidence(self):
|
||
need_remove_list = []
|
||
layout_dets = list(filter(
|
||
lambda x: x['category_id'] in [
|
||
CategoryId.Title,
|
||
CategoryId.Text,
|
||
CategoryId.ImageBody,
|
||
CategoryId.ImageCaption,
|
||
CategoryId.TableBody,
|
||
CategoryId.TableCaption,
|
||
CategoryId.TableFootnote,
|
||
CategoryId.InterlineEquation_Layout,
|
||
CategoryId.InterlineEquationNumber_Layout,
|
||
], self.__page_model_info['layout_dets']
|
||
)
|
||
)
|
||
for i in range(len(layout_dets)):
|
||
for j in range(i + 1, len(layout_dets)):
|
||
layout_det1 = layout_dets[i]
|
||
layout_det2 = layout_dets[j]
|
||
|
||
if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
|
||
|
||
layout_det_need_remove = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
|
||
|
||
if layout_det_need_remove not in need_remove_list:
|
||
need_remove_list.append(layout_det_need_remove)
|
||
|
||
for need_remove in need_remove_list:
|
||
self.__page_model_info['layout_dets'].remove(need_remove)
|
||
|
||
def __fix_footnote(self):
|
||
footnotes = []
|
||
figures = []
|
||
tables = []
|
||
|
||
for obj in self.__page_model_info['layout_dets']:
|
||
if obj['category_id'] == CategoryId.TableFootnote:
|
||
footnotes.append(obj)
|
||
elif obj['category_id'] == CategoryId.ImageBody:
|
||
figures.append(obj)
|
||
elif obj['category_id'] == CategoryId.TableBody:
|
||
tables.append(obj)
|
||
if len(footnotes) * len(figures) == 0:
|
||
continue
|
||
dis_figure_footnote = {}
|
||
dis_table_footnote = {}
|
||
|
||
for i in range(len(footnotes)):
|
||
for j in range(len(figures)):
|
||
pos_flag_count = sum(
|
||
list(
|
||
map(
|
||
lambda x: 1 if x else 0,
|
||
bbox_relative_pos(
|
||
footnotes[i]['bbox'], figures[j]['bbox']
|
||
),
|
||
)
|
||
)
|
||
)
|
||
if pos_flag_count > 1:
|
||
continue
|
||
dis_figure_footnote[i] = min(
|
||
self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
|
||
dis_figure_footnote.get(i, float('inf')),
|
||
)
|
||
for i in range(len(footnotes)):
|
||
for j in range(len(tables)):
|
||
pos_flag_count = sum(
|
||
list(
|
||
map(
|
||
lambda x: 1 if x else 0,
|
||
bbox_relative_pos(
|
||
footnotes[i]['bbox'], tables[j]['bbox']
|
||
),
|
||
)
|
||
)
|
||
)
|
||
if pos_flag_count > 1:
|
||
continue
|
||
|
||
dis_table_footnote[i] = min(
|
||
self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
|
||
dis_table_footnote.get(i, float('inf')),
|
||
)
|
||
for i in range(len(footnotes)):
|
||
if i not in dis_figure_footnote:
|
||
continue
|
||
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
|
||
footnotes[i]['category_id'] = CategoryId.ImageFootnote
|
||
|
||
def _bbox_distance(self, bbox1, bbox2):
|
||
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
|
||
flags = [left, right, bottom, top]
|
||
count = sum([1 if v else 0 for v in flags])
|
||
if count > 1:
|
||
return float('inf')
|
||
if left or right:
|
||
l1 = bbox1[3] - bbox1[1]
|
||
l2 = bbox2[3] - bbox2[1]
|
||
else:
|
||
l1 = bbox1[2] - bbox1[0]
|
||
l2 = bbox2[2] - bbox2[0]
|
||
|
||
if l2 > l1 and (l2 - l1) / l1 > 0.3:
|
||
return float('inf')
|
||
|
||
return bbox_distance(bbox1, bbox2)
|
||
|
||
def __tie_up_category_by_distance_v3(self, subject_category_id, object_category_id):
|
||
# 定义获取主体和客体对象的函数
|
||
def get_subjects():
|
||
return reduct_overlap(
|
||
list(
|
||
map(
|
||
lambda x: {'bbox': x['bbox'], 'score': x['score']},
|
||
filter(
|
||
lambda x: x['category_id'] == subject_category_id,
|
||
self.__page_model_info['layout_dets'],
|
||
),
|
||
)
|
||
)
|
||
)
|
||
|
||
def get_objects():
|
||
return reduct_overlap(
|
||
list(
|
||
map(
|
||
lambda x: {'bbox': x['bbox'], 'score': x['score']},
|
||
filter(
|
||
lambda x: x['category_id'] == object_category_id,
|
||
self.__page_model_info['layout_dets'],
|
||
),
|
||
)
|
||
)
|
||
)
|
||
|
||
# 调用通用方法
|
||
return tie_up_category_by_distance_v3(
|
||
get_subjects,
|
||
get_objects
|
||
)
|
||
|
||
def get_imgs(self):
|
||
with_captions = self.__tie_up_category_by_distance_v3(
|
||
CategoryId.ImageBody, CategoryId.ImageCaption
|
||
)
|
||
with_footnotes = self.__tie_up_category_by_distance_v3(
|
||
CategoryId.ImageBody, CategoryId.ImageFootnote
|
||
)
|
||
ret = []
|
||
for v in with_captions:
|
||
record = {
|
||
'image_body': v['sub_bbox'],
|
||
'image_caption_list': v['obj_bboxes'],
|
||
}
|
||
filter_idx = v['sub_idx']
|
||
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
|
||
record['image_footnote_list'] = d['obj_bboxes']
|
||
ret.append(record)
|
||
return ret
|
||
|
||
def get_tables(self) -> list:
|
||
with_captions = self.__tie_up_category_by_distance_v3(
|
||
CategoryId.TableBody, CategoryId.TableCaption
|
||
)
|
||
with_footnotes = self.__tie_up_category_by_distance_v3(
|
||
CategoryId.TableBody, CategoryId.TableFootnote
|
||
)
|
||
ret = []
|
||
for v in with_captions:
|
||
record = {
|
||
'table_body': v['sub_bbox'],
|
||
'table_caption_list': v['obj_bboxes'],
|
||
}
|
||
filter_idx = v['sub_idx']
|
||
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
|
||
record['table_footnote_list'] = d['obj_bboxes']
|
||
ret.append(record)
|
||
return ret
|
||
|
||
def get_equations(self) -> tuple[list, list, list]: # 有坐标,也有字
|
||
inline_equations = self.__get_blocks_by_type(
|
||
CategoryId.InlineEquation, ['latex']
|
||
)
|
||
interline_equations = self.__get_blocks_by_type(
|
||
CategoryId.InterlineEquation_YOLO, ['latex']
|
||
)
|
||
interline_equations_blocks = self.__get_blocks_by_type(
|
||
CategoryId.InterlineEquation_Layout
|
||
)
|
||
return inline_equations, interline_equations, interline_equations_blocks
|
||
|
||
def get_discarded(self) -> list: # 自研模型,只有坐标
|
||
blocks = self.__get_blocks_by_type(CategoryId.Abandon)
|
||
return blocks
|
||
|
||
def get_text_blocks(self) -> list: # 自研模型搞的,只有坐标,没有字
|
||
blocks = self.__get_blocks_by_type(CategoryId.Text)
|
||
return blocks
|
||
|
||
def get_title_blocks(self) -> list: # 自研模型,只有坐标,没字
|
||
blocks = self.__get_blocks_by_type(CategoryId.Title)
|
||
return blocks
|
||
|
||
def get_all_spans(self) -> list:
|
||
|
||
def remove_duplicate_spans(spans):
|
||
new_spans = []
|
||
for span in spans:
|
||
if not any(span == existing_span for existing_span in new_spans):
|
||
new_spans.append(span)
|
||
return new_spans
|
||
|
||
all_spans = []
|
||
layout_dets = self.__page_model_info['layout_dets']
|
||
allow_category_id_list = [
|
||
CategoryId.ImageBody,
|
||
CategoryId.TableBody,
|
||
CategoryId.InlineEquation,
|
||
CategoryId.InterlineEquation_YOLO,
|
||
CategoryId.OcrText,
|
||
]
|
||
"""当成span拼接的"""
|
||
for layout_det in layout_dets:
|
||
category_id = layout_det['category_id']
|
||
if category_id in allow_category_id_list:
|
||
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
|
||
if category_id == CategoryId.ImageBody:
|
||
span['type'] = ContentType.IMAGE
|
||
elif category_id == CategoryId.TableBody:
|
||
# 获取table模型结果
|
||
latex = layout_det.get('latex', None)
|
||
html = layout_det.get('html', None)
|
||
if latex:
|
||
span['latex'] = latex
|
||
elif html:
|
||
span['html'] = html
|
||
span['type'] = ContentType.TABLE
|
||
elif category_id == CategoryId.InlineEquation:
|
||
span['content'] = layout_det['latex']
|
||
span['type'] = ContentType.INLINE_EQUATION
|
||
elif category_id == CategoryId.InterlineEquation_YOLO:
|
||
span['content'] = layout_det['latex']
|
||
span['type'] = ContentType.INTERLINE_EQUATION
|
||
elif category_id == CategoryId.OcrText:
|
||
span['content'] = layout_det['text']
|
||
span['type'] = ContentType.TEXT
|
||
all_spans.append(span)
|
||
return remove_duplicate_spans(all_spans)
|
||
|
||
def __get_blocks_by_type(
|
||
self, category_type: int, extra_col=None
|
||
) -> list:
|
||
if extra_col is None:
|
||
extra_col = []
|
||
blocks = []
|
||
layout_dets = self.__page_model_info.get('layout_dets', [])
|
||
for item in layout_dets:
|
||
category_id = item.get('category_id', -1)
|
||
bbox = item.get('bbox', None)
|
||
|
||
if category_id == category_type:
|
||
block = {
|
||
'bbox': bbox,
|
||
'score': item.get('score'),
|
||
}
|
||
for col in extra_col:
|
||
block[col] = item.get(col, None)
|
||
blocks.append(block)
|
||
return blocks |