Comments
Transformations in torchvision.transforms work on images, on tensors (representing images) and possibly on numpy arrays (representing images). However, a transformation (e.g., ToTensor) might behave differently depending on the input type, so you should be clear about exactly what a transformation function does. A good practice is to always convert your non-tensor input data to tensors using the ToTensor transformation first and then apply the other transformation functions (which then consume tensors and produce tensors). It is also a good idea to normalize your input tensors to be within a small range (e.g., [0, 1]).
import torch
import torchvision
import numpy as np
from PIL import Image
img = Image.open("../../home/media/poker/4h.png")
img
arr = np.array(img)
arr
arr.shape
torchvision.transforms.ToTensor
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8.
This is the transformation that you always need when preparing a dataset for a computer vision task.
trans = torchvision.transforms.ToTensor()
t1 = trans(img)
t1
t1.shape
t2 = trans(arr)
t2
t2.shape