The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images [1].
import numpy as np
from torchvision import datasets, transforms

# Assumed transform: the original cell does not show how `transform` was defined.
transform = transforms.ToTensor()

# Download (if needed) and load the training and test splits
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

print("Training set has {} instances".format(len(cifar10_train)))
print("Test set has {} instances".format(len(cifar10_test)))

# Raw images as a NumPy array; the targets list is converted to an array
train_images, train_labels = cifar10_train.data, np.array(cifar10_train.targets)
test_images, test_labels = cifar10_test.data, np.array(cifar10_test.targets)

print("Train images shape:", train_images.shape)
print("Train labels shape:", train_labels.shape)

# For labels
class_names = cifar10_train.classes
print(f'class names: {class_names}')
Training set has 50000 instances
Test set has 10000 instances
Train images shape: (50000, 32, 32, 3)
Train labels shape: (50000,)
class names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
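The per-class balance mentioned in the dataset description can be checked directly from the label array. The following is a minimal sketch, not part of the original notebook: `class_counts` is a hypothetical helper, and `demo_labels` is synthetic stand-in data so the snippet runs without downloading anything.

```python
import numpy as np


def class_counts(labels, num_classes):
    """Return how many samples carry each integer label 0..num_classes-1."""
    labels = np.asarray(labels)
    # minlength keeps a slot for classes that never occur in `labels`
    return np.bincount(labels, minlength=num_classes)


# Synthetic stand-in for train_labels: 10 classes, 5 samples each (hypothetical data)
demo_labels = np.repeat(np.arange(10), 5)
demo_counts = class_counts(demo_labels, num_classes=10)
print(demo_counts)  # [5 5 5 5 5 5 5 5 5 5]
```

With the real data, `class_counts(train_labels, len(class_names))` would be expected to report 5000 images for each class in the training split.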
Plot first image
import matplotlib.pyplot as plt

# Plotting first image
plt.figure()
plt.imshow(train_images[0])  # first image in the dataset
plt.colorbar()
plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.show()
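The plot above shows the image without its label. One way to attach the class name is to look up the integer label in the list of class names. This is only a sketch: `label_name` is a hypothetical helper, and `demo_labels` is made-up stand-in data rather than the real CIFAR-10 targets.

```python
import numpy as np


def label_name(labels, class_names, index):
    """Map the integer label at position `index` to its class name."""
    return class_names[int(labels[index])]


# Class names as printed earlier; demo_labels is hypothetical stand-in data
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
demo_labels = np.array([6, 9, 9, 4])

print(label_name(demo_labels, class_names, 0))  # frog
```

In the plotting cell, calling `plt.title(label_name(train_labels, class_names, 0))` before `plt.show()` would put the class name above the image.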