Source code for diplomat.frontends.deeplabcut.load_model

import shutil
import tempfile
from io import BytesIO
from pathlib import Path, PosixPath, PurePosixPath
from typing import Optional, List
from zipfile import is_zipfile, ZipFile

import numpy as np
import yaml

import diplomat.processing.type_casters as tc
from diplomat.frontends import ModelInfo, ModelLike
from ._verify_func import _load_dlc_like_zip_file
from .dlc_importer import ort, tf, onnx
from diplomat.processing import TrackingData
from diplomat.utils.cli_tools import Flag


def _get_model_folder(
    cfg: dict,
    project_root: Path,
    shuffle: int = 1,
    train_fraction: float = None,
    model_prefix: str = "",
    is_pytorch: bool = False,
) -> Path:
    task = cfg["Task"]
    date = cfg["date"]
    iterate = f"iteration-{str(cfg['iteration'])}"
    train_fraction = (
        train_fraction if (train_fraction is not None) else cfg["TrainingFraction"][0]
    )
    model_prefix = (
        ""
        if (model_prefix in ["..", "."] or "/" in model_prefix or "\\" in model_prefix)
        else model_prefix
    )
    return Path(project_root) / Path(
        model_prefix,
        "dlc-models-pytorch" if is_pytorch else "dlc-models",
        iterate,
        f"{task}{date}-trainset{str(int(train_fraction * 100))}shuffle{str(shuffle)}",
    )


def _build_provider_ordering(device_index: Optional[int], use_cpu: bool):
    supported_devices = ort.get_available_providers()
    device_config = []

    def _add(val, extra=None):
        if extra is None:
            extra = {}
        if device_index is not None:
            extra["device_id"] = device_index
        return (val, extra)

    if not use_cpu:
        if "CUDAExecutionProvider" in supported_devices:
            device_config.append(_add("CUDAExecutionProvider"))
        if "ROCMExecutionProvider" in supported_devices:
            device_config.append(_add("ROCMExecutionProvider"))
        if "CoreMLExecutionProvider" in supported_devices:
            device_config.append("CoreMLExecutionProvider")

    # Fallback...
    device_config.append("CPUExecutionProvider")
    return device_config


def _prune_tf_model(graph_def, outputs: List[str]):
    name_to_idx = {n.name: i for i, n in enumerate(graph_def.node)}
    visited = [False] * len(name_to_idx)

    if not all(o in name_to_idx for o in outputs):
        raise ValueError("Not all output nodes exist in the model!")

    stack = []
    stack.extend(outputs)

    while len(stack) > 0:
        node_name = stack.pop()
        idx = name_to_idx[node_name]
        visited[idx] = True
        for input_node_name in graph_def.node[idx].input:
            if (
                input_node_name in name_to_idx
                and not visited[name_to_idx[input_node_name]]
            ):
                stack.append(input_node_name)

    temp_stack = []
    for i in range(len(visited) - 1, -1, -1):
        node = graph_def.node.pop()
        if visited[i]:
            temp_stack.append(node)
    graph_def.node.extend(temp_stack[::-1])

    print(f"Total nodes: {len(visited)}")
    print(f"Removed nodes: {len(visited) - sum(visited)}")

    return graph_def


def _load_meta_graph_def(meta_file):
    meta_graph_def = tf.compat.v1.MetaGraphDef()
    with open(meta_file, "rb") as f:
        meta_graph_def.MergeFromString(f.read())
    return meta_graph_def


