Intro

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images [1].

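Loading with Keras:

import keras  # or: from tensorflow import keras

# Load CIFAR-10 and check the shapes of the returned NumPy arrays
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()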
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)
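
The arrays returned above hold uint8 pixel values, and the labels carry a trailing axis of length 1. A common next step, sketched here as an assumption rather than something the original prescribes, is to rescale the pixels to [0, 1] and flatten the labels. The small random arrays are stand-ins with the same dtype and layout, so the sketch runs without downloading the dataset.

```python
import numpy as np

# Stand-in arrays (assumption): same dtype and layout as the Keras output,
# but only 8 samples so the sketch runs offline.
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
y_train = rng.integers(0, 10, size=(8, 1), dtype=np.uint8)

# Rescale pixel values from [0, 255] to [0.0, 1.0]
x_scaled = x_train.astype("float32") / 255.0

# Drop the trailing axis so the labels have shape (N,)
y_flat = y_train.reshape(-1)

print(x_scaled.shape, x_scaled.dtype)
print(y_flat.shape)
```

With the real data, the same two lines apply unchanged to the full (50000, 32, 32, 3) and (50000, 1) arrays.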

Loading with torchvision (class signature):

torchvision.datasets.CIFAR10(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
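
The `transform` argument takes a callable that is applied to each image when a sample is fetched. A frequent choice is `transforms.ToTensor()`, which turns a height x width x channel uint8 image into a channel-first float tensor in [0, 1]. The sketch below reproduces only that layout change in NumPy to show what the shapes look like; `to_chw_float` and the blank stand-in image are hypothetical and not part of torchvision.

```python
import numpy as np

# Hypothetical stand-in for one CIFAR-10 image: height x width x channels, uint8
image_hwc = np.zeros((32, 32, 3), dtype=np.uint8)
image_hwc[..., 0] = 255  # fill the red channel


def to_chw_float(image):
    """Move channels first and rescale [0, 255] to [0.0, 1.0]."""
    chw = np.transpose(image, (2, 0, 1))
    return chw.astype(np.float32) / 255.0


image_chw = to_chw_float(image_hwc)
print(image_chw.shape)  # (3, 32, 32)
```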

Example

  • Load dataset
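import numpy as np
from torchvision import datasets, transforms
 
# Assumed transform (the original does not show its definition).
# It is applied when samples are indexed; the raw .data array below is unaffected.
transform = transforms.ToTensor()
 
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_test  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)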
 
print("Training set has {} instances".format(len(cifar10_train)))
print("Test set has {} instances".format(len(cifar10_test)))
 
train_images, train_labels = cifar10_train.data, np.array(cifar10_train.targets)
test_images, test_labels = cifar10_test.data, np.array(cifar10_test.targets)
 
print("Train images shape:", train_images.shape)
print("Train labels shape:", train_labels.shape)
 
 
# For labels
class_names = cifar10_train.classes
print(f'class names: {class_names}')
Output:
Training set has 50000 instances
Test set has 10000 instances
Train images shape: (50000, 32, 32, 3)
Train labels shape: (50000,)
class names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
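
Since the labels are integer indices into the class-name list, a per-class count is a quick sanity check after loading. The sketch below uses a small stand-in label array (5 labels per class, an assumption made so it runs without the download); with the real data, `train_labels` from the example above takes its place.

```python
import numpy as np

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Stand-in labels (assumption): 5 per class instead of the real training labels
train_labels = np.repeat(np.arange(10), 5)

# Count how many samples carry each label index, then map indices to names
counts = np.bincount(train_labels, minlength=len(class_names))
per_class = dict(zip(class_names, counts.tolist()))
print(per_class)
```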
  • Plot first image
import matplotlib.pyplot as plt

# Plot the first training image (train_images comes from the loading step above)
plt.figure()
plt.imshow(train_images[0])  # first image in the dataset
plt.colorbar()
plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.show()

Footnotes

  1. https://www.cs.toronto.edu/~kriz/cifar.html