import functools
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Optional, List
from zipfile import is_zipfile, ZipFile
import numpy as np
import yaml
import diplomat.processing.type_casters as tc
from diplomat.frontends import ModelInfo, ModelLike
from ._verify_func import _load_dlc_like_zip_file
from .dlc_importer import ort, tf, tf2onnx, onnx
from diplomat.processing import TrackingData
from diplomat.utils.cli_tools import Flag
def _get_model_folder(cfg: dict, project_root: Path, shuffle: int = 1, train_fraction: float = None, model_prefix: str = ""):
task = cfg["Task"]
date = cfg["date"]
iterate = f"iteration-{str(cfg['iteration'])}"
train_fraction = train_fraction if(train_fraction is not None) else cfg["TrainingFraction"][0]
model_prefix = "" if(model_prefix in ["..", "."] or "/" in model_prefix or "\\" in model_prefix) else model_prefix
return Path(project_root) / Path(
model_prefix,
"dlc-models",
iterate,
f"{task}{date}-trainset{str(int(train_fraction * 100))}shuffle{str(shuffle)}"
)
def _build_provider_ordering(device_index: Optional[int], use_cpu: bool):
    """
    Build the list of onnxruntime execution providers to request, in priority order.

    :param device_index: When not None, stored as the "device_id" option of the CUDA and ROCM entries.
    :param use_cpu: When True, no accelerator providers are listed at all.

    :return: A list whose entries are either a provider name, or a (provider name, options dict) tuple.
             "CPUExecutionProvider" is always the last entry.
    """
    available = ort.get_available_providers()

    def _with_options(provider, options=None):
        # Pair a provider name with its options, adding the device id when one was requested.
        options = {} if options is None else options
        if device_index is not None:
            options["device_id"] = device_index
        return (provider, options)

    ordering = []

    if not use_cpu:
        # These two providers get a (name, options) entry.
        for provider in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
            if provider in available:
                ordering.append(_with_options(provider))
        # CoreML is listed by bare name, without an options dict.
        if "CoreMLExecutionProvider" in available:
            ordering.append("CoreMLExecutionProvider")

    # Fallback...
    ordering.append("CPUExecutionProvider")
    return ordering
def _prune_tf_model(graph_def, outputs: List[str]):
name_to_idx = {n.name: i for i, n in enumerate(graph_def.node)}
visited = [False] * len(name_to_idx)
if not all(o in name_to_idx for o in outputs):
raise ValueError("Not all output nodes exist in the model!")
stack = []
stack.extend(outputs)
while len(stack) > 0:
node_name = stack.pop()
idx = name_to_idx[node_name]
visited[idx] = True
for input_node_name in graph_def.node[idx].input:
if input_node_name in name_to_idx and not visited[name_to_idx[input_node_name]]:
stack.append(input_node_name)
temp_stack = []
for i in range(len(visited) - 1, -1, -1):
node = graph_def.node.pop()
if visited[i]:
temp_stack.append(node)
graph_def.node.extend(temp_stack[::-1])
print(f"Total nodes: {len(visited)}")
print(f"Removed nodes: {len(visited) - sum(visited)}")
return graph_def
def _load_meta_graph_def(meta_file):
    """
    Read a serialized tensorflow MetaGraphDef from a file.

    :param meta_file: The file to read; it is opened in binary mode.

    :return: A tf.compat.v1.MetaGraphDef with the contents of the file merged into it.
    """
    parsed = tf.compat.v1.MetaGraphDef()
    with open(meta_file, "rb") as handle:
        serialized = handle.read()
    parsed.MergeFromString(serialized)
    return parsed
