With shuffle=False the DataLoader keeps the original data order; with shuffle=True the samples are shuffled randomly. The example below reads data from an HDF5 file and iterates it with shuffle=True; a short comparison of the two settings follows the code.
import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset

# Write a small HDF5 file containing a 'data' and a 'label' dataset.
h5f = h5py.File('train.h5', 'w')
data1 = np.array([[1, 2, 3], [2, 5, 6], [3, 5, 6], [4, 5, 6]])
data2 = np.array([[1, 1, 1], [1, 2, 6], [1, 3, 6], [1, 4, 6]])
h5f.create_dataset('data', data=data1)
h5f.create_dataset('label', data=data2)
h5f.close()  # close the write handle before reopening the file for reading


class H5Dataset(Dataset):
    """Dataset wrapper that reads samples and labels from train.h5."""

    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']

    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label

    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0]


dataset_train = H5Dataset()
loader_train = DataLoader(dataset=dataset_train, batch_size=2, shuffle=True)
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)
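To make the effect of the shuffle flag concrete, here is a minimal sketch (my own addition, reusing the H5Dataset class and the train.h5 file from the snippet above) that iterates the same data once in order and once shuffled:

loader_ordered = DataLoader(H5Dataset(), batch_size=2, shuffle=False)
loader_shuffled = DataLoader(H5Dataset(), batch_size=2, shuffle=True)

print("shuffle=False:")
for batch_data, batch_label in loader_ordered:
    print(batch_data.tolist())   # rows appear in file order on every run

print("shuffle=True:")
for batch_data, batch_label in loader_shuffled:
    print(batch_data.tolist())   # row order changes from run to run

With shuffle=False the two batches are always [[1, 2, 3], [2, 5, 6]] and [[3, 5, 6], [4, 5, 6]]; with shuffle=True the grouping of rows into batches varies between runs.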
PyTorch DataLoader usage details
Background:
My initial confusion was about data augmentation: I only saw data transforms (torchvision.transforms) and no separate augmentation step. It later became clear that in PyTorch, data augmentation is the combined effect of torchvision.transforms + torch.utils.data.DataLoader + training for multiple epochs.
The data transforms used here are the following:
from torchvision import transforms

composed = transforms.Compose([
    transforms.Resize((448, 448)),               # resize
    transforms.RandomCrop(300),                  # random crop
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],   # normalize
                         std=[0.5, 0.5, 0.5])])
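As a quick sanity check of what this pipeline produces, here is a minimal sketch (my own addition; 'example.jpg' is a hypothetical 3-channel RGB image path, not from the original post):

from PIL import Image

im = Image.open('example.jpg')              # hypothetical RGB image
out = composed(im)
print(out.shape)                            # torch.Size([3, 300, 300]) after RandomCrop(300)
print(out.min().item(), out.max().item())   # roughly within [-1, 1] after Normalize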
A simple dataset class that only returns images in PIL format:
import os
import csv
from PIL import Image
from torch.utils import data


class MyDataset(data.Dataset):
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels_file)

    def __getitem__(self, idx):
        # the first CSV column holds the image file name, relative to root_dir
        im_name = os.path.join(self.root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        if self.transform:
            im = self.transform(im)
        return im
Below is the main program:
import matplotlib.pyplot as plt

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)

# The original dataset has 3 images; with batch_size=1 and 2 epochs,
# all resulting images (3 per epoch, 6 in total) are displayed.
for epoch in range(2):
    plt.figure(figsize=(6, 6))
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind + 1)
        plt.imshow(a)
    plt.show()
As the displayed images show, each epoch re-applies the transform to the original images, and that is exactly what produces the data augmentation.
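This is easy to verify in isolation: running the same image through the composed pipeline twice almost always yields different tensors, because RandomCrop draws a new crop location on every call. A minimal sketch (my own addition, again using the hypothetical 'example.jpg'):

import torch
from PIL import Image

im = Image.open('example.jpg')   # hypothetical RGB image
a = composed(im)
b = composed(im)
# RandomCrop is re-drawn on each call, so the two passes differ,
# which is the per-epoch augmentation effect described above.
print(torch.equal(a, b))         # usually False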
The above is my personal experience; I hope it serves as a useful reference, and I hope everyone continues to support 服務(wù)器之家.
Original article: https://blog.csdn.net/qq_35752161/article/details/110875040