149 lines
5.6 KiB
Python
149 lines
5.6 KiB
Python
import os
|
||
|
||
from PIL import Image
|
||
import cv2
|
||
import numpy as np
|
||
import onnxruntime
|
||
from loguru import logger
|
||
from tqdm import tqdm
|
||
|
||
from mineru.backend.pipeline.model_list import AtomicModel
|
||
from mineru.utils.enum_class import ModelPath
|
||
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
||
|
||
|
||
class PaddleTableClsModel:
    """Classify a table image as a wired or a wireless table with an ONNX model.

    Every image goes through the same preprocessing before inference:

    1. Resize so the shorter side equals ``less_length``.
    2. Center-crop to ``cw`` x ``ch``.
    3. Multiply pixels by ``scale`` (1/255).
    4. Normalize each channel with ``mean`` / ``std``.
    5. Transpose to channel-first layout.

    The session is fed an N x C x H x W float32 tensor under the input name
    ``"x"``. The argmax over the first output selects an entry of ``labels``.
    Channels are used in the order the caller supplies them; no RGB/BGR
    conversion happens here.
    """

    def __init__(self):
        model_root = auto_download_and_get_model_root_path(ModelPath.paddle_table_cls)
        self.sess = onnxruntime.InferenceSession(
            os.path.join(model_root, ModelPath.paddle_table_cls)
        )
        # Length of the shorter image side after the initial resize.
        self.less_length = 256
        # Width and height of the center crop fed to the network.
        self.cw, self.ch = 224, 224
        # Per-channel normalization: (pixel * scale - mean) / std.
        self.std = [0.229, 0.224, 0.225]
        self.scale = 0.00392156862745098  # == 1 / 255
        self.mean = [0.485, 0.456, 0.406]
        # Maps the model's output index to a label; the order must match the
        # class order of the exported model.
        self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]

    @staticmethod
    def _to_three_channels(img):
        """Return *img* as an H x W x 3 array.

        Grayscale inputs (H x W or H x W x 1) are replicated into 3 channels.
        For 4-channel inputs the 4th channel (e.g. alpha) is discarded, not
        composited.

        Raises:
            ValueError: for any other shape.
        """
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        if img.ndim != 3 or img.shape[2] not in (1, 3, 4):
            raise ValueError(
                f"Unsupported image shape {img.shape}; expected H x W, "
                "H x W x 1, H x W x 3 or H x W x 4."
            )
        if img.shape[2] == 1:
            return np.repeat(img, 3, axis=2)
        if img.shape[2] == 4:
            # The slice is a strided view; make it contiguous for OpenCV.
            return np.ascontiguousarray(img[:, :, :3])
        return img

    def _preprocess_one(self, img):
        """Turn one image into a normalized C x H x W float32 array.

        Args:
            img: a PIL image or anything ``np.asarray`` accepts, shaped H x W,
                H x W x 1, H x W x 3 or H x W x 4.

        Returns:
            A ``(3, self.ch, self.cw)`` float32 array.

        Raises:
            ValueError: if the image has an empty side or an unsupported
                shape, or if it is smaller than the crop after resizing.
        """
        img = self._to_three_channels(np.asarray(img))

        # Resize so that the shorter side becomes `less_length`.
        h, w = img.shape[:2]
        short_side = min(h, w)
        if short_side == 0:
            raise ValueError(f"Input image ({w}, {h}) has an empty dimension.")
        ratio = self.less_length / short_side
        img = cv2.resize(
            img,
            (round(w * ratio), round(h * ratio)),
            interpolation=cv2.INTER_LINEAR,
        )

        # Center-crop to cw x ch.
        h, w = img.shape[:2]
        cw, ch = self.cw, self.ch
        if w < cw or h < ch:
            # Only reachable if less_length was configured below cw/ch.
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        img = img[y1:y1 + ch, x1:x1 + cw]

        # Normalize in float32: pixel * (scale / std) - mean / std.
        alpha = np.array([self.scale / s for s in self.std], dtype=np.float32)
        beta = np.array(
            [-m / s for m, s in zip(self.mean, self.std)], dtype=np.float32
        )
        img = img.astype(np.float32) * alpha + beta

        # HWC -> CHW.
        return img.transpose((2, 0, 1))

    def _decode(self, scores):
        """Map one row of model scores to ``(label, score)``.

        The score is the highest output value. Whether it is a probability
        depends on the exported model (presumably softmax; not verified here).
        """
        idx = int(np.argmax(scores))
        return self.labels[idx], float(scores[idx])

    def preprocess(self, input_img):
        """Preprocess a single image into a 1 x 3 x ch x cw float32 batch.

        Raises:
            ValueError: see ``_preprocess_one``.
        """
        return np.stack([self._preprocess_one(input_img)], axis=0)

    def predict(self, input_img):
        """Classify one table image.

        Args:
            input_img: a PIL image or a numpy array.

        Returns:
            ``(label, score)``, where ``label`` is an entry of ``self.labels``.

        Raises:
            ValueError: if ``input_img`` is neither a PIL image nor a numpy
                array, or if it cannot be preprocessed.
        """
        if isinstance(input_img, Image.Image):
            np_img = np.asarray(input_img)
        elif isinstance(input_img, np.ndarray):
            np_img = input_img
        else:
            raise ValueError("Input must be a pillow object or a numpy array.")
        x = self.preprocess(np_img)
        result = self.sess.run(None, {"x": x})
        # First session output, first (only) row of the batch.
        return self._decode(result[0][0])

    def list_2_batch(self, img_list, batch_size=16):
        """Split a list of any length into consecutive batches.

        Args:
            img_list: the input list.
            batch_size: maximum size of each batch (default 16); must be >= 1.

        Returns:
            A list of sub-lists of ``img_list`` in their original order. The
            last batch may be shorter. An empty input yields ``[]``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}."
            )
        # Slicing clamps at the end of the list, so no explicit min() is needed.
        return [
            img_list[start:start + batch_size]
            for start in range(0, len(img_list), batch_size)
        ]

    def batch_preprocess(self, imgs):
        """Preprocess a non-empty sequence of images into an N x 3 x ch x cw float32 batch.

        Raises:
            ValueError: if ``imgs`` is empty, or if any image cannot be
                preprocessed.
        """
        return np.stack([self._preprocess_one(img) for img in imgs], axis=0)

    def batch_predict(self, img_info_list, batch_size=16):
        """Classify many table images, writing results into the input dicts.

        Each item must provide an image under ``"wired_table_img"`` and a dict
        under ``"table_res"``. For each item this method sets
        ``item["table_res"]["cls_label"]`` to the predicted label and
        ``item["table_res"]["cls_score"]`` to the score rounded to 3 decimals.
        Returns ``None``.
        """
        imgs = [item["wired_table_img"] for item in img_info_list]
        label_res = []
        with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict", disable=True) as pbar:
            for img_batch in self.list_2_batch(imgs, batch_size=batch_size):
                x = self.batch_preprocess(img_batch)
                result = self.sess.run(None, {"x": x})
                # result[0] holds one row of scores per image in the batch.
                label_res.extend(self._decode(scores) for scores in result[0])
                pbar.update(len(img_batch))

        for img_info, (label, conf) in zip(img_info_list, label_res):
            img_info['table_res']["cls_label"] = label
            img_info['table_res']["cls_score"] = round(conf, 3)