UnisMindMap/mineru/model/ori_cls/paddle_ori_cls.py

283 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) Opendatalab. All rights reserved.
import os
from PIL import Image
from collections import defaultdict
from typing import List, Dict
from tqdm import tqdm
import cv2
import numpy as np
import onnxruntime
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
class PaddleOrientationClsModel:
    """Table orientation classifier backed by an ONNX orientation model.

    Tall ("portrait") table crops are first screened with the OCR text
    detector. If enough detected text boxes are taller than they are wide,
    the crop is treated as rotated and fed to the ONNX classifier. The
    classifier's arg-max index selects one of ``self.labels``
    ("0", "90", "180", "270").
    """

    # Height/width ratio above which an image is a candidate for rotation checks.
    PORTRAIT_ASPECT_RATIO = 1.2
    # A text box whose width/height ratio is below this counts as "vertical".
    VERTICAL_BOX_ASPECT_RATIO = 0.8
    # An image is flagged as rotated when at least this fraction of its boxes
    # are vertical AND at least VERTICAL_BOX_MIN_COUNT boxes are vertical.
    VERTICAL_BOX_MIN_FRACTION = 0.28
    VERTICAL_BOX_MIN_COUNT = 3
    # Images are grouped and padded to multiples of this size so the text
    # detector can process each group as a uniform-size batch.
    RESOLUTION_GROUP_STRIDE = 128

    def __init__(self, ocr_engine):
        """Load the ONNX orientation model and keep a handle to ``ocr_engine``.

        Args:
            ocr_engine: OCR engine used for text detection. It must provide
                ``ocr(img, rec=False)`` (used by :meth:`predict`) and
                ``text_detector.batch_predict(images, batch_size)`` (used by
                :meth:`batch_predict`).
        """
        self.sess = onnxruntime.InferenceSession(
            os.path.join(
                auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification),
                ModelPath.paddle_orientation_classification,
            )
        )
        self.ocr_engine = ocr_engine
        self.less_length = 256  # the shorter image side is resized to this length
        self.cw, self.ch = 224, 224  # size of the center crop fed to the model
        self.std = [0.229, 0.224, 0.225]
        self.scale = 0.00392156862745098  # 1 / 255
        self.mean = [0.485, 0.456, 0.406]
        self.labels = ["0", "90", "180", "270"]

    # ------------------------------------------------------------------ #
    # Small pure helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def _is_portrait(cls, height, width) -> bool:
        """Return True if height/width exceeds PORTRAIT_ASPECT_RATIO.

        A zero width is treated as aspect ratio 1.0, i.e. not portrait.
        """
        aspect_ratio = height / width if width > 0 else 1.0
        return aspect_ratio > cls.PORTRAIT_ASPECT_RATIO

    @staticmethod
    def _ceil_to_stride(value, stride):
        """Round ``value`` up to the nearest multiple of ``stride``."""
        return ((value + stride - 1) // stride) * stride

    @staticmethod
    def _pad_to(bgr_img, target_h, target_w):
        """Paste ``bgr_img`` at the top-left of a white uint8 canvas of the target size."""
        padded = np.full((target_h, target_w, 3), 255, dtype=np.uint8)
        h, w = bgr_img.shape[:2]
        padded[:h, :w] = bgr_img
        return padded

    @classmethod
    def _looks_rotated(cls, boxes) -> bool:
        """Decide from detected text boxes whether the image appears rotated.

        Each box is unpacked as four points ``p1, p2, p3, p4``. Width and
        height are taken as ``p3 - p1``; presumably p1 is the top-left corner
        and p3 the bottom-right (TODO confirm against the detector's output
        order). A box with zero or negative height counts as aspect ratio 1.0.

        Returns:
            True when the number of vertical boxes is both at least
            VERTICAL_BOX_MIN_FRACTION of all boxes and at least
            VERTICAL_BOX_MIN_COUNT. Always False for an empty box list.
        """
        vertical_count = 0
        for p1, _p2, p3, _p4 in boxes:
            width = p3[0] - p1[0]
            height = p3[1] - p1[1]
            aspect_ratio = width / height if height > 0 else 1.0
            if aspect_ratio < cls.VERTICAL_BOX_ASPECT_RATIO:  # taller than wide
                vertical_count += 1
        return (
            vertical_count >= len(boxes) * cls.VERTICAL_BOX_MIN_FRACTION
            and vertical_count >= cls.VERTICAL_BOX_MIN_COUNT
        )

    # ------------------------------------------------------------------ #
    # Preprocessing
    # ------------------------------------------------------------------ #
    def _preprocess_one(self, img):
        """Resize, center-crop and normalize one image into a CHW float32 array.

        Steps:
            1. Scale the shorter side to ``self.less_length``.
            2. Cut a ``self.cw`` x ``self.ch`` window from the center.
            3. Map each channel to ``(pixel * self.scale - mean) / std``.

        Args:
            img: HxWxC numpy array. C is assumed to match
                ``len(self.mean)`` (3) — TODO confirm for RGBA/gray inputs.

        Returns:
            ``np.float32`` array of shape (C, self.ch, self.cw).

        Raises:
            ValueError: if a side has zero length, or if the image is still
                smaller than the crop after resizing.
        """
        h, w = img.shape[:2]
        if min(h, w) == 0:
            raise ValueError(f"Input image ({w}, {h}) has a zero-length side.")
        ratio = self.less_length / min(h, w)
        img = cv2.resize(
            img, (round(w * ratio), round(h * ratio)), interpolation=cv2.INTER_LINEAR
        )

        # Center crop to cw x ch.
        h, w = img.shape[:2]
        cw, ch = self.cw, self.ch
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        img = img[y1:y1 + ch, x1:x1 + cw, ...]

        # Fold scale/mean/std into one float32 multiply-add per channel.
        alpha = np.asarray([self.scale / s for s in self.std], dtype=np.float32)
        beta = np.asarray([-m / s for m, s in zip(self.mean, self.std)], dtype=np.float32)
        img = img.astype(np.float32) * alpha + beta

        # HWC -> CHW
        return img.transpose((2, 0, 1))

    def preprocess(self, input_img):
        """Turn one image into a model input batch of shape (1, C, ch, cw)."""
        return np.stack([self._preprocess_one(input_img)], axis=0).astype(
            dtype=np.float32, copy=False
        )

    def batch_preprocess(self, imgs):
        """Preprocess ``img_info["table_img"]`` for every dict in ``imgs``.

        Returns:
            float32 array of shape (len(imgs), C, ch, cw).

        Raises:
            ValueError: from ``np.stack`` if ``imgs`` is empty, or from the
                per-image preprocessing.
        """
        chws = [self._preprocess_one(np.asarray(img_info["table_img"])) for img_info in imgs]
        return np.stack(chws, axis=0).astype(dtype=np.float32, copy=False)

    def list_2_batch(self, img_list, batch_size=16):
        """Split ``img_list`` into consecutive chunks of at most ``batch_size``.

        Args:
            img_list: any sliceable sequence.
            batch_size: maximum chunk length (default 16).

        Returns:
            List of slices of ``img_list``; the last one may be shorter.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        return [img_list[i:i + batch_size] for i in range(0, len(img_list), batch_size)]

    # ------------------------------------------------------------------ #
    # Prediction
    # ------------------------------------------------------------------ #
    def predict(self, input_img):
        """Predict the rotation label of a single image.

        Args:
            input_img: PIL image or numpy array, converted with
                ``cv2.COLOR_RGB2BGR`` (i.e. treated as RGB).

        Returns:
            One of ``self.labels``. Returns "0" without running the classifier
            when the image is not portrait or its text boxes do not look
            rotated.

        Raises:
            ValueError: if ``input_img`` is neither a PIL image nor a numpy array.
        """
        if isinstance(input_img, Image.Image):
            np_img = np.asarray(input_img)
        elif isinstance(input_img, np.ndarray):
            np_img = input_img
        else:
            raise ValueError("Input must be a pillow object or a numpy array.")

        bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
        img_height, img_width = bgr_image.shape[:2]
        if not self._is_portrait(img_height, img_width):
            return "0"

        det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
        if not det_res or not self._looks_rotated(det_res):
            return "0"

        x = self.preprocess(np_img)
        (result,) = self.sess.run(None, {"x": x})
        return self.labels[int(np.argmax(result))]

    def _detect_rotated_candidates(self, imgs, det_batch_size):
        """Stage 1: run batched text detection and return images that look rotated.

        Side effect: stores the BGR version of ``table_img`` under
        ``"table_img_bgr"`` in EVERY dict of ``imgs``, including those that
        are skipped.
        """
        stride = self.RESOLUTION_GROUP_STRIDE
        resolution_groups = defaultdict(list)
        for img in imgs:
            bgr_img = cv2.cvtColor(np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR)
            img["table_img_bgr"] = bgr_img
            img_height, img_width = bgr_img.shape[:2]
            # Only portrait images (aspect ratio > 1.2) are checked further.
            if self._is_portrait(img_height, img_width):
                group_key = (
                    self._ceil_to_stride(img_height, stride),
                    self._ceil_to_stride(img_width, stride),
                )
                resolution_groups[group_key].append(img)

        rotated = []
        for (target_h, target_w), group_imgs in tqdm(
            resolution_groups.items(), desc="Table-ori cls stage1 predict", disable=True
        ):
            # Every image in a group rounds up to the group key, so the key is
            # the smallest stride-aligned size that fits all of them.
            batch_images = [
                self._pad_to(info["table_img_bgr"], target_h, target_w) for info in group_imgs
            ]
            batch_results = self.ocr_engine.text_detector.batch_predict(
                batch_images, min(len(batch_images), det_batch_size)
            )
            for img_info, (dt_boxes, _elapse) in zip(group_imgs, batch_results):
                if dt_boxes is not None and self._looks_rotated(dt_boxes):
                    rotated.append(img_info)
        return rotated

    def _classify_and_rotate(self, rotated_imgs, batch_size, on_progress=None):
        """Stage 2: classify rotation angles in batches and rotate images in place.

        Args:
            rotated_imgs: dicts accepted by :meth:`batch_preprocess` and
                :meth:`img_rotate`.
            batch_size: maximum number of images per ONNX call.
            on_progress: optional callable invoked with 1 per processed image.
        """
        for img_batch in self.list_2_batch(rotated_imgs, batch_size=batch_size):
            x = self.batch_preprocess(img_batch)
            scores = self.sess.run(None, {"x": x})[0]
            # Pair scores with THIS batch's images so each score row is
            # applied to the image it was computed from.
            for img_info, res in zip(img_batch, scores):
                self.img_rotate(img_info, self.labels[int(np.argmax(res))])
                if on_progress is not None:
                    on_progress(1)

    def batch_predict(
        self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
    ) -> None:
        """Detect rotated tables in ``imgs`` and rotate them back upright, in place.

        Args:
            imgs: dicts with a ``"table_img"`` (treated as RGB) and, for images
                that end up rotated, a ``"wired_table_img"``.
            det_batch_size: upper bound on the text detector's batch size.
            batch_size: batch size for the orientation classifier.

        Side effects:
            Adds ``"table_img_bgr"`` to every dict. Rotates ``"table_img"`` and
            ``"wired_table_img"`` of images classified as "90" or "270".
        """
        import torch
        from packaging import version

        # NOTE(review): the whole pass is skipped on torch >= 2.8; the reason
        # is not visible in this file — confirm before changing.
        if version.parse(torch.__version__) >= version.parse("2.8.0"):
            return None

        rotated_imgs = self._detect_rotated_candidates(imgs, det_batch_size)
        if not rotated_imgs:
            return None

        with tqdm(
            total=len(rotated_imgs), desc="Table-ori cls stage2 predict", disable=True
        ) as pbar:
            self._classify_and_rotate(rotated_imgs, batch_size, on_progress=pbar.update)
        return None

    def img_rotate(self, img_info, label):
        """Undo a detected rotation of ``table_img`` and ``wired_table_img`` in place.

        "270" is corrected by a 90° clockwise rotation and "90" by a 90°
        counter-clockwise rotation. "0" and "180" (and any other label) are
        left untouched.
        """
        if label == "270":
            rotate_code = cv2.ROTATE_90_CLOCKWISE
        elif label == "90":
            rotate_code = cv2.ROTATE_90_COUNTERCLOCKWISE
        else:
            # 180 and 0 degrees are not handled.
            return
        for key in ("table_img", "wired_table_img"):
            img_info[key] = cv2.rotate(np.asarray(img_info[key]), rotate_code)