Module delta.subcommands.train

Train a neural network.

Expand source code
# Copyright © 2020, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All rights reserved.
#
# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
# licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Train a neural network.
"""

import sys
import time

# import logging
# logging.getLogger("tensorflow").setLevel(logging.DEBUG)

import tensorflow as tf
from tensorflow.keras import mixed_precision

from delta.config import config
from delta.imagery import imagery_dataset
from delta.ml.train import train
from delta.ml.config_parser import config_model
from delta.ml.io import save_model, load_model


# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)

def mixed_policy_device_compatible():
    """
    Check whether all visible GPUs are suitable for mixed precision training.

    Queries TensorFlow for the physical GPU devices and their device details.
    A GPU counts as supported only when it reports a compute capability of at
    least 7.0. A GPU that reports no compute capability (probably not an
    Nvidia GPU) counts as unsupported.

    Returns
    -------
    bool
        True if at least one GPU is present and every GPU is supported,
        False otherwise (including when no GPU is found).
    """
    # gpu check logic taken from https://github.com/keras-team/keras/blob/70d7d07bd186b929d81f7a8ceafff5d78d8bd701/keras/mixed_precision/device_compatibility_check.py # pylint: disable=line-too-long
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        # No GPU at all: there is nothing to run the mixed policy on.
        return False

    capabilities = [tf.config.experimental.get_device_details(g).get('compute_capability')
                    for g in gpus]

    # A single unsupported device is enough to rule out the mixed policy.
    return all(bool(cc) and cc >= (7, 0) for cc in capabilities)


def main(options):
    """
    Entry point of the train subcommand: build a dataset and model, then train.

    Parameters
    ----------
    options:
        Parsed command line options. This function reads `options.resume`
        (model to resume from, or None), `options.autoencoder` (train an
        autoencoder, so no labels are needed) and `options.model` (path to
        save the trained model to, or None).

    Returns
    -------
    int
        1 if no images are configured, or if labels are required but none are
        configured. 0 otherwise, including when training is cancelled with a
        keyboard interrupt.
    """
    if mixed_policy_device_compatible() and not config.train.disable_mixed_precision():
        mixed_precision.set_global_policy('mixed_float16')
        print('Tensorflow Mixed Precision is enabled. This improves training performance on compatible GPUs. '
              'However certain precautions should be taken and several additional changes can be made to improve '
              'performance further. Details: https://www.tensorflow.org/guide/mixed_precision#summary')

    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1

    # The first image is loaded only to find out how many bands the model needs.
    img = images.load(0)
    model = config_model(img.num_bands())
    if options.resume is not None and not options.resume.endswith('.h5'):
        temp_model = load_model(options.resume)
    else:
        # this one is not built with proper scope, just used to get input and output shapes
        temp_model = model()

    start_time = time.time()
    tile_size = config.io.tile_size()
    tile_overlap = None
    stride = config.train.spec().stride

    # compute input and output sizes
    if temp_model.input_shape[1] is None:
        # The model has no fixed spatial input size, so whole tiles are fed in.
        # The tile overlap is the difference between the tile size and the
        # output size the model computes for a tile.
        in_shape = None
        out_shape = temp_model.compute_output_shape((0, tile_size[0], tile_size[1], temp_model.input_shape[3]))
        out_shape = out_shape[1:3]
        tile_overlap = (tile_size[0] - out_shape[0], tile_size[1] - out_shape[1])
    else:
        in_shape = temp_model.input_shape[1:3]
        out_shape = temp_model.output_shape[1:3]

    if options.autoencoder:
        ids = imagery_dataset.AutoencoderDataset(images, in_shape, tile_shape=tile_size,
                                                 tile_overlap=tile_overlap, stride=stride,
                                                 max_rand_offset=config.train.spec().max_tile_offset)
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        ids = imagery_dataset.ImageryDataset(images, labels, out_shape, in_shape,
                                             tile_shape=tile_size, tile_overlap=tile_overlap,
                                             stride=stride, max_rand_offset=config.train.spec().max_tile_offset)

    assert temp_model.input_shape[1] == temp_model.input_shape[2], 'Must have square chunks in model.'
    assert temp_model.input_shape[3] == ids.num_bands(), 'Model takes wrong number of bands.'
    tf.keras.backend.clear_session()

    # Try to have the internal model format we use match the output model format.
    # Test the suffix (as is done for options.resume) rather than a substring, so
    # that a '.h5' elsewhere in the path does not select the wrong format.
    internal_model_extension = '.savedmodel'
    if options.model and options.model.endswith('.h5'):
        internal_model_extension = '.h5'
    try:
        model, _ = train(model, ids, config.train.spec(), options.resume, internal_model_extension)

        if options.model is not None:
            save_model(model, options.model)
    except KeyboardInterrupt:
        # Cancelling is not treated as an error; nothing is saved in this case.
        print('Training cancelled.')

    stop_time = time.time()
    print('Elapsed time = ', stop_time-start_time)
    return 0

Functions

def main(options)
Expand source code
def main(options):
    """
    Entry point of the train subcommand: build a dataset and model, then train.

    Parameters
    ----------
    options:
        Parsed command line options. This function reads `options.resume`
        (model to resume from, or None), `options.autoencoder` (train an
        autoencoder, so no labels are needed) and `options.model` (path to
        save the trained model to, or None).

    Returns
    -------
    int
        1 if no images are configured, or if labels are required but none are
        configured. 0 otherwise, including when training is cancelled with a
        keyboard interrupt.
    """
    if mixed_policy_device_compatible() and not config.train.disable_mixed_precision():
        mixed_precision.set_global_policy('mixed_float16')
        print('Tensorflow Mixed Precision is enabled. This improves training performance on compatible GPUs. '
              'However certain precautions should be taken and several additional changes can be made to improve '
              'performance further. Details: https://www.tensorflow.org/guide/mixed_precision#summary')

    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1

    # The first image is loaded only to find out how many bands the model needs.
    img = images.load(0)
    model = config_model(img.num_bands())
    if options.resume is not None and not options.resume.endswith('.h5'):
        temp_model = load_model(options.resume)
    else:
        # this one is not built with proper scope, just used to get input and output shapes
        temp_model = model()

    start_time = time.time()
    tile_size = config.io.tile_size()
    tile_overlap = None
    stride = config.train.spec().stride

    # compute input and output sizes
    if temp_model.input_shape[1] is None:
        # The model has no fixed spatial input size, so whole tiles are fed in.
        # The tile overlap is the difference between the tile size and the
        # output size the model computes for a tile.
        in_shape = None
        out_shape = temp_model.compute_output_shape((0, tile_size[0], tile_size[1], temp_model.input_shape[3]))
        out_shape = out_shape[1:3]
        tile_overlap = (tile_size[0] - out_shape[0], tile_size[1] - out_shape[1])
    else:
        in_shape = temp_model.input_shape[1:3]
        out_shape = temp_model.output_shape[1:3]

    if options.autoencoder:
        ids = imagery_dataset.AutoencoderDataset(images, in_shape, tile_shape=tile_size,
                                                 tile_overlap=tile_overlap, stride=stride,
                                                 max_rand_offset=config.train.spec().max_tile_offset)
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        ids = imagery_dataset.ImageryDataset(images, labels, out_shape, in_shape,
                                             tile_shape=tile_size, tile_overlap=tile_overlap,
                                             stride=stride, max_rand_offset=config.train.spec().max_tile_offset)

    assert temp_model.input_shape[1] == temp_model.input_shape[2], 'Must have square chunks in model.'
    assert temp_model.input_shape[3] == ids.num_bands(), 'Model takes wrong number of bands.'
    tf.keras.backend.clear_session()

    # Try to have the internal model format we use match the output model format.
    # Test the suffix (as is done for options.resume) rather than a substring, so
    # that a '.h5' elsewhere in the path does not select the wrong format.
    internal_model_extension = '.savedmodel'
    if options.model and options.model.endswith('.h5'):
        internal_model_extension = '.h5'
    try:
        model, _ = train(model, ids, config.train.spec(), options.resume, internal_model_extension)

        if options.model is not None:
            save_model(model, options.model)
    except KeyboardInterrupt:
        # Cancelling is not treated as an error; nothing is saved in this case.
        print('Training cancelled.')

    stop_time = time.time()
    print('Elapsed time = ', stop_time-start_time)
    return 0
def mixed_policy_device_compatible()
Expand source code
def mixed_policy_device_compatible():
    """
    Check whether all visible GPUs are suitable for mixed precision training.

    Queries TensorFlow for the physical GPU devices and their device details.
    A GPU counts as supported only when it reports a compute capability of at
    least 7.0. A GPU that reports no compute capability (probably not an
    Nvidia GPU) counts as unsupported.

    Returns
    -------
    bool
        True if at least one GPU is present and every GPU is supported,
        False otherwise (including when no GPU is found).
    """
    # gpu check logic taken from https://github.com/keras-team/keras/blob/70d7d07bd186b929d81f7a8ceafff5d78d8bd701/keras/mixed_precision/device_compatibility_check.py # pylint: disable=line-too-long
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        # No GPU at all: there is nothing to run the mixed policy on.
        return False

    capabilities = [tf.config.experimental.get_device_details(g).get('compute_capability')
                    for g in gpus]

    # A single unsupported device is enough to rule out the mixed policy.
    return all(bool(cc) and cc >= (7, 0) for cc in capabilities)