def from_checkpoint(model_path, input_names, output_names):
    """Load tensorflow graph from checkpoint."""
    import tensorflow as tf
    import tf2onnx

    tf_v1 = tf.compat.v1
    # make sure we start with clean default graph
    tf_v1.reset_default_graph()
    # model_path = checkpoint/checkpoint.meta
    with tf.device("/cpu:0"):
        with tf_v1.Session() as sess:
            saver = tf_v1.train.import_meta_graph(model_path, clear_devices=True)
            # restore from model_path minus the ".meta"
            sess.run(tf_v1.global_variables_initializer())
            saver.restore(sess, model_path[:-5])
            input_names = tf2onnx.tf_loader.inputs_without_resource(sess, input_names)
            frozen_graph = tf2onnx.tf_loader.freeze_session(
                sess, input_names=input_names, output_names=output_names
            )
            input_names = tf2onnx.tf_loader.remove_redundant_inputs(
                frozen_graph, input_names
            )

        tf_v1.reset_default_graph()
        with tf_v1.Session() as sess:
            frozen_graph = tf2onnx.tf_loader.tf_optimize(
                input_names, output_names, frozen_graph
            )
    tf_v1.reset_default_graph()
    return frozen_graph, input_names, output_names


def _find_direct_consumers(graph_def, node):
    consumers = []
    for n in graph_def.node:
        for i, ins in enumerate(n.input):
            if ins == node:
                consumers.append(f"{n.name}:{i}")

    return consumers


def _get_dlc_inputs_and_outputs(meta_path):
    meta_graph_def = _load_meta_graph_def(meta_path)

    desired_outputs = [
        ("pose/part_pred/block4/BiasAdd:0", True),
        ("pose/locref_pred/block4/BiasAdd:0", False),
    ]
    output_names = []
    op_names = {n.name for n in meta_graph_def.graph_def.node}

    for op_name, is_required in desired_outputs:
        op_only = op_name.split(":")[0]
        if op_only in op_names:
            output_names.append(op_name)
        elif is_required:
            raise ValueError(
                f"Unable to find weights for layer: {op_name} in DLC model, which is required."
            )

    input_names = ["fifo_queue_Dequeue:0"]
    for input_name in input_names:
        if input_name.split(":")[0] not in op_names:
            raise ValueError("Can't find input node!")
    return input_names, output_names


def _load_and_convert_model(
    model_dir: Path, device_index: Optional[int], use_cpu: bool
):
    import tensorflow as tf
    import tensorflow.compat.v1 as tf_v1
    import tf2onnx

    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.disable_v2_behavior()
    tf.compat.v1.reset_default_graph()

    meta_files = [
        file
        for file in model_dir.iterdir()
        if file.stem.startswith("snapshot-") and file.suffix == ".meta"
    ]
    if len(meta_files) == 0:
        raise ValueError(
            "No checkpoint files, make sure you've trained a DLC model first!"
        )
    latest_meta_file = max(meta_files, key=lambda k: int(k.stem.split("-")[-1]))

    inputs, outputs = _get_dlc_inputs_and_outputs(str(latest_meta_file))

    graph_def, inputs, outputs = from_checkpoint(str(latest_meta_file), inputs, outputs)

    model, __ = tf2onnx.convert.from_graph_def(
        graph_def,
        name=str(latest_meta_file.name),
        input_names=inputs,
        output_names=outputs,
        shape_override={inputs[0]: [None, None, None, 3]},
        opset=17,
    )

    b = BytesIO()
    onnx.save(model, b)

    return ort.InferenceSession(
        b.getvalue(), providers=_build_provider_ordering(device_index, use_cpu)
    )


class FakeTempDir:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class FrameExtractor:
    def __init__(self, onnx_model: ort.InferenceSession, model_config: dict):
        self._model = onnx_model
        self._image_input_name = self._model.get_inputs()[0].name
        self._config = model_config

    def __call__(self, frames: np.ndarray) -> TrackingData:
        outputs = self._model.run(
            None, {self._image_input_name: frames.astype(np.float32)}
        )

        locref = outputs[1] if (len(outputs) > 1) else None

        if locref is not None:
            locref = locref.reshape((*locref.shape[:-1], -1, 2))
            locref *= self._config["locref_stdev"]

        return TrackingData(
            1 / (1 + np.exp(-outputs[0])),
            locref,
            float(
                np.ceil(
                    max(
                        frames.shape[1] / outputs[0].shape[1],
                        frames.shape[2] / outputs[0].shape[2],
                    )
                )
            ),
        )


