PyTorch:初学者全面指南torch函数使用

2024-01-02 13:18:57

目录

引言

多维张量数据结构

数学运算

序列化功能

CUDA 支持

其他实用工具

PyTorch的关键特性

torch相关函数演示和简介?

is_tensor

函数作用

使用技巧

使用方法和示例

展示结果

is_storage

存储对象(Storage)简介

函数作用

使用技巧

使用方法和示例

?is_complex

函数作用

使用技巧

使用方法和示例

?is_conj

函数作用

使用技巧

使用方法和示例

is_floating_point

函数作用

使用技巧

使用方法和示例

is_nonzero

函数作用

使用注意点

使用方法和示例

set_default_dtype

函数作用

使用注意点

使用方法和示例

get_default_dtype

函数作用

使用注意点

使用方法和示例

set_default_device

函数作用

使用注意点

使用方法和示例

set_default_tensor_type

函数作用

使用注意点

使用方法和示例

numel

函数作用

使用方法和示例

set_printoptions

函数作用

参数详解

使用方法和示例

set_flush_denormal

函数作用

使用注意点

使用方法和示例


引言

? ?torch 是一个功能丰富的库,是 PyTorch 框架的核心,用于深度学习和张量计算。它提供了一系列工具和功能,使得科学计算和机器学习更加高效和易于操作。

多维张量数据结构

  • torch 提供了多维张量(Tensor)的数据结构,这是其最基本的组成部分。张量在 PyTorch 中类似于 NumPy 的数组,但与此同时,它们还能在 GPU 上进行计算以加速运行速度。
  • 张量用于存储和操作数据,是深度学习模型中各种计算的基础。

数学运算

  • torch 定义了一系列的数学运算来操作这些张量,例如加法、乘法、转置等,以及更高级的操作,如矩阵乘法、卷积等。
  • 这些运算是构建和训练深度学习模型的基石,用于实现前向传播和反向传播算法。

序列化功能

  • torch 提供了有效的工具来序列化和反序列化张量数据,使得数据的保存和加载变得容易。
  • 序列化对于模型的持久化存储、分享和部署至关重要。

CUDA 支持

  • 通过其 CUDA 对应部分,torch 允许张量计算在 NVIDIA GPU 上运行,条件是 GPU 的计算能力需要大于等于 3.0。
  • 使用 GPU 可以显著加快训练过程,特别是在处理大型神经网络和大规模数据集时。

其他实用工具

  • torch 还包括其他多种实用工具,如自动微分(重要于神经网络训练中的反向传播)、数据集加载和处理工具等。
  • 这些工具使得从数据处理到模型训练的整个流程更加高效和标准化。

????????总之,torch 包作为 PyTorch 框架的核心,为深度学习和广泛的科学计算提供了强大的工具和功能。它的灵活性和易用性使得开发、训练和部署深度学习模型变得更加简单和高效。

PyTorch的关键特性

  1. 张量(Tensors):PyTorch的核心是张量。张量在PyTorch中是多维数组,类似于NumPy数组,但它们具有在GPU上运行的附加优势。它们是PyTorch的基础构件,用于各种操作。

  2. CUDA支持:PyTorch提供对CUDA的支持,允许在NVIDIA GPU上进行张量计算,前提是GPU的计算能力需要大于等于3.0。这意味着你可以将你的张量计算无缝迁移到GPU上,从而实现显著的性能提升。

torch相关函数演示和简介?

is_tensor

?????????torch.is_tensor 是 PyTorch 中的一个函数,用来判断给定的对象是否是一个 PyTorch 张量(Tensor)。这个函数非常有用,特别是在你处理多种数据类型时,需要确认某个对象是否为PyTorch张量的情况下。

函数作用

  • 目的: torch.is_tensor 的主要目的是进行类型检查,即确认一个对象是否为 PyTorch 的 Tensor 类型。
  • 使用场景: 在混合不同类型数据(如列表、数组、张量等)的处理或计算中,确保进行张量相关操作的对象确实是张量。

