Module delta.imagery.imagery_dataset

Tools for loading input images into the TensorFlow Dataset class.

Expand source code
# Copyright © 2020, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All rights reserved.
#
# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
# licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tools for loading input images into the TensorFlow Dataset class.
"""
from concurrent.futures import ThreadPoolExecutor
import functools
import random
import threading
import tensorflow as tf
import numpy as np

from delta.imagery import rectangle
from delta.config import config

class ImageryDataset: # pylint: disable=too-many-instance-attributes,too-many-arguments
    """
    A dataset for tiling very large imagery for training with tensorflow.
    """

    def __init__(self, images, labels, output_shape, chunk_shape, stride=None,
                 tile_shape=(256, 256), tile_overlap=None, max_rand_offset=None):
        """
        Parameters
        ----------
        images: ImageSet
            Images to train on
        labels: ImageSet
            Corresponding labels to train on
        output_shape: (int, int)
            Shape of the corresponding labels for a given chunk or tile size.
        chunk_shape: (int, int)
            If specified, divide tiles into individual chunks of this shape.
        stride: (int, int)
            Skip this stride between chunks. Only valid with chunk_shape.
        tile_shape: (int, int)
            Size of tiles to load from the images at a time.
        tile_overlap: (int, int)
            If specified, overlap tiles by this amount.
        max_rand_offset: int
            If specified, in each epoch, offset all tiles by a random amount in x and y
            in the range(-max_rand_offset, max_rand_offset).
        """

        self._iopool = ThreadPoolExecutor(config.io.threads())

        # Record some of the config values
        self.set_chunk_output_shapes(chunk_shape, output_shape)
        self._output_dims  = 1
        # one for imagery, one for labels
        if stride is None:
            stride = (1, 1)
        self._stride = stride
        self._data_type    = tf.float32
        self._label_type   = tf.uint8
        self._tile_shape = tile_shape
        if tile_overlap is None:
            tile_overlap = (0, 0)
        self._tile_overlap = tile_overlap
        self._max_rand_offset = max_rand_offset if max_rand_offset else 0

        if labels:
            assert len(images) == len(labels)
        self._images = images
        self._labels = labels
        self._epoch = [0, 0] # track images and labels separately for simplicity

        # Load the first image to get the number of bands for the input files.
        self._num_bands = images.load(0).num_bands()
        self._random_seed = random.randint(0, 1 << 16)

    def _list_tiles(self, i): # pragma: no cover
        """
        Parameters
        ----------
        i: int
            Image to list tiles for.

        Returns
        -------
        List[Rectangle]:
            List of tiles to read from the given image
        """
        img = self._images.load(i)

        if self._labels: # If we have labels make sure they are the same size as the input images
            label = self._labels.load(i)
            if label.size() != img.size():
                raise AssertionError('Label file ' + self._labels[i] + ' with size ' + str(label.size())
                                     + ' does not match input image ' + self._images[i] + ' size of ' + str(img.size()))
        tile_shape = self._tile_shape
        if self._chunk_shape:
            assert tile_shape[0] >= self._chunk_shape[0] and \
                   tile_shape[1] >= self._chunk_shape[1], 'Tile too small.'
            return img.tiles((tile_shape[0], tile_shape[1]), min_shape=self._chunk_shape,
                             overlap_shape=self._tile_overlap,
                             by_block=True)
        return img.tiles((tile_shape[0], tile_shape[1]), partials=False, partials_overlap=True,
                         overlap_shape=self._tile_overlap, by_block=True)

    def _tile_generator(self, is_labels): # pragma: no cover
        """
        A generator that yields image tiles over all images.

        Parameters
        ----------
        is_labels: bool
            Load the label if true, image if false

        Returns
        -------
        Iterator[numpy.ndarray]:
            Iterator over iamge tiles.
        """
        # track epoch (must be same for label and non-label)
        epoch = self._epoch[1 if is_labels else 0]
        self._epoch[1 if is_labels else 0] += 1
        images = [(self._labels if is_labels else self._images).load(i) for i in range(len(self._images))]
        # create lock and get preprocessing function for each image
        image_locks = {}
        image_preprocesses = {}
        for img in images:
            image_locks[img] = threading.Lock()
            image_preprocesses[img] = img.get_preprocess()
            img.set_preprocess(None) # parallelize preprocessing outside lock

        # use same seed for labels and not labels, differ by epoch times big prime number
        rand = random.Random(self._random_seed + epoch * 11617)

        # generator that creates tiles in a random order, but consistent between images and labels
        # returns generator of (img, tile_list) tuples
        def tile_gen():
            image_tiles = [(images[i], self._list_tiles(i)) for i in range(len(images))]
            # shuffle tiles within each image
            for (img, tiles) in image_tiles:
                rand.shuffle(tiles)
            # create iterator
            image_tiles = [(img, iter(tiles)) for (img, tiles) in image_tiles]
            while image_tiles:
                index = rand.randrange(len(image_tiles))
                (img, it) = image_tiles[index]
                try:
                    yield (img, next(it))
                except StopIteration:
                    del image_tiles[index]

        if self._max_rand_offset:
            rand_offset = (rand.randint(-self._max_rand_offset, self._max_rand_offset),
                           rand.randint(-self._max_rand_offset, self._max_rand_offset))
        else:
            rand_offset = (0, 0)
        # lock an image and read it. Necessary because gdal doesn't do multi-threading.
        def read_image(img, rect):
            lock = image_locks[img]
            preprocess = image_preprocesses[img]
            buf = np.zeros(shape=(img.num_bands(), rect.height(), rect.width()), dtype=img.dtype())

            mod_r = rectangle.Rectangle(min_x=rect.min_x, min_y=rect.min_y, max_x=rect.max_x, max_y=rect.max_y)
            mod_r.shift(rand_offset[0], rand_offset[1])
            request_r = mod_r.get_intersection(rectangle.Rectangle(min_x=0, min_y=0, width=img.width(),
                                                                   height=img.height()))

            lock.acquire()
            partial_buf = buf[:, request_r.min_y - mod_r.min_y:mod_r.height() + request_r.max_y - mod_r.max_y,
                              request_r.min_x - mod_r.min_x:mod_r.width() + request_r.max_x - mod_r.max_x]
            img.read(request_r, buf=partial_buf)
            lock.release()
            # preprocess outside of lock for concurrency
            buf = np.transpose(buf, [1, 2, 0])
            if preprocess:
                buf = preprocess(buf, rect, None)
            return buf

        # add a buffer to read to the multiprocessing queue
        def add_to_queue(buf_queue, item):
            (img, (rect, sub_tiles)) = item
            buf = self._iopool.submit(lambda: read_image(img, rect))
            buf_queue.append((rect, sub_tiles, buf))

        gen = tile_gen()
        buf_queue = []
        for _ in range(config.io.threads() * 2): # add a bit ahead
            try:
                next_item = next(gen)
            except StopIteration:
                break
            add_to_queue(buf_queue, next_item)
        # process buffers and yield sub tiles. For efficiency, we just
        # return an entire buffer's sub tiles at once, so not fully random
        cur_bufs = []
        while buf_queue or cur_bufs:
            while len(cur_bufs) < config.io.interleave_blocks() and buf_queue:
                (_, sub_tiles, buf) = buf_queue.pop(0)
                cur_bufs.append((sub_tiles, buf.result()))
                try:
                    add_to_queue(buf_queue, next(gen))
                except StopIteration:
                    pass
            while True:
                buf_index = rand.randrange(len(cur_bufs))
                (sub_tiles, buf) = cur_bufs[buf_index]
                if not sub_tiles:
                    del cur_bufs[buf_index]
                    break
                sub_index = rand.randrange(len(sub_tiles))
                s = sub_tiles[sub_index]
                del sub_tiles[sub_index]
                yield buf[s.min_y:s.max_y, s.min_x:s.max_x, :]

    def _load_images(self, is_labels, data_type):
        """
        Loads a list of images as tensors.

        Parameters
        ----------
        is_labels: bool
            Load labels if true, images if not
        data_type: numpy.dtype
            Data type that will be returned.

        Returns
        -------
        Dataset:
            Dataset of image tiles
        """
        self._epoch[1 if is_labels else 0] = 0 # count epochs for random
        return tf.data.Dataset.from_generator(functools.partial(self._tile_generator,
                                                                is_labels=is_labels),
                                              output_types=data_type,
                                              output_shapes=tf.TensorShape((None, None, None)))

    def _chunk_image(self, image): # pragma: no cover
        """Split up a tensor image into tensor chunks"""

        ksizes  = [1, self._chunk_shape[0], self._chunk_shape[1], 1] # Size of the chunks
        strides = [1, self._stride[0], self._stride[1], 1] # Spacing between chunk starts
        rates   = [1, 1, 1, 1]
        result  = tf.image.extract_patches(tf.expand_dims(image, 0), ksizes, strides, rates,
                                           padding='VALID')
        # Output is [1, M, N, chunk*chunk*bands]
        result = tf.reshape(result, [-1, self._chunk_shape[0], self._chunk_shape[1], self._num_bands])

        return result

    def _reshape_labels(self, labels): # pragma: no cover
        """Reshape the labels to account for the chunking process."""
        if self._chunk_shape:
            h = (self._chunk_shape[0] - self._output_shape[0]) // 2
            w = (self._chunk_shape[1] - self._output_shape[1]) // 2
        else:
            h = (tf.shape(labels)[0] - self._output_shape[0]) // 2
            w = (tf.shape(labels)[1] - self._output_shape[1]) // 2
        labels = tf.image.crop_to_bounding_box(labels, h, w, tf.shape(labels)[0] - 2 * h,
                                               tf.shape(labels)[1] - 2 * w)
        if not self._chunk_shape:
            return labels

        ksizes  = [1, self._output_shape[0], self._output_shape[1], 1]
        strides = [1, self._stride[0], self._stride[1], 1]
        rates   = [1, 1, 1, 1]
        labels = tf.image.extract_patches(tf.expand_dims(labels, 0), ksizes, strides, rates,
                                          padding='VALID')
        result = tf.reshape(labels, [-1, self._output_shape[0], self._output_shape[1], 1])
        return result

    def data(self):
        """
        Returns
        -------
        Dataset:
            image chunks / tiles.
        """
        ret = self._load_images(False, self._data_type)
        if self._chunk_shape:
            ret = ret.map(self._chunk_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            return ret.unbatch()
        return ret

    def labels(self):
        """
        Returns
        -------
        Dataset:
            Unbatched dataset of labels corresponding to `data()`.
        """
        label_set = self._load_images(True, self._label_type)
        if self._chunk_shape or self._output_shape:
            label_set = label_set.map(self._reshape_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE) #pylint: disable=C0301
            if self._chunk_shape:
                return label_set.unbatch()
        return label_set

    def dataset(self, class_weights=None, augment_function=None):
        """
        Returns a tensorflow dataset as configured by the class.

        Parameters
        ----------
        class_weights: list
            list of weights for the classes.
        augment_function: Callable[[Tensor, Tensor], (Tensor, Tensor)]
            Function to be applied to the image and label before use.

        Returns
        -------
        tensorflow Dataset:
            With (data, labels, optionally weights)
        """

        # Pair the data and labels in our dataset
        ds = tf.data.Dataset.zip((self.data(), self.labels()))
        # ignore chunks which are all nodata (nodata is re-indexed to be after the classes)
        # cannot do with max_rand_offset since would have different number of tiles which
        # breaks keras fit
        if self._labels.nodata_value() is not None:
            ds = ds.filter(lambda x, y: tf.math.reduce_any(tf.math.not_equal(y, self._labels.nodata_value())))
        if augment_function is not None:
            ds = ds.map(augment_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if class_weights is not None:
            class_weights.append(0.0)
            lookup = tf.constant(class_weights)
            ds = ds.map(lambda x, y: (x, y, tf.gather(lookup, tf.cast(y, tf.int32), axis=None)),
                        num_parallel_calls=config.io.threads())
        return ds

    def num_bands(self):
        """
        Returns
        -------
        int:
            number of bands in each image
        """
        return self._num_bands

    def set_chunk_output_shapes(self, chunk_shape, output_shape):
        """
        Parameters
        ----------
        chunk_shape: (int, int)
            Size of chunks to read at a time. Set to None to
            use on a per tile basis (i.e., for FCNs).
        output_shape: (int, int)
            Shape output by the network. May differ from the input size
            (dervied from chunk_shape or tile_shape)
        """
        if chunk_shape:
            assert len(chunk_shape) == 2, 'Chunk must be two dimensional.'
            assert (chunk_shape[0] % 2) == (chunk_shape[1] % 2) == \
                   (output_shape[0] % 2) == (output_shape[1] % 2), 'Chunk and output shapes must both be even or odd.'
        if output_shape:
            assert len(output_shape) == 2 or len(output_shape) == 3, 'Output must be two or three dimensional.'
            if len(output_shape) == 3:
                output_shape = output_shape[0:2]
        self._chunk_shape = chunk_shape
        self._output_shape = output_shape

    def chunk_shape(self):
        """
        Returns
        -------
        (int, int):
            Size of chunks used for inputs.
        """
        return self._chunk_shape

    def input_shape(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            Input size for the network.
        """
        if self._chunk_shape:
            return (self._chunk_shape[0], self._chunk_shape[1], self._num_bands)
        return (None, None, self._num_bands)

    def output_shape(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            Output size, size of blocks of labels
        """
        if self._output_shape:
            return (self._output_shape[0], self._output_shape[1], self._output_dims)
        return (None, None, self._output_dims)

    def image_set(self):
        """
        Returns
        -------
        ImageSet:
            set of images
        """
        return self._images
    def label_set(self):
        """
        Returns
        -------
        ImageSet:
            set of labels
        """
        return self._labels

    def set_tile_shape(self, tile_shape):
        """
        Set the tile size.

        Parameters
        ----------
        tile_shape: (int, int)
            New tile shape"""
        self._tile_shape = tile_shape

    def tile_shape(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            tile shape to load at a time
        """
        return self._tile_shape

    def tile_overlap(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            the amount tiles overlap
        """
        return self._tile_overlap

    def stride(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            Stride between chunks (only when chunk_shape is set).
        """
        return self._stride

class AutoencoderDataset(ImageryDataset):
    """
    Slightly modified dataset class for the autoencoder.

    Instead of specifying labels, the inputs are used as labels.
    """

    def __init__(self, images, chunk_shape, stride=(1, 1), tile_shape=(256, 256), tile_overlap=None,
                 max_rand_offset=None):
        super().__init__(images, None, chunk_shape, chunk_shape, tile_shape=tile_shape,
                         stride=stride, tile_overlap=tile_overlap, max_rand_offset=max_rand_offset)
        self._labels = self._images
        self._output_dims = self.num_bands()

    def labels(self):
        return self.data()

    def dataset(self, class_weights=None, augment_function=None):
        return self.data().map(lambda x: (x, x))

Classes

class AutoencoderDataset (images, chunk_shape, stride=(1, 1), tile_shape=(256, 256), tile_overlap=None, max_rand_offset=None)

Slightly modified dataset class for the autoencoder.

Instead of specifying labels, the inputs are used as labels.

Parameters

images : ImageSet
Images to train on
labels : ImageSet
Corresponding labels to train on
output_shape : (int, int)
Shape of the corresponding labels for a given chunk or tile size.
chunk_shape : (int, int)
If specified, divide tiles into individual chunks of this shape.
stride : (int, int)
Skip this stride between chunks. Only valid with chunk_shape.
tile_shape : (int, int)
Size of tiles to load from the images at a time.
tile_overlap : (int, int)
If specified, overlap tiles by this amount.
max_rand_offset : int
If specified, in each epoch, offset all tiles by a random amount in x and y in the range(-max_rand_offset, max_rand_offset).
Expand source code
class AutoencoderDataset(ImageryDataset):
    """
    Slightly modified dataset class for the autoencoder.

    Instead of specifying labels, the inputs are used as labels.
    """

    def __init__(self, images, chunk_shape, stride=(1, 1), tile_shape=(256, 256), tile_overlap=None,
                 max_rand_offset=None):
        super().__init__(images, None, chunk_shape, chunk_shape, tile_shape=tile_shape,
                         stride=stride, tile_overlap=tile_overlap, max_rand_offset=max_rand_offset)
        self._labels = self._images
        self._output_dims = self.num_bands()

    def labels(self):
        return self.data()

    def dataset(self, class_weights=None, augment_function=None):
        return self.data().map(lambda x: (x, x))

Ancestors

Inherited members

class ImageryDataset (images, labels, output_shape, chunk_shape, stride=None, tile_shape=(256, 256), tile_overlap=None, max_rand_offset=None)

A dataset for tiling very large imagery for training with tensorflow.

Parameters

images : ImageSet
Images to train on
labels : ImageSet
Corresponding labels to train on
output_shape : (int, int)
Shape of the corresponding labels for a given chunk or tile size.
chunk_shape : (int, int)
If specified, divide tiles into individual chunks of this shape.
stride : (int, int)
Skip this stride between chunks. Only valid with chunk_shape.
tile_shape : (int, int)
Size of tiles to load from the images at a time.
tile_overlap : (int, int)
If specified, overlap tiles by this amount.
max_rand_offset : int
If specified, in each epoch, offset all tiles by a random amount in x and y in the range(-max_rand_offset, max_rand_offset).
Expand source code
class ImageryDataset: # pylint: disable=too-many-instance-attributes,too-many-arguments
    """
    A dataset for tiling very large imagery for training with tensorflow.
    """

    def __init__(self, images, labels, output_shape, chunk_shape, stride=None,
                 tile_shape=(256, 256), tile_overlap=None, max_rand_offset=None):
        """
        Parameters
        ----------
        images: ImageSet
            Images to train on
        labels: ImageSet
            Corresponding labels to train on
        output_shape: (int, int)
            Shape of the corresponding labels for a given chunk or tile size.
        chunk_shape: (int, int)
            If specified, divide tiles into individual chunks of this shape.
        stride: (int, int)
            Skip this stride between chunks. Only valid with chunk_shape.
        tile_shape: (int, int)
            Size of tiles to load from the images at a time.
        tile_overlap: (int, int)
            If specified, overlap tiles by this amount.
        max_rand_offset: int
            If specified, in each epoch, offset all tiles by a random amount in x and y
            in the range(-max_rand_offset, max_rand_offset).
        """

        self._iopool = ThreadPoolExecutor(config.io.threads())

        # Record some of the config values
        self.set_chunk_output_shapes(chunk_shape, output_shape)
        self._output_dims  = 1
        # one for imagery, one for labels
        if stride is None:
            stride = (1, 1)
        self._stride = stride
        self._data_type    = tf.float32
        self._label_type   = tf.uint8
        self._tile_shape = tile_shape
        if tile_overlap is None:
            tile_overlap = (0, 0)
        self._tile_overlap = tile_overlap
        self._max_rand_offset = max_rand_offset if max_rand_offset else 0

        if labels:
            assert len(images) == len(labels)
        self._images = images
        self._labels = labels
        self._epoch = [0, 0] # track images and labels separately for simplicity

        # Load the first image to get the number of bands for the input files.
        self._num_bands = images.load(0).num_bands()
        self._random_seed = random.randint(0, 1 << 16)

    def _list_tiles(self, i): # pragma: no cover
        """
        Parameters
        ----------
        i: int
            Image to list tiles for.

        Returns
        -------
        List[Rectangle]:
            List of tiles to read from the given image
        """
        img = self._images.load(i)

        if self._labels: # If we have labels make sure they are the same size as the input images
            label = self._labels.load(i)
            if label.size() != img.size():
                raise AssertionError('Label file ' + self._labels[i] + ' with size ' + str(label.size())
                                     + ' does not match input image ' + self._images[i] + ' size of ' + str(img.size()))
        tile_shape = self._tile_shape
        if self._chunk_shape:
            assert tile_shape[0] >= self._chunk_shape[0] and \
                   tile_shape[1] >= self._chunk_shape[1], 'Tile too small.'
            return img.tiles((tile_shape[0], tile_shape[1]), min_shape=self._chunk_shape,
                             overlap_shape=self._tile_overlap,
                             by_block=True)
        return img.tiles((tile_shape[0], tile_shape[1]), partials=False, partials_overlap=True,
                         overlap_shape=self._tile_overlap, by_block=True)

    def _tile_generator(self, is_labels): # pragma: no cover
        """
        A generator that yields image tiles over all images.

        Parameters
        ----------
        is_labels: bool
            Load the label if true, image if false

        Returns
        -------
        Iterator[numpy.ndarray]:
            Iterator over iamge tiles.
        """
        # track epoch (must be same for label and non-label)
        epoch = self._epoch[1 if is_labels else 0]
        self._epoch[1 if is_labels else 0] += 1
        images = [(self._labels if is_labels else self._images).load(i) for i in range(len(self._images))]
        # create lock and get preprocessing function for each image
        image_locks = {}
        image_preprocesses = {}
        for img in images:
            image_locks[img] = threading.Lock()
            image_preprocesses[img] = img.get_preprocess()
            img.set_preprocess(None) # parallelize preprocessing outside lock

        # use same seed for labels and not labels, differ by epoch times big prime number
        rand = random.Random(self._random_seed + epoch * 11617)

        # generator that creates tiles in a random order, but consistent between images and labels
        # returns generator of (img, tile_list) tuples
        def tile_gen():
            image_tiles = [(images[i], self._list_tiles(i)) for i in range(len(images))]
            # shuffle tiles within each image
            for (img, tiles) in image_tiles:
                rand.shuffle(tiles)
            # create iterator
            image_tiles = [(img, iter(tiles)) for (img, tiles) in image_tiles]
            while image_tiles:
                index = rand.randrange(len(image_tiles))
                (img, it) = image_tiles[index]
                try:
                    yield (img, next(it))
                except StopIteration:
                    del image_tiles[index]

        if self._max_rand_offset:
            rand_offset = (rand.randint(-self._max_rand_offset, self._max_rand_offset),
                           rand.randint(-self._max_rand_offset, self._max_rand_offset))
        else:
            rand_offset = (0, 0)
        # lock an image and read it. Necessary because gdal doesn't do multi-threading.
        def read_image(img, rect):
            lock = image_locks[img]
            preprocess = image_preprocesses[img]
            buf = np.zeros(shape=(img.num_bands(), rect.height(), rect.width()), dtype=img.dtype())

            mod_r = rectangle.Rectangle(min_x=rect.min_x, min_y=rect.min_y, max_x=rect.max_x, max_y=rect.max_y)
            mod_r.shift(rand_offset[0], rand_offset[1])
            request_r = mod_r.get_intersection(rectangle.Rectangle(min_x=0, min_y=0, width=img.width(),
                                                                   height=img.height()))

            lock.acquire()
            partial_buf = buf[:, request_r.min_y - mod_r.min_y:mod_r.height() + request_r.max_y - mod_r.max_y,
                              request_r.min_x - mod_r.min_x:mod_r.width() + request_r.max_x - mod_r.max_x]
            img.read(request_r, buf=partial_buf)
            lock.release()
            # preprocess outside of lock for concurrency
            buf = np.transpose(buf, [1, 2, 0])
            if preprocess:
                buf = preprocess(buf, rect, None)
            return buf

        # add a buffer to read to the multiprocessing queue
        def add_to_queue(buf_queue, item):
            (img, (rect, sub_tiles)) = item
            buf = self._iopool.submit(lambda: read_image(img, rect))
            buf_queue.append((rect, sub_tiles, buf))

        gen = tile_gen()
        buf_queue = []
        for _ in range(config.io.threads() * 2): # add a bit ahead
            try:
                next_item = next(gen)
            except StopIteration:
                break
            add_to_queue(buf_queue, next_item)
        # process buffers and yield sub tiles. For efficiency, we just
        # return an entire buffer's sub tiles at once, so not fully random
        cur_bufs = []
        while buf_queue or cur_bufs:
            while len(cur_bufs) < config.io.interleave_blocks() and buf_queue:
                (_, sub_tiles, buf) = buf_queue.pop(0)
                cur_bufs.append((sub_tiles, buf.result()))
                try:
                    add_to_queue(buf_queue, next(gen))
                except StopIteration:
                    pass
            while True:
                buf_index = rand.randrange(len(cur_bufs))
                (sub_tiles, buf) = cur_bufs[buf_index]
                if not sub_tiles:
                    del cur_bufs[buf_index]
                    break
                sub_index = rand.randrange(len(sub_tiles))
                s = sub_tiles[sub_index]
                del sub_tiles[sub_index]
                yield buf[s.min_y:s.max_y, s.min_x:s.max_x, :]

    def _load_images(self, is_labels, data_type):
        """
        Loads a list of images as tensors.

        Parameters
        ----------
        is_labels: bool
            Load labels if true, images if not
        data_type: numpy.dtype
            Data type that will be returned.

        Returns
        -------
        Dataset:
            Dataset of image tiles
        """
        self._epoch[1 if is_labels else 0] = 0 # count epochs for random
        return tf.data.Dataset.from_generator(functools.partial(self._tile_generator,
                                                                is_labels=is_labels),
                                              output_types=data_type,
                                              output_shapes=tf.TensorShape((None, None, None)))

    def _chunk_image(self, image): # pragma: no cover
        """Split up a tensor image into tensor chunks"""

        ksizes  = [1, self._chunk_shape[0], self._chunk_shape[1], 1] # Size of the chunks
        strides = [1, self._stride[0], self._stride[1], 1] # Spacing between chunk starts
        rates   = [1, 1, 1, 1]
        result  = tf.image.extract_patches(tf.expand_dims(image, 0), ksizes, strides, rates,
                                           padding='VALID')
        # Output is [1, M, N, chunk*chunk*bands]
        result = tf.reshape(result, [-1, self._chunk_shape[0], self._chunk_shape[1], self._num_bands])

        return result

    def _reshape_labels(self, labels): # pragma: no cover
        """Reshape the labels to account for the chunking process."""
        if self._chunk_shape:
            h = (self._chunk_shape[0] - self._output_shape[0]) // 2
            w = (self._chunk_shape[1] - self._output_shape[1]) // 2
        else:
            h = (tf.shape(labels)[0] - self._output_shape[0]) // 2
            w = (tf.shape(labels)[1] - self._output_shape[1]) // 2
        labels = tf.image.crop_to_bounding_box(labels, h, w, tf.shape(labels)[0] - 2 * h,
                                               tf.shape(labels)[1] - 2 * w)
        if not self._chunk_shape:
            return labels

        ksizes  = [1, self._output_shape[0], self._output_shape[1], 1]
        strides = [1, self._stride[0], self._stride[1], 1]
        rates   = [1, 1, 1, 1]
        labels = tf.image.extract_patches(tf.expand_dims(labels, 0), ksizes, strides, rates,
                                          padding='VALID')
        result = tf.reshape(labels, [-1, self._output_shape[0], self._output_shape[1], 1])
        return result

    def data(self):
        """
        Returns
        -------
        Dataset:
            image chunks / tiles.
        """
        ret = self._load_images(False, self._data_type)
        if self._chunk_shape:
            ret = ret.map(self._chunk_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            return ret.unbatch()
        return ret

    def labels(self):
        """
        Returns
        -------
        Dataset:
            Unbatched dataset of labels corresponding to `data()`.
        """
        label_set = self._load_images(True, self._label_type)
        if self._chunk_shape or self._output_shape:
            label_set = label_set.map(self._reshape_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE) #pylint: disable=C0301
            if self._chunk_shape:
                return label_set.unbatch()
        return label_set

    def dataset(self, class_weights=None, augment_function=None):
        """
        Returns a tensorflow dataset as configured by the class.

        Parameters
        ----------
        class_weights: list
            list of weights for the classes.
        augment_function: Callable[[Tensor, Tensor], (Tensor, Tensor)]
            Function to be applied to the image and label before use.

        Returns
        -------
        tensorflow Dataset:
            With (data, labels, optionally weights)
        """

        # Pair the data and labels in our dataset
        ds = tf.data.Dataset.zip((self.data(), self.labels()))
        # ignore chunks which are all nodata (nodata is re-indexed to be after the classes)
        # cannot do with max_rand_offset since would have different number of tiles which
        # breaks keras fit
        if self._labels.nodata_value() is not None:
            ds = ds.filter(lambda x, y: tf.math.reduce_any(tf.math.not_equal(y, self._labels.nodata_value())))
        if augment_function is not None:
            ds = ds.map(augment_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if class_weights is not None:
            class_weights.append(0.0)
            lookup = tf.constant(class_weights)
            ds = ds.map(lambda x, y: (x, y, tf.gather(lookup, tf.cast(y, tf.int32), axis=None)),
                        num_parallel_calls=config.io.threads())
        return ds

    def num_bands(self):
        """
        Returns
        -------
        int:
            number of bands in each image
        """
        return self._num_bands

    def set_chunk_output_shapes(self, chunk_shape, output_shape):
        """
        Parameters
        ----------
        chunk_shape: (int, int)
            Size of chunks to read at a time. Set to None to
            use on a per tile basis (i.e., for FCNs).
        output_shape: (int, int)
            Shape output by the network. May differ from the input size
            (dervied from chunk_shape or tile_shape)
        """
        if chunk_shape:
            assert len(chunk_shape) == 2, 'Chunk must be two dimensional.'
            assert (chunk_shape[0] % 2) == (chunk_shape[1] % 2) == \
                   (output_shape[0] % 2) == (output_shape[1] % 2), 'Chunk and output shapes must both be even or odd.'
        if output_shape:
            assert len(output_shape) == 2 or len(output_shape) == 3, 'Output must be two or three dimensional.'
            if len(output_shape) == 3:
                output_shape = output_shape[0:2]
        self._chunk_shape = chunk_shape
        self._output_shape = output_shape

    def chunk_shape(self):
        """
        Returns
        -------
        (int, int):
            Size of chunks used for inputs.
        """
        return self._chunk_shape

    def input_shape(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            Input size for the network.
        """
        if self._chunk_shape:
            return (self._chunk_shape[0], self._chunk_shape[1], self._num_bands)
        return (None, None, self._num_bands)

    def output_shape(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            Output size, size of blocks of labels
        """
        if self._output_shape:
            return (self._output_shape[0], self._output_shape[1], self._output_dims)
        return (None, None, self._output_dims)

    def image_set(self):
        """
        Returns
        -------
        ImageSet:
            set of images
        """
        return self._images
    def label_set(self):
        """
        Returns
        -------
        ImageSet:
            set of labels
        """
        return self._labels

    def set_tile_shape(self, tile_shape):
        """
        Set the tile size.

        Parameters
        ----------
        tile_shape: (int, int)
            New tile shape"""
        self._tile_shape = tile_shape

    def tile_shape(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            tile shape to load at a time
        """
        return self._tile_shape

    def tile_overlap(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            the amount tiles overlap
        """
        return self._tile_overlap

    def stride(self):
        """
        Returns
        -------
        Tuple[int, ...]:
            Stride between chunks (only when chunk_shape is set).
        """
        return self._stride

Subclasses

Methods

def chunk_shape(self)

Returns

(int, int): Size of chunks used for inputs.

Expand source code
def chunk_shape(self):
    """
    Returns
    -------
    (int, int):
        Size of chunks used for inputs.
    """
    return self._chunk_shape
def data(self)

Returns

Dataset

image chunks / tiles.

Expand source code
def data(self):
    """
    Returns
    -------
    Dataset:
        image chunks / tiles.
    """
    ret = self._load_images(False, self._data_type)
    if self._chunk_shape:
        ret = ret.map(self._chunk_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        return ret.unbatch()
    return ret
def dataset(self, class_weights=None, augment_function=None)

Returns a tensorflow dataset as configured by the class.

Parameters

class_weights : list
list of weights for the classes.
augment_function : Callable[[Tensor, Tensor], (Tensor, Tensor)]
Function to be applied to the image and label before use.

Returns

tensorflow Dataset:
With (data, labels, optionally weights)
Expand source code
def dataset(self, class_weights=None, augment_function=None):
    """
    Returns a tensorflow dataset as configured by the class.

    Parameters
    ----------
    class_weights: list
        list of weights for the classes.
    augment_function: Callable[[Tensor, Tensor], (Tensor, Tensor)]
        Function to be applied to the image and label before use.

    Returns
    -------
    tensorflow Dataset:
        With (data, labels, optionally weights)
    """

    # Pair the data and labels in our dataset
    ds = tf.data.Dataset.zip((self.data(), self.labels()))
    # ignore chunks which are all nodata (nodata is re-indexed to be after the classes)
    # cannot do with max_rand_offset since would have different number of tiles which
    # breaks keras fit
    if self._labels.nodata_value() is not None:
        ds = ds.filter(lambda x, y: tf.math.reduce_any(tf.math.not_equal(y, self._labels.nodata_value())))
    if augment_function is not None:
        ds = ds.map(augment_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if class_weights is not None:
        class_weights.append(0.0)
        lookup = tf.constant(class_weights)
        ds = ds.map(lambda x, y: (x, y, tf.gather(lookup, tf.cast(y, tf.int32), axis=None)),
                    num_parallel_calls=config.io.threads())
    return ds
def image_set(self)

Returns

Imageset

set of images

Expand source code
def image_set(self):
    """
    Returns
    -------
    ImageSet:
        set of images
    """
    return self._images
def input_shape(self)

Returns

Tuple[int, ...]:
Input size for the network.
Expand source code
def input_shape(self):
    """
    Returns
    -------
    Tuple[int, ...]:
        Input size for the network.
    """
    if self._chunk_shape:
        return (self._chunk_shape[0], self._chunk_shape[1], self._num_bands)
    return (None, None, self._num_bands)
def label_set(self)

Returns

Imageset

set of labels

Expand source code
def label_set(self):
    """
    Returns
    -------
    ImageSet:
        set of labels
    """
    return self._labels
def labels(self)

Returns

Dataset

Unbatched dataset of labels corresponding to data().

Expand source code
def labels(self):
    """
    Returns
    -------
    Dataset:
        Unbatched dataset of labels corresponding to `data()`.
    """
    label_set = self._load_images(True, self._label_type)
    if self._chunk_shape or self._output_shape:
        label_set = label_set.map(self._reshape_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE) #pylint: disable=C0301
        if self._chunk_shape:
            return label_set.unbatch()
    return label_set
def num_bands(self)

Returns

int:
number of bands in each image
Expand source code
def num_bands(self):
    """
    Returns
    -------
    int:
        number of bands in each image
    """
    return self._num_bands
def output_shape(self)

Returns

Tuple[int, ...]:
Output size, size of blocks of labels
Expand source code
def output_shape(self):
    """
    Returns
    -------
    Tuple[int, ...]:
        Output size, size of blocks of labels
    """
    if self._output_shape:
        return (self._output_shape[0], self._output_shape[1], self._output_dims)
    return (None, None, self._output_dims)
def set_chunk_output_shapes(self, chunk_shape, output_shape)

Parameters

chunk_shape : (int, int)
Size of chunks to read at a time. Set to None to use on a per tile basis (i.e., for FCNs).
output_shape : (int, int)
Shape output by the network. May differ from the input size (dervied from chunk_shape or tile_shape)
Expand source code
def set_chunk_output_shapes(self, chunk_shape, output_shape):
    """
    Parameters
    ----------
    chunk_shape: (int, int)
        Size of chunks to read at a time. Set to None to
        use on a per tile basis (i.e., for FCNs).
    output_shape: (int, int)
        Shape output by the network. May differ from the input size
        (dervied from chunk_shape or tile_shape)
    """
    if chunk_shape:
        assert len(chunk_shape) == 2, 'Chunk must be two dimensional.'
        assert (chunk_shape[0] % 2) == (chunk_shape[1] % 2) == \
               (output_shape[0] % 2) == (output_shape[1] % 2), 'Chunk and output shapes must both be even or odd.'
    if output_shape:
        assert len(output_shape) == 2 or len(output_shape) == 3, 'Output must be two or three dimensional.'
        if len(output_shape) == 3:
            output_shape = output_shape[0:2]
    self._chunk_shape = chunk_shape
    self._output_shape = output_shape
def set_tile_shape(self, tile_shape)

Set the tile size.

Parameters

tile_shape : (int, int)
New tile shape
Expand source code
def set_tile_shape(self, tile_shape):
    """
    Set the tile size.

    Parameters
    ----------
    tile_shape: (int, int)
        New tile shape"""
    self._tile_shape = tile_shape
def stride(self)

Returns

Tuple[int, ...]:
Stride between chunks (only when chunk_shape is set).
Expand source code
def stride(self):
    """
    Returns
    -------
    Tuple[int, ...]:
        Stride between chunks (only when chunk_shape is set).
    """
    return self._stride
def tile_overlap(self)

Returns

Tuple[int, ...]:
the amount tiles overlap
Expand source code
def tile_overlap(self):
    """
    Returns
    -------
    Tuple[int, ...]:
        the amount tiles overlap
    """
    return self._tile_overlap
def tile_shape(self)

Returns

Tuple[int, ...]:
tile shape to load at a time
Expand source code
def tile_shape(self):
    """
    Returns
    -------
    Tuple[int, ...]:
        tile shape to load at a time
    """
    return self._tile_shape