283 lines
12 KiB
Python
283 lines
12 KiB
Python
# Copyright (c) Opendatalab. All rights reserved.
|
||
import os
|
||
|
||
from PIL import Image
|
||
from collections import defaultdict
|
||
from typing import List, Dict
|
||
from tqdm import tqdm
|
||
import cv2
|
||
import numpy as np
|
||
import onnxruntime
|
||
|
||
from mineru.utils.enum_class import ModelPath
|
||
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
||
|
||
|
||
class PaddleOrientationClsModel:
|
||
def __init__(self, ocr_engine):
|
||
self.sess = onnxruntime.InferenceSession(
|
||
os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
|
||
)
|
||
self.ocr_engine = ocr_engine
|
||
self.less_length = 256
|
||
self.cw, self.ch = 224, 224
|
||
self.std = [0.229, 0.224, 0.225]
|
||
self.scale = 0.00392156862745098
|
||
self.mean = [0.485, 0.456, 0.406]
|
||
self.labels = ["0", "90", "180", "270"]
|
||
|
||
def preprocess(self, input_img):
|
||
# 放大图片,使其最短边长为256
|
||
h, w = input_img.shape[:2]
|
||
scale = 256 / min(h, w)
|
||
h_resize = round(h * scale)
|
||
w_resize = round(w * scale)
|
||
img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
|
||
# 调整为224*224的正方形
|
||
h, w = img.shape[:2]
|
||
cw, ch = 224, 224
|
||
x1 = max(0, (w - cw) // 2)
|
||
y1 = max(0, (h - ch) // 2)
|
||
x2 = min(w, x1 + cw)
|
||
y2 = min(h, y1 + ch)
|
||
if w < cw or h < ch:
|
||
raise ValueError(
|
||
f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
|
||
)
|
||
img = img[y1:y2, x1:x2, ...]
|
||
# 正则化
|
||
split_im = list(cv2.split(img))
|
||
std = [0.229, 0.224, 0.225]
|
||
scale = 0.00392156862745098
|
||
mean = [0.485, 0.456, 0.406]
|
||
alpha = [scale / std[i] for i in range(len(std))]
|
||
beta = [-mean[i] / std[i] for i in range(len(std))]
|
||
for c in range(img.shape[2]):
|
||
split_im[c] = split_im[c].astype(np.float32)
|
||
split_im[c] *= alpha[c]
|
||
split_im[c] += beta[c]
|
||
img = cv2.merge(split_im)
|
||
# 5. 转换为 CHW 格式
|
||
img = img.transpose((2, 0, 1))
|
||
imgs = [img]
|
||
x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
|
||
return x
|
||
|
||
def predict(self, input_img):
|
||
rotate_label = "0" # Default to 0 if no rotation detected or not portrait
|
||
if isinstance(input_img, Image.Image):
|
||
np_img = np.asarray(input_img)
|
||
elif isinstance(input_img, np.ndarray):
|
||
np_img = input_img
|
||
else:
|
||
raise ValueError("Input must be a pillow object or a numpy array.")
|
||
bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
|
||
# First check the overall image aspect ratio (height/width)
|
||
img_height, img_width = bgr_image.shape[:2]
|
||
img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
|
||
img_is_portrait = img_aspect_ratio > 1.2
|
||
|
||
if img_is_portrait:
|
||
|
||
det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
|
||
# Check if table is rotated by analyzing text box aspect ratios
|
||
if det_res:
|
||
vertical_count = 0
|
||
is_rotated = False
|
||
|
||
for box_ocr_res in det_res:
|
||
p1, p2, p3, p4 = box_ocr_res
|
||
|
||
# Calculate width and height
|
||
width = p3[0] - p1[0]
|
||
height = p3[1] - p1[1]
|
||
|
||
aspect_ratio = width / height if height > 0 else 1.0
|
||
|
||
# Count vertical vs horizontal text boxes
|
||
if aspect_ratio < 0.8: # Taller than wide - vertical text
|
||
vertical_count += 1
|
||
# elif aspect_ratio > 1.2: # Wider than tall - horizontal text
|
||
# horizontal_count += 1
|
||
|
||
if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
|
||
is_rotated = True
|
||
# logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
|
||
|
||
# If we have more vertical text boxes than horizontal ones,
|
||
# and vertical ones are significant, table might be rotated
|
||
if is_rotated:
|
||
x = self.preprocess(np_img)
|
||
(result,) = self.sess.run(None, {"x": x})
|
||
rotate_label = self.labels[np.argmax(result)]
|
||
# logger.debug(f"Orientation classification result: {label}")
|
||
|
||
return rotate_label
|
||
|
||
def list_2_batch(self, img_list, batch_size=16):
|
||
"""
|
||
将任意长度的列表按照指定的batch size分成多个batch
|
||
|
||
Args:
|
||
img_list: 输入的列表
|
||
batch_size: 每个batch的大小,默认为16
|
||
|
||
Returns:
|
||
一个包含多个batch的列表,每个batch都是原列表的一个子列表
|
||
"""
|
||
batches = []
|
||
for i in range(0, len(img_list), batch_size):
|
||
batch = img_list[i : min(i + batch_size, len(img_list))]
|
||
batches.append(batch)
|
||
return batches
|
||
|
||
def batch_preprocess(self, imgs):
|
||
res_imgs = []
|
||
for img_info in imgs:
|
||
img = np.asarray(img_info["table_img"])
|
||
# 放大图片,使其最短边长为256
|
||
h, w = img.shape[:2]
|
||
scale = 256 / min(h, w)
|
||
h_resize = round(h * scale)
|
||
w_resize = round(w * scale)
|
||
img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
|
||
# 调整为224*224的正方形
|
||
h, w = img.shape[:2]
|
||
cw, ch = 224, 224
|
||
x1 = max(0, (w - cw) // 2)
|
||
y1 = max(0, (h - ch) // 2)
|
||
x2 = min(w, x1 + cw)
|
||
y2 = min(h, y1 + ch)
|
||
if w < cw or h < ch:
|
||
raise ValueError(
|
||
f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
|
||
)
|
||
img = img[y1:y2, x1:x2, ...]
|
||
# 正则化
|
||
split_im = list(cv2.split(img))
|
||
std = [0.229, 0.224, 0.225]
|
||
scale = 0.00392156862745098
|
||
mean = [0.485, 0.456, 0.406]
|
||
alpha = [scale / std[i] for i in range(len(std))]
|
||
beta = [-mean[i] / std[i] for i in range(len(std))]
|
||
for c in range(img.shape[2]):
|
||
split_im[c] = split_im[c].astype(np.float32)
|
||
split_im[c] *= alpha[c]
|
||
split_im[c] += beta[c]
|
||
img = cv2.merge(split_im)
|
||
# 5. 转换为 CHW 格式
|
||
img = img.transpose((2, 0, 1))
|
||
res_imgs.append(img)
|
||
x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
|
||
return x
|
||
|
||
def batch_predict(
|
||
self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
|
||
) -> None:
|
||
|
||
import torch
|
||
from packaging import version
|
||
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||
return None
|
||
|
||
"""
|
||
批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
|
||
"""
|
||
RESOLUTION_GROUP_STRIDE = 128
|
||
# 跳过长宽比小于1.2的图片
|
||
resolution_groups = defaultdict(list)
|
||
for img in imgs:
|
||
# RGB图像转换BGR
|
||
bgr_img: np.ndarray = cv2.cvtColor(np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR)
|
||
img["table_img_bgr"] = bgr_img
|
||
img_height, img_width = bgr_img.shape[:2]
|
||
img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
|
||
if img_aspect_ratio > 1.2:
|
||
# 归一化尺寸到RESOLUTION_GROUP_STRIDE的倍数
|
||
normalized_h = ((img_height + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE # 向上取整到RESOLUTION_GROUP_STRIDE的倍数
|
||
normalized_w = ((img_width + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
||
group_key = (normalized_h, normalized_w)
|
||
resolution_groups[group_key].append(img)
|
||
|
||
# 对每个分辨率组进行批处理
|
||
rotated_imgs = []
|
||
for group_key, group_imgs in tqdm(resolution_groups.items(), desc="Table-ori cls stage1 predict", disable=True):
|
||
# 计算目标尺寸(组内最大尺寸,向上取整到RESOLUTION_GROUP_STRIDE的倍数)
|
||
max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
|
||
max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
|
||
target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
||
target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
||
|
||
# 对所有图像进行padding到统一尺寸
|
||
batch_images = []
|
||
for img in group_imgs:
|
||
bgr_img = img["table_img_bgr"]
|
||
h, w = bgr_img.shape[:2]
|
||
# 创建目标尺寸的白色背景
|
||
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
|
||
# 将原图像粘贴到左上角
|
||
padded_img[:h, :w] = bgr_img
|
||
batch_images.append(padded_img)
|
||
|
||
# 批处理检测
|
||
batch_results = self.ocr_engine.text_detector.batch_predict(
|
||
batch_images, min(len(batch_images), det_batch_size)
|
||
)
|
||
|
||
# 根据批处理结果检测图像是否旋转,旋转的图像放入列表中,继续进行旋转角度的预测
|
||
|
||
for index, (img_info, (dt_boxes, elapse)) in enumerate(
|
||
zip(group_imgs, batch_results)
|
||
):
|
||
vertical_count = 0
|
||
for box_ocr_res in dt_boxes:
|
||
p1, p2, p3, p4 = box_ocr_res
|
||
|
||
# Calculate width and height
|
||
width = p3[0] - p1[0]
|
||
height = p3[1] - p1[1]
|
||
|
||
aspect_ratio = width / height if height > 0 else 1.0
|
||
|
||
# Count vertical text boxes
|
||
if aspect_ratio < 0.8: # Taller than wide - vertical text
|
||
vertical_count += 1
|
||
|
||
if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
|
||
rotated_imgs.append(img_info)
|
||
|
||
# 对旋转的图片进行旋转角度预测
|
||
if len(rotated_imgs) > 0:
|
||
imgs = self.list_2_batch(rotated_imgs, batch_size=batch_size)
|
||
with tqdm(total=len(rotated_imgs), desc="Table-ori cls stage2 predict", disable=True) as pbar:
|
||
for img_batch in imgs:
|
||
x = self.batch_preprocess(img_batch)
|
||
results = self.sess.run(None, {"x": x})
|
||
for img_info, res in zip(rotated_imgs, results[0]):
|
||
label = self.labels[np.argmax(res)]
|
||
self.img_rotate(img_info, label)
|
||
pbar.update(1)
|
||
|
||
def img_rotate(self, img_info, label):
|
||
if label == "270":
|
||
img_info["table_img"] = cv2.rotate(
|
||
np.asarray(img_info["table_img"]),
|
||
cv2.ROTATE_90_CLOCKWISE,
|
||
)
|
||
img_info["wired_table_img"] = cv2.rotate(
|
||
np.asarray(img_info["wired_table_img"]),
|
||
cv2.ROTATE_90_CLOCKWISE,
|
||
)
|
||
elif label == "90":
|
||
img_info["table_img"] = cv2.rotate(
|
||
np.asarray(img_info["table_img"]),
|
||
cv2.ROTATE_90_COUNTERCLOCKWISE,
|
||
)
|
||
img_info["wired_table_img"] = cv2.rotate(
|
||
np.asarray(img_info["wired_table_img"]),
|
||
cv2.ROTATE_90_COUNTERCLOCKWISE,
|
||
)
|
||
else:
|
||
# 180度和0度不做处理
|
||
pass
|