使用技巧

  • 类型检查: 尽管 torch.is_tensor 可以用来检查一个对象是否为张量,但在进行类型检查时,使用 isinstance(obj, Tensor) 通常更推荐。这是因为 isinstance 对于类型检查来说更加明确,且更适合于类型提示工具(如mypy)。
  • 调试和验证: 当编写涉及多种数据结构的复杂代码时,使用 torch.is_tensor 可以帮助调试和验证数据类型,确保传入的参数是正确的类型。

使用方法和示例

以下是一个使用 torch.is_tensor 的简单例子:

import torch

# 创建一个PyTorch张量
x = torch.tensor([1, 2, 3])

# 使用torch.is_tensor来检查x是否为张量如果是则返回true,不是则返回false
if torch.is_tensor(x):
    print("x is a tensor")
else:
    print("x is not a tensor")

展示结果

在上面的代码中,输出将会是:

x is a tensor

?????????这是因为变量 x 被创建为一个 PyTorch 张量,所以 torch.is_tensor(x) 返回 True。

is_storage

????????torch.is_storage 是 PyTorch 库中的一个函数,用于判断给定的对象是否是一个 PyTorch 存储对象(Storage Object)。在深入了解其用途之前,我们首先需要明白什么是 PyTorch 的存储对象。

存储对象(Storage)简介

PyTorch 中的存储对象是一种底层数据结构,用于持有张量(Tensor)的数据。每个张量都与一个存储对象相关联,该对象实际上包含了张量的数据。虽然多个张量可以共享相同的存储,但它们可以有不同的视图(比如不同的大小和步长)。

函数作用

  • 目的: torch.is_storage 的主要目的是确认一个对象是否为 PyTorch 的存储对象。
  • 使用场景: 这个函数在底层数据操作和高级张量处理中非常有用,尤其是在你需要处理直接存储数据或者进行优化和内存管理时。

使用技巧

  • 内存和性能优化: 在进行内存优化或者性能调优时,了解张量背后的存储机制是非常重要的。使用 torch.is_storage 可以帮助你在调试和开发过程中确认对象是否为存储对象。
  • 深入理解张量: 理解存储对象可以帮助你更深入地理解张量是如何在 PyTorch 中表示和处理的。

使用方法和示例

以下是使用 torch.is_storage 的一个示例:

import torch

# 创建一个PyTorch张量
x = torch.tensor([1, 2, 3])

# 获取与张量x关联的存储对象
storage = x.storage()

# 检查storage是否为PyTorch存储对象
if torch.is_storage(storage):
    print("storage is a PyTorch storage object")
else:
    print("storage is not a PyTorch storage object")

# 在上述代码中,输出将会是:
## storage is a PyTorch storage object
## 这表明变量 storage 是与张量 x 相关联的 PyTorch 存储对象。

?is_complex

? ? torch.is_complex 是 PyTorch 中的一个函数,用于判断给定的张量是否是复数数据类型。这个函数特别有用在处理涉及复数运算的场景中,比如信号处理或者某些特定的数学计算。

函数作用

  • 目的: 确认输入张量的数据类型是否为复数类型。在 PyTorch 中,复数类型主要有 torch.complex64torch.complex128
  • 使用场景: 当你的代码涉及到复数运算时,使用 torch.is_complex 可以确保你的输入是合适的数据类型,从而避免类型不匹配导致的错误。

使用技巧

  • 数据类型验证: 在执行涉及复数的操作之前,使用 torch.is_complex 来验证数据类型可以确保代码的正确性和稳定性。
  • 调试和开发: 在开发和调试阶段,特别是在进行数学和算法的实现时,确认输入是否为复数类型非常重要。

使用方法和示例

以下是 torch.is_complex 的使用示例:

import torch

# 创建一个复数类型的张量
z = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64)

# 使用torch.is_complex来检查z是否为复数类型
is_complex = torch.is_complex(z)
print(is_complex)  # 输出结果: True

# 创建一个非复数类型的张量
x = torch.tensor([1, 2, 3])

