Module delta.subcommands.visualize
Visualize the training data.
Expand source code
# Copyright © 2020, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All rights reserved.
#
# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
# licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualize the training data.
"""
import sys
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
from delta.config import config
from delta.imagery import imagery_dataset
from delta.ml.config_parser import config_model, config_augmentation
# Module-level state for the interactive viewer, shared with plot_images().
figure = None  # single matplotlib Figure, created lazily on the first plot_images() call
img_plots, lab_plots = [], []  # per-row AxesImage handles: image bands / label layers
done_plot = False  # set True by the key-press handler to release the wait loop
# Matplotlib widgets connect their event handlers through weak references and
# stop responding once garbage-collected ("you must maintain a reference to
# it"), so the label-alpha slider is kept alive here across calls.
_widgets = []


def plot_images(images, labels):
    """Show a batch of tiles with their labels and block until a key is pressed.

    On the first call, a figure is created with one row per tile. Each row has
    one column per image band plus a final column showing the label alone.
    Every band panel is overlaid with the label; the overlay transparency is
    controlled by the "Label Alpha" slider. Later calls reuse the same figure
    and only replace the displayed data, so batches should have the same
    number of tiles and bands as the first one.

    Pressing 'q' exits the program and any other key advances to the next
    batch. Closing the figure window also exits.

    Args:
        images: sequence of arrays indexed as ``image[:, :, band]``; band values
            are displayed on a fixed 0.0-1.0 scale.
        labels: sequence of label arrays, one per image, drawn with the
            ``inferno`` colormap normalized to ``0 .. len(config.dataset.classes) + 1``.
    """
    global done_plot, figure  #pylint: disable=global-statement
    # Explicit import: ``import matplotlib`` alone does not guarantee the
    # widgets submodule is loaded.
    from matplotlib.widgets import Slider  #pylint: disable=import-outside-toplevel

    initial_alpha = 0.1  # starting opacity of the label overlays (and the slider)

    if figure is None:
        # squeeze=False always yields a 2-D array of axes, even for a single row.
        figure, axarr = plt.subplots(len(images), images[0].shape[2] + 1, squeeze=False)
        label_norm = matplotlib.colors.Normalize(0, len(config.dataset.classes) + 1)
        for (row, image) in enumerate(images):
            img_plots.append([])
            lab_plots.append([])
            label = labels[row]
            for band in range(image.shape[2]):
                img = axarr[row, band].imshow(image[:, :, band],
                                              norm=matplotlib.colors.Normalize(0.0, 1.0))
                img_plots[row].append(img)
                # Label overlay on each band; its alpha is driven by the slider.
                lab = axarr[row, band].imshow(label, cmap='inferno', alpha=initial_alpha,
                                              norm=label_norm)
                lab_plots[row].append(lab)
            # Last column: the label alone at full opacity (not driven by the slider).
            lab = axarr[row, -1].imshow(label, cmap='inferno', norm=label_norm)
            lab_plots[row].append(lab)
        plt.subplots_adjust(bottom=0.15, right=0.90)
        axslide = plt.axes([0.15, 0.05, 0.70, 0.03])
        axcolor = plt.axes([0.92, 0.2, 0.06, 0.6])
        figure.colorbar(img_plots[0][0], cax=axcolor)
        # Start the slider at the overlays' real alpha so the two agree.
        slide = Slider(axslide, 'Label Alpha', 0.0, 1.0, valinit=initial_alpha, valstep=0.05)

        def update_alpha(alpha):
            for row_plots in lab_plots:
                # Skip the last entry: the standalone label column stays opaque.
                for overlay in row_plots[:-1]:
                    overlay.set_alpha(alpha)
        slide.on_changed(update_alpha)
        _widgets.append(slide)  # keep the slider responsive after this call returns

        def on_press(event):
            global done_plot  #pylint: disable=global-statement
            if event.key == 'q':
                sys.exit(0)
            done_plot = True
        figure.canvas.mpl_connect('key_press_event', on_press)
    else:
        for (image, band_plots) in zip(images, img_plots):
            for (band, img) in enumerate(band_plots):
                img.set_data(image[:, :, band])
        for (label, row_plots) in zip(labels, lab_plots):
            for lab in row_plots:
                lab.set_data(label)
        figure.canvas.draw_idle()  # make sure the new data is actually rendered

    done_plot = False
    # Mouse clicks (e.g. dragging the slider) also end a wait; only a key press,
    # via on_press, sets done_plot and releases this loop.
    while not done_plot:
        # Check before waiting: if the window was closed, waiting would
        # otherwise spawn a fresh, empty figure.
        if not plt.fignum_exists(figure.number):
            sys.exit(0)
        figure.waitforbuttonpress()
def main(options):
    """Run the ``visualize`` subcommand: display training tiles interactively.

    Builds the tiled dataset from the configured images (and labels), applying
    the configured augmentations. The tiles are shown ``plot_at_once`` at a
    time through ``plot_images``.

    Args:
        options: parsed command-line options. ``options.autoencoder`` selects an
            ``AutoencoderDataset`` instead of an ``ImageryDataset`` built from
            the configured labels.

    Returns:
        1 if no images or labels are configured, or if the model is incompatible
        with the data. Otherwise 0 once the dataset iteration finishes.
    """
    plot_at_once = 7  # number of tiles (figure rows) shown per batch

    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1
    # Build the model once, only to discover its input/output tile shapes.
    img = images.load(0)
    model = config_model(img.num_bands())
    temp_model = model()
    tile_size = config.io.tile_size()
    tile_overlap = None
    stride = config.train.spec().stride

    # compute input and output sizes
    if temp_model.input_shape[1] is None:
        # No fixed input size: derive the output size from the configured tile
        # size, and overlap tiles by the amount the model trims off.
        in_shape = None
        out_shape = temp_model.compute_output_shape((0, tile_size[0], tile_size[1], temp_model.input_shape[3]))
        out_shape = out_shape[1:3]
        tile_overlap = (tile_size[0] - out_shape[0], tile_size[1] - out_shape[1])
    else:
        in_shape = temp_model.input_shape[1:3]
        out_shape = temp_model.output_shape[1:3]

    if options.autoencoder:
        ids = imagery_dataset.AutoencoderDataset(images, in_shape, tile_shape=tile_size,
                                                 tile_overlap=tile_overlap, stride=stride)
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        ids = imagery_dataset.ImageryDataset(images, labels, out_shape, in_shape,
                                             tile_shape=tile_size, tile_overlap=tile_overlap,
                                             stride=stride)

    # Explicit checks rather than ``assert``, which is stripped under ``python -O``.
    if temp_model.input_shape[1] != temp_model.input_shape[2]:
        print('Must have square chunks in model.', file=sys.stderr)
        return 1
    if temp_model.input_shape[3] != ids.num_bands():
        print('Model takes wrong number of bands.', file=sys.stderr)
        return 1
    tf.keras.backend.clear_session()

    batch_images = []
    batch_labels = []
    for result in ids.dataset(config.dataset.classes.weights(), config_augmentation()):
        image = result[0].numpy()
        label = result[1].numpy()
        # The label may be smaller than its image tile; zero-pad it so the two
        # line up. When the size difference is odd, the extra pixel goes after.
        pad_rows = image.shape[0] - label.shape[0]
        pad_cols = image.shape[1] - label.shape[1]
        pad_width = [(pad_rows // 2, pad_rows - pad_rows // 2),
                     (pad_cols // 2, pad_cols - pad_cols // 2)]
        pad_width += [(0, 0)] * (label.ndim - 2)  # leave any trailing axes untouched
        label = np.pad(label, pad_width)
        batch_images.append(image)
        batch_labels.append(label)
        if len(batch_images) == plot_at_once:
            plot_images(batch_images, batch_labels)
            batch_images = []
            batch_labels = []
    return 0
Functions
def main(options)
-
Expand source code
def main(options):
    """Run the ``visualize`` subcommand: display training tiles interactively.

    Builds the tiled dataset from the configured images (and labels), applying
    the configured augmentations. The tiles are shown ``plot_at_once`` at a
    time through ``plot_images``.

    Args:
        options: parsed command-line options. ``options.autoencoder`` selects an
            ``AutoencoderDataset`` instead of an ``ImageryDataset`` built from
            the configured labels.

    Returns:
        1 if no images or labels are configured, or if the model is incompatible
        with the data. Otherwise 0 once the dataset iteration finishes.
    """
    plot_at_once = 7  # number of tiles (figure rows) shown per batch

    images = config.dataset.images()
    if not images:
        print('No images specified.', file=sys.stderr)
        return 1
    # Build the model once, only to discover its input/output tile shapes.
    img = images.load(0)
    model = config_model(img.num_bands())
    temp_model = model()
    tile_size = config.io.tile_size()
    tile_overlap = None
    stride = config.train.spec().stride

    # compute input and output sizes
    if temp_model.input_shape[1] is None:
        # No fixed input size: derive the output size from the configured tile
        # size, and overlap tiles by the amount the model trims off.
        in_shape = None
        out_shape = temp_model.compute_output_shape((0, tile_size[0], tile_size[1], temp_model.input_shape[3]))
        out_shape = out_shape[1:3]
        tile_overlap = (tile_size[0] - out_shape[0], tile_size[1] - out_shape[1])
    else:
        in_shape = temp_model.input_shape[1:3]
        out_shape = temp_model.output_shape[1:3]

    if options.autoencoder:
        ids = imagery_dataset.AutoencoderDataset(images, in_shape, tile_shape=tile_size,
                                                 tile_overlap=tile_overlap, stride=stride)
    else:
        labels = config.dataset.labels()
        if not labels:
            print('No labels specified.', file=sys.stderr)
            return 1
        ids = imagery_dataset.ImageryDataset(images, labels, out_shape, in_shape,
                                             tile_shape=tile_size, tile_overlap=tile_overlap,
                                             stride=stride)

    # Explicit checks rather than ``assert``, which is stripped under ``python -O``.
    if temp_model.input_shape[1] != temp_model.input_shape[2]:
        print('Must have square chunks in model.', file=sys.stderr)
        return 1
    if temp_model.input_shape[3] != ids.num_bands():
        print('Model takes wrong number of bands.', file=sys.stderr)
        return 1
    tf.keras.backend.clear_session()

    batch_images = []
    batch_labels = []
    for result in ids.dataset(config.dataset.classes.weights(), config_augmentation()):
        image = result[0].numpy()
        label = result[1].numpy()
        # The label may be smaller than its image tile; zero-pad it so the two
        # line up. When the size difference is odd, the extra pixel goes after.
        pad_rows = image.shape[0] - label.shape[0]
        pad_cols = image.shape[1] - label.shape[1]
        pad_width = [(pad_rows // 2, pad_rows - pad_rows // 2),
                     (pad_cols // 2, pad_cols - pad_cols // 2)]
        pad_width += [(0, 0)] * (label.ndim - 2)  # leave any trailing axes untouched
        label = np.pad(label, pad_width)
        batch_images.append(image)
        batch_labels.append(label)
        if len(batch_images) == plot_at_once:
            plot_images(batch_images, batch_labels)
            batch_images = []
            batch_labels = []
    return 0
def plot_images(images, labels)
-
Expand source code
# Matplotlib widgets connect their event handlers through weak references and
# stop responding once garbage-collected ("you must maintain a reference to
# it"), so the label-alpha slider is kept alive here across calls.
_widgets = []


def plot_images(images, labels):
    """Show a batch of tiles with their labels and block until a key is pressed.

    On the first call, a figure is created with one row per tile. Each row has
    one column per image band plus a final column showing the label alone.
    Every band panel is overlaid with the label; the overlay transparency is
    controlled by the "Label Alpha" slider. Later calls reuse the same figure
    and only replace the displayed data, so batches should have the same
    number of tiles and bands as the first one.

    Pressing 'q' exits the program and any other key advances to the next
    batch. Closing the figure window also exits.

    Args:
        images: sequence of arrays indexed as ``image[:, :, band]``; band values
            are displayed on a fixed 0.0-1.0 scale.
        labels: sequence of label arrays, one per image, drawn with the
            ``inferno`` colormap normalized to ``0 .. len(config.dataset.classes) + 1``.
    """
    global done_plot, figure  #pylint: disable=global-statement
    # Explicit import: ``import matplotlib`` alone does not guarantee the
    # widgets submodule is loaded.
    from matplotlib.widgets import Slider  #pylint: disable=import-outside-toplevel

    initial_alpha = 0.1  # starting opacity of the label overlays (and the slider)

    if figure is None:
        # squeeze=False always yields a 2-D array of axes, even for a single row.
        figure, axarr = plt.subplots(len(images), images[0].shape[2] + 1, squeeze=False)
        label_norm = matplotlib.colors.Normalize(0, len(config.dataset.classes) + 1)
        for (row, image) in enumerate(images):
            img_plots.append([])
            lab_plots.append([])
            label = labels[row]
            for band in range(image.shape[2]):
                img = axarr[row, band].imshow(image[:, :, band],
                                              norm=matplotlib.colors.Normalize(0.0, 1.0))
                img_plots[row].append(img)
                # Label overlay on each band; its alpha is driven by the slider.
                lab = axarr[row, band].imshow(label, cmap='inferno', alpha=initial_alpha,
                                              norm=label_norm)
                lab_plots[row].append(lab)
            # Last column: the label alone at full opacity (not driven by the slider).
            lab = axarr[row, -1].imshow(label, cmap='inferno', norm=label_norm)
            lab_plots[row].append(lab)
        plt.subplots_adjust(bottom=0.15, right=0.90)
        axslide = plt.axes([0.15, 0.05, 0.70, 0.03])
        axcolor = plt.axes([0.92, 0.2, 0.06, 0.6])
        figure.colorbar(img_plots[0][0], cax=axcolor)
        # Start the slider at the overlays' real alpha so the two agree.
        slide = Slider(axslide, 'Label Alpha', 0.0, 1.0, valinit=initial_alpha, valstep=0.05)

        def update_alpha(alpha):
            for row_plots in lab_plots:
                # Skip the last entry: the standalone label column stays opaque.
                for overlay in row_plots[:-1]:
                    overlay.set_alpha(alpha)
        slide.on_changed(update_alpha)
        _widgets.append(slide)  # keep the slider responsive after this call returns

        def on_press(event):
            global done_plot  #pylint: disable=global-statement
            if event.key == 'q':
                sys.exit(0)
            done_plot = True
        figure.canvas.mpl_connect('key_press_event', on_press)
    else:
        for (image, band_plots) in zip(images, img_plots):
            for (band, img) in enumerate(band_plots):
                img.set_data(image[:, :, band])
        for (label, row_plots) in zip(labels, lab_plots):
            for lab in row_plots:
                lab.set_data(label)
        figure.canvas.draw_idle()  # make sure the new data is actually rendered

    done_plot = False
    # Mouse clicks (e.g. dragging the slider) also end a wait; only a key press,
    # via on_press, sets done_plot and releases this loop.
    while not done_plot:
        # Check before waiting: if the window was closed, waiting would
        # otherwise spawn a fresh, empty figure.
        if not plt.fignum_exists(figure.number):
            sys.exit(0)
        figure.waitforbuttonpress()