def from_checkpoint(model_path, input_names, output_names):
    """
    Load tensorflow graph from checkpoint.

    :param model_path: Path of the checkpoint's meta graph file, as a string. The weights are restored from
                       this same path with its last 5 characters (the ".meta" suffix) removed.
    :param input_names: Names of the graph inputs.
    :param output_names: Names of the graph outputs.

    :return: A tuple of (frozen and optimized graph, input names, output names). The returned input names are
             the ones left after tf2onnx's inputs_without_resource and remove_redundant_inputs calls.
    """
    # NOTE(review): block nesting below was reconstructed from a flattened listing and follows the layout of
    # tf2onnx.tf_loader.from_checkpoint; confirm against the original file.
    import tensorflow as tf
    import tf2onnx
    tf_v1 = tf.compat.v1
    # make sure we start with clean default graph
    tf_v1.reset_default_graph()
    # model_path = checkpoint/checkpoint.meta
    with tf.device("/cpu:0"):
        with tf_v1.Session() as sess:
            saver = tf_v1.train.import_meta_graph(model_path, clear_devices=True)
            # restore from model_path minus the ".meta"
            # (the variable initializer is run first, then the checkpoint values are restored on top)
            sess.run(tf_v1.global_variables_initializer())
            saver.restore(sess, model_path[:-5])
            input_names = tf2onnx.tf_loader.inputs_without_resource(sess, input_names)
            frozen_graph = tf2onnx.tf_loader.freeze_session(sess, input_names=input_names, output_names=output_names)
            input_names = tf2onnx.tf_loader.remove_redundant_inputs(frozen_graph, input_names)
        # start from a clean default graph again before running the optimizer in a fresh session
        tf_v1.reset_default_graph()
        with tf_v1.Session() as sess:
            frozen_graph = tf2onnx.tf_loader.tf_optimize(input_names, output_names, frozen_graph)
    # leave a clean default graph behind for the caller
    tf_v1.reset_default_graph()
    return frozen_graph, input_names, output_names
def _find_direct_consumers(graph_def, node):
consumers = []
for n in graph_def.node:
for i, ins in enumerate(n.input):
if(ins == node):
consumers.append(f"{n.name}:{i}")
return consumers
def _get_dlc_inputs_and_outputs(meta_path):
    """
    Work out which input and output tensors of a DLC meta graph are present.

    :param meta_path: Path of the meta graph file; passed to _load_meta_graph_def.

    :return: A tuple of (input tensor names, output tensor names). The part prediction output is always
             included; the location refinement output only when its op exists in the graph.

    :raises ValueError: If the required part prediction op, or the input op, is not in the graph.
    """
    graph_nodes = _load_meta_graph_def(meta_path).graph_def.node
    known_ops = {graph_node.name for graph_node in graph_nodes}

    def _op_of(tensor_name):
        # Tensor names are "<op name>:<index>"; only the op name appears in the graph's node list.
        return tensor_name.split(":")[0]

    # (tensor name, whether its absence is an error)
    wanted_outputs = [
        ("pose/part_pred/block4/BiasAdd:0", True),
        ("pose/locref_pred/block4/BiasAdd:0", False)
    ]

    output_names = []
    for tensor_name, required in wanted_outputs:
        if _op_of(tensor_name) in known_ops:
            output_names.append(tensor_name)
            continue
        if required:
            raise ValueError(f"Unable to find weights for layer: {tensor_name} in DLC model, which is required.")

    input_names = ["fifo_queue_Dequeue:0"]
    if any(_op_of(tensor_name) not in known_ops for tensor_name in input_names):
        raise ValueError("Can't find input node!")

    return input_names, output_names
