Pytorch学习torch.clamp ()用法浅析

2024-01-10 04:01:15

首先给出官方对此函数的定义网页:torch.clamp — PyTorch 2.1 documentation

一、官方定义

torch.clamp(input, min=None, max=None, *, out=None) → Tensor

其中:

  • input: 输入张量,即需要进行元素限制的张量。
  • min: 张量中的元素的最小值。如果元素小于这个值,将被替换为这个最小值。
  • max: 张量中的元素的最大值。如果元素大于这个值,将被替换为这个最大值。
  • out (可选): 输出张量,用于保存结果。如果没有提供,函数会创建一个新的张量来保存结果。

二、作用详解

  1. 将元素限制在指定范围内: 对于输入张量 input 中的每个元素,torch.clamp 将其限制在指定的范围 [min, max] 内。如果元素小于 min,就被替换为 min;如果元素大于 max,就被替换为 max

  2. 示例:

    import torch
    
    x = torch.tensor([1, 5, 10, -3, 8])
    result = torch.clamp(x, min=2, max=8)
    
    print(result)
    

    输出:

    tensor([2, 5, 8, 2, 8])
    

    在这个例子中,torch.clamp 将张量 x 中的元素限制在范围 [2, 8] 内,小于2的元素被替换为2,大于8的元素被替换为8。

三、注意事项

  1. 如果参数中未指定min,则不限制张量的下边界;如果参数中未指定max,则不限制张量的上边界;如果minmax均未提供,则不进行任何限制,函数返回的张量将和原始张量保持一致。

    示例

    import torch
    
    x = torch.tensor([1, 5, 10, -3, 8])
    result = torch.clamp(x, max=8)#未指定min值,则不限制下边界
    
    print(result)
    

    输出

    tensor([ 1,  5,  8, -3,  8])
    
  2. min和max的指定并不要求为整数,可以为浮点数,如下示例中,张量的元素被限制在[-2.5,8.7]

    示例

    import torch
    
    x = torch.tensor([1, 5, 10, -3, 8])
    result = torch.clamp(x, min=2.5, max=8.7)
    
    print(result)
    

    输出

    tensor([2.5000, 5.0000, 8.7000, 2.5000, 8.0000])
    

文章来源:https://blog.csdn.net/qq_46396470/article/details/135418165
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。