Source code for diplomat.frontends.csv.label_videos

from pathlib import Path
from typing import Tuple

from diplomat.utils.cli_tools import extra_cli_args
from diplomat.frontends.sleap.visual_settings import FULL_VISUAL_SETTINGS
import diplomat.processing.type_casters as tc
from diplomat.utils.track_formats import load_diplomat_table, to_diplomat_pose
from diplomat.utils.video_io import ContextVideoWriter, ContextVideoCapture
from diplomat.utils.shapes import CV2DotShapeDrawer, shape_iterator
from diplomat.processing import Config, TQDMProgressBar
from diplomat.utils.colormaps import iter_colormap

import cv2

from .csv_utils import _fix_paths


[docs] @extra_cli_args(FULL_VISUAL_SETTINGS, auto_cast=False) @tc.typecaster_function def label_videos( config: tc.Union[tc.List[tc.PathLike], tc.PathLike], videos: tc.Union[tc.List[tc.PathLike], tc.PathLike], body_parts_to_plot: tc.Optional[tc.List[str]] = None, video_extension: str = "mp4", **kwargs ): """ Labeled videos with arbitrary csv files in diplomat's csv format. :param config: The path (or list of paths) to the csv file(s) to label the videos with. :param videos: Paths to video file(s) corresponding to the provided csv files. :param body_parts_to_plot: A set or list of body part names to label, or None, indicating to label all parts. :param video_extension: The file extension to use on the created labeled video, excluding the dot. Defaults to 'mp4'. :param kwargs: The following additional arguments are supported: {extra_cli_args} """ config, videos = _fix_paths(config, videos) visual_settings = Config(kwargs, FULL_VISUAL_SETTINGS) for c, v in zip(config, videos): _label_videos_single(str(c), str(v), body_parts_to_plot, video_extension, visual_settings)
class EverythingSet: def __contains__(self, item): return True def _to_cv2_color(color: Tuple[float, float, float, float]) -> Tuple[int, int, int, int]: r, g, b, a = [min(255, max(0, int(val * 256))) for val in color] return (b, g, r, a) def _label_videos_single( csv: str, video: str, body_parts_to_plot: tc.Optional[tc.List[str]], video_extension: str, visual_settings: Config ): pose_data = load_diplomat_table(csv) poses, bp_names, num_outputs = to_diplomat_pose(pose_data) video_extension = video_extension if(video_extension.startswith(".")) else f".{video_extension}" video_path = Path(video) # Create the output path... output_path = video_path.parent / (video_path.stem + "_labeled" + video_extension) body_parts_to_plot = EverythingSet() if(body_parts_to_plot is None) else set(body_parts_to_plot) upscale = 1 if(visual_settings.upscale_factor is None) else visual_settings.upscale_factor with ContextVideoCapture(video) as in_video: out_w, out_h = tuple( int(dim * upscale) for dim in [ in_video.get(cv2.CAP_PROP_FRAME_WIDTH), in_video.get(cv2.CAP_PROP_FRAME_HEIGHT) ] ) with ContextVideoWriter( str(output_path), visual_settings.output_codec, in_video.get(cv2.CAP_PROP_FPS), (out_w, out_h) ) as writer: with TQDMProgressBar(total=poses.get_frame_count()) as p: for f_i in range(poses.get_frame_count()): retval, frame = in_video.read() if(not retval): continue if (visual_settings.upscale_factor is not None): frame = cv2.resize( frame, (out_w, out_h), interpolation=cv2.INTER_NEAREST ) overlay = frame.copy() colors = iter_colormap(visual_settings.colormap, poses.get_bodypart_count()) shapes = shape_iterator(visual_settings.shape_list, num_outputs) part_iter = zip( [name for name in bp_names for _ in range(num_outputs)], poses.get_x_at(f_i, slice(None)), poses.get_y_at(f_i, slice(None)), poses.get_prob_at(f_i, slice(None)), colors, shapes ) for (name, x, y, prob, color, shape) in part_iter: if (x != x or y != y): continue if (name not in body_parts_to_plot): continue shape_drawer = CV2DotShapeDrawer( overlay, _to_cv2_color(tuple(color[:3]) + (1,)), -1 if (prob > visual_settings.pcutoff) else visual_settings.line_thickness, cv2.LINE_AA if (visual_settings.antialiasing) else None )[shape] if (prob > visual_settings.pcutoff or visual_settings.draw_hidden_tracks): shape_drawer(int(x * upscale), int(y * upscale), int(visual_settings.dotsize * upscale)) writer.write(cv2.addWeighted( overlay, visual_settings.alphavalue, frame, 1 - visual_settings.alphavalue, 0 )) p.update()