# 再次使用torch.is_complex来检查x
is_complex_x = torch.is_complex(x)
print(is_complex_x)  # 输出结果: False

?????????在上述代码中,torch.is_complex(z) 的结果是 True,因为 z 是一个复数类型的张量。而 torch.is_complex(x) 的结果是 False,因为 x 是一个实数类型的张量。

?is_conj

? ? torch.is_conj 是 PyTorch 库中的一个函数,用于判断一个张量是否为共轭张量。这个函数在处理涉及复数张量的场景,特别是在涉及共轭运算的数学和工程计算中非常有用。

函数作用

  • 目的: torch.is_conj 的主要作用是检查输入的张量是否已经被设置为共轭状态,即其共轭位是否为True。
  • 使用场景: 在执行涉及复数张量的共轭操作时,这个函数可以确保你正在处理的是正确的张量类型。这在某些算法中非常重要,比如在信号处理和复数域中的数学计算。

使用技巧

  • 共轭状态确认: 在进行共轭操作或处理共轭数据时,使用 torch.is_conj 确认张量的共轭状态可以避免可能的错误。
  • 算法实现与调试: 在实现和调试涉及复数运算的算法时,torch.is_conj 是一个有用的工具,可以帮助开发者理解和追踪张量的状态。

使用方法和示例

以下是 torch.is_conj 的使用示例:

import torch

# 创建一个复数张量
z = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64)

# 假设我们对z进行了共轭操作
# 注意:PyTorch在某些版本中可能没有提供直接的共轭操作函数
conj_z = z.conj()

# 使用torch.is_conj来检查z是否为共轭张量
is_conjugated = torch.is_conj(conj_z)
print(is_conjugated)  # 输出结果: True 或 False,取决于是否进行了共轭操作

?????????在上述代码中,由于 PyTorch 在某些版本中可能还没有提供直接的共轭操作或共轭位的设置,所以 torch.is_conj 函数的具体使用可能受到限制。但是,其基本用途是用于检查一个张量是否被设置为共轭状态。

is_floating_point

????????torch.is_floating_point 是 PyTorch 中的一个函数,用于判断一个张量的数据类型是否为浮点类型。这个函数在处理涉及浮点运算的场景中非常有用,尤其是在精度和性能优化方面。

函数作用

  • 目的: 确认输入张量的数据类型是否为浮点类型。PyTorch 中的浮点类型包括 torch.float64 (双精度浮点), torch.float32 (单精度浮点), torch.float16 (半精度浮点), 和 torch.bfloat16
  • 使用场景: 在执行数值计算时,尤其是涉及到精度要求或性能优化的计算时,使用 torch.is_floating_point 可以确保输入数据具有合适的数据类型。

使用技巧

  • 精度和性能优化: 在进行深度学习模型的训练和推断时,不同的浮点精度对性能和内存占用有显著影响。使用 torch.is_floating_point 可以帮助你确认当前的数据类型,从而做出相应的优化决策。
  • 数据类型验证: 在复杂的计算中,特别是涉及不同数据类型的操作时,使用 torch.is_floating_point 确保你的输入数据是浮点类型,可以防止类型不匹配导致的错误。

使用方法和示例

以下是 torch.is_floating_point 的使用示例:

import torch

# 创建不同数据类型的张量
tensor_float64 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
tensor_int = torch.tensor([1, 2, 3])

# 检查tensor_float64是否为浮点类型
print(torch.is_floating_point(tensor_float64))  # 输出: True

# 检查tensor_int是否为浮点类型
print(torch.is_floating_point(tensor_int))  # 输出: False

?????????在上述代码中,torch.is_floating_point(tensor_float64) 的结果是 True,因为 tensor_float64 是一个双精度浮点类型的张量。而 torch.is_floating_point(tensor_int) 的结果是 False,因为 tensor_int 是一个整数类型的张量。

is_nonzero

? ? torch.is_nonzero 是 PyTorch 中的一个函数,用于判断一个单元素张量(tensor)是否非零。这个函数在进行条件检查和断言时特别有用,尤其是在处理只包含一个元素的张量时。

