Pytorch:nn.Linear() 基本定义和用法

2023-12-21 16:40:18

nn.Linear的基本定义

nn.Linear定义一个神经网络的线性层,方法签名如下:

torch.nn.Linear(     in_features, # 输入的神经元个数
	             out_features, # 输出神经元个数
	             bias=True # 是否包含偏置
	             )

Linear其实就是对输入 X n × i X^{n×i} Xn×i
H n × o = X n × i W ( i × o ) + b ( o ) H^{n×o} = X^{n×i}W^{(i×o)} + b^{(o)} Hn×o=Xn×iW(i×o)+b(o)
其中:

  • n n n为输入向量的行数
  • i i i为输入神经元的个数(例如你的样本特征数为5,则 i = 5 i=5 i=5)
  • o o o为输出神经元的个数

举个例子:

from torch import nn
import torch

model = nn.Linear(2, 1) # 输入特征数为2,输出特征数为1
input = torch.Tensor([1, 2]) # 给一个样本,该样本有2个特征(这两个特征的值分别为1和2)
output = model(input)
# output :tensor([-1.4166], grad_fn=<AddBackward0>)
# 查看模型参数
for param in model.parameters():
    print(param)

# 可以看到,模型有3个参数,分别为两个权重和一个偏执:
# Parameter containing:
# tensor([[ 0.1098, -0.5404]], requires_grad=True)
# Parameter containing:
# tensor([-0.4456], requires_grad=True)

文章来源:https://blog.csdn.net/weixin_42046845/article/details/135129376
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。