# Copyright 2021 Sean Robertson
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Submodule for corpus iterators"""
from abc import abstractmethod
from builtins import str as text
from collections.abc import Iterable
from collections.abc import Sized
from contextlib import ExitStack
from itertools import cycle
from itertools import zip_longest
from typing import Callable, Iterator, Optional, Sequence, Union
from warnings import warn

import numpy as np

from pydrobert.kaldi.io import open as io_open
from pydrobert.kaldi.io.enums import RxfilenameType
from pydrobert.kaldi.io.util import parse_kaldi_input_path
# Names exported by ``from <this module> import *``; underscore-prefixed helpers
# such as ``_handle_sub_batch`` are module-private.
__all__ = [
    "batch_data",
    "Data",
    "ShuffledData",
    "SequentialData",
]
def _handle_sub_batch(sub_batch, axis, pad_mode, pad_kwargs):
assert len(sub_batch)
try:
first_dtype = sub_batch[0].dtype
first_shape = sub_batch[0].shape
except AttributeError:
return sub_batch
mismatched_shapes = False
max_shape = first_shape
for sample in sub_batch:
if not isinstance(sample, (np.ndarray, np.generic)) or (
not np.issubdtype(sample.dtype, first_dtype)
):
return sub_batch
if sample.shape != first_shape:
if pad_mode:
max_shape = (max(x, y) for x, y in zip(sample.shape, max_shape))
mismatched_shapes = True
else:
return sub_batch
if mismatched_shapes:
max_shape = tuple(max_shape)
for samp_idx in range(len(sub_batch)):
sample = sub_batch[samp_idx]
if sample.shape != max_shape:
pad_widths = tuple((0, y - x) for x, y in zip(sample.shape, max_shape))
sample = np.pad(sample, pad_widths, mode=pad_mode, **pad_kwargs)
sub_batch[samp_idx] = sample
ret = np.stack(sub_batch, axis=axis)
return ret
def batch_data(
    input_iter: Iterator,
    subsamples: bool = True,
    batch_size: Optional[int] = None,
    axis: Union[int, Sequence[int]] = 0,
    cast_to_array: Optional[Union[np.dtype, Sequence]] = None,
    pad_mode: Optional[Union[str, Callable]] = None,
    **pad_kwargs
):
    """Generate batched data from an input generator

    Takes some fixed number of samples from `input_iter`, encapsulates them, and yields
    them.

    If `subsamples` is :obj:`True`, data from `input_iter` are expected to be
    encapsulated in fixed-length sequences (e.g. ``(feat, label, len)``). Each sample
    will be batched separately into a sub-batch and returned in a tuple (e.g.
    ``(feat_batch, label_batch, len_batch)``).

    The format of a (sub-)batch depends on the properties of its samples:

    1. If `cast_to_array` applies to this sub-batch, cast each of its samples to a
       numpy array of the target type.
    2. If all samples in the (sub-)batch are numpy arrays of the same type and shape,
       samples are stacked in a bigger numpy array along the axis specified by `axis`
       (see Parameters).
    3. If all samples are numpy arrays of the same type and number of dimensions but
       differ in shape, and `pad_mode` is specified, pad all sample arrays to the right
       such that they all have the same (supremum) shape, then perform 2.
    4. Otherwise, simply return a list of samples as-is (ignoring axis).

    Parameters
    ----------
    input_iter
        An iterator over samples
    subsamples
        `input_iter` yields tuples to be divided into different sub-batches if
        :obj:`True`
    batch_size
        The size of batches, except perhaps the last one. If not set or ``0``, will
        yield samples (casting and encapsulating in tuples when necessary)
    axis
        Where to insert the batch index/indices into the shape/shapes of the inputs. If
        a sequence, `subsamples` must be :obj:`True` and `input_iter` should yield
        samples of the same length as axis. If an :class:`int` and subsamples is
        :obj:`True`, the same axis will be used for all sub-samples.
    cast_to_array
        Dictates whether data should be cast to numpy arrays and of what type. If a
        sequence, `subsamples` must be :obj:`True` and `input_iter` should yield samples
        of the same length as `cast_to_array`. If a single value (a string such as
        ``"float32"`` counts as a single value) and `subsamples` is :obj:`True`, the
        same value will be used for all sub-samples. Value(s) of :obj:`None` indicate
        no casting should be done for this (sub-)sample. Other values will be used to
        cast (sub-)samples to numpy arrays. Casting applies whether or not batching
        occurs.
    pad_mode
        If set, inputs within a batch will be padded on the end to match the largest
        shapes in the batch. How the inputs are padded matches the argument to
        :func:`numpy.pad`. If not set, a (sub-)batch whose samples do not all share
        the same shape is returned as a list of samples (rule 4 above)
    pad_kwargs
        Additional keyword arguments are passed along to :func:`numpy.pad`
        if padding.

    Yields
    ------
    batch
        A (sub-)batch, or a tuple of sub-batches if `subsamples` is :obj:`True`

    Raises
    ------
    ValueError
        If `axis` and `cast_to_array` are both sequences of different lengths, or if a
        sample does not contain the expected number of sub-samples

    See Also
    --------
    numpy.pad
        For different pad modes and options
    """
    num_sub = None

    def _cast(value, dtype):
        # np.asarray copies only when it must; np.array(..., copy=False) raises in
        # NumPy >= 2.0 whenever a copy (e.g. a dtype change) is unavoidable.
        return value if dtype is None else np.asarray(value, dtype=dtype)

    def _split(sample):
        # Tuple-ize a sample and make sure it has as many sub-samples as the rest.
        nonlocal num_sub
        sample = tuple(sample)
        if num_sub is None:
            num_sub = len(sample)
        elif num_sub != len(sample):
            raise ValueError(
                "Expected {} sub-samples per sample, got {}".format(
                    num_sub, len(sample)
                )
            )
        return sample

    def _emit(batch):
        # Convert the accumulated samples into the yielded (tuple of) sub-batch(es).
        if subsamples:
            return tuple(
                _handle_sub_batch(sub_batch, sub_axis, pad_mode, pad_kwargs)
                for sub_batch, sub_axis in zip(batch, cycle(axis))
            )
        return _handle_sub_batch(batch, axis, pad_mode, pad_kwargs)

    if subsamples:
        try:
            axis = tuple(axis)
            num_sub = len(axis)
        except TypeError:  # one value
            axis = (axis,)
        if isinstance(cast_to_array, str):
            # a dtype name is iterable but denotes a single value
            cast_to_array = (cast_to_array,)
        else:
            try:
                cast_to_array = tuple(cast_to_array)
            except TypeError:  # one value
                cast_to_array = (cast_to_array,)
            else:
                if num_sub is None:
                    num_sub = len(cast_to_array)
                elif len(cast_to_array) != num_sub:
                    raise ValueError(
                        "axis and cast_to_array should be of the same "
                        "length if both sequences (got {} and {} resp)".format(
                            num_sub, len(cast_to_array)
                        )
                    )

    if not batch_size:
        for sample in input_iter:
            if subsamples:
                # each sub-sample gets its own cast, cycling a single value
                yield tuple(
                    _cast(sub_sample, sub_cast)
                    for sub_sample, sub_cast in zip(
                        _split(sample), cycle(cast_to_array)
                    )
                )
            else:
                yield _cast(sample, cast_to_array)
        return

    cur_batch = []
    cur_batch_size = 0
    for sample in input_iter:
        if subsamples:
            sample = _split(sample)
            if not cur_batch:
                cur_batch = [[] for _ in range(num_sub)]
            for sub_batch, sub_sample, sub_cast in zip(
                cur_batch, sample, cycle(cast_to_array)
            ):
                sub_batch.append(_cast(sub_sample, sub_cast))
        else:
            cur_batch.append(_cast(sample, cast_to_array))
        cur_batch_size += 1
        if cur_batch_size == batch_size:
            yield _emit(cur_batch)
            cur_batch = []
            cur_batch_size = 0
    if cur_batch_size:
        # trailing, smaller batch
        yield _emit(cur_batch)
class Data(Iterable, Sized):
    """Abstract base class for data iterables

    A template for providing iterators over kaldi tables. They can be used like this

    >>> data = DataSubclass(
    ...     'scp:feats.scp', 'scp:labels.scp', batch_size=10)
    >>> for feat_batch, label_batch in data:
    ...     pass  # do something
    >>> for feat_batch, label_batch in data:
    ...     pass  # do something again

    Where `DataSubclass` is some concrete subclass of this abstract class. Calling
    :func:`iter` on an instance (which occurs implicitly in for-loops) will generate a
    new iterator over the entire data set.

    The class takes an arbitrary positive number of positional arguments on
    initialization, each a table to open. Each argument is one of:

    1. An rspecifier (ideally for a script file). Assumed to be of type
       :class:`KaldiDataType.BaseMatrix`
    2. A sequence of length 2: the first element is the rspecifier, the
       second the rspecifier's :class:`KaldiDataType`
    3. A sequence of length 3: the first element is the rspecifier, the second the
       rspecifier's :class:`KaldiDataType`, and the third is a dictionary to be passed
       as keyword arguments to the :func:`pydrobert.kaldi.io.open` function

    All tables are assumed to index data using the same keys.

    If `batch_size` is set, data are stacked in batches along a new axis. The keyword
    arguments `batch_axis`, `batch_cast_to_array`, `batch_pad_mode`, and any remaining
    keywords are sent to this module's :func:`batch_data` function. If `batch_size` is
    :obj:`None` or ``0``, samples are returned one-by-one. Values are only cast to
    numpy arrays where `batch_cast_to_array` requests it. Consult :func:`batch_data`
    for more information on batching.

    If only one table is specified and neither `axis_lengths` nor `add_key` is
    specified, iterators yield batches of the table's data directly. Otherwise,
    iterators yield "batches" of tuples containing "sub-batches" from each respective
    data source. Sub-batches belonging to the same batch share the same subset of
    ordered keys.

    If `add_key` is :obj:`True`, a sub-batch of referent keys is added as the first
    element of a batch tuple.

    For batched sequence-to-sequence tasks, it is often important to know the original
    length of data before padding. Setting `axis_lengths` adds one or more sub-batches
    to the end of a batch tuple with this information. These sub-batches are filled with
    signed 32-bit integers. `axis_lengths` can be one of:

    1. An integer specifying an axis from the first table to get the lengths of.
    2. A pair of integers. The first element is the table index, the second is the axis
       index in that table.
    3. A sequence of pairs of integers. Sub-batches will be appended to the batch tuple
       in that order

    Note that axes in `axis_lengths` index the axes in individual samples, not the
    batch. For instance, if ``batch_axis == 0`` and ``axis_lengths == 0``, then the last
    sub-batch will refer to the pre-padded value of sub-batch 0's axis 1
    (``batch[0].shape[1]``).

    The length of this object is the number of batches it serves per
    epoch.
    """

    _DATA_PARAMS_DOC = """
    Parameters
    ----------
    table
        The first table specifier
    additional_tables
        Table specifiers past the first. If not empty, will iterate over tuples of
        sub-batches
    add_key
        If :obj:`True`, will insert sub-samples into the 0th index of each sample
        sequence that specify the key that this sample was indexed by. Defaults to
        :obj:`False`
    axis_lengths
        If set, sub-batches of axis lengths will be appended to the end of a batch tuple
    batch_axis
        The axis or axes (in the case of multiple tables) along which samples are
        stacked in (sub-)batches. batch_axis should take into account axis length and
        key sub-batches when applicable. Defaults to ``0``
    batch_cast_to_array
        A numpy type or sequence of types to cast each (sub-)batch to. :obj:`None`
        values indicate no casting should occur. `batch_cast_to_array` should take into
        account axis length and key sub-batches when applicable. A string such as
        ``"float32"`` is treated as a single type
    batch_kwargs
        Additional keyword arguments to pass to ``batch_data`` (and from there to
        :func:`numpy.pad` when padding)
    batch_pad_mode
        If set, pads samples in (sub-)batches according to this :func:`numpy.pad`
        strategy when samples do not have the same length
    batch_size
        The number of samples per (sub-)batch. Defaults to :obj:`None`, which means
        samples are served without batching
    ignore_missing
        If :obj:`True` and some provided table does not have some key, that key will
        simply be ignored. Otherwise, a missing key raises an :class:`IOError`.
        Defaults to :obj:`False`
    """

    _DATA_ATTRIBUTES_DOC = """
    Attributes
    ----------
    table_specifiers
        A tuple of triples indicating ``(rspecifier, kaldi_dtype, open_kwargs)`` for
        each table
    add_key
        Whether a sub-batch of table keys has been prepended to existing sub-batches
    axis_lengths
        A tuple of pairs for each axis-length sub-batch requested. Each pair is
        ``(sub_batch_idx, axis)``.
    batch_axis
        A tuple of length `num_sub` indicating which axis (sub-)samples will be arrayed
        along in a given (sub-)batch when all (sub-)samples are (or are cast to) fixed
        length numpy arrays of the same type
    batch_cast_to_array
        A tuple of length `num_sub` indicating what numpy types, if any (sub-)samples
        should be cast to. Values of :obj:`None` indicate no casting should be done on
        that (sub-)sample
    batch_kwargs
        Additional keyword arguments to pass to ``batch_data``
    batch_pad_mode
        If set, pads samples in (sub-)batches according to this :func:`numpy.pad`
        strategy when samples do not have the same length
    batch_size
        The number of samples per (sub-)batch
    ignore_missing
        If :obj:`True` and some provided table does not have some key, that key will
        simply be ignored. Otherwise, a missing key raises an :class:`IOError`
    num_sub
        The number of sub-batches per batch. If > 1, batches are yielded as tuples of
        sub-batches. This number accounts for key, table, and axis-length sub-batches
    """

    if __doc__ is not None:  # docstrings are stripped under ``python -OO``
        __doc__ += _DATA_PARAMS_DOC + "\n" + _DATA_ATTRIBUTES_DOC

    def __init__(self, table, *additional_tables, **kwargs):
        table_specifiers = []
        for table_spec in (table,) + additional_tables:
            if isinstance(table_spec, str):
                spec = (table_spec, "bm", dict())
            else:
                # Build a new tuple so a caller's list is never extended in place.
                spec = tuple(table_spec)
                if len(spec) == 2:
                    spec += (dict(),)
                elif len(spec) != 3:
                    raise ValueError("Invalid table spec {}".format(table_spec))
            table_specifiers.append(spec)
        self.table_specifiers = tuple(table_specifiers)
        self.add_key = bool(kwargs.pop("add_key", False))
        axis_lengths = kwargs.pop("axis_lengths", None)
        batch_axis = kwargs.pop("batch_axis", 0)
        batch_cast_to_array = kwargs.pop("batch_cast_to_array", None)
        self.batch_pad_mode = kwargs.pop("batch_pad_mode", None)
        self.batch_size = kwargs.pop("batch_size", None)
        self.ignore_missing = bool(kwargs.pop("ignore_missing", False))
        self.batch_kwargs = kwargs
        # these batch_data arguments are controlled by this class
        invalid_kwargs = {"axis", "cast_to_array", "pad_mode", "subsamples"}
        invalid_kwargs &= set(kwargs.keys())
        if invalid_kwargs:
            raise TypeError("Invalid argument {}".format(invalid_kwargs.pop()))
        integer_types = (int, np.integer)
        if axis_lengths is None:
            self.axis_lengths = tuple()
        elif isinstance(axis_lengths, integer_types):
            self.axis_lengths = ((0, axis_lengths),)
        else:
            axis_lengths = tuple(axis_lengths)  # in case generator
            if len(axis_lengths) == 2 and all(
                isinstance(x, integer_types) for x in axis_lengths
            ):
                self.axis_lengths = (axis_lengths,)
            else:
                self.axis_lengths = tuple(tuple(pair) for pair in axis_lengths)
        self.num_sub = len(table_specifiers) + int(self.add_key)
        self.num_sub += len(self.axis_lengths)
        for attribute_name, variable in (
            ("batch_axis", batch_axis),
            ("batch_cast_to_array", batch_cast_to_array),
        ):
            if isinstance(variable, str):
                # e.g. "float32": iterable, but a single value
                setattr(self, attribute_name, (variable,) * self.num_sub)
                continue
            try:
                variable = tuple(variable)
            except TypeError:  # scalar: broadcast to every sub-batch
                setattr(self, attribute_name, (variable,) * self.num_sub)
                continue
            if len(variable) != self.num_sub:
                error_msg = "Expected {} to be a scalar or ".format(attribute_name)
                error_msg += "container of length {}, got {}".format(
                    self.num_sub, len(variable)
                )
                if len(variable) >= len(table_specifiers):
                    error_msg += " (did you forget to account for "
                    error_msg += "axis_lengths or add_key?)"
                raise ValueError(error_msg)
            setattr(self, attribute_name, variable)

    @property
    @abstractmethod
    def num_samples(self) -> int:
        """int : the number of samples yielded per epoch

        This number takes into account the number of terms missing if
        ``self.ignore_missing == True``
        """
        pass

    @property
    def num_batches(self) -> int:
        """int : the number of batches yielded per epoch

        This number takes into account the number of terms missing if
        ``self.ignore_missing == True``
        """
        if self.batch_size:
            # integer ceiling division; no float rounding for large counts
            return int(-(-self.num_samples // self.batch_size))
        return self.num_samples

    def __len__(self):
        return self.num_batches

    @abstractmethod
    def sample_generator_for_epoch(self):
        """A generator which yields individual samples from data for an epoch

        An epoch means one pass through the data from start to finish. Equivalent to
        ``sample_generator(False)``.

        Yields
        ------
        sample : np.array or tuple
            A sample if ``self.num_sub == 1``, otherwise a tuple of sub-samples
        """
        pass

    def sample_generator(self, repeat: bool = False):
        """A generator which yields individual samples from data

        Parameters
        ----------
        repeat
            Whether to stop generating after one epoch (False) or keep
            restarting and generating indefinitely

        Yields
        ------
        sample : np.array or tuple
            A sample if ``self.num_sub == 1``, otherwise a tuple of
            sub-samples
        """
        while True:
            yield from self.sample_generator_for_epoch()
            if not repeat:
                break

    def batch_generator(self, repeat: bool = False):
        """A generator which yields batches of data

        Parameters
        ----------
        repeat
            Whether to stop generating after one epoch (False) or keep
            restarting and generating indefinitely

        Yields
        ------
        batch : np.array or tuple
            A batch if ``self.num_sub == 1``, otherwise a tuple of sub-batches. If
            self.batch_size does not divide an epoch's worth of data evenly, the last
            batch of every epoch will be smaller
        """
        subsamples = self.num_sub != 1
        while True:
            for batch in batch_data(
                self.sample_generator_for_epoch(),
                subsamples=subsamples,
                batch_size=self.batch_size,
                axis=self.batch_axis if subsamples else self.batch_axis[0],
                pad_mode=self.batch_pad_mode,
                cast_to_array=(
                    self.batch_cast_to_array
                    if subsamples
                    else self.batch_cast_to_array[0]
                ),
                **self.batch_kwargs
            ):
                yield batch
            if not repeat:
                break

    def __iter__(self):
        yield from self.batch_generator()
class ShuffledData(Data):
    """Provides iterators over shuffled data

    A master list of keys is either provided by keyword argument or inferred from the
    first table. Every new iterator requested shuffles that list of keys and returns
    batches in that order. Appropriate for training data.

    Notes
    -----
    For efficiency, it is highly recommended to use scripts to access tables rather than
    archives.
    """

    if __doc__ is not None:  # docstrings are stripped under ``python -OO``
        __doc__ += (
            Data._DATA_PARAMS_DOC
            + """
    key_list
        A master list of keys. No other keys will be queried. If not specified, the key
        list will be inferred by passing through the first table once
    rng
        Either a :class:`numpy.random.RandomState` object or a seed to create one. It
        will be used to shuffle the list of keys
    """
        )
        __doc__ += (
            "\n"
            + Data._DATA_ATTRIBUTES_DOC
            + """
    key_list
        The master list of keys
    rng
        Used to shuffle the list of keys every epoch
    table_handles
        A tuple of table readers opened in random access mode
    """
        )

    def __init__(self, table, *additional_tables, **kwargs):
        key_list = kwargs.pop("key_list", None)
        rng = kwargs.pop("rng", None)
        super(ShuffledData, self).__init__(table, *additional_tables, **kwargs)
        if key_list is None:
            # infer the master key list by reading through the first table once
            _, rx_fn, rx_type, _ = parse_kaldi_input_path(self.table_specifiers[0][0])
            if rx_type == RxfilenameType.InvalidInput:
                raise IOError("Invalid rspecifier {}".format(rx_fn))
            elif rx_type == RxfilenameType.StandardInput:
                raise IOError("Cannot infer key list from stdin (cannot reopen)")
            with io_open(*self.table_specifiers[0][:2]) as reader:
                self.key_list = tuple(reader.keys())
        else:
            self.key_list = tuple(key_list)  # also materializes generators
        if self.ignore_missing:
            # computed lazily by num_samples or the first full epoch
            self._num_samples = None
        else:
            self._num_samples = len(self.key_list)
        if isinstance(rng, np.random.RandomState):
            self.rng = rng
        else:
            self.rng = np.random.RandomState(rng)
        self.table_handles = tuple(
            io_open(rspecifier, kdtype, mode="r+", **o_kwargs)
            for rspecifier, kdtype, o_kwargs in self.table_specifiers
        )

    @property
    def num_samples(self) -> int:
        if self._num_samples is None:
            # count the keys present in every table
            self._num_samples = sum(
                1
                for key in self.key_list
                if all(key in handle for handle in self.table_handles)
            )
        return self._num_samples

    def sample_generator_for_epoch(self):
        shuffled_keys = np.array(self.key_list)
        self.rng.shuffle(shuffled_keys)
        num_samples = 0
        # tolist() turns numpy string scalars back into plain Python keys
        for key in shuffled_keys.tolist():
            samp_tup = []
            missing = False
            for spec, handle in zip(self.table_specifiers, self.table_handles):
                if key not in handle:
                    if self.ignore_missing:
                        missing = True
                        break
                    raise IOError("Table {} missing key {}".format(spec[0], key))
                samp_tup.append(handle[key])
            if missing:
                continue
            num_samples += 1
            for sub_batch_idx, axis_idx in self.axis_lengths:
                # np.asarray (not np.array(copy=False)) works on NumPy >= 2 too
                samp_tup.append(np.asarray(samp_tup[sub_batch_idx]).shape[axis_idx])
            if self.add_key:
                samp_tup.insert(0, key)
            if self.num_sub != 1:
                yield tuple(samp_tup)
            else:
                yield samp_tup[0]
        if self._num_samples is None:
            self._num_samples = num_samples
        elif self._num_samples != num_samples:
            raise IOError("Different number of samples from last time!")

    sample_generator_for_epoch.__doc__ = Data.sample_generator_for_epoch.__doc__
class SequentialData(Data):
    """Provides iterators to read data sequentially

    Tables are always assumed to be sorted so reading can proceed in lock-step.

    Warning
    -------
    Each time an iterator is requested, new sequential readers are opened. Be careful
    with stdin!
    """

    if __doc__ is not None:  # docstrings are stripped under ``python -OO``
        __doc__ += Data._DATA_PARAMS_DOC + "\n" + Data._DATA_ATTRIBUTES_DOC

    def __init__(self, table, *additional_tables, **kwargs):
        super(SequentialData, self).__init__(table, *additional_tables, **kwargs)
        self._num_samples = None
        sorteds = tuple(
            parse_kaldi_input_path(spec[0])[3]["sorted"]
            for spec in self.table_specifiers
        )
        if not all(sorteds):
            # suggest the same rspecifier with the ",s" (sorted) option added
            uns_rspec = self.table_specifiers[sorteds.index(False)][0]
            uns_rspec_split = uns_rspec.split(":")
            uns_rspec_split[0] += ",s"
            sor_rspec = ":".join(uns_rspec_split)
            warn(
                'SequentialData assumes data are sorted, and "{}" does '
                "not promise to be sorted. To supress this warning, "
                "check that this table is sorted, then add the sorted "
                'flag to this rspecifier ("{}")'.format(uns_rspec, sor_rspec)
            )
        if self.ignore_missing and len(self.table_specifiers) > 1:
            self._sample_generator_for_epoch = self._ignore_epoch
        else:
            self._sample_generator_for_epoch = self._no_ignore_epoch

    def _open_items(self, stack):
        """Open every table on `stack` and return their ``(key, value)`` iterators

        Readers are registered with the :class:`contextlib.ExitStack` so they are
        closed when the epoch ends, errors out, or is abandoned.
        """
        return tuple(
            stack.enter_context(io_open(spec[0], spec[1], **spec[2])).items()
            for spec in self.table_specifiers
        )

    def _finish_sample(self, key, samp_tup):
        """Append axis-length sub-samples, prepend the key if asked, and package"""
        for sub_batch_idx, axis_idx in self.axis_lengths:
            # np.asarray (not np.array(copy=False)) works on NumPy >= 2 too
            samp_tup.append(np.asarray(samp_tup[sub_batch_idx]).shape[axis_idx])
        if self.add_key:
            samp_tup.insert(0, key)
        if self.num_sub != 1:
            return tuple(samp_tup)
        return samp_tup[0]

    def _record_num_samples(self, num_samples):
        """Remember the epoch size, raising IOError if it changed between epochs"""
        if self._num_samples is None:
            self._num_samples = num_samples
        elif self._num_samples != num_samples:
            raise IOError(
                "Different number of samples from last time! (is a "
                "table from stdin?)"
            )

    def _ignore_epoch(self):
        """Epoch of samples w/ ignore_missing

        Advances all (sorted) tables in lock-step, only yielding keys present in every
        table. The epoch ends as soon as any table runs out: the remaining keys of the
        other tables are then necessarily missing from it.
        """
        num_samples = 0
        with ExitStack() as stack:
            iters = self._open_items(stack)
            num_tabs = len(iters)
            try:
                while True:
                    samp_tup = [None] * num_tabs
                    high_key = None
                    tab_idx = 0
                    while tab_idx < num_tabs:
                        if samp_tup[tab_idx] is None:
                            key, value = next(iters[tab_idx])
                            if high_key is None:
                                high_key = key
                            elif high_key < key:
                                # key is further along than keys in
                                # samp_tup. Discard those and keep this
                                samp_tup = [None] * num_tabs
                                samp_tup[tab_idx] = value
                                high_key = key
                                tab_idx = 0
                                continue
                            elif high_key > key:
                                # key is behind high_key. keep pushing this
                                # iterator forward
                                continue
                            samp_tup[tab_idx] = value
                        tab_idx += 1
                    num_samples += 1
                    # the last key read always equals high_key here
                    yield self._finish_sample(key, samp_tup)
            except StopIteration:
                # don't care if one iterator ends first - rest will be missing
                # that iterator's value
                pass
        self._record_num_samples(num_samples)

    def _no_ignore_epoch(self):
        """Epoch of samples w/o ignore_missing

        Raises IOError as soon as the tables disagree on a key, including when one
        table ends before the others.
        """
        num_samples = 0
        with ExitStack() as stack:
            iters = self._open_items(stack)
            # zip_longest (unlike zip) never swallows an entry from a longer table,
            # so a table that ends early is always detected
            for kv_pairs in zip_longest(*iters):
                ended = [idx for idx, pair in enumerate(kv_pairs) if pair is None]
                if ended:
                    miss_key = next(pair[0] for pair in kv_pairs if pair is not None)
                    miss_rspec = self.table_specifiers[ended[0]][0]
                    raise IOError(
                        "Table {} missing key {}".format(miss_rspec, miss_key)
                    )
                samp_tup = []
                past_key = None
                for tab_idx, (key, sample) in enumerate(kv_pairs):
                    if past_key is None:
                        past_key = key
                    elif past_key != key:
                        # assume sorted, base on which is first
                        if past_key < key:
                            miss_rspec = self.table_specifiers[tab_idx][0]
                            miss_key = past_key
                        else:
                            miss_rspec = self.table_specifiers[tab_idx - 1][0]
                            miss_key = key
                        raise IOError(
                            "Table {} missing key {} (or tables are sorted "
                            "differently)".format(miss_rspec, miss_key)
                        )
                    samp_tup.append(sample)
                num_samples += 1
                yield self._finish_sample(key, samp_tup)
        self._record_num_samples(num_samples)

    @property
    def num_samples(self) -> int:
        if self._num_samples is None:
            # A full epoch records the count as a side effect. Run it outside of any
            # ``assert`` so the count is still computed under ``python -O``.
            num_samples = sum(1 for _ in self.sample_generator_for_epoch())
            if self._num_samples is None:
                self._num_samples = num_samples
        return self._num_samples

    def sample_generator_for_epoch(self):
        return self._sample_generator_for_epoch()

    sample_generator_for_epoch.__doc__ = Data.sample_generator_for_epoch.__doc__
# The concrete ``num_samples`` properties reuse the abstract property's docstring.
for _subclass in (SequentialData, ShuffledData):
    _subclass.num_samples.__doc__ = Data.num_samples.__doc__
del _subclass