Module delta.subcommands.validate

Check if the input data is valid.

Expand source code
# Copyright © 2020, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All rights reserved.
#
# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
# licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Check if the input data is valid.
"""

import sys
import os

import numpy as np
from osgeo import gdal

from delta.config import config

def get_image_stats(path):
    '''Return a list of image band statistics like [[min, max, mean, stddev], ...]'''
    tif_handle = gdal.Open(path)
    num_bands = tif_handle.RasterCount

    output = []
    for b in range(0,num_bands):
        band = tif_handle.GetRasterBand(b+1)
        stats = band.GetStatistics(False, True)
        output.append(stats)

    return output


def get_class_dict():
    '''Populate dictionary with class names by index number'''
    d = {}
    for c in config.dataset.classes:
        d[c.end_value] = c.name
    if config.dataset.labels().nodata_value():
        d[len(config.dataset.classes)] = 'nodata'
    return d

def classes_string(classes, values, image_name):
    '''Generate a formatted string out of strings or numbers.
       "classes" must come from get_class_dict()'''
    s = '%-20s   ' % (image_name)
    is_integer = np.issubdtype(type(values[0]), np.integer)
    is_float   = isinstance(values[0], float)
    if is_integer:
        total = sum(values.values())
        nodata_class = None
        if config.dataset.labels().nodata_value():
            nodata_class = len(config.dataset.classes)
            total -= values[nodata_class]
    for (j, name) in classes.items():
        if name == 'nodata':
            continue
        v = values[j] if j in values else 0
        if is_integer:
            s += '%12.2f%% ' % (v / total * 100, )
        else:
            if is_float:
                s += '%12.2f ' % (v)
            else:
                s += '%12s  ' % (v, )
    return s



def check_image(images, measures, total_counts, i):
    '''Accumulate total_counts and print out image statistics'''

    # Find min, max, mean, std
    stats = get_image_stats(images[i])

    # Accumulate statistics
    if not total_counts:
        for band in stats: #pylint: disable=W0612
            total_counts.append({'min'   : 0.0,
                                 'max'   : 0.0,
                                 'mean'  : 0.0,
                                 'stddev': 0.0})

    for (b, bandstats) in enumerate(stats):
        total_counts[b]['min'   ] += bandstats[0]
        total_counts[b]['max'   ] += bandstats[1]
        total_counts[b]['mean'  ] += bandstats[2]
        total_counts[b]['stddev'] += bandstats[3]
        name = ''
        if b == 0:
            name = os.path.basename(images[i])
        print(classes_string(measures, dict(enumerate(bandstats)), name))

    return ''

def print_image_totals(images, measures, total_counts):
    '''Convert from source image stat totals to averages and print'''
    num_images = len(images)
    num_bands  = len(total_counts)
    for b in range(0,num_bands):
        values = []
        for m in range(0,len(measures)): #pylint: disable=C0200
            values.append(total_counts[b][measures[m]]/num_images)
        name = ''
        if b == 0:
            name = 'Total'
        print(classes_string(measures, dict(enumerate(values)), name))

def check_label(images, labels, classes, total_counts, i):
    '''Accumulate total_counts and print out image statistics'''
    img   = images.load(i)
    label = labels.load(i)
    if label.size() != img.size():
        return 'Error: size mismatch for %s and %s.\n' % (images[i], labels[i])
    # Count number of times each label appears in image
    v, counts = np.unique(label.read(), return_counts=True)

    # Load the label counts into dictionary and accumulate total_counts
    values = { k:0 for (k, _) in classes.items() }
    for (j, value) in enumerate(v):
        values[value] = counts[j]
        if value not in total_counts:
            total_counts[value] = 0
        total_counts[value] += counts[j]
    # Print out display line with percentages
    print(classes_string(classes, values, labels[i].split('/')[-1]))
    return ''

def evaluate_images(images, labels):
    '''Print class statistics for a set of images with matching labels'''
    errors = ''
    classes = get_class_dict()

    # Evaluate labels first
    counts = {}
    if config.dataset.labels().nodata_value():
        counts[len(config.dataset.classes)] = 0
    header = classes_string(classes, classes, 'Label')
    print(header)
    print('-' * len(header))
    for i in range(len(labels)):
        errors += check_label(images, labels, classes, counts, i)
    print('-' * len(header))
    print(classes_string(classes, counts, 'Total'))
    print()

    if config.dataset.labels().nodata_value():
        nodata_c = counts[len(config.dataset.classes)]
        total = sum(counts.values())
        print('Nodata is %6.2f%% of the data. Total Pixels: %.2f million.' % \
              (nodata_c / total * 100, (total - nodata_c) / 1000000))

    # Now evaluate source images
    counts = []
    print()
    measures = {0:'min', 1:'max', 2:'mean', 3:'stddev'}
    header = classes_string(classes, measures, 'Image')
    print(header)
    print('-' * len(header))
    for i in range(len(images)):
        errors += check_image(images, measures, counts, i)
    print('-' * len(header))
    print_image_totals(images, measures, counts)
    print()


    return errors

def main(_):
    images = config.dataset.images() # Get all image paths based on config values
    labels = config.dataset.labels() # Get all label paths based on config pathn
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1
    if not labels:
        print('No labels specified.', file=sys.stderr)
    else:
        assert len(images) == len(labels)
    print('Validating %d images.' % (len(images)))
    errors = evaluate_images(images, labels)
    tc = config.train.spec()
    if tc.validation.images:
        print('Validating %d validation images.' % (len(tc.validation.images)))
        errors += evaluate_images(tc.validation.images, tc.validation.labels)
    if errors:
        print(errors, file=sys.stderr)
        return -1

    print('Validation successful.')
    return 0

Functions

def check_image(images, measures, total_counts, i)

Accumulate total_counts and print out image statistics

Expand source code
def check_image(images, measures, total_counts, i):
    '''Accumulate total_counts and print out image statistics'''

    # Find min, max, mean, std
    stats = get_image_stats(images[i])

    # Accumulate statistics
    if not total_counts:
        for band in stats: #pylint: disable=W0612
            total_counts.append({'min'   : 0.0,
                                 'max'   : 0.0,
                                 'mean'  : 0.0,
                                 'stddev': 0.0})

    for (b, bandstats) in enumerate(stats):
        total_counts[b]['min'   ] += bandstats[0]
        total_counts[b]['max'   ] += bandstats[1]
        total_counts[b]['mean'  ] += bandstats[2]
        total_counts[b]['stddev'] += bandstats[3]
        name = ''
        if b == 0:
            name = os.path.basename(images[i])
        print(classes_string(measures, dict(enumerate(bandstats)), name))

    return ''
def check_label(images, labels, classes, total_counts, i)

Accumulate total_counts and print out image statistics

Expand source code
def check_label(images, labels, classes, total_counts, i):
    '''Accumulate total_counts and print out image statistics'''
    img   = images.load(i)
    label = labels.load(i)
    if label.size() != img.size():
        return 'Error: size mismatch for %s and %s.\n' % (images[i], labels[i])
    # Count number of times each label appears in image
    v, counts = np.unique(label.read(), return_counts=True)

    # Load the label counts into dictionary and accumulate total_counts
    values = { k:0 for (k, _) in classes.items() }
    for (j, value) in enumerate(v):
        values[value] = counts[j]
        if value not in total_counts:
            total_counts[value] = 0
        total_counts[value] += counts[j]
    # Print out display line with percentages
    print(classes_string(classes, values, labels[i].split('/')[-1]))
    return ''
def classes_string(classes, values, image_name)

Generate a formatted string out of strings or numbers. "classes" must come from get_class_dict()

Expand source code
def classes_string(classes, values, image_name):
    '''Generate a formatted string out of strings or numbers.
       "classes" must come from get_class_dict()'''
    s = '%-20s   ' % (image_name)
    is_integer = np.issubdtype(type(values[0]), np.integer)
    is_float   = isinstance(values[0], float)
    if is_integer:
        total = sum(values.values())
        nodata_class = None
        if config.dataset.labels().nodata_value():
            nodata_class = len(config.dataset.classes)
            total -= values[nodata_class]
    for (j, name) in classes.items():
        if name == 'nodata':
            continue
        v = values[j] if j in values else 0
        if is_integer:
            s += '%12.2f%% ' % (v / total * 100, )
        else:
            if is_float:
                s += '%12.2f ' % (v)
            else:
                s += '%12s  ' % (v, )
    return s
def evaluate_images(images, labels)

Print class statistics for a set of images with matching labels

Expand source code
def evaluate_images(images, labels):
    '''Print class statistics for a set of images with matching labels'''
    errors = ''
    classes = get_class_dict()

    # Evaluate labels first
    counts = {}
    if config.dataset.labels().nodata_value():
        counts[len(config.dataset.classes)] = 0
    header = classes_string(classes, classes, 'Label')
    print(header)
    print('-' * len(header))
    for i in range(len(labels)):
        errors += check_label(images, labels, classes, counts, i)
    print('-' * len(header))
    print(classes_string(classes, counts, 'Total'))
    print()

    if config.dataset.labels().nodata_value():
        nodata_c = counts[len(config.dataset.classes)]
        total = sum(counts.values())
        print('Nodata is %6.2f%% of the data. Total Pixels: %.2f million.' % \
              (nodata_c / total * 100, (total - nodata_c) / 1000000))

    # Now evaluate source images
    counts = []
    print()
    measures = {0:'min', 1:'max', 2:'mean', 3:'stddev'}
    header = classes_string(classes, measures, 'Image')
    print(header)
    print('-' * len(header))
    for i in range(len(images)):
        errors += check_image(images, measures, counts, i)
    print('-' * len(header))
    print_image_totals(images, measures, counts)
    print()


    return errors
def get_class_dict()

Populate dictionary with class names by index number

Expand source code
def get_class_dict():
    '''Populate dictionary with class names by index number'''
    d = {}
    for c in config.dataset.classes:
        d[c.end_value] = c.name
    if config.dataset.labels().nodata_value():
        d[len(config.dataset.classes)] = 'nodata'
    return d
def get_image_stats(path)

Return a list of image band statistics like [[min, max, mean, stddev], …]

Expand source code
def get_image_stats(path):
    '''Return a list of image band statistics like [[min, max, mean, stddev], ...]'''
    tif_handle = gdal.Open(path)
    num_bands = tif_handle.RasterCount

    output = []
    for b in range(0,num_bands):
        band = tif_handle.GetRasterBand(b+1)
        stats = band.GetStatistics(False, True)
        output.append(stats)

    return output
def main(_)
Expand source code
def main(_):
    images = config.dataset.images() # Get all image paths based on config values
    labels = config.dataset.labels() # Get all label paths based on config pathn
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1
    if not labels:
        print('No labels specified.', file=sys.stderr)
    else:
        assert len(images) == len(labels)
    print('Validating %d images.' % (len(images)))
    errors = evaluate_images(images, labels)
    tc = config.train.spec()
    if tc.validation.images:
        print('Validating %d validation images.' % (len(tc.validation.images)))
        errors += evaluate_images(tc.validation.images, tc.validation.labels)
    if errors:
        print(errors, file=sys.stderr)
        return -1

    print('Validation successful.')
    return 0
def print_image_totals(images, measures, total_counts)

Convert from source image stat totals to averages and print

Expand source code
def print_image_totals(images, measures, total_counts):
    '''Convert from source image stat totals to averages and print'''
    num_images = len(images)
    num_bands  = len(total_counts)
    for b in range(0,num_bands):
        values = []
        for m in range(0,len(measures)): #pylint: disable=C0200
            values.append(total_counts[b][measures[m]]/num_images)
        name = ''
        if b == 0:
            name = 'Total'
        print(classes_string(measures, dict(enumerate(values)), name))