make_training.py
############################################
# First load all the modules at the beginning
############################################
from __future__ import print_function
import os

# To save dictionaries
import pickle

# Keras utilities
import keras
from keras.models import model_from_json, load_model
from keras.optimizers import SGD
from keras.utils import np_utils

import h5py
import numpy as np

# Local modules
import models
import generator
from multi_gpu import make_parallel

# Probe the dataset once to get the total number of samples
dataset = "train.h5"
hf = h5py.File(dataset, 'r')
n1 = hf.get('data')
total_count = n1.shape[0]
hf.close()
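
# For reference, train.h5 is assumed to hold an image array under 'data'
# (confirmed above) plus a matching label array consumed by the generator
# (its name is an assumption; generator.py is not shown here). A quick way
# to inspect the actual layout:
#
#   with h5py.File(dataset, 'r') as f:
#       for key in f.keys():
#           print(key, f[key].shape, f[key].dtype)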

def train_model(model, dataset, validation_ratio=0.2, batch_size=64):
    with h5py.File(dataset, "r") as data:

        # Shuffle the sample ids, then split them into train/validation sets
        total_ids = np.random.permutation(range(0, total_count))
        train_total_ids = total_ids[0:int((1 - validation_ratio) * total_count)]
        test_total_ids = total_ids[int((1 - validation_ratio) * total_count):]

        training_sequence_generator = generator.produce_seq(batch_size=batch_size,
                                                            data=data,
                                                            sample_ids=train_total_ids)
        validation_sequence_generator = generator.produce_seq(batch_size=batch_size,
                                                              data=data,
                                                              sample_ids=test_total_ids)

        # Keras 1.x generator training; the HDF5 file must stay open while
        # fit_generator runs, so the call sits inside the with block
        history = model.fit_generator(generator=training_sequence_generator,
                                      validation_data=validation_sequence_generator,
                                      samples_per_epoch=len(train_total_ids),
                                      nb_val_samples=len(test_total_ids),
                                      nb_epoch=1,
                                      max_q_size=1,
                                      verbose=1,
                                      class_weight=None,
                                      nb_worker=1)
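
        # For reference only: under the Keras 2 API the same call would use
        # renamed arguments (steps are counted in batches rather than samples):
        #
        #   history = model.fit_generator(generator=training_sequence_generator,
        #                                 validation_data=validation_sequence_generator,
        #                                 steps_per_epoch=len(train_total_ids) // batch_size,
        #                                 validation_steps=len(test_total_ids) // batch_size,
        #                                 epochs=1,
        #                                 max_queue_size=1,
        #                                 verbose=1,
        #                                 class_weight=None,
        #                                 workers=1)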

    directory = 'logs/'
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Save the history dictionary so the curves can be plotted later
    with open(directory + 'history.pickle', 'wb') as handle:
        pickle.dump(history.history, handle, protocol=2)
    print("The training/testing logs saved")
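
    # A minimal sketch of reading the history back in a separate plotting
    # script (matplotlib is an assumption; it is not used in this file):
    #
    #   with open('logs/history.pickle', 'rb') as handle:
    #       history_dict = pickle.load(handle)
    #   plt.plot(history_dict['loss'], label='train')
    #   plt.plot(history_dict['val_loss'], label='validation')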

    # Serialize the model architecture to JSON
    model_json = model.to_json()
    with open(directory + "model.json", "w") as json_file:
        json_file.write(model_json)
    print("The arch saved")

    # Serialize the weights to HDF5, and the full model for easy recovery
    model.save_weights(directory + "model_weights.h5")
    model.save(directory + 'model_4recover.h5')
    print("The weights and model saved")
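
    # A minimal sketch of recovering the model later with the helpers imported
    # at the top of this file (load_model restores the compiled model saved by
    # model.save; the JSON + weights route rebuilds it in two steps):
    #
    #   model = load_model('logs/model_4recover.h5')
    #   # or:
    #   with open('logs/model.json') as json_file:
    #       model = model_from_json(json_file.read())
    #   model.load_weights('logs/model_weights.h5')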


# Build the 5-class network and train it
model = models.CVN(5)
#model = make_parallel(model, 2)
learning_rate = 0.02
decay_rate = 0.1  # Keras applies decay per update: lr / (1 + decay * iterations)
momentum = 0.9
opt = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['acc', 'top_k_categorical_accuracy'])

train_model(model, dataset)
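
# For context: generator.produce_seq(batch_size, data, sample_ids) is defined
# in generator.py and models.CVN(num_classes) in models.py; neither file is
# shown here. A minimal sketch of a generator with a compatible interface,
# assuming the HDF5 file stores labels under 'label' as integer class ids:
#
#   def produce_seq(batch_size, data, sample_ids):
#       while True:  # Keras expects the generator to loop forever
#           for start in range(0, len(sample_ids), batch_size):
#               batch_ids = sorted(sample_ids[start:start + batch_size])  # h5py wants increasing indices
#               X = data['data'][batch_ids]
#               y = np_utils.to_categorical(data['label'][batch_ids], 5)
#               yield X, y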