NumPy expand_dims() 函数扩展数组的形状。它插入一个新轴,该轴将出现在扩展数组形状中的轴位置。
语法
numpy.expand_dims(a, axis)
参数
a | 必填。 指定输入数组。 |
axis | 必填。 指定扩展轴中放置新轴(或多个轴)的位置。它可以是整数或整数元组。 |
返回值
返回a的视图,其中维数增加。
示例:
在下面的示例中,数组在给定轴上扩展。
import numpy as np
x = np.array([1, 2, 3])
#在轴上扩展x的维度=0
x1 = np.expand_dims(x, axis=0)
#扩展 x 在轴上的维度=1
x2 = np.expand_dims(x, axis=1)
#扩展x在轴上的维度=(0,1)
x3 = np.expand_dims(x, axis=(0,1))
#显示结果
print("shape of x:", x.shape)
print("x contains:")
print(x)
print("\nshape of x1:", x1.shape)
print("x1 contains:")
print(x1)
print("\nshape of x2:", x2.shape)
print("x2 contains:")
print(x2)
print("\nshape of x3:", x3.shape)
print("x3 contains:")
print(x3)
上述代码的输出将为:
shape of x: (3,)
x contains:
[1 2 3]
shape of x1: (1, 3)
x1 contains:
[[1 2 3]]
shape of x2: (3, 1)
x2 contains:
[[1]
[2]
[3]]
shape of x3: (1, 1, 3)
x3 contains:
[[[1 2 3]]]