"""
Provides a generic API and some core data structures for frame store formats, or files which store model outputs on disk.
"""
from typing import Any, Optional, Iterator, List, BinaryIO, MutableMapping
from abc import ABC, abstractmethod
import numpy as np
from diplomat.processing import TrackingData
# Required data types. Each one is pinned to little endian ("<") byte order, spelled here
# as numpy type strings: "u" = unsigned integer, "f" = floating point, followed by the size in bytes.
luint8 = np.dtype("<u1")
luint16 = np.dtype("<u2")
luint32 = np.dtype("<u4")
luint64 = np.dtype("<u8")
ldouble = np.dtype("<f8")
lfloat = np.dtype("<f4")
[docs]
def to_bytes(obj: Any, dtype: np.dtype) -> bytes:
    """
    Converts an object to bytes.

    :param obj: The object to convert to bytes.
    :param dtype: The numpy data type to interpret the object as when converting to bytes. The byte order of the
                  data type is honored in the produced bytes.

    :return: A bytes object, representing the object obj as type dtype.
    """
    # Build a 0-d array carrying the full dtype rather than calling dtype.type(obj): the scalar
    # class has no byte order attached, so it would always serialize in the host's native order
    # and ignore an explicit "<" or ">" on the dtype.
    return np.array(obj, dtype=dtype).tobytes()
[docs]
def from_bytes(data: bytes, dtype: np.dtype) -> Any:
    """
    Decodes a single value of the given data type from a bytes object.

    :param data: The bytes to decode. Interpreted as a buffer of dtype items; only the first item is returned.
    :param dtype: The numpy data type to decode the bytes as.

    :return: The first item of the decoded buffer, as a numpy scalar of the passed data type.
    """
    decoded = np.frombuffer(data, dtype=dtype)
    return decoded[0]
[docs]
def string_list(lister: list):
    """
    Copies the passed iterable into a new list, verifying that every entry is a string.

    :param lister: The list (or other iterable) of values to check.

    :return: A new list holding the same entries, all of which are strings.

    :raises: ValueError if any entry is not a string.
    """
    strings = list(lister)

    if not all(isinstance(entry, str) for entry in strings):
        raise ValueError("Must be a list of strings!")

    return strings
[docs]
def edge_list(lister: list):
    """
    Normalizes a list of graph edges. It does so by ordering edges so the node with the lower index goes first,
    removing duplicates, and then sorting the final list of edges.

    :param lister: A list of tuples (or lists), each containing 2 integers. Each integer is the index of a node.

    :return: A sorted list of tuples, each tuple with 2 integers, the normalized edge list.

    :raises: ValueError if any entry is not a list or tuple of exactly 2 integers.
    """
    unique_edges = set()

    for item in lister:
        if (
            not isinstance(item, (list, tuple))
            or len(item) != 2
            # Reject the entry when NOT every member is an integer.
            or not all(isinstance(v, int) for v in item)
        ):
            raise ValueError("Must be a list of 2 integer tuples!")

        a, b = item
        # Lower node index first, so (a, b) and (b, a) collapse into a single set entry.
        unique_edges.add(tuple(sorted([int(a), int(b)])))

    # sorted() returns the new list; list.sort() sorts in place and returns None.
    return sorted(unique_edges)
[docs]
def non_max_int32(val: luint32) -> Optional[int]:
    """
    Converts a value to an integer, mapping the maximum unsigned 32-bit value (and any negative value) to None.

    :param val: The value to convert. May be None, in which case None is returned.

    :return: The value as an integer, or None if it was None, negative, or equal to the largest value
             representable by the luint32 data type.
    """
    if val is None:
        return None

    number = int(val)
    is_max = number == np.iinfo(luint32).max

    return None if (is_max or number < 0) else number
[docs]
class FrameReader(ABC):
    """
    Abstract interface for frame readers. A frame reader pulls frames out of a diplomat frame store format
    and returns them as :py:class:`~diplomat.processing.track_data.TrackingData` objects.

    Instances can be used in a with statement, which calls :py:meth:`close` on exit.
    """

    @abstractmethod
    def __init__(self, file: BinaryIO):
        """
        Construct a new frame reader.

        :param file: The file to read frames from.
        """

    @abstractmethod
    def has_next(self, num_frames: int = 1) -> bool:
        """
        Check whether more frames are available for reading.

        :param num_frames: An integer, the number of frames to check for. Defaults to 1 frame.

        :returns: A boolean, True if at least num_frames frames are still available for reading from the file,
                  otherwise False.
        """

    @abstractmethod
    def tell_frame(self) -> int:
        """
        Get the frame this frame reader is currently positioned at.

        :returns: An integer, the index of the frame that will be read next.
        """

    def seek_frame(self, frame_idx: int):
        """
        Move this frame reader to the specified frame of the frame store. Supporting this method is optional for
        implementors of FrameReader; this default implementation always raises a NotImplementedError.

        :param frame_idx: An integer, the index of the frame to move to.
        """
        raise NotImplementedError(
            "Seeking functionality is not supported for this implementation of FrameReader!"
        )

    @abstractmethod
    def read_frames(self, num_frames: int = 1) -> TrackingData:
        """
        Read frames from the frame store.

        :param num_frames: An integer, the number of frames to read from the frame store. Defaults to 1.

        :returns: A DIPLOMAT TrackingData object, containing the probability frames for all num_frames frames.

        :throws: ValueError if the end of the file is reached and more frames were requested than are available
                 in the frame store.
        """

    # Context manager support, so callers do not need to call close() by hand...
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @abstractmethod
    def close(self):
        """
        Close this frame reader. Only the frame reader itself is closed; the file handler it is utilizing is
        not closed.
        """
[docs]
class FrameWriter(ABC):
    """
    Abstract interface for frame writers. A frame writer takes frames in the form of
    :py:class:`~diplomat.processing.track_data.TrackingData` objects and writes them to a diplomat frame
    store format.

    Instances can be used in a with statement, which calls :py:meth:`close` on exit.
    """

    @abstractmethod
    def __init__(
        self,
        file: BinaryIO,
        header: DLFSHeader,
        threshold: Optional[float] = 1e-6,
    ):
        """
        Construct a new frame writer.

        :param file: A binary file object, the file to write the frames to.
        :param header: The DLFSHeader for this frame store, contains important metadata.
        :param threshold: The minimum threshold for keeping probabilities. Passing None tells the frame writer
                          to store the probability frames in a non-sparse way. Defaults to 1e-6.
        """

    @abstractmethod
    def tell_frame(self) -> int:
        """
        Get the frame this frame writer is currently positioned at.

        :returns: An integer, the index of the frame that will be written next.
        """

    @abstractmethod
    def write_data(self, data: TrackingData):
        """
        Write frames to the file through this frame writer.

        :param data: A TrackingData object, the frames to write to the file.

        :throws: ValueError if writing would exceed the total number of frames specified in the DLFSHeader
                 passed when this frame writer was created.
        """

    @abstractmethod
    def close(self):
        """
        Close this frame writer. Only the frame writer itself is closed; the file handler it is utilizing is
        not closed.
        """

    # Context manager support, so callers do not need to call close() by hand...
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()