"Transform" NumPy Array: Move Dimension


I'm creating an array a:

import numpy as np
a = np.zeros((3, 10, 10), np.uint8)  # channels first: (RGB, rows, columns)
a[0, 5, 5] = 255  # channel 0 is red in RGB order

with a red dot in the center, where the RGB is the first dimension. Then I plot it using matplotlib:

import matplotlib.pyplot as plt
plt.imshow(a)

Of course this doesn't work, because imshow expects an RGB image with shape (10, 10, 3), i.e. (rows, columns, channels), and I am feeding it an array with shape (3, 10, 10). How can I 'flip' the array so that RGB is the third dimension instead of the first?


1 Answer


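What you need is np.swapaxes, which exchanges two axes of an array. Note that swapping axes 0 and 2 also exchanges rows and columns, so the image comes out transposed; np.moveaxis(a, 0, -1) (equivalently a.transpose(1, 2, 0)) moves only the RGB axis to the end.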

import numpy as np
a = np.zeros((3, 10, 10), np.uint8)

print(a.shape)  # (3, 10, 10)

print(np.swapaxes(a, 0, 2).shape)  # (10, 10, 3)
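Applied to the plotting code from the question, a minimal sketch could look like this. It assumes matplotlib is installed and uses np.moveaxis rather than swapaxes, so rows stay on axis 0 and columns on axis 1:

```python
import numpy as np
import matplotlib.pyplot as plt

# Channels-first image, as in the question: shape (3, rows, columns).
a = np.zeros((3, 10, 10), np.uint8)
a[0, 5, 5] = 255  # channel 0 is red

# Move the RGB axis (axis 0) to the end; rows and columns keep their order.
img = np.moveaxis(a, 0, -1)
print(img.shape)  # (10, 10, 3)

plt.imshow(img)  # now in the (rows, columns, channels) layout imshow expects
plt.show()
```

np.moveaxis returns a view, so no pixel data is copied.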

See the NumPy documentation for numpy.swapaxes (and numpy.moveaxis, which moves a single axis to a new position).
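For the question's 10x10 array the difference between swapaxes and moveaxis is invisible, because the dot at (5, 5) lies on the diagonal. The sketch below uses a hypothetical non-square array (3 channels, 2 rows, 4 columns) to show that swapaxes(a, 0, 2) puts columns first, i.e. transposes the image, while moveaxis(a, 0, -1) and transpose(1, 2, 0) keep the row/column order:

```python
import numpy as np

# Hypothetical channels-first image: 3 channels, 2 rows, 4 columns.
a = np.arange(3 * 2 * 4).reshape(3, 2, 4)

swapped = np.swapaxes(a, 0, 2)  # shape (4, 2, 3): columns now come first
moved = np.moveaxis(a, 0, -1)   # shape (2, 4, 3): rows, columns, channels
same = a.transpose(1, 2, 0)     # equivalent to np.moveaxis(a, 0, -1)

print(swapped.shape)                # (4, 2, 3)
print(moved.shape)                  # (2, 4, 3)
print(np.array_equal(moved, same))  # True

# The pixel at row 1, column 3 is found at moved[1, 3],
# but at swapped[3, 1]: swapaxes has transposed the image.
print(np.array_equal(moved[1, 3], a[:, 1, 3]))    # True
print(np.array_equal(swapped[3, 1], a[:, 1, 3]))  # True
```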