def _load_and_convert_model(model_dir: Path, device_index: Optional[int], use_cpu: bool):
    """
    Convert the newest DLC snapshot in a folder to ONNX and open it as an onnxruntime session.

    :param model_dir: Folder that is searched for "snapshot-<number>.meta" files.
    :param device_index: Forwarded to _build_provider_ordering.
    :param use_cpu: Forwarded to _build_provider_ordering.

    :return: An ort.InferenceSession created from the converted model, held in memory.

    :raises ValueError: If model_dir contains no snapshot meta files.
    """
    import tensorflow as tf
    # NOTE(review): the next two imports are not referenced below; kept as in the original, confirm whether
    # anything relies on them being imported.
    from tensorflow.python.training import py_checkpoint_reader
    import tensorflow.compat.v1 as tf_v1
    import tf2onnx

    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.disable_v2_behavior()
    tf.compat.v1.reset_default_graph()

    def _snapshot_number(meta_path):
        # "snapshot-1000.meta" -> 1000
        return int(meta_path.stem.split("-")[-1])

    snapshot_metas = [
        entry for entry in model_dir.iterdir()
        if entry.stem.startswith("snapshot-") and entry.suffix == ".meta"
    ]
    if not snapshot_metas:
        raise ValueError("No checkpoint files, make sure you've trained a DLC model first!")

    newest_meta = max(snapshot_metas, key=_snapshot_number)
    newest_meta_path = str(newest_meta)

    inputs, outputs = _get_dlc_inputs_and_outputs(newest_meta_path)
    print(inputs, outputs)

    graph_def, inputs, outputs = from_checkpoint(newest_meta_path, inputs, outputs)

    # Batch size, and the two axes after it, are left dynamic; the last axis is fixed at 3.
    input_shapes = {inputs[0]: [None, None, None, 3]}
    onnx_model, _unused = tf2onnx.convert.from_graph_def(
        graph_def,
        name=str(newest_meta.name),
        input_names=inputs,
        output_names=outputs,
        shape_override=input_shapes,
        opset=17
    )

    # Serialize into memory instead of a file, and hand the raw bytes to onnxruntime.
    serialized = BytesIO()
    onnx.save(onnx_model, serialized)

    providers = _build_provider_ordering(device_index, use_cpu)
    return ort.InferenceSession(serialized.getvalue(), providers=providers)
class FakeTempDir:
    """
    A context manager that only carries a "name" attribute.

    Entering it yields the object itself and leaving it does nothing, so nothing is created or removed.
    """

    def __init__(self, name):
        # The value exposed as ".name"; stored as given.
        self.name = name

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to clean up. Returning None means exceptions raised in the with body propagate.
        return None
class FrameExtractor:
    """
    Callable that runs a batch of frames through an onnxruntime session and wraps the result in TrackingData.
    """

    def __init__(self, onnx_model: ort.InferenceSession, model_config: dict):
        """
        :param onnx_model: The session to run. The name of its first input is used to feed frames.
        :param model_config: The model config; "locref_stdev" is read when the model has a second output.
        """
        self._model = onnx_model
        first_input = self._model.get_inputs()[0]
        self._image_input_name = first_input.name
        self._config = model_config

    def __call__(self, frames: np.ndarray) -> TrackingData:
        """
        Run the model on the passed frames.

        :param frames: The frames to process; converted to float32 before being fed to the model.

        :return: A TrackingData built from the sigmoid of the first model output, the rescaled second
                 output (or None when the model has a single output), and a scale factor between the frame
                 and output sizes.
        """
        feed = {self._image_input_name: frames.astype(np.float32)}
        outputs = self._model.run(None, feed)
        scores = outputs[0]

        offsets = outputs[1] if len(outputs) > 1 else None
        if offsets is not None:
            # Split the last axis into pairs, then scale in place by the configured standard deviation.
            paired_shape = (*offsets.shape[:-1], -1, 2)
            offsets = offsets.reshape(paired_shape)
            offsets *= self._config["locref_stdev"]

        probabilities = 1 / (1 + np.exp(-scores))

        # Largest input/output size ratio over axes 1 and 2, rounded up.
        # presumably these are the image height and width axes; TODO confirm
        ratio_axis_1 = frames.shape[1] / scores.shape[1]
        ratio_axis_2 = frames.shape[2] / scores.shape[2]
        scale = float(np.ceil(max(ratio_axis_1, ratio_axis_2)))

        return TrackingData(probabilities, offsets, scale)
