首页 文章资讯内容详情

如何缩小 PyTorch 中的张量?

2026-06-02 2 花语

torch.narrow()方法用于对PyTorch张量执行窄操作。它返回一个新的张量,它是原始输入张量的缩小版本。

例如,[4,3]的张量可以缩小为[2,3]或[4,2]大小的张量。我们可以一次缩小一个维度上的张量。在这里,我们不能将两个维度都缩小到[2,2]的大小。我们也可以用来缩小张量的范围。Tensor.narrow()

语法

torch.narrow(input, dim, start, length) Tensor.narrow(dim, start, length)

参数

输入——它是要缩小的PyTorch张量。

dim–这是我们必须缩小原始张量输入的维度。

开始-开始维度。

长度–从起始尺寸到结束尺寸的长度。

脚步

导入火炬库。确保您已经安装了它。

import torch

创建一个PyTorch张量并打印张量及其大小。

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) print("Tensor:\n", t) print("张量的大小:", t.size()) # size 3x3

计算并将值分配给变量。torch.narrow(input,dim,start,length)

t1 = torch.narrow(t, 0, 1, 2)

缩小后打印结果张量及其大小。

print("Tensor after Narrowing:\n", t2) print("缩小后的尺寸:", t2.size())

示例1

在以下Python代码中,输入张量大小为[3,3]。我们使用dim=0,start=1和length=2沿维度0缩小张量。它返回一个维度为[2,3]的新张量。

请注意,新张量沿维度0变窄,沿维度0的长度更改为2。

# import the library import torch # create a tensor t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # print the created tensor print("Tensor:\n", t) print("张量的大小:", t.size()) # Narrow-down the tensor in dimension 0 t1 = torch.narrow(t, 0, 1, 2) print("Tensor after Narrowing:\n", t1) print("缩小后的尺寸:", t1.size()) # Narrow down the tensor in dimension 1 t2 = torch.narrow(t, 1, 1, 2) print("Tensor after Narrowing:\n", t2) print("缩小后的尺寸:", t2.size())输出结果Tensor: tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 张量的大小: torch.Size([3, 3]) Tensor after Narrowing: tensor([[4, 5, 6], [7, 8, 9]]) 缩小后的尺寸: torch.Size([2, 3]) Tensor after Narrowing: tensor([[2, 3], [5, 6], [8, 9]]) 缩小后的尺寸: torch.Size([3, 2])

示例2

以下程序显示了如何使用.Tensor.narrow()

# import required library import torch # create a tensor t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) # print the above created tensor print("Tensor:\n", t) print("张量的大小:", t.size()) # Narrow-down the tensor in dimension 0 t1 = t.narrow(0, 1, 2) print("Tensor after Narrowing:\n", t1) print("缩小后的尺寸:", t1.size()) # Narrow down the tensor in dimension 1 t2 = t.narrow(1, 0, 2) print("Tensor after Narrowing:\n", t2) print("缩小后的尺寸:", t2.size())输出结果Tensor: tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]]) 张量的大小: torch.Size([4, 3]) Tensor after Narrowing: tensor([[4, 5, 6], [7, 8, 9]]) 缩小后的尺寸: torch.Size([2, 3]) Tensor after Narrowing: tensor([[ 1, 2], [ 4, 5], [ 7, 8], [10, 11]]) 缩小后的尺寸: torch.Size([4, 2])