UnisMindMap/mineru/model/utils/pytorchocr/modeling/heads/cls_head.py

24 lines
596 B
Python

import torch
import torch.nn.functional as F
from torch import nn
class ClsHead(nn.Module):
"""
Class orientation
Args:
params(dict): super parameters for build Class network
"""
def __init__(self, in_channels, class_dim, **kwargs):
super(ClsHead, self).__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(in_channels, class_dim, bias=True)
def forward(self, x):
x = self.pool(x)
x = torch.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.fc(x)
x = F.softmax(x, dim=1)
return x