NumPy 函数

NumPy squeeze() 函数从 a 中删除长度为 1 的轴。

语法

numpy.squeeze(a, axis=None) 

参数

a必填。 指定输入数组。
axis可选。 它可以是 None 或 int 或 int 元组。选择形状中长度为 1 的条目的子集。如果选择的轴的形状条目大于 1,则会引发错误。

返回值

返回输入数组,但删除了长度为 1 的所有维度或部分维度。

示例:

在下面的示例中,squeeze() 函数用于删除给定数组中长度为 1 的轴。

import numpy as np

Arr = np.array([[[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]]])

#挤压数组
Arr1 = np.squeeze(Arr)

#显示结果
print("shape of Arr:", Arr.shape)
print("Arr is:")
print(Arr)

print("\nshape of Arr1:", Arr1.shape)
print("Arr1 is (squeeze with axis=None):")
print(Arr1) 

上述代码的输出将是:

shape of Arr: (1, 3, 3)
Arr is:
[[[1 2 3]
  [4 5 6]
  [7 8 9]]]

shape of Arr1: (3, 3)
Arr1 is (squeeze with axis=None):
[[1 2 3]
 [4 5 6]
 [7 8 9]] 

示例:

数组不能被挤压在形状大于 1 的轴上。考虑下面的示例:

import numpy as np

#创建形状为 (1, 3, 1) 的数组
Arr = np.array([[[1],[2],[3]]])

#数组不能在轴上挤压
#其中形状大于一。例如
#at axis=1,Arr 的形状 = 3,因此
#在 axis=1 处挤压它会引发异常

#squeezeing 到达轴=0
Arr1 = np.squeeze(Arr, axis=0)

#squeezeing 到达轴=2
Arr2 = np.squeeze(Arr, axis=2)

#显示结果
print("shape of Arr:", Arr.shape)
print("Arr is:")
print(Arr)

print("\nshape of Arr1:", Arr1.shape)
print("Arr1 is (squeeze with axis=0):")
print(Arr1)

print("\nshape of Arr2:", Arr2.shape)
print("Arr2 is (squeeze with axis=2):")
print(Arr1) 

上述代码的输出将是:

shape of Arr: (1, 3, 1)
Arr is:
[[[1]
  [2]
  [3]]]

shape of Arr1: (3, 1)
Arr1 is (squeeze with axis=0):
[[1]
 [2]
 [3]]

shape of Arr2: (1, 3)
Arr2 is (squeeze with axis=2):
[[1]
 [2]
 [3]]