函数作用

  • 目的: 确认输入张量是一个单元素张量,并且这个元素的值不等于零。
  • 使用场景: 通常用于条件判断和断言,特别是在需要验证某个计算结果是否为非零值时。这在数学和算法编程中非常常见。

使用注意点

  • 单元素张量: torch.is_nonzero 只对单元素张量有效。如果输入张量包含多于一个元素,或者是空的,函数将抛出 RuntimeError
  • 类型转换: 函数会根据需要进行类型转换。例如,布尔值 False 被视为 0,浮点数或整数中的 0 也被视为 0

使用方法和示例

以下是一些使用 torch.is_nonzero 的示例,包括它如何对不同的输入做出反应:

import torch

# 对于单元素张量且元素为0,返回False
print(torch.is_nonzero(torch.tensor([0.])))  # 输出: False

# 对于单元素张量且元素不为0,返回True
print(torch.is_nonzero(torch.tensor([1.5])))  # 输出: True

# 对于单元素布尔型张量False,返回False
print(torch.is_nonzero(torch.tensor([False])))  # 输出: False

# 对于单元素整数型张量且元素不为0,返回True
print(torch.is_nonzero(torch.tensor([3])))  # 输出: True

# 对于多元素张量,抛出RuntimeError
try:
    torch.is_nonzero(torch.tensor([1, 3, 5]))
except RuntimeError as e:
    print(e)  # 输出: bool value of Tensor with more than one value is ambiguous

# 对于空张量,抛出RuntimeError
try:
    torch.is_nonzero(torch.tensor([]))
except RuntimeError as e:
    print(e)  # 输出: bool value of Tensor with no values is ambiguous

????????在这些示例中,我们可以看到 torch.is_nonzero 如何对不同类型和大小的张量做出反应。特别需要注意的是,它只适用于单元素张量,且当张量为空或包含多个元素时,会引发错误。

set_default_dtype

? ? torch.set_default_dtype 是 PyTorch 库中的一个函数,用于设置 PyTorch 中默认的浮点数据类型(dtype)。这个设置对于控制浮点数张量的默认精度非常重要,特别是在创建新的张量时没有明确指定数据类型的情况下。

函数作用

  • 目的: 设置 PyTorch 中的默认浮点数据类型。PyTorch 初始化时,默认的浮点类型是 torch.float32torch.set_default_dtype 可以将其改为 torch.float32torch.float64
  • 使用场景: 用于控制新创建张量的默认数据类型,尤其是在从 Python 的浮点数或复数创建张量时。此外,它还会影响复数张量的默认数据类型和张量类型提升的结果。

使用注意点

  • 支持类型: 仅支持 torch.float32torch.float64 作为输入。虽然其他数据类型可能被接受而不报错,但它们并不受支持,且可能无法按预期工作。
  • 类型推断: 类似于 NumPy,当设置默认类型为 torch.float64 时,会促进类似 NumPy 的类型推断。
  • 复数数据类型影响: 默认的浮点类型也隐式决定了默认的复数数据类型。例如,当默认浮点类型为 float32 时,默认复数类型为 complex64;当为 float64 时,默认为 complex128

使用方法和示例

以下是 torch.set_default_dtype 的使用示例:

import torch

# 初始默认的浮点类型为torch.float32
print(torch.tensor([1.2, 3]).dtype)  # 输出: torch.float32

# 初始默认的复数类型为torch.complex64
print(torch.tensor([1.2, 3j]).dtype)  # 输出: torch.complex64

# 设置默认的浮点类型为torch.float64
torch.set_default_dtype(torch.float64)

# 现在Python浮点数被解释为float64
print(torch.tensor([1.2, 3]).dtype)  # 输出: torch.float64

# 现在复数Python数字被解释为complex128
print(torch.tensor([1.2, 3j]).dtype)  # 输出: torch.complex128

?????????在这个示例中,我们可以看到在调用 torch.set_default_dtype(torch.float64) 后,新创建的浮点数和复数张量的数据类型分别变为 torch.float64torch.complex128