[docs]
@tc.typecaster_function
def load_model(
    config: tc.PathLike,
    num_outputs: tc.Optional[int] = None,
    batch_size: tc.Optional[int] = None,
    gpu_index: tc.Optional[int] = None,
    model_prefix: str = "",
    shuffle: int = 1,
    training_set_index: int = 0,
    use_cpu: Flag = False
) -> tc.Tuple[ModelInfo, ModelLike]:
    """
    Load a trained DEEPLABCUT network from a DEEPLABCUT project so it can be used for tracking with DIPLOMAT.

    :param config: The path to the DLC config for the DEEPLABCUT project, or to a zip file containing the project.
                   A list or tuple holding exactly one such path is also accepted.
    :param shuffle: int, optional. Integer specifying the shuffle index of the trained model to load. Defaults to 1.
    :param training_set_index: int, optional. Integer specifying which TrainingsetFraction to use. By default the first
                               (note that TrainingFraction is a list in config.yaml).
    :param gpu_index: Integer index of the GPU to use for inference. Defaults to None, in which case no device index
                      is requested explicitly.
    :param batch_size: The batch size to use while processing. Defaults to None, which uses the default batch size for the project.
    :param model_prefix: The string prefix of the DEEPLABCUT model to use defaults to no prefix (the default model).
    :param num_outputs: The number of outputs, or bodies to track in the video. Defaults to the value specified in the DLC config,
                        then the value in the model's pose config, or None if one is not specified.
    :param use_cpu: If True, run on cpu even if a gpu is available. Defaults to False.

    :return: A model info dictionary, and a deeplabcut model wrapper that can be used to estimate poses from video frames.

    :raises ValueError: If more than one config file is passed.
    :raises FileNotFoundError: If no pose_cfg.yaml exists for the selected iteration, training fraction and shuffle.
    """
    if(isinstance(config, (tuple, list))):
        if(len(config) != 1):
            raise ValueError("Can't pass multiple config files!")
        config = config[0]

    is_zip = is_zipfile(config)
    # Zipped projects are unpacked into a real temporary directory; plain projects are used where they are.
    tmp_dir = tempfile.TemporaryDirectory() if(is_zip) else FakeTempDir(str(config))

    # No "as" target on purpose: entering a TemporaryDirectory yields the directory name as a plain string,
    # so rebinding tmp_dir to it would make the tmp_dir.name lookups below fail for zipped projects.
    with tmp_dir:
        if(is_zip):
            with ZipFile(config, "r") as z:
                # presumably returns the config's path inside the archive and the parsed config; TODO confirm
                sub_path, config = _load_dlc_like_zip_file(z)
                z.extractall(tmp_dir.name)
            project_dir = (Path(tmp_dir.name) / sub_path).resolve().parent
        else:
            project_dir = Path(config).resolve().parent
            with open(config, "rb") as f:
                config = yaml.load(f, yaml.SafeLoader)

        iteration = config["iteration"]
        train_frac = config["TrainingFraction"][training_set_index]
        model_directory = _get_model_folder(config, project_dir, shuffle, train_frac, model_prefix)
        model_directory = model_directory.resolve()

        try:
            with (model_directory / "test" / "pose_cfg.yaml").open("rb") as f:
                model_config = yaml.load(f, yaml.SafeLoader)
        except FileNotFoundError as e:
            print(e)
            # Chain the original error so the missing path stays visible in the traceback.
            raise FileNotFoundError(
                f"Invalid model selection: (Iteration {iteration}, Training Fraction {train_frac}, Shuffle: {shuffle})"
            ) from e

        # Set the number of outputs...
        if(num_outputs is None):
            num_outputs = config.get("num_outputs", model_config.get("num_outputs", None))
        if(num_outputs is not None):
            num_outputs = int(num_outputs)

        batch_size = batch_size if(batch_size is not None) else config["batch_size"]

        body_parts = list(model_config["all_joints_names"])
        if("partaffinityfield_graph" in model_config):
            # Each edge becomes a sorted pair of body part names; duplicates are dropped.
            skeleton = sorted(
                set(tuple(sorted([body_parts[a], body_parts[b]])) for a, b in model_config["partaffinityfield_graph"])
            )
        else:
            skeleton = []

        # The network must be loaded before leaving the with block, as a temporary directory is deleted on exit.
        return (
            ModelInfo(
                num_outputs=num_outputs,
                batch_size=batch_size,
                dotsize=int(config.get("dotsize", 4)),
                colormap=config.get("colormap", None),
                shape_list=None,
                alphavalue=config.get("alphavalue", 0.7),
                pcutoff=config.get("pcutoff", 0.1),
                line_thickness=1,
                bp_names=body_parts,
                skeleton=skeleton,
                frontend="deeplabcut"
            ),
            FrameExtractor(
                _load_and_convert_model(model_directory / "train", gpu_index, bool(use_cpu)),
                model_config
            )
        )