The complete parameter list of DataLoader is:
```python
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                  sampler=None, batch_sampler=None,
                                  num_workers=0,
                                  collate_fn=<function default_collate>,
                                  pin_memory=False, drop_last=False,
                                  timeout=0, worker_init_fn=None)
```
DataLoader provides a single- or multi-process iterator over a dataset.
The key parameters:
- shuffle: when set to True, the data is reshuffled at every epoch
- collate_fn: controls how a list of samples is merged into a batch; you can supply your own function to get exactly the batching behavior you want
- drop_last: decides what to do with the samples left over when the dataset size is not divisible by batch_size. True drops them; False keeps them as a final, smaller batch (see the sketch after this list)
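As a quick illustration of drop_last (a minimal sketch with made-up data, not from the article's example): with 10 samples and batch_size=3, drop_last=False keeps a fourth, smaller batch, while drop_last=True discards the leftover sample.

```python
import torch
import torch.utils.data as Data

# 10 samples with batch_size=3 leaves one sample over (10 = 3*3 + 1)
ds = Data.TensorDataset(torch.arange(10))

kept = Data.DataLoader(ds, batch_size=3, drop_last=False)
dropped = Data.DataLoader(ds, batch_size=3, drop_last=True)

print(len(kept))     # 4 -- the last batch holds the single leftover sample
print(len(dropped))  # 3 -- the leftover sample is thrown away
```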
A test example:
```python
import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing, target)
batch = 3

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,  # batch size
    # If the number of samples in the dataset is not divisible by
    # batch_size, the last batch simply uses whatever is left over
    collate_fn=lambda x: (
        torch.cat(
            [x[i][j].unsqueeze(0) for i in range(len(x))], 0
        ).unsqueeze(0)
        for j in range(len(x[0]))
    )
)

for (i, j) in loader:
    print(i)
    print(j)
```
Output:
```
tensor([[[ 0,  1,  2],
         [ 1,  2,  3],
         [ 2,  3,  4]]], dtype=torch.int32)
tensor([[[0],
         [1],
         [2]]], dtype=torch.int32)
tensor([[[ 3,  4,  5],
         [ 4,  5,  6],
         [ 5,  6,  7]]], dtype=torch.int32)
tensor([[[3],
         [4],
         [5]]], dtype=torch.int32)
tensor([[[ 6,  7,  8],
         [ 7,  8,  9],
         [ 8,  9, 10]]], dtype=torch.int32)
tensor([[[6],
         [7],
         [8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[9]]], dtype=torch.int32)
```
Without the custom collate_fn (i.e. using the default), the output becomes:
```
tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4]], dtype=torch.int32)
tensor([[0],
        [1],
        [2]], dtype=torch.int32)
tensor([[3, 4, 5],
        [4, 5, 6],
        [5, 6, 7]], dtype=torch.int32)
tensor([[3],
        [4],
        [5]], dtype=torch.int32)
tensor([[ 6,  7,  8],
        [ 7,  8,  9],
        [ 8,  9, 10]], dtype=torch.int32)
tensor([[6],
        [7],
        [8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[9]], dtype=torch.int32)
```
So compared with the default collate (which stacks the samples into tensors of shape (batch_size, ...)), the effect of our collate_fn is simply to add one extra leading dimension to each batch.
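A quick way to verify this (reusing the loader and custom collate_fn defined above) is to print the shapes on both sides:

```python
for i, j in loader:
    print(i.shape, j.shape)
# With the custom collate_fn:  torch.Size([1, 3, 3]) torch.Size([1, 3, 1])
#   (last batch: torch.Size([1, 1, 3]) torch.Size([1, 1, 1]))
# With the default collate_fn: torch.Size([3, 3])    torch.Size([3, 1])
```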
To see what collate_fn's argument actually looks like, change it to the identity function:
```python
collate_fn=lambda x: x
```
并輸出
```python
for i in loader:
    print(i)
```
The result:
```
[(tensor([0, 1, 2], dtype=torch.int32), tensor([0], dtype=torch.int32)),
 (tensor([1, 2, 3], dtype=torch.int32), tensor([1], dtype=torch.int32)),
 (tensor([2, 3, 4], dtype=torch.int32), tensor([2], dtype=torch.int32))]
[(tensor([3, 4, 5], dtype=torch.int32), tensor([3], dtype=torch.int32)),
 (tensor([4, 5, 6], dtype=torch.int32), tensor([4], dtype=torch.int32)),
 (tensor([5, 6, 7], dtype=torch.int32), tensor([5], dtype=torch.int32))]
[(tensor([6, 7, 8], dtype=torch.int32), tensor([6], dtype=torch.int32)),
 (tensor([7, 8, 9], dtype=torch.int32), tensor([7], dtype=torch.int32)),
 (tensor([8, 9, 10], dtype=torch.int32), tensor([8], dtype=torch.int32))]
[(tensor([9, 10, 11], dtype=torch.int32), tensor([9], dtype=torch.int32))]
```
Each i is a list of batch_size tuples, and each tuple holds one individual sample from the TensorDataset. So to regroup each batch into a 1×3×3 input and a 1×3×1 target, we have to unpack the samples and pack them back together. With that in mind, look again at our collate_fn:
```python
collate_fn=lambda x: (
    torch.cat(
        [x[i][j].unsqueeze(0) for i in range(len(x))], 0
    ).unsqueeze(0)
    for j in range(len(x[0]))
)
```
Here j iterates over the two fields, input and target, while i iterates over the batch_size samples. unsqueeze(0) adds a leading dimension to each individual sample, torch.cat(..., 0) stacks them along that dimension, and a final unsqueeze(0) adds one more leading dimension. That's all there is to it.
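An equivalent collate_fn written as a named function may be easier to read. This is a sketch, not from the original article; it uses zip(*samples) to regroup the list of (input, target) tuples into per-field sequences and torch.stack to do the stacking, producing the same (1, batch, ...) shapes:

```python
import torch

def repack_collate(samples):
    # samples: a list of (input, target) tuples, one tuple per sample.
    # zip(*samples) regroups them into (all_inputs, all_targets);
    # torch.stack(field, 0) stacks each field into a (batch, ...) tensor,
    # and unsqueeze(0) adds the extra leading dimension, as before.
    return tuple(torch.stack(field, 0).unsqueeze(0) for field in zip(*samples))

# Usage (hypothetical name, same loader setup as above):
# loader = Data.DataLoader(dataset=torch_dataset, batch_size=batch,
#                          collate_fn=repack_collate)
```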
That wraps up this walkthrough of the collate_fn parameter of PyTorch's DataLoader; hopefully it gives you a useful reference.
Original article: https://blog.csdn.net/weixin_42028364/article/details/81675021