get_default_dtype

? ? torch.get_default_dtype 是 PyTorch 库中的一个函数,它用于获取当前设置的默认浮点数据类型(dtype)。这个函数在需要确认或记录当前环境的默认浮点数据类型时非常有用,尤其是在进行多种数据处理任务时,了解当前的默认数据类型有助于保持代码的一致性和准确性。

函数作用

  • 目的: 获取当前 PyTorch 环境中的默认浮点数据类型。
  • 使用场景: 在需要根据当前的默认数据类型调整算法行为或者在进行数据类型敏感的操作之前,使用此函数可以确保你了解当前的默认设置。

使用注意点

  • set_default_dtype配合使用: torch.get_default_dtypetorch.set_default_dtype 通常一起使用,前者用于获取当前默认的浮点类型,后者用于设置默认的浮点类型。
  • 受其他设置影响: 设置默认张量类型(如通过 torch.set_default_tensor_type)也会影响默认的浮点数据类型,因为不同的张量类型可能有不同的默认数据类型。

使用方法和示例

以下是使用 torch.get_default_dtype 的示例:

import torch

# 获取初始默认的浮点数据类型
print(torch.get_default_dtype())  # 输出: torch.float32,初始默认的浮点类型

# 设置默认的浮点数据类型为torch.float64
torch.set_default_dtype(torch.float64)

# 获取更改后的默认数据类型
print(torch.get_default_dtype())  # 输出: torch.float64

# 设置默认的张量类型为torch.FloatTensor
torch.set_default_tensor_type(torch.FloatTensor)

# 再次获取默认的浮点数据类型
print(torch.get_default_dtype())  # 输出: torch.float32,因为torch.FloatTensor的默认类型是torch.float32

?????????在这个示例中,我们首先获取了初始的默认浮点数据类型(torch.float32),然后将其更改为 torch.float64,并验证了这一更改。随后,我们改变了默认的张量类型为 torch.FloatTensor,这也导致默认的浮点数据类型变回 torch.float32

set_default_device

? ? torch.set_default_device 是 PyTorch 库中的一个函数,用于设置默认的设备,以便在未明确指定设备时,新创建的 torch.Tensor 将被分配到该设备上。这个函数在需要统一管理张量分配位置的场景中非常有用,尤其是在进行多设备计算时。

函数作用

  • 目的: 设置 PyTorch 张量默认的分配设备。这对于优化代码,使其在特定设备(如 GPU)上运行非常有用。
  • 使用场景: 在需要将大部分张量操作默认在特定设备上执行的情况下,如在 GPU 上进行深度学习模型训练时。

使用注意点

  • 初始默认设备: 初始情况下,PyTorch 的默认设备是 CPU。使用 torch.set_default_device 可以将其更改为 GPU 或其他设备。
  • 不影响显式指定的设备调用: 如果使用工厂函数(如 torch.tensor)并显式指定了设备,这个设置不会影响那些调用。
  • 性能成本: 使用此函数会在每次调用 PyTorch API 时带来轻微的性能成本。
  • 临时更改设备: 如果只想临时更改默认设备,建议使用 with torch.device(device): 上下文管理器。

使用方法和示例

以下是 torch.set_default_device 的使用示例:

import torch

# 检查一个新创建的张量默认分配到哪个设备
print(torch.tensor([1.2, 3]).device)  # 输出: device(type='cpu')

# 将默认设备设置为 CUDA 设备 0
torch.set_default_device('cuda')

# 再次检查新创建的张量默认分配到哪个设备
print(torch.tensor([1.2, 3]).device)  # 输出: device(type='cuda', index=0)

# 将默认设备设置为 CUDA 设备 1  多个显卡这么搞
torch.set_default_device('cuda:1')

# 再次检查新创建的张量默认分配到哪个设备
print(torch.tensor([1.2, 3]).device)  # 输出: device(type='cuda', index=1)

