Module delta.ml.train

Train neural networks.

Expand source code
# Copyright © 2020, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All rights reserved.
#
# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
# licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Train neural networks.
"""

import datetime
import os
import tempfile
import shutil

import mlflow
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K #pylint: disable=no-name-in-module
from tensorflow.keras.layers import Layer #pylint: disable=no-name-in-module

from delta.config import config
from delta.imagery.imagery_dataset import ImageryDataset
from delta.imagery.imagery_dataset import AutoencoderDataset
from .io import save_model, load_model, print_network
from .config_parser import config_callbacks, loss_from_dict, metric_from_dict, \
                           optimizer_from_dict, config_augmentation

class DeltaLayer(Layer):
    """
    Network layer class with extra features specific to DELTA.

    Extends `tensorflow.keras.layers.Layer`.
    """
    def callback(self): # pylint:disable=no-self-use
        """
        Override this method to make a layer automatically register
        a training callback.

        The base implementation registers nothing.

        Returns
        -------
        tensorflow.keras.callbacks.Callback:
            The callback to register (or None).
        """
        return None

def _devices(num_gpus):
    '''
    Selects the TensorFlow logical devices to run on and returns their names.

    Arguments

    num_gpus -- Number of GPUs to use.  Zero selects all logical CPUs;
                a negative value selects every logical GPU available.

    Returns a list of device name strings.  An assertion fails if more GPUs
    are requested than TensorFlow reports.
    '''
    if num_gpus == 0:
        return [cpu.name for cpu in tf.config.list_logical_devices('CPU')]

    gpu_names = [gpu.name for gpu in tf.config.list_logical_devices('GPU')]
    assert num_gpus <= len(gpu_names),\
           "Requested %d GPUs with only %d available." % (num_gpus, len(gpu_names))
    # A negative count means "use them all", so only trim for a positive request.
    return gpu_names[:num_gpus] if num_gpus > 0 else gpu_names

def _strategy(devices):
    '''
    Returns a distribution strategy for a list of TensorFlow logical devices.

    A single device gets a OneDeviceStrategy; any other number of devices
    gets a MirroredStrategy over the whole list.
    '''
    if len(devices) != 1:
        return tf.distribute.MirroredStrategy(devices=devices)
    return tf.distribute.OneDeviceStrategy(device=devices[0])

def _prep_datasets(ids, tc):
    """
    Build the batched training dataset and, when configured, the batched
    validation dataset.

    `ids` is the imagery dataset to train on and `tc` is the training
    specification. Validation data is either carved off the front of the
    training data (`tc.validation.from_training`) or loaded from the separate
    validation images (with labels if given, otherwise as an autoencoder
    dataset).

    Returns (training dataset, validation dataset or None).
    """
    if tc.max_tile_offset:
        # with filtering nodata, number of tiles changes
        assert tc.steps, 'max_tile_offset only supported with steps set.'
    train_ds = ids.dataset(config.dataset.classes.weights(), config_augmentation())

    val_ds = None
    val_spec = tc.validation
    if val_spec:
        if val_spec.from_training:
            # The first `steps` elements become validation data; the rest is trained on.
            val_ds = train_ds.take(val_spec.steps)
            train_ds = train_ds.skip(val_spec.steps)
        else:
            val_images = val_spec.images
            val_labels = val_spec.labels
            if val_images:
                if val_labels:
                    val_source = ImageryDataset(val_images, val_labels, ids.output_shape(), ids.chunk_shape(),
                                                tile_shape=ids.tile_shape(), stride=ids.stride(),
                                                tile_overlap=ids.tile_overlap())
                else:
                    val_source = AutoencoderDataset(val_images, ids.chunk_shape(), tile_shape=ids.tile_shape(),
                                                    stride=ids.stride(), tile_overlap=ids.tile_overlap())
                val_ds = val_source.dataset(config.dataset.classes.weights())
        if val_ds:
            val_ds = val_ds.batch(tc.batch_size, drop_remainder=True)

    return (train_ds.batch(tc.batch_size, drop_remainder=True), val_ds)

def _log_mlflow_params(model, dataset, training_spec):
    """
    Record the run's image, training, and model parameters with MLflow.

    Parameters
    ----------
    model: tensorflow.keras.models.Model
        The model being trained; its layer and weight counts are logged.
    dataset: delta.imagery.imagery_dataset.ImageryDataset
        The training dataset; its image set, tile shape, and input/output shapes are logged.
    training_spec: delta.ml.ml_config.TrainingSpec
        The training parameters to log.
    """
    images = dataset.image_set()
    #labels = dataset.label_set()
    mlflow.log_param('Images - Type',   images.type())
    mlflow.log_param('Images - Count',   len(images))
    mlflow.log_param('Images - Stride', training_spec.stride)
    # The tile size comes from the dataset; the layer count is logged separately
    # below as 'Model - Layers'.
    mlflow.log_param('Images - Tile Size', dataset.tile_shape())
    mlflow.log_param('Train - Steps', training_spec.steps)
    mlflow.log_param('Train - Loss Function', training_spec.loss)
    mlflow.log_param('Train - Epochs', training_spec.epochs)
    mlflow.log_param('Train - Batch Size', training_spec.batch_size)
    mlflow.log_param('Train - Optimizer', training_spec.optimizer)
    mlflow.log_param('Model - Layers', len(model.layers))
    mlflow.log_param('Model - Parameters - Non-Trainable',
                     np.sum([K.count_params(w) for w in model.non_trainable_weights]))
    mlflow.log_param('Model - Parameters - Trainable',
                     np.sum([K.count_params(w) for w in model.trainable_weights]))
    mlflow.log_param('Model - Shape - Output',   dataset.output_shape())
    mlflow.log_param('Model - Shape - Input',   dataset.input_shape())
    #mlflow.log_param('Status', 'Running') Illegal to change the value!

class _MLFlowCallback(tf.keras.callbacks.Callback):
    """
    Callback to log everything for MLFlow.

    Logs metrics at the end of epochs and periodically at the end of training
    batches, and uploads model checkpoints as MLflow artifacts at the
    configured epoch frequency.
    """
    def __init__(self, temp_dir, model_extension):
        """
        Parameters
        ----------
        temp_dir: str
            Scratch directory where checkpoints are written before upload.
        model_extension: str
            Extension appended to saved checkpoint file names.
        """
        super().__init__()
        # Most recent epoch / batch index seen, kept so callers can report progress.
        self.epoch = 0
        self.batch = 0
        self.temp_dir = temp_dir
        self.model_extension = model_extension

    def on_epoch_end(self, epoch, logs=None):
        """
        Log the epoch's metrics and, when a checkpoint is due, save the model
        and upload it as an artifact under 'checkpoints'.
        """
        self.epoch = epoch
        # Keras normally supplies logs, but the signature allows None.
        logs = logs or {}
        for k, value in logs.items():
            if k.startswith('val_'):
                mlflow.log_metric('Validation ' + k[4:], value, epoch)
            else:
                mlflow.log_metric('Epoch ' + k, value, epoch)
        checkpoint_freq = config.mlflow.checkpoints.frequency()
        if checkpoint_freq and epoch > 0 and epoch % checkpoint_freq == 0:
            filename = os.path.join(self.temp_dir, '%d%s' % (epoch, self.model_extension))
            save_model(self.model, filename)
            if config.mlflow.checkpoints.only_save_latest():
                # Reuse one artifact name so each upload replaces the previous checkpoint.
                old = filename
                filename = os.path.join(self.temp_dir, 'latest' + self.model_extension)
                os.rename(old, filename)
            mlflow.log_artifact(filename, 'checkpoints')
            # The saved model may be a directory or a single file depending on the format.
            if os.path.isdir(filename):
                shutil.rmtree(filename)
            else:
                os.remove(filename)

    def on_train_batch_end(self, batch, logs=None):
        """
        Log the batch metrics every `config.mlflow.frequency()` batches.
        """
        self.batch = batch
        logs = logs or {}
        # An unset or zero frequency disables batch logging rather than
        # failing on the modulo below.
        log_freq = config.mlflow.frequency()
        if log_freq and batch > 0 and batch % log_freq == 0:
            for k, value in logs.items():
                if k in ('batch', 'size'):
                    continue
                mlflow.log_metric(k, value, step=batch)

def _mlflow_train_setup(model, dataset, training_spec, model_extension):
    """
    Start an MLflow run for this training session.

    Points MLflow at the configured tracking URI and experiment, starts the
    run, logs the run parameters, and uploads the exported configuration as
    `config.yaml`.

    Returns an `_MLFlowCallback` bound to a newly created temporary directory.
    """
    mlflow.set_tracking_uri(config.mlflow.uri())
    mlflow.set_experiment(config.mlflow.experiment())
    mlflow.start_run()
    _log_mlflow_params(model, dataset, training_spec)

    scratch_dir = tempfile.mkdtemp()
    config_path = os.path.join(scratch_dir, 'config.yaml')
    with open(config_path, 'w') as config_file:
        config_file.write(config.export())
    mlflow.log_artifact(config_path)
    # Only the uploaded copy is needed; the scratch directory itself is kept
    # for the callback to use.
    os.remove(config_path)

    return _MLFlowCallback(scratch_dir, model_extension)

def _build_callbacks(model, dataset, training_spec, model_extension):
    """
    Assemble the Keras callbacks to use during training.

    The list always starts with TerminateOnNaN, followed by any callbacks
    offered by DeltaLayers in the model, the MLflow and TensorBoard callbacks
    when those are enabled in the configuration, and finally the callbacks
    returned by `config_callbacks()`.

    Returns (list of callbacks, mlflow callback). The mlflow callback is None
    when MLflow is disabled.
    """
    callbacks = [tf.keras.callbacks.TerminateOnNaN()]

    # DeltaLayers may each contribute one callback (or None).
    offered = (layer.callback() for layer in model.layers if isinstance(layer, DeltaLayer))
    callbacks.extend(cb for cb in offered if cb)

    mcb = None
    if config.mlflow.enabled():
        mcb = _mlflow_train_setup(model, dataset, training_spec, model_extension)
        callbacks.append(mcb)
        if config.general.verbose():
            print('Using mlflow folder: ' + mlflow.get_artifact_uri())

    if config.tensorboard.enabled():
        log_root = config.tensorboard.dir()
        if config.mlflow.enabled():
            # Name the TensorBoard directory after the MLflow run and record it there.
            log_dir = os.path.join(log_root, str(mlflow.active_run().info.run_id))
            mlflow.log_param('TensorBoard Directory', log_dir)
        else:
            timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
            log_dir = os.path.join(log_root, timestamp)
        callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                        update_freq='epoch',
                                                        histogram_freq=1,
                                                        write_images=True,
                                                        embeddings_freq=1))

    callbacks += config_callbacks()

    return (callbacks, mcb)

def _compile_helper(model, training_spec):
    """
    Compile `model` with the optimizer, loss, and metrics described by `training_spec`.
    """
    optimizer = optimizer_from_dict(training_spec.optimizer)
    loss = loss_from_dict(training_spec.loss)
    metrics = list(map(metric_from_dict, training_spec.metrics))
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

class ContinueTrainingException(Exception):
    """
    Raised by a callback to interrupt fitting so that the model can be
    modified, optionally recompiled, and training then resumed.
    """
    def __init__(self, msg: str=None, completed_epochs: int=0,
                 recompile_model: bool=False, learning_rate: float=None):
        """
        Parameters
        ----------
        msg: str
            Optional error message.
        completed_epochs: int
            How many epochs have already finished; training resumes at the next one.
        recompile_model: bool
            Set to True to have the model recompiled, which is required after
            the model itself has been changed.
        learning_rate: float
            If given, the learning rate to switch to when training resumes.
        """
        Exception.__init__(self, msg)
        self.learning_rate = learning_rate
        self.recompile_model = recompile_model
        self.completed_epochs = completed_epochs

def compile_model(model_fn, training_spec, resume_path=None):
    """
    Compile and check that the model is valid.

    Parameters
    ----------
    model_fn: Callable[[], tensorflow.keras.model.Model]
        Function to construct a keras Model. An already constructed
        `tensorflow.keras.models.Model` is also accepted and is compiled as is,
        which allows recompiling an existing model.
    training_spec: delta.ml.ml_config.TrainingSpec
        Training parameters. If it has no `strategy` attribute, one is
        created from the configured GPU count and stored on it.
    resume_path: str
        File name to load initial model weights from. A path ending in '.h5'
        is loaded as weights into the constructed model; any other path is
        loaded as a complete model in place of calling `model_fn`.

    Returns
    -------
    tensorflow.keras.models.Model:
        The compiled model, ready for training.
    """
    if not hasattr(training_spec, 'strategy'):
        training_spec.strategy = _strategy(_devices(config.general.gpus()))
    with training_spec.strategy.scope():
        if resume_path is not None and not resume_path.endswith('.h5'):
            model = load_model(resume_path)
        else:
            # A Model instance is itself callable, so it must be detected
            # before treating model_fn as a zero-argument factory.
            if isinstance(model_fn, tf.keras.models.Model):
                model = model_fn
            else:
                model = model_fn()
            assert isinstance(model, tf.keras.models.Model), \
                    "Model is not a Tensorflow Keras model"
            if resume_path is not None:
                model.load_weights(resume_path)

        _compile_helper(model, training_spec)

    input_shape = model.input_shape
    output_shape = model.output_shape

    assert len(input_shape) == 4, 'Input to network is wrong shape.'
    assert input_shape[0] is None, 'Input is not batched.'
    # The below may no longer be valid if we move to convolutional architectures.
    assert input_shape[1] == input_shape[2], 'Input to network is not chunked'
    assert len(output_shape) == 2 or output_shape[1] == output_shape[2], 'Output from network is not chunked'

    if config.general.verbose():
        print('Training model:')
        # NOTE(review): the (512, 512, 8) shape is hard-coded rather than taken
        # from the model -- confirm this is intended.
        print_network(model, (512, 512, 8))
        # summary() prints the table itself and returns None, so wrapping it
        # in print() would emit a stray "None" line.
        model.summary(line_length=120)

    return model

def train(model_fn, dataset : ImageryDataset, training_spec, resume_path=None, internal_model_extension='.h5'):
    """
    Trains the specified model on a dataset according to a training
    specification.

    When MLflow is enabled, the final model is uploaded as an artifact on
    success. On any failure or interruption the current model is uploaded as
    'aborted_model', the run is marked FAILED, and the exception is re-raised.

    Parameters
    ----------
    model_fn: Callable[[], tensorflow.keras.model.Model]
        Function that constructs a model.
    dataset: delta.imagery.imagery_dataset.ImageryDataset
        Dataset to train on.
    training_spec: delta.ml.ml_config.TrainingSpec
        Training parameters.
    resume_path: str
        Optional file to load initial model weights from.
    internal_model_extension: str
        Extension appended to the file names of models saved as MLflow
        artifacts (checkpoints, final model, aborted model).

    Returns
    -------
    (tensorflow.keras.models.Model, History):
        The trained model and the training history.
    """
    model = compile_model(model_fn, training_spec, resume_path)
    assert model.input_shape[3] == dataset.num_bands(), 'Number of bands in model does not match data.'
    # last element differs for the sparse metrics
    assert model.output_shape[1:-1] == dataset.output_shape()[:-1] or (model.output_shape[1] is None), \
            'Network output shape %s does not match label shape %s.' % \
            (model.output_shape[1:], dataset.output_shape()[:-1])

    (ds, validation) = _prep_datasets(dataset, training_spec)

    (callbacks, mcb) = _build_callbacks(model, dataset, training_spec, internal_model_extension)

    try:

        if (training_spec.steps is None) or (training_spec.steps > 0):
            if training_spec.steps is not None:
                ds = ds.repeat() # repeat for ever, use steps and epochs to stop
            done = False
            epochs = training_spec.epochs
            initial_epoch = 0
            # A callback may interrupt fit() with ContinueTrainingException;
            # keep calling fit() until it runs to completion.
            while not done:
                try:
                    history = model.fit(ds,
                                        epochs=epochs,
                                        initial_epoch=initial_epoch,
                                        callbacks=callbacks,
                                        validation_data=validation,
                                        validation_steps=None, # Steps are controlled in the dataset setup
                                        steps_per_epoch=training_spec.steps,
                                        verbose=1) # Set to 2 when logging
                    done = True
                except ContinueTrainingException as cte:
                    print('Recompiling model and resuming training.')
                    initial_epoch += cte.completed_epochs
                    if cte.recompile_model:
                        # compile_model expects a factory function, so wrap the
                        # existing (modified) model instead of passing it directly.
                        model = compile_model(lambda current=model: current, training_spec)
                    if cte.learning_rate:
                        K.set_value(model.optimizer.lr, cte.learning_rate)
        else: # Skip training
            print('Skipping straight to validation')
            history = model.evaluate(validation, steps=training_spec.validation.steps,
                                     callbacks=callbacks, verbose=1)

        if config.mlflow.enabled():
            model_path = os.path.join(mcb.temp_dir, 'final_model' + internal_model_extension)
            print('\nFinished, saving model to %s.'
                  % (mlflow.get_artifact_uri() + '/final_model' + internal_model_extension))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            else:
                os.remove(model_path)
            mlflow.log_param('Status', 'Completed')
    except:
        # Deliberately catches everything (including KeyboardInterrupt) so an
        # interrupted run still saves its model; the exception is re-raised below.
        if config.mlflow.enabled():
            try:
                mlflow.log_param('Status', 'Aborted')
                mlflow.log_param('Epoch', mcb.epoch)
                mlflow.log_param('Batch', mcb.batch)
                model_path = os.path.join(mcb.temp_dir, 'aborted_model' + internal_model_extension)
                print('\nAborting, saving current model to %s.'
                      % (mlflow.get_artifact_uri() + '/aborted_model' + internal_model_extension))
                save_model(model, model_path)
                mlflow.log_artifact(model_path)
                if os.path.isdir(model_path):
                    shutil.rmtree(model_path)
                else:
                    os.remove(model_path)
            finally:
                # End the run only after the artifact is logged: with no active
                # run, the mlflow calls above would implicitly start a new run
                # and attach the aborted model to it.
                mlflow.end_run('FAILED')
        raise
    finally:
        if config.mlflow.enabled():
            if mcb and mcb.temp_dir:
                shutil.rmtree(mcb.temp_dir)
            mlflow.end_run()

    return model, history

Functions

def compile_model(model_fn, training_spec, resume_path=None)

Compile and check that the model is valid.

Parameters

model_fn : Callable[[], tensorflow.keras.model.Model]
Function to construct a keras Model.
training_spec : TrainingSpec
Training parameters.
resume_path : str
File name to load initial model weights from.

Returns

tensorflow.keras.models.Model:
The compiled model, ready for training.
Expand source code
def compile_model(model_fn, training_spec, resume_path=None):
    """
    Compile and check that the model is valid.

    Parameters
    ----------
    model_fn: Callable[[], tensorflow.keras.model.Model]
        Function to construct a keras Model. An already constructed
        `tensorflow.keras.models.Model` is also accepted and is compiled as is,
        which allows recompiling an existing model.
    training_spec: delta.ml.ml_config.TrainingSpec
        Training parameters. If it has no `strategy` attribute, one is
        created from the configured GPU count and stored on it.
    resume_path: str
        File name to load initial model weights from. A path ending in '.h5'
        is loaded as weights into the constructed model; any other path is
        loaded as a complete model in place of calling `model_fn`.

    Returns
    -------
    tensorflow.keras.models.Model:
        The compiled model, ready for training.
    """
    if not hasattr(training_spec, 'strategy'):
        training_spec.strategy = _strategy(_devices(config.general.gpus()))
    with training_spec.strategy.scope():
        if resume_path is not None and not resume_path.endswith('.h5'):
            model = load_model(resume_path)
        else:
            # A Model instance is itself callable, so it must be detected
            # before treating model_fn as a zero-argument factory.
            if isinstance(model_fn, tf.keras.models.Model):
                model = model_fn
            else:
                model = model_fn()
            assert isinstance(model, tf.keras.models.Model), \
                    "Model is not a Tensorflow Keras model"
            if resume_path is not None:
                model.load_weights(resume_path)

        _compile_helper(model, training_spec)

    input_shape = model.input_shape
    output_shape = model.output_shape

    assert len(input_shape) == 4, 'Input to network is wrong shape.'
    assert input_shape[0] is None, 'Input is not batched.'
    # The below may no longer be valid if we move to convolutional architectures.
    assert input_shape[1] == input_shape[2], 'Input to network is not chunked'
    assert len(output_shape) == 2 or output_shape[1] == output_shape[2], 'Output from network is not chunked'

    if config.general.verbose():
        print('Training model:')
        # NOTE(review): the (512, 512, 8) shape is hard-coded rather than taken
        # from the model -- confirm this is intended.
        print_network(model, (512, 512, 8))
        # summary() prints the table itself and returns None, so wrapping it
        # in print() would emit a stray "None" line.
        model.summary(line_length=120)

    return model
def train(model_fn, dataset: ImageryDataset, training_spec, resume_path=None, internal_model_extension='.h5')

Trains the specified model on a dataset according to a training specification.

Parameters

model_fn : Callable[[], tensorflow.keras.model.Model]
Function that constructs a model.
dataset : ImageryDataset
Dataset to train on.
training_spec : TrainingSpec
Training parameters.
resume_path : str
Optional file to load initial model weights from.

Returns

(tensorflow.keras.models.Model, History): The trained model and the training history.

Expand source code
def train(model_fn, dataset : ImageryDataset, training_spec, resume_path=None, internal_model_extension='.h5'):
    """
    Trains the specified model on a dataset according to a training
    specification.

    When MLflow is enabled, the final model is uploaded as an artifact on
    success. On any failure or interruption the current model is uploaded as
    'aborted_model', the run is marked FAILED, and the exception is re-raised.

    Parameters
    ----------
    model_fn: Callable[[], tensorflow.keras.model.Model]
        Function that constructs a model.
    dataset: delta.imagery.imagery_dataset.ImageryDataset
        Dataset to train on.
    training_spec: delta.ml.ml_config.TrainingSpec
        Training parameters.
    resume_path: str
        Optional file to load initial model weights from.
    internal_model_extension: str
        Extension appended to the file names of models saved as MLflow
        artifacts (checkpoints, final model, aborted model).

    Returns
    -------
    (tensorflow.keras.models.Model, History):
        The trained model and the training history.
    """
    model = compile_model(model_fn, training_spec, resume_path)
    assert model.input_shape[3] == dataset.num_bands(), 'Number of bands in model does not match data.'
    # last element differs for the sparse metrics
    assert model.output_shape[1:-1] == dataset.output_shape()[:-1] or (model.output_shape[1] is None), \
            'Network output shape %s does not match label shape %s.' % \
            (model.output_shape[1:], dataset.output_shape()[:-1])

    (ds, validation) = _prep_datasets(dataset, training_spec)

    (callbacks, mcb) = _build_callbacks(model, dataset, training_spec, internal_model_extension)

    try:

        if (training_spec.steps is None) or (training_spec.steps > 0):
            if training_spec.steps is not None:
                ds = ds.repeat() # repeat for ever, use steps and epochs to stop
            done = False
            epochs = training_spec.epochs
            initial_epoch = 0
            # A callback may interrupt fit() with ContinueTrainingException;
            # keep calling fit() until it runs to completion.
            while not done:
                try:
                    history = model.fit(ds,
                                        epochs=epochs,
                                        initial_epoch=initial_epoch,
                                        callbacks=callbacks,
                                        validation_data=validation,
                                        validation_steps=None, # Steps are controlled in the dataset setup
                                        steps_per_epoch=training_spec.steps,
                                        verbose=1) # Set to 2 when logging
                    done = True
                except ContinueTrainingException as cte:
                    print('Recompiling model and resuming training.')
                    initial_epoch += cte.completed_epochs
                    if cte.recompile_model:
                        # compile_model expects a factory function, so wrap the
                        # existing (modified) model instead of passing it directly.
                        model = compile_model(lambda current=model: current, training_spec)
                    if cte.learning_rate:
                        K.set_value(model.optimizer.lr, cte.learning_rate)
        else: # Skip training
            print('Skipping straight to validation')
            history = model.evaluate(validation, steps=training_spec.validation.steps,
                                     callbacks=callbacks, verbose=1)

        if config.mlflow.enabled():
            model_path = os.path.join(mcb.temp_dir, 'final_model' + internal_model_extension)
            print('\nFinished, saving model to %s.'
                  % (mlflow.get_artifact_uri() + '/final_model' + internal_model_extension))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            else:
                os.remove(model_path)
            mlflow.log_param('Status', 'Completed')
    except:
        # Deliberately catches everything (including KeyboardInterrupt) so an
        # interrupted run still saves its model; the exception is re-raised below.
        if config.mlflow.enabled():
            try:
                mlflow.log_param('Status', 'Aborted')
                mlflow.log_param('Epoch', mcb.epoch)
                mlflow.log_param('Batch', mcb.batch)
                model_path = os.path.join(mcb.temp_dir, 'aborted_model' + internal_model_extension)
                print('\nAborting, saving current model to %s.'
                      % (mlflow.get_artifact_uri() + '/aborted_model' + internal_model_extension))
                save_model(model, model_path)
                mlflow.log_artifact(model_path)
                if os.path.isdir(model_path):
                    shutil.rmtree(model_path)
                else:
                    os.remove(model_path)
            finally:
                # End the run only after the artifact is logged: with no active
                # run, the mlflow calls above would implicitly start a new run
                # and attach the aborted model to it.
                mlflow.end_run('FAILED')
        raise
    finally:
        if config.mlflow.enabled():
            if mcb and mcb.temp_dir:
                shutil.rmtree(mcb.temp_dir)
            mlflow.end_run()

    return model, history

Classes

class ContinueTrainingException (msg: str = None, completed_epochs: int = 0, recompile_model: bool = False, learning_rate: float = None)

Callbacks can raise this exception to modify the model, recompile, and continue training.

Parameters

msg : str
Optional error message.
completed_epochs : int
The number of epochs that have been finished. (resumes from the next epoch)
recompile_model : bool
If True, recompile the model. This is necessary if the model has been changed.
learning_rate : float
Optionally set the learning rate to the given value.
Expand source code
class ContinueTrainingException(Exception):
    """
    Raised by a callback to interrupt fitting so that the model can be
    modified, optionally recompiled, and training then resumed.
    """
    def __init__(self, msg: str=None, completed_epochs: int=0,
                 recompile_model: bool=False, learning_rate: float=None):
        """
        Parameters
        ----------
        msg: str
            Optional error message.
        completed_epochs: int
            How many epochs have already finished; training resumes at the next one.
        recompile_model: bool
            Set to True to have the model recompiled, which is required after
            the model itself has been changed.
        learning_rate: float
            If given, the learning rate to switch to when training resumes.
        """
        Exception.__init__(self, msg)
        self.learning_rate = learning_rate
        self.recompile_model = recompile_model
        self.completed_epochs = completed_epochs

Ancestors

  • builtins.Exception
  • builtins.BaseException
class DeltaLayer (trainable=True, name=None, dtype=None, dynamic=False, **kwargs)

Network layer class with extra features specific to DELTA.

Extends tensorflow.keras.layers.Layer.

Expand source code
class DeltaLayer(Layer):
    """
    Network layer class with extra features specific to DELTA.

    Extends `tensorflow.keras.layers.Layer`.
    """
    def callback(self): # pylint:disable=no-self-use
        """
        Override this method to make a layer automatically register
        a training callback.

        The base implementation registers nothing.

        Returns
        -------
        tensorflow.keras.callbacks.Callback:
            The callback to register (or None).
        """
        return None

Ancestors

  • keras.engine.base_layer.Layer
  • tensorflow.python.module.module.Module
  • tensorflow.python.training.tracking.autotrackable.AutoTrackable
  • tensorflow.python.training.tracking.base.Trackable
  • keras.utils.version_utils.LayerVersionSelector

Subclasses

Methods

def callback(self)

Override this method to make a layer automatically register a training callback.

Returns

tensorflow.keras.callbacks.Callback:
The callback to register (or None).
Expand source code
def callback(self): # pylint:disable=no-self-use
    """
    Override this method to make a layer automatically register
    a training callback.

    The base implementation registers nothing.

    Returns
    -------
    tensorflow.keras.callbacks.Callback:
        The callback to register (or None).
    """
    return None