[docs] @tc.typecaster_function def load_model( config: tc.PathLike, num_outputs: tc.Optional[int] = None, batch_size: tc.Optional[int] = None, gpu_index: tc.Optional[int] = None, model_prefix: str = "", shuffle: int = 1, training_set_index: int = 0, use_cpu: Flag = False, ) -> tc.Tuple[ModelInfo, ModelLike]: """ Run DIPLOMAT tracking on videos using a DEEPLABCUT project and trained network. :param config: The path to the DLC config for the DEEPLABCUT project. :param shuffle: int, optional. Integer specifying which TrainingsetFraction to use. By default, the first (note that TrainingFraction is a list in config.yaml). :param training_set_index: int, optional. Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml). :param gpu_index: Integer index of the GPU to use for inference (in tensorflow) defaults to 0, or selecting the first detected GPU if available. :param batch_size: The batch size to use while processing. Defaults to None, which uses the default batch size for the project. :param model_prefix: The string prefix of the DEEPLABCUT model to use defaults to no prefix (the default model). :param num_outputs: The number of outputs, or bodies to track in the video. Defaults to the value specified in the DLC config, or None if one is not specified. :param use_cpu: If True, run on cpu even if a gpu is available. Defaults to False. :return: A model info dictionary, and a deeplabcut model wrapper that can be used to estimate poses from video frames. """ if isinstance(config, (tuple, list)): if len(config) != 1: raise ValueError("Can't pass multiple config files!") config = config[0] if is_zipfile(config): tmp_dir = tempfile.TemporaryDirectory() is_zip = True else: tmp_dir = FakeTempDir(str(config)) is_zip = False with tmp_dir as tmp_dir: if is_zip: with ZipFile(config, "r") as z: config_path, config = _load_dlc_like_zip_file(z) zip_project_dir = PurePosixPath(config_path).parent for zip_info in z.infolist(): if zip_info.is_dir(): continue zip_path_obj = PurePosixPath(zip_info.filename) try: sub_path = zip_path_obj.relative_to(zip_project_dir) if sub_path.parts[0] not in ["dlc-models", "config.yaml"]: continue dst_path = Path(tmp_dir, sub_path).resolve() dst_path.parent.mkdir(parents=True, exist_ok=True) with z.open(zip_info, "r") as fsrc: with Path(tmp_dir, sub_path).open("wb") as fdst: shutil.copyfileobj(fsrc, fdst) except ValueError: pass project_dir = tmp_dir else: project_dir = Path(config).resolve().parent with open(config, "rb") as f: config = yaml.load(f, yaml.SafeLoader) iteration = config["iteration"] train_frac = config["TrainingFraction"][training_set_index] model_directory = _get_model_folder( config, project_dir, shuffle, train_frac, model_prefix ) model_directory = model_directory.resolve() try: with (model_directory / "test" / "pose_cfg.yaml").open("rb") as f: model_config = yaml.load(f, yaml.SafeLoader) except FileNotFoundError as e: print(e) raise FileNotFoundError( f"Invalid model selection: (Iteration {iteration}, Training Fraction {train_frac}, Shuffle: {shuffle})" ) # Set the number of outputs... num_outputs = ( config.get("num_outputs", model_config.get("num_outputs", None)) if (num_outputs is None) else num_outputs ) if num_outputs is not None: num_outputs = int(num_outputs) batch_size = batch_size if (batch_size is not None) else config["batch_size"] body_parts = list(model_config["all_joints_names"]) if "partaffinityfield_graph" in model_config: skeleton = sorted( set( tuple(sorted([body_parts[a], body_parts[b]])) for a, b in model_config["partaffinityfield_graph"] ) ) else: skeleton = [] return ( ModelInfo( num_outputs=num_outputs, batch_size=batch_size, dotsize=int(config.get("dotsize", 4)), colormap=config.get("colormap", None), shape_list=None, alphavalue=config.get("alphavalue", 0.7), pcutoff=config.get("pcutoff", 0.1), line_thickness=1, bp_names=body_parts, skeleton=skeleton, frontend="deeplabcut", ), FrameExtractor( _load_and_convert_model( model_directory / "train", gpu_index, bool(use_cpu) ), model_config, ), )