?????????在这个示例中,我们首先检查了新创建的张量默认分配到 CPU。然后,我们将默认设备更改为 CUDA 设备 0 和 CUDA 设备 1,并观察到新创建的张量随之分配到相应的设备。

set_default_tensor_type

? ? torch.set_default_tensor_type 是 PyTorch 库中的一个函数,用于设置默认的 torch.Tensor 类型。特别地,这个设置影响浮点张量的默认类型,也用于 torch.tensor() 函数中的类型推断。

函数作用

  • 目的: 设置默认的 torch.Tensor 类型,特别是浮点张量的类型。这对于控制新创建的张量的数据类型非常重要。
  • 使用场景: 当你希望改变 PyTorch 中新创建的浮点张量的默认类型时,这个函数非常有用。例如,如果你想要默认所有新的浮点张量都是双精度(torch.float64),就可以使用这个函数来实现。

使用注意点

  • 默认类型: 初始情况下,PyTorch 的默认浮点张量类型是 torch.FloatTensor,即 torch.float32
  • 类型设置: 该函数接受一个类型或其名称作为参数,用于设置默认的张量类型。
  • 类型推断影响: 设置的默认类型也会用于 torch.tensor() 中的类型推断。例如,如果设置了 torch.DoubleTensor(即 torch.float64),那么在使用 torch.tensor() 创建张量时,没有显式指定类型的浮点数将默认为 torch.float64

使用方法和示例

以下是 torch.set_default_tensor_type 的使用示例:

import torch

# 检查一个新创建的浮点张量的默认类型
print(torch.tensor([1.2, 3]).dtype)  # 输出: torch.float32

# 将默认的浮点张量类型设置为双精度浮点类型(torch.float64)
torch.set_default_tensor_type(torch.DoubleTensor)

# 再次检查新创建的浮点张量的类型
print(torch.tensor([1.2, 3]).dtype)  # 输出: torch.float64

????????在这个示例中,我们首先检查了新创建的浮点张量的默认类型是 torch.float32。然后,我们将默认的浮点张量类型设置为 torch.DoubleTensor(即 torch.float64),并观察到新创建的浮点张量类型随之变为 torch.float64

numel

? ? torch.numel 是 PyTorch 库中的一个函数,用于返回输入张量中的元素总数。这个函数在处理张量时非常有用,特别是当你需要了解张量的大小或者在进行形状变换操作之前。

函数作用

  • 目的: 计算并返回一个张量中的元素总数。
  • 使用场景: 当需要知道张量中包含多少个元素时,比如在分配内存、调整张量形状或在进行某些特定的数学运算之前,了解张量的总元素数量是非常重要的。

使用方法和示例

以下是 torch.numel 的使用示例:

import torch

# 创建一个随机张量,形状为1x2x3x4x5
a = torch.randn(1, 2, 3, 4, 5)

# 使用torch.numel获取张量a的元素总数
print(torch.numel(a))  # 输出: 120

# 创建一个4x4的零张量
b = torch.zeros(4, 4)

# 使用torch.numel获取张量b的元素总数
print(torch.numel(b))  # 输出: 16

????????在这个示例中,我们首先创建了一个形状为 1x2x3x4x5 的张量 a,它有 120 个元素。随后,我们创建了一个 4x4 的零张量 b,它有 16 个元素。使用 torch.numel 函数,我们能够快速得到这些张量中的元素数量。?

set_printoptions

? ? torch.set_printoptions 是 PyTorch 库中的一个函数,用于设置张量打印时的格式选项。这个函数借鉴了 NumPy 的打印选项,使得在控制台或日志文件中输出的张量更易于阅读和理解。

函数作用

  • 目的: 自定义张量在打印时的显示格式,包括浮点数精度、元素的显示数量、每行的宽度等。
  • 使用场景: 当你需要在控制台输出或记录较大张量时,适当设置打印选项可以使输出更加清晰和有助于理解。这在调试和展示数据时特别有用。

