4种feature classification在代码的实现上是怎么样的?Linear / MLP / CNN / Attention-Based Heads
2023-12-24 11:32:43
具体的分类效果可以看:【Arxiv 2023】Diffusion Models Beat GANs on Image Classification
1、线性分类器 (Linear, A)
使用一个简单的线性层,通常与一个激活函数结合使用。
import torch.nn as nn
class LinearClassifier(nn.Module):
def __init__(self, input_size, num_classes):
super(LinearClassifier, self).__init__()
self.linear = nn.Linear(input_size, num_classes)
def forward(self, x):
return self.linear(x)
2、多层感知机 (Multi-Layer Perceptron, B)
包括多个线性层,每层之间可能有激活函数和dropout层。
class MLPClassifier(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(MLPClassifier, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
3、卷积神经网络 (Convolutional Neural Network, CNN, C)
使用一系列卷积层,通常包括池化层和全连接层。
class CNNClassifier(nn.Module):
def __init__(self, num_classes):
super(CNNClassifier, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)
self.fc = nn.Linear(64 * 7 * 7, num_classes) # Assuming input size is 28x28
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(x.size(0), -1) # Flatten the tensor
x = self.fc(x)
return x
4、基于注意力机制的头部 (Attention-Based Heads, D)
使用注意力机制,如Transformer的头部结构。
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class AttentionClassifier(nn.Module):
def __init__(self, input_size, num_classes, nhead, nhid, nlayers):
super(AttentionClassifier, self).__init__()
self.model_type = 'Transformer'
self.encoder_layer = TransformerEncoderLayer(d_model=input_size, nhead=nhead, dim_feedforward=nhid)
self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=nlayers)
self.decoder = nn.Linear(input_size, num_classes)
def forward(self, src):
output = self.transformer_encoder(src)
output = self.decoder(output.mean(dim=1))
return output
文章来源:https://blog.csdn.net/weixin_43135178/article/details/135178993
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!