NumPy 函数

NumPy expand_dims() 函数扩展数组的形状。它插入一个新轴,该轴将出现在扩展数组形状中的轴位置。

语法

numpy.expand_dims(a, axis) 

参数

a必填。 指定输入数组。
axis必填。 指定扩展轴中放置新轴(或多个轴)的位置。它可以是整数或整数元组。

返回值

返回a的视图,其中维数增加。

示例:

在下面的示例中,数组在给定轴上扩展。

import numpy as np

x = np.array([1, 2, 3])

#在轴上扩展x的维度=0
x1 = np.expand_dims(x, axis=0)

#扩展 x 在轴上的维度=1
x2 = np.expand_dims(x, axis=1)

#扩展x在轴上的维度=(0,1)
x3 = np.expand_dims(x, axis=(0,1))

#显示结果
print("shape of x:", x.shape)
print("x contains:")
print(x)
print("\nshape of x1:", x1.shape)
print("x1 contains:")
print(x1)
print("\nshape of x2:", x2.shape)
print("x2 contains:")
print(x2)
print("\nshape of x3:", x3.shape)
print("x3 contains:")
print(x3) 

上述代码的输出将为:

shape of x: (3,)
x contains:
[1 2 3]

shape of x1: (1, 3)
x1 contains:
[[1 2 3]]

shape of x2: (3, 1)
x2 contains:
[[1]
 [2]
 [3]]

shape of x3: (1, 1, 3)
x3 contains:
[[[1 2 3]]]