参数详解

  • precision: 浮点数输出的数字精度(默认值为4)。
  • threshold: 触发摘要显示而非完整显示的元素总数(默认值为1000)。
  • edgeitems: 摘要显示时,在每个维度的开始和结束显示的元素数量(默认值为3)。
  • linewidth: 每行的字符数,用于插入换行符(默认值为80)。超过阈值的矩阵会忽略这个参数。
  • profile: 美化打印的默认设置。可以使用上述任何选项进行覆盖(可选值包括'default', 'short', 'full')。
  • sci_mode: 启用(True)或禁用(False)科学计数法。如果指定为None(默认),则值由内部的格式化设置决定。

使用方法和示例

以下是 torch.set_printoptions 的使用示例:

import torch

# 限制元素的精度
torch.set_printoptions(precision=2)
print(torch.tensor([1.12345]))  # 输出: tensor([1.12])

# 限制显示的元素数量
torch.set_printoptions(threshold=5)
print(torch.arange(10))  # 输出: tensor([0, 1, 2, ..., 7, 8, 9])

# 恢复默认设置
torch.set_printoptions(profile='default')
print(torch.tensor([1.12345]))  # 输出: tensor([1.1235])
print(torch.arange(10))  # 输出: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

????????在这个示例中,我们首先设置了浮点数的精度,然后设置了触发摘要显示的元素数量阈值。最后,我们使用 'default' 配置文件恢复了默认的打印设置。?

set_flush_denormal

? ? torch.set_flush_denormal 是 PyTorch 库中的一个函数,用于在 CPU 上禁用非正规化(denormal)浮点数。这个功能主要用于优化性能,因为处理非正规化浮点数在某些情况下可能会导致显著的性能下降。

函数作用

  • 目的: 禁用 CPU 上的非正规化浮点数。非正规化数是非常接近于零的小数,在浮点数运算中它们可能导致性能问题。
  • 使用场景: 当你遇到由于处理非正规化浮点数而导致的性能问题时,使用这个函数可以帮助提升性能。它特别适用于深度学习或高性能计算场景。

使用注意点

  • 系统支持: torch.set_flush_denormal 仅在支持 SSE3 指令集的 x86 架构上受支持。
  • 返回值: 如果系统支持刷新非正规化数并且成功配置了刷新非正规化模式,则返回 True

使用方法和示例

以下是 torch.set_flush_denormal 的使用示例:

import torch

# 启用刷新非正规化模式
result = torch.set_flush_denormal(True)
print(result)  # 输出: True,表示设置成功

# 创建一个非常小的浮点数张量,看其如何被处理
print(torch.tensor([1e-323], dtype=torch.float64))  # 输出: tensor([ 0.], dtype=torch.float64)

# 禁用刷新非正规化模式
result = torch.set_flush_denormal(False)
print(result)  # 输出: True,表示设置成功

# 再次创建同样的浮点数张量,看其如何被处理
print(torch.tensor([1e-323], dtype=torch.float64))  # 输出: tensor([9.88131e-324], dtype=torch.float64)

?????????在这个示例中,我们首先启用了刷新非正规化模式,然后创建了一个非常小的浮点数张量。由于刷新非正规化模式处于启用状态,这个非常小的数被处理为零。随后,我们禁用了该模式,再次创建同样的张量时,这次这个小数没有被处理为零。

总结

????????这篇博客全面而详细地介绍了 PyTorch 框架及其核心组件 torch 的多种关键功能和使用场景。从基本的多维张量数据结构和数学运算,到高级功能如序列化、CUDA 支持以及各种实用工具,每个方面都进行了深入的探讨。特别地,博客还详细演示了一系列 torch 相关函数(如 is_tensor, is_storage, is_complex, is_conj, is_floating_point, is_nonzero, set_default_dtype, get_default_dtype, set_default_device, set_default_tensor_type, numel, set_printoptions, set_flush_denormal)的用途、应用场景和示例使用方法。通过这些内容,读者可以获得对 PyTorch 功能的全面理解,明白如何在实际项目中有效地应用这些工具,以优化深度学习和科学计算的流程。?

文章来源:https://blog.csdn.net/qq_42452134/article/details/135336114
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。