Concatenating joins a sequence of tensors along an existing axis. The PyTorch function for concatenation is cat(). Stacking joins a sequence of tensors along a new axis. The PyTorch function for stacking is stack().
This tutorial will go through the two PyTorch functions with code examples.
Table of contents
- PyTorch Cat
- Syntax
- Example
- PyTorch Stack
- Syntax
- Example
- PyTorch Cat Vs Stack
- Summary
PyTorch Cat
We can use the PyTorch cat() function to concatenate a sequence of tensors along the same dimension. The tensors must have the same shape (except in the concatenating dimension) or be empty.
Syntax
torch.cat(tensors, dim=0, *, out=None)
Parameters
- tensors (sequence of Tensors): Required. Any Python sequence of tensors of the same type. Non-empty tensors must have the same shape except in the concatenating dimension.
- dim (int): Optional. The dimension to concatenate the tensors over.
Keyword Arguments
- out (Tensor): Optional. The output tensor.
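As a quick illustration of the out keyword, here is a minimal sketch; the tensor names and values are made up for illustration:
import torch

a = torch.tensor([1, 2])
b = torch.tensor([3, 4, 5])

# Pre-allocate a tensor to hold the concatenated result (dtype must match the inputs)
result = torch.empty(5, dtype=torch.int64)
torch.cat((a, b), dim=0, out=result)
print(result)  # tensor([1, 2, 3, 4, 5])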
Example
Let’s look at an example where we concatenate three tensors into one tensor using cat(). First, we have to import the PyTorch library and then use the tensor() function to create the tensors:
import torch
x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30])
z = torch.tensor([7, 22, 4, 8, 3, 6])
Next, we can concatenate the tensors along the 0th dimension, the only available axis.
xyz = torch.cat((x, y, z), dim=0)
print(xyz)
print(xyz.shape)
Let’s run the code to see the result:
tensor([ 2,  3,  4,  5,  4, 10, 30,  7, 22,  4,  8,  3,  6])
torch.Size([13])
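Concatenation is not limited to dim=0. As a minimal sketch with two-dimensional tensors (the values here are arbitrary), cat() along dim=1 joins them column-wise as long as the shapes match outside the concatenating dimension:
import torch

a = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
b = torch.tensor([[5], [6]])        # shape (2, 1)

# Shapes only need to agree outside the concatenating dimension (dim=1 here)
ab = torch.cat((a, b), dim=1)
print(ab)        # tensor([[1, 2, 5],
                 #         [3, 4, 6]])
print(ab.shape)  # torch.Size([2, 3])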
PyTorch Stack
We can use the PyTorch stack() function to concatenate a sequence of tensors along a new dimension. The tensors must have the same shape.
Syntax
torch.stack(tensors, dim=0, *, out=None)
Parameters
- tensors (sequence of Tensors): Required. Python sequence of tensors of the same size.
- dim (int): Optional. The new dimension to insert. The dimension must be between 0 and the number of dimensions of the concatenated tensors (inclusive).
Keyword Arguments
- out (Tensor): Optional. The output tensor.
Example
Let’s look at an example where we stack three tensors into one tensor using stack(). First, we have to import the PyTorch library and then use the tensor() function to create the tensors:
import torch
x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])
z = torch.tensor([8, 7, 16, 14])
In the above code, the tensors x, y, and z are one-dimensional, each having four elements. Next, we will stack the tensors along dim=0 and dim=1.
# Stacking tensors along dimension 0
stacked_0 = torch.stack((x, y, z), dim=0)
# Stacking tensors along dimension 1
stacked_1 = torch.stack((x, y, z), dim=1)
# Resultant combined tensor with a new axis along dimension 0
print(stacked_0)
# Shape of combined tensor
print(stacked_0.shape)
# Resultant combined tensor with a new axis along dimension 1
print(stacked_1)
# Shape of combined tensor
print(stacked_1.shape)
Let’s run the code to get the result:
tensor([[ 2,  3,  4,  5],
        [ 4, 10, 30, 40],
        [ 8,  7, 16, 14]])
torch.Size([3, 4])
tensor([[ 2,  4,  8],
        [ 3, 10,  7],
        [ 4, 30, 16],
        [ 5, 40, 14]])
torch.Size([4, 3])
Both resulting tensors are two-dimensional. As the individual tensors are one-dimensional, we can stack them along dimension 0 or dimension 1.
With dim=0, the tensors are stacked row-wise, giving us a 3×4 matrix. With dim=1, the tensors are stacked column-wise, giving us a 4×3 matrix, which is the transpose of the dim=0 result.
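We can confirm the transpose relationship directly. A minimal, self-contained sketch reusing the same values as the example above:
import torch

x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])
z = torch.tensor([8, 7, 16, 14])

stacked_0 = torch.stack((x, y, z), dim=0)
stacked_1 = torch.stack((x, y, z), dim=1)

# For 1-D inputs, stacking along dim=1 is the transpose of stacking along dim=0
print(torch.equal(stacked_1, stacked_0.t()))  # True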
PyTorch Cat Vs Stack
The two PyTorch functions offer similar functionality but differ in how they join tensors. The cat() function concatenates tensors along an existing dimension. The stack() function joins tensors along a new dimension that is not present in the individual tensors.
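The difference is easiest to see in the resulting shapes. As a minimal sketch with two identically shaped 2-D tensors (the values are arbitrary), cat() keeps the number of dimensions while stack() adds one:
import torch

a = torch.ones(3, 4)
b = torch.zeros(3, 4)

print(torch.cat((a, b), dim=0).shape)    # torch.Size([6, 4]) - same number of dimensions
print(torch.stack((a, b), dim=0).shape)  # torch.Size([2, 3, 4]) - one extra dimension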
We can reproduce the results of the stack() function using the cat() function by applying the unsqueeze() operation to each tensor before passing them to cat(). Let’s look at the result with the tensors from the previous example:
import torch
x = torch.tensor([2, 3, 4, 5])
y = torch.tensor([4, 10, 30, 40])
z = torch.tensor([8, 7, 16, 14])
xyz = torch.cat((x.unsqueeze(0), y.unsqueeze(0), z.unsqueeze(0)), dim=0)
print(xyz)
print(xyz.shape)
The unsqueeze operation adds a new dimension of length one to the tensors, and then we concatenate along the first axis. Let’s run the code to get the result:
tensor([[ 2,  3,  4,  5],
        [ 4, 10, 30, 40],
        [ 8,  7, 16, 14]])
torch.Size([3, 4])
Therefore, torch.stack((A, B), dim=0) is equivalent to torch.cat((A.unsqueeze(0), B.unsqueeze(0)), dim=0).
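A quick check of this equivalence, with two arbitrary example tensors A and B:
import torch

A = torch.tensor([1, 2, 3])
B = torch.tensor([4, 5, 6])

stacked = torch.stack((A, B), dim=0)
concatenated = torch.cat((A.unsqueeze(0), B.unsqueeze(0)), dim=0)

# Both approaches produce the same 2x3 tensor
print(torch.equal(stacked, concatenated))  # True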
Summary
Congratulations on reading to the end of this tutorial. We have gone through how to join tensors using both cat() and stack() and explained the differences between the two functions.
For further reading on PyTorch, go to the article: How to Convert NumPy Array to PyTorch Tensor.
To learn more about Python for data science and machine learning, go to the online courses page on Python for the most comprehensive courses available.
Have fun and happy researching!