使用Keras時,如果要用大規模數據集對網絡進行訓練,就沒辦法先全部加載進內存再從內存直接傳到顯存了。除了使用Sequence類以外,還可以使用迭代器去生成數據,但迭代器無法在fit_generator里開啟多進程,會影響數據的讀取和預處理效率,在本文中就不再敘述了,有需要的可以另外去百度。
下面是我所使用的代碼
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
class SequenceData(Sequence):
    """Keras Sequence that streams (face1, face2, audio) -> label batches.

    ``path`` is an index file with one text line per sample; the heavy
    decoding/preprocessing is deferred to :meth:`data_generation`, so only
    the small index lines are held in memory.
    """

    def __init__(self, path, batch_size=32):
        self.path = path
        self.batch_size = batch_size
        # Read the index file once, with a context manager so the handle
        # is closed (the original left the file open).
        with open(path) as f:
            self.datas = f.readlines()
        self.L = len(self.datas)
        # Random permutation of all sample positions (epoch-level shuffle).
        self.index = random.sample(range(self.L), self.L)

    def __len__(self):
        # Number of batches per epoch; called via len(<instance>).
        # The original returned ``self.L - self.batch_size``, which makes
        # batches overlap (each sample appears ~batch_size times per epoch);
        # the conventional value is ceil(num_samples / batch_size).
        return (self.L + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        # Non-overlapping slice of the shuffled permutation -> one batch.
        # (Original sliced ``idx:(idx + batch_size)``, i.e. a sliding
        # window shifted by one sample per step.)
        batch_indexs = self.index[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_datas = [self.datas[k] for k in batch_indexs]
        img1s, img2s, audios, labels = self.data_generation(batch_datas)
        # Dict keys must match the model's input/output layer names.
        return ({'face1_input_1': img1s,
                 'face2_input_2': img2s,
                 'input_3': audios},
                {'activation_7': labels})

    def data_generation(self, batch_datas):
        # Placeholder for per-batch decoding/preprocessing. The original
        # template returned names that were never assigned (NameError);
        # raise explicitly so a missing implementation fails loudly.
        raise NotImplementedError("implement per-batch preprocessing here")
然后在代碼里通過fit_generator函數調用并訓練
這里要注意,use_multiprocessing參數表示是否開啟多進程。由于python的多線程受GIL限制、不是真正的并行,所以多進程還是會獲得比較可觀的加速;但該選項在windows下不被支持——windows下Keras的fit_generator無法以多進程方式讀取數據。
1
2
3
4
|
# Stream training data through the Sequence; 20 workers with
# use_multiprocessing=True keep preprocessing off the main process.
train_seq = SequenceData('train.csv')
model_train.fit_generator(
    generator=train_seq,
    steps_per_epoch=int(len(train_seq)),
    epochs=2,
    workers=20,
    # callbacks=[checkpoint],
    use_multiprocessing=True,
    validation_data=SequenceData('vali.csv'),
    validation_steps=int(20000 / 32),
)
同樣的,也可以在測試的時候使用
# The same streaming mechanism works for evaluation:
# 125100 test samples consumed in batches of 32 by 32 workers.
test_seq = SequenceData('face_test.csv')
model.evaluate_generator(generator=test_seq, steps=int(125100 / 32), workers=32)
補充知識:keras數據自動生成器,繼承keras.utils.Sequence,結合fit_generator實現節約內存訓練
我就廢話不多說了,大家還是直接看代碼吧~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
|
#coding=utf-8
'''
Created on 2018-7-10

Memory-friendly training: a keras.utils.Sequence subclass yields batches
on demand, so the full dataset never has to fit in RAM.
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense


class DataGenerator(keras.utils.Sequence):
    """Yields (images, one-hot labels) batches from a list of image paths."""

    def __init__(self, datas, batch_size=1, shuffle=True):
        self.batch_size = batch_size
        self.datas = datas  # list of image file paths
        self.indexes = np.arange(len(self.datas))
        self.shuffle = shuffle

    def __len__(self):
        # Number of batches per epoch (the last batch may be smaller).
        return math.ceil(len(self.datas) / self.batch_size)

    def __getitem__(self, index):
        # Slice this batch's positions out of the (possibly shuffled)
        # index array, then load the corresponding samples.
        batch_indexs = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_datas = [self.datas[k] for k in batch_indexs]
        X, y = self.data_generation(batch_datas)
        return X, y

    def on_epoch_end(self):
        # Reshuffle between epochs so batch composition varies.
        # (Original compared ``self.shuffle == True``; truthiness suffices.)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def data_generation(self, batch_datas):
        """Load images and derive one-hot labels from the parent directory name."""
        images = []
        labels = []
        for data in batch_datas:
            # X: image as loaded by OpenCV; append the ndarray directly —
            # the original's list(image) conversion was a needless copy.
            images.append(cv2.imread(data))
            # y: class name is the file's parent directory. os.path is used
            # instead of rfind("\\") so '/'-separated paths work too.
            class_name = os.path.basename(os.path.dirname(data))
            if class_name == "dog":
                labels.append([0, 1])
            else:
                labels.append([1, 0])
        # For a multi-output model, y must instead be a list of arrays:
        # [numpy_out1, numpy_out2, numpy_out3].
        return np.array(images), np.array(labels)


# Collect sample paths: one sub-directory per class under the root.
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
    file_path = os.path.join("D:/xxx", file)
    if os.path.isdir(file_path):
        class_num = class_num + 1
        for sub_file in os.listdir(file_path):
            train_datas.append(os.path.join(file_path, sub_file))

# Batch generator
training_generator = DataGenerator(train_datas)

# Build the network
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
# The original called model.compile twice with the same settings
# (only argument order differed); compiling once is sufficient.
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50, max_queue_size=10, workers=1)
以上這篇keras使用Sequence類調用大規模數據集進行訓練的實現就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。
原文鏈接:https://blog.csdn.net/qq_22033759/article/details/88798423