UnisMindMap/mineru/model/table/cls/paddle_table_cls.py

149 lines
5.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
from PIL import Image
import cv2
import numpy as np
import onnxruntime
from loguru import logger
from tqdm import tqdm
from mineru.backend.pipeline.model_list import AtomicModel
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
class PaddleTableClsModel:
    """Classify table images as wired or wireless with an ONNX model.

    The model file is resolved through
    ``auto_download_and_get_model_root_path`` and loaded into an
    ``onnxruntime.InferenceSession``.  Images are resized so the shorter side
    equals ``less_length``, centre-cropped to ``cw`` x ``ch``, normalised
    per channel with ``scale``/``mean``/``std`` and fed to the session as a
    float32 NCHW tensor under the input name ``"x"``.
    """

    def __init__(self):
        model_root = auto_download_and_get_model_root_path(ModelPath.paddle_table_cls)
        self.sess = onnxruntime.InferenceSession(
            os.path.join(model_root, ModelPath.paddle_table_cls)
        )
        # Target length of the shorter image side after resizing.
        self.less_length = 256
        # Width and height of the centre crop fed to the model.
        self.cw, self.ch = 224, 224
        # Per-channel normalisation: (pixel * scale - mean) / std.
        self.std = [0.229, 0.224, 0.225]
        self.scale = 0.00392156862745098  # 1 / 255
        self.mean = [0.485, 0.456, 0.406]
        # Class index -> label, in the order of the model's output scores.
        self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]

    @staticmethod
    def _to_three_channels(img):
        """Return *img* as a contiguous H x W x 3 array.

        Grayscale input (H x W or H x W x 1) is replicated across three
        channels.  Any other channel count is rejected, because the
        normalisation constants are defined for exactly three channels.

        Raises:
            ValueError: If *img* is not 2-D/3-D or has an unsupported number
                of channels.
        """
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.ndim == 3 and img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        if img.ndim != 3 or img.shape[2] != 3:
            raise ValueError(
                f"Input image must have 1 or 3 channels, got shape {img.shape}."
            )
        return np.ascontiguousarray(img)

    def _preprocess_one(self, img):
        """Resize, centre-crop and normalise one image; return a CHW float32 array.

        Raises:
            ValueError: If the image is empty, has an unsupported channel
                layout, or ends up smaller than the crop size after resizing.
        """
        img = self._to_three_channels(np.asarray(img))

        # Resize so that the shorter side equals ``less_length``.
        h, w = img.shape[:2]
        short_side = min(h, w)
        if short_side == 0:
            raise ValueError(f"Input image ({w}, {h}) is empty.")
        ratio = self.less_length / short_side
        img = cv2.resize(
            img,
            (round(w * ratio), round(h * ratio)),
            interpolation=cv2.INTER_LINEAR,
        )

        # Centre-crop to ``cw`` x ``ch``.
        h, w = img.shape[:2]
        cw, ch = self.cw, self.ch
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        img = img[y1:y1 + ch, x1:x1 + cw]

        # Normalise: pixel * (scale / std) - mean / std, broadcast over the
        # trailing channel axis.
        alpha = np.array([self.scale / s for s in self.std], dtype=np.float32)
        beta = np.array(
            [-m / s for m, s in zip(self.mean, self.std)], dtype=np.float32
        )
        img = img.astype(np.float32)
        img *= alpha
        img += beta

        # HWC -> CHW.
        return img.transpose((2, 0, 1))

    def preprocess(self, input_img):
        """Preprocess a single image into a ``(1, 3, ch, cw)`` float32 batch.

        Args:
            input_img: Image as a numpy array (anything accepted by
                ``np.asarray`` also works).

        Returns:
            A float32 numpy array holding a batch of one CHW image.

        Raises:
            ValueError: See ``_preprocess_one``.
        """
        chw = self._preprocess_one(input_img)
        return np.stack([chw], axis=0).astype(dtype=np.float32, copy=False)

    def predict(self, input_img):
        """Classify a single image.

        Args:
            input_img: A ``PIL.Image.Image`` or a numpy array.

        Returns:
            A ``(label, confidence)`` tuple, where ``label`` is an entry of
            ``self.labels`` and ``confidence`` is the highest output score.

        Raises:
            ValueError: If the input is neither a pillow image nor a numpy
                array, or if preprocessing rejects it.
        """
        if isinstance(input_img, Image.Image):
            np_img = np.asarray(input_img)
        elif isinstance(input_img, np.ndarray):
            np_img = input_img
        else:
            raise ValueError("Input must be a pillow object or a numpy array.")
        x = self.preprocess(np_img)
        result = self.sess.run(None, {"x": x})
        # Read the first session output, as batch_predict does; the batch
        # holds a single image, so flattening yields its class scores.
        scores = np.asarray(result[0]).reshape(-1)
        idx = int(np.argmax(scores))
        conf = float(scores[idx])
        return self.labels[idx], conf

    def list_2_batch(self, img_list, batch_size=16):
        """Split a list of any length into consecutive batches.

        Args:
            img_list: The input list.
            batch_size: Maximum size of each batch; defaults to 16.

        Returns:
            A list of batches, each a sub-list of the original list.  The
            last batch may be shorter than ``batch_size``.
        """
        return [
            img_list[start:start + batch_size]
            for start in range(0, len(img_list), batch_size)
        ]

    def batch_preprocess(self, imgs):
        """Preprocess several images into one ``(N, 3, ch, cw)`` float32 batch.

        Args:
            imgs: Non-empty iterable of images (numpy arrays or objects
                convertible with ``np.asarray``).

        Returns:
            A float32 numpy array stacking the preprocessed CHW images.

        Raises:
            ValueError: If ``imgs`` is empty (``np.stack`` needs at least one
                array) or an image is rejected by ``_preprocess_one``.
        """
        res_imgs = [self._preprocess_one(img) for img in imgs]
        return np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)

    def batch_predict(self, img_info_list, batch_size=16):
        """Classify many table images and store the results in place.

        Args:
            img_info_list: List of dicts.  Each dict must provide the image
                under ``"wired_table_img"`` and a dict under ``"table_res"``;
                the latter receives ``"cls_label"`` and ``"cls_score"``
                (confidence rounded to three decimals).
            batch_size: Number of images per inference call; defaults to 16.

        Returns:
            None.  The entries of ``img_info_list`` are updated in place.
        """
        imgs = [item["wired_table_img"] for item in img_info_list]
        batches = self.list_2_batch(imgs, batch_size=batch_size)
        label_res = []
        with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict", disable=True) as pbar:
            for img_batch in batches:
                x = self.batch_preprocess(img_batch)
                result = self.sess.run(None, {"x": x})
                # result[0] holds one score vector per image in the batch.
                for img_res in result[0]:
                    idx = int(np.argmax(img_res))
                    conf = float(np.max(img_res))
                    label_res.append((self.labels[idx], conf))
                pbar.update(len(img_batch))
        for img_info, (label, conf) in zip(img_info_list, label_res):
            img_info['table_res']["cls_label"] = label
            img_info['table_res']["cls_score"] = round(conf, 3)