Source code for diplomat.wx_gui.probability_displayer

"""
Provides a plotting widget, which displays a filled line graph. Used for displaying metrics at the bottom of the UI.
"""

from enum import IntEnum
from typing import Iterable, NamedTuple, Optional
import wx
import numpy as np


class DrawMode(IntEnum):
    """
    How a run of plotted points should be rendered by a ProbabilityDisplayer.

    The integer values matter: mode arrays are stored as small numpy integers and turned back into members with
    ``DrawMode(value)``.
    """
    # An ordinary data point.
    NORMAL = 0
    # An index listed among the displayer's bad (poorly labeled) locations.
    POORLY_LABELED = 1
    # A point whose data value is NaN; runs in this mode are not drawn.
    USER_MODIFIED = 2
    # A bad location that the displayer considers touched by a prior user modification
    # (numerically POORLY_LABELED | USER_MODIFIED).
    USER_MODIFIED_AND_POORLY_LABELED = 3
class DrawCommand(NamedTuple):
    """
    Drawing instructions for one contiguous run of plot points that all share a single DrawMode.
    """
    # Rendering mode applied to every point of the run.
    draw_mode: DrawMode
    # Pixel coordinates of the run; when produced by ProbabilityDisplayer this is an (N, 2) array of x, y pairs.
    points: np.ndarray
    # x, y of the neighbouring point just before the run, or of the run's own first point when there is no usable
    # neighbour (start of the visible data, or a NaN neighbour).
    point_before: np.ndarray
    # x, y of the neighbouring point just after the run, or of the run's own last point when there is no usable
    # neighbour (end of the visible data, or a NaN neighbour).
    point_after: np.ndarray
class DrawingInfo(NamedTuple):
    """
    Everything needed to render one paint of a ProbabilityDisplayer, expressed in pixel coordinates.
    """
    # Horizontal pixel position of the current-location indicator.
    x_center: float
    # Vertical pixel position of the data point at the current location (NaN when that data value is NaN).
    y_center: float
    # DrawMode of the data point at the current location.
    center_draw_mode: DrawMode
    # x pixel positions of the visible segment-start markers.
    segment_xs: Iterable[int]
    # x pixel positions of the visible segment fix-frame markers.
    segment_fix_xs: Iterable[int]
    # Per-run drawing instructions; ProbabilityDisplayer supplies a generator here, so it can be consumed once.
    draw_commands: Iterable[DrawCommand]
class ProbabilityDisplayer(wx.Control):
    """
    A custom wx.Control which displays a list of probabilities in the form of a line segment plot. Uses native
    colors so as to match other native widgets in the UI.
    """
    # Minimum pixels between probabilities....
    MIN_PROB_STEP = 10
    # The number of probabilities to default to displaying on the screen...
    VISIBLE_PROBS = 100
    # Default height, pointer triangle size, and padding values in pixels...
    DEF_HEIGHT = 50
    TRIANGLE_SIZE = 7
    TOP_PADDING = 3

    def __init__(
        self,
        parent,
        data: np.ndarray = None,
        bad_locations: np.ndarray = None,
        text: str = None,
        height: int = DEF_HEIGHT,
        visible_probs: int = VISIBLE_PROBS,
        style=wx.BORDER_DEFAULT,
        name="ProbabilityDisplayer",
        **kwargs
    ):
        """
        Construct a new ProbabilityDisplayer.

        :param parent: The parent widget.
        :param data: The probability data, a 1D numpy array (or array-like) of number type values. It is copied,
                     so later changes to the passed array do not affect the display.
        :param bad_locations: Indexes into data of poorly labeled locations, or None for no bad locations. Expected
                              to be sorted ascending, as lookups use binary search.
        :param text: The text to display in the top left corner of this probability display, or None to display
                     no text.
        :param height: The minimum height of the probability display. Defaults to 50 pixels.
        :param visible_probs: The max number of probabilities to show on screen at once. Defaults to 100.
        :param style: WX Control Style. See wx.Control docs for possible options. (Defaults to wx.BORDER_DEFAULT).
        :param name: WX internal name of widget.
        :param kwargs: Any other keyword arguments are forwarded to wx.Control (for example id, pos, size or
                       validator).

        :raises ValueError: If data is not 1 dimensional.
        """
        super().__init__(parent, style=style | wx.FULL_REPAINT_ON_RESIZE, name=name, **kwargs)
        # This tell WX that we are going to handle background painting ourselves, disabling system background
        # clearing and avoiding glitchy rendering and flickering...
        self.SetBackgroundStyle(wx.BG_STYLE_PAINT)

        # asarray turns None into a 0-d array, so a missing data argument is reported by the check below instead
        # of failing with an AttributeError.
        data = np.asarray(data)
        if(data.ndim != 1):
            raise ValueError("Invalid data! Must be a numpy array of 1 dimension...")

        self._data = np.copy(data)
        self._bad_locations = self._to_index_array(bad_locations)
        self._user_modified_from_last_pass = np.array([], dtype=np.uint64)
        self._max_data_point = np.nanmax(self._data)
        self._refresh_bad_locations()
        self._ticks_visible = visible_probs
        self._segment_starts = None
        self._segment_fix_frames = None

        self._best_size = wx.Size(
            self.MIN_PROB_STEP * 5,
            max(height, (self.TRIANGLE_SIZE * 4) + self.TOP_PADDING)
        )
        self.SetMinSize(self._best_size)
        self.SetInitialSize(self._best_size)

        self._current_index = 0
        self._text = text

        # Rig up paint event, and also disable erase background event...
        self.Bind(wx.EVT_PAINT, self.on_paint)
        self.Bind(wx.EVT_ERASE_BACKGROUND, lambda evt: None)

    @staticmethod
    def _to_index_array(locations) -> np.ndarray:
        """
        PRIVATE: Convert an index sequence (or None, meaning "no indexes") into a uint64 numpy array.
        """
        if(locations is None):
            return np.array([], dtype=np.uint64)
        return np.asarray(locations).astype(np.uint64)

    def DoGetBestSize(self):
        return self._best_size

    def on_paint(self, event: wx.PaintEvent):
        """
        PRIVATE: Triggered on a wx paint event, redraws the probability display...
        """
        # If the platform already uses double buffering, use a plain old PaintDC, otherwise use a BufferedPaintDC
        # to avoid flickering on unbuffered platforms....
        painter = wx.PaintDC(self) if(self.IsDoubleBuffered()) else wx.BufferedPaintDC(self)
        # Using a GCDC allows for much prettier aliased painting, making plot look nicer.
        painter = wx.GCDC(painter)
        self.on_draw(painter)

    @staticmethod
    def _is_touched(idx, bad_labels, old_user_mods):
        """
        PRIVATE: Check whether a bad label is connected to a prior user modification.

        Returns True when idx is itself in bad_labels AND the nearest entry of old_user_mods on either side of idx
        (an entry equal to idx counts as the upper neighbour) is also in bad_labels, with every index in between
        also being a bad label (an unbroken run).

        :param idx: The index to test.
        :param bad_labels: Sorted array of bad label indexes.
        :param old_user_mods: Sorted array of indexes modified by the user in a prior pass.

        :returns: A boolean, True if idx is part of such a run.
        """
        if(len(old_user_mods) == 0):
            return False

        idx2 = np.searchsorted(old_user_mods, idx)
        low_goal = old_user_mods[max(0, idx2 - 1)]
        high_goal = old_user_mods[min(idx2, len(old_user_mods) - 1)]

        low_idx = np.searchsorted(bad_labels, low_goal)
        high_idx = np.searchsorted(bad_labels, high_goal)
        mid_idx = np.searchsorted(bad_labels, idx)

        low_value = int(bad_labels[min(low_idx, len(bad_labels) - 1)])
        high_value = int(bad_labels[min(high_idx, len(bad_labels) - 1)])
        mid_value = int(bad_labels[min(mid_idx, len(bad_labels) - 1)])

        # The number of bad labels between two positions equals their index distance only when the run is
        # unbroken (bad labels are unique and sorted).
        low_gap_match = low_value == low_goal and mid_idx - low_idx == mid_value - low_value
        high_gap_match = high_value == high_goal and high_idx - mid_idx == high_value - mid_value

        return (mid_value == idx) and (low_gap_match or high_gap_match)

    @classmethod
    def _get_draw_commands(
        cls,
        x_arr,
        y_arr,
        mode_arr,
        low_val,
        bad_labels,
        old_user_mods
    ) -> Iterable[DrawCommand]:
        """
        PRIVATE: Split visible points into runs of equal draw mode, yielding one DrawCommand per run.

        :param x_arr: X pixel positions of the visible points.
        :param y_arr: Y pixel positions of the visible points (NaN for user modified points).
        :param mode_arr: Integer DrawMode value per visible point.
        :param low_val: Data index of the first visible point.
        :param bad_labels: All bad label indexes (sorted).
        :param old_user_mods: Prior user modified indexes (sorted).

        :returns: A generator of DrawCommand objects.
        """
        # We compute islands by finding locations where the pairwise
        # difference of modes array is non-zero (Indicating mode change)
        change_locs = np.flatnonzero(mode_arr[1:] - mode_arr[:-1]) + 1
        change_locs = np.concatenate(([0], change_locs, [len(mode_arr)]))

        for start, end in zip(change_locs[:-1], change_locs[1:]):
            before_idx = max(start - 1, 0)
            after_idx = min(end, len(x_arr) - 1)

            # Never connect to a NaN neighbour; fall back to the run's own end points instead.
            if(np.isnan(x_arr[before_idx]) or np.isnan(y_arr[before_idx])):
                before_idx = start
            if(np.isnan(x_arr[after_idx]) or np.isnan(y_arr[after_idx])):
                after_idx -= 1

            mode = DrawMode(mode_arr[start])
            if(mode_arr[start] == DrawMode.POORLY_LABELED):
                if(cls._is_touched(low_val + start, bad_labels, old_user_mods)):
                    mode = DrawMode.USER_MODIFIED_AND_POORLY_LABELED

            yield DrawCommand(
                mode,
                np.stack([x_arr[start:end], y_arr[start:end]], -1),
                np.array([x_arr[before_idx], y_arr[before_idx]]),
                np.array([x_arr[after_idx], y_arr[after_idx]])
            )

    def _compute_points(self, height: int, width: int) -> DrawingInfo:
        """
        PRIVATE: Computes the points to be rendered to the screen given the probability data.

        :param height: The height of the control.
        :param width: The width of the control.

        :returns: A DrawingInfo, holding the indicator center, the y position and draw mode of the current data
                  point, the pixel positions of visible segment markers, and a generator of DrawCommands
                  describing the plot line runs.
        """
        data = self._data

        # Compute the amount of probabilities to display per side based on configured parameters...
        tick_step = max(self.MIN_PROB_STEP, int(width / self._ticks_visible))
        center = (width // 2)
        # Clamp at 0 so very narrow widgets still show the current point instead of producing an empty range.
        values_per_side = max(0, (center - 1) // tick_step)

        # Compute the lowest and highest indexes for probabilities we can show...
        low_val = max(self._current_index - values_per_side, 0)
        high_val = min(self._current_index + values_per_side + 1, len(data))

        # Points are distributed evenly by tick_step on x-axis.... On y axis we set there value by interpolating
        # between the available space for the probabilities within the widget.
        offset = center - ((self._current_index - low_val) * tick_step)

        # Identify "bad" locations as they'll be drawn differently...
        low_bad = np.searchsorted(self._bad_locations, low_val)
        high_bad = np.searchsorted(self._bad_locations, high_val)
        bad_locations = self._bad_locations[low_bad:high_bad] - low_val

        # If there are segments, identify what segments we can see...
        if(self._segment_starts is not None):
            seg_low = np.searchsorted(self._segment_starts[1:], low_val)
            seg_high = np.searchsorted(self._segment_starts[1:], high_val)
            seg_offsets = (
                (self._segment_starts[1:][seg_low:seg_high] - low_val) * tick_step + offset - (tick_step / 2)
            )
        else:
            seg_offsets = np.array([])

        # If there are segment fix frames, identify which ones we can see...
        if(self._segment_fix_frames is not None):
            seg_low = np.searchsorted(self._segment_fix_frames, low_val)
            seg_high = np.searchsorted(self._segment_fix_frames, high_val)
            seg_fix_offsets = (self._segment_fix_frames[seg_low:seg_high] - low_val) * tick_step + offset
        else:
            seg_fix_offsets = np.array([])

        # Avoid 0 / 0 on all-zero data, which would turn every point into NaN (and hide it).
        max_value = self._max_data_point
        if(max_value == 0):
            max_value = 1

        x = np.arange(0, high_val - low_val) * tick_step + offset
        y = data[low_val:high_val]
        y = (1 - (y / max_value)) * (height - ((self.TRIANGLE_SIZE * 2) + self.TOP_PADDING)) + self.TOP_PADDING

        # Build a mode array.
        mode = np.zeros(len(y), dtype=np.int8)
        mode[bad_locations] = DrawMode.POORLY_LABELED
        mode[np.isnan(y)] = DrawMode.USER_MODIFIED

        return DrawingInfo(
            int(center),
            y[self._current_index - low_val],
            DrawMode(mode[self._current_index - low_val]),
            seg_offsets.astype(int),
            seg_fix_offsets.astype(int),
            self._get_draw_commands(
                x, y, mode, low_val, self._bad_locations, self._user_modified_from_last_pass
            )
        )

    def on_draw(self, dc: wx.DC):
        """
        For internal use! Executed on drawing update, redraws the probability display. Expects a wx.DC for drawing
        to.
        """
        width, height = self.GetClientSize()

        if((not width) or (not height)):
            return

        # Clear the background with the default color...
        dc.SetBackground(
            wx.Brush(self.GetBackgroundColour(), wx.BRUSHSTYLE_SOLID)
        )
        dc.Clear()

        # Colors used in pens and brushes below...
        highlight_color = wx.SystemSettings.GetColour(wx.SYS_COLOUR_HIGHLIGHT)
        highlight_color2 = wx.Colour(
            *highlight_color[:3], int(highlight_color.Alpha() * 0.3)
        )
        # WX widgets doesn't provide an error highlight color. Since the
        # highlight color doesn't typically match the foreground or
        # background, we take the complement of it as a second selection color
        # (This color happens to usually be a Blue, so this typically produces
        # a Red/Orange)
        error_color = wx.Colour(
            255 - highlight_color.Red(),
            255 - highlight_color.Green(),
            255 - highlight_color.Blue(),
            highlight_color.Alpha()
        )
        error_color2 = wx.Colour(*error_color[:3], int(error_color.Alpha() * 0.3))
        foreground_color = self.GetForegroundColour()
        # Midpoint of background and foreground, used for bad locations touched by a prior user modification.
        fixed_error_color = wx.Colour(*(((
            np.asarray(self.GetBackgroundColour(), int) + np.asarray(foreground_color, int)
        ) / 2).astype(int)))
        fixed_error_color2 = wx.Colour(
            *fixed_error_color[:3], int(fixed_error_color.Alpha() * 0.3)
        )

        # All the pens and brushes we will need...
        transparent_pen = wx.Pen(highlight_color, 2, wx.PENSTYLE_TRANSPARENT)
        # Primary highlight color (normal plot locations)...
        highlight_pen = wx.Pen(highlight_color, 2, wx.PENSTYLE_SOLID)
        highlight_pen2 = wx.Pen(highlight_color, 5, wx.PENSTYLE_SOLID)
        highlight_brush = wx.Brush(highlight_color2, wx.BRUSHSTYLE_SOLID)
        error_pen = wx.Pen(error_color, 2, wx.PENSTYLE_SOLID)
        error_pen2 = wx.Pen(error_color, 5, wx.PENSTYLE_SOLID)
        error_brush = wx.Brush(error_color2, wx.BRUSHSTYLE_SOLID)
        fixed_error_pen = wx.Pen(fixed_error_color, 2, wx.PENSTYLE_SOLID)
        fixed_error_pen2 = wx.Pen(fixed_error_color, 5, wx.PENSTYLE_SOLID)
        fixed_error_brush = wx.Brush(fixed_error_color2, wx.BRUSHSTYLE_SOLID)
        indicator_brush = wx.Brush(foreground_color, wx.BRUSHSTYLE_SOLID)
        indicator_pen = wx.Pen(foreground_color, 5, wx.PENSTYLE_SOLID)
        indicator_pen2 = wx.Pen(foreground_color, 1, wx.PENSTYLE_SOLID)

        # This patches the point drawing for the latest versions of wxWidgets, which don't respect the pen's width
        # correctly...
        def draw_points(points, pen):
            top = np.round((np.asarray(points) - pen.GetWidth() / 2)).astype(int)
            args = np.concatenate([top, np.full(top.shape, pen.GetWidth(), dtype=int)], axis=-1)
            dc.DrawEllipseList(args, transparent_pen, wx.Brush(pen.GetColour(), wx.BRUSHSTYLE_SOLID))

        # Compute the center and points to place on the line...
        draw_info = self._compute_points(height, width)

        for seg_x in draw_info.segment_xs:
            seg_x = int(seg_x)
            dc.DrawLineList([[seg_x, 0, seg_x, height]], fixed_error_pen)
            dc.DrawPolygonList([[
                [seg_x - int(self.TRIANGLE_SIZE / 2), 0],
                [seg_x + int(self.TRIANGLE_SIZE / 2), 0],
                [seg_x, int(self.TRIANGLE_SIZE)]
            ]], fixed_error_pen, fixed_error_brush)

        for seg_x in draw_info.segment_fix_xs:
            seg_x = int(seg_x)
            dc.DrawPolygonList([[
                [seg_x - int(self.TRIANGLE_SIZE / 2), 0],
                [seg_x + int(self.TRIANGLE_SIZE / 2), 0],
                [seg_x, int(self.TRIANGLE_SIZE)]
            ]], indicator_pen2, highlight_brush)

        # Plot all of the points the filled-in polygon underneath, and the line connecting the points...
        for draw_command in draw_info.draw_commands:
            # NaN (user modified) runs have no y values to draw.
            if(draw_command.draw_mode == DrawMode.USER_MODIFIED):
                continue

            if(draw_command.draw_mode == DrawMode.NORMAL):
                pen = highlight_pen
                pen2 = highlight_pen2
                brush = highlight_brush
            elif(draw_command.draw_mode == DrawMode.POORLY_LABELED):
                pen = error_pen
                pen2 = error_pen2
                brush = error_brush
            else:
                pen = fixed_error_pen
                pen2 = fixed_error_pen2
                brush = fixed_error_brush

            # Runs meet their neighbours halfway, so adjacent runs join without gaps.
            poly_begin_point = (draw_command.points[0] + draw_command.point_before) / 2
            poly_end_point = (draw_command.points[-1] + draw_command.point_after) / 2

            wrap_polygon_points = np.array([
                poly_end_point,
                [poly_end_point[0], height],
                [poly_begin_point[0], height],
                poly_begin_point
            ])

            dc.DrawPolygonList(
                [np.concatenate((draw_command.points, wrap_polygon_points))],
                transparent_pen,
                brush
            )

            all_points = np.concatenate(
                ([poly_begin_point], draw_command.points, [poly_end_point])
            )
            dc.DrawLineList(np.concatenate((all_points[1:], all_points[:-1]), 1).astype(int), pen)
            draw_points(draw_command.points.astype(int), pen2)

        # Draw the current location indicating line, point and arrow, indicates which data point we are currently
        # on.
        dc.DrawLineList([[int(draw_info.x_center), 0, int(draw_info.x_center), height]], indicator_pen2)
        dc.DrawPolygonList([[
            [int(draw_info.x_center - self.TRIANGLE_SIZE), height],
            [int(draw_info.x_center + self.TRIANGLE_SIZE), height],
            [int(draw_info.x_center), height - int(self.TRIANGLE_SIZE * 1.5)]
        ]], indicator_pen2, indicator_brush)
        if(draw_info.center_draw_mode != DrawMode.USER_MODIFIED):
            draw_points([[int(draw_info.x_center), int(draw_info.y_center)]], indicator_pen)

        # If the user set the name of this probability display plot, write it to the top-left corner...
        if(self._text is not None):
            back_pen = wx.Pen(self.GetBackgroundColour(), 3, wx.PENSTYLE_SOLID)
            back_brush = wx.Brush(self.GetBackgroundColour(), wx.BRUSHSTYLE_SOLID)
            dc.SetTextBackground(self.GetBackgroundColour())
            dc.SetTextForeground(self.GetForegroundColour())
            dc.SetFont(self.GetFont())
            size: wx.Size = dc.GetTextExtent(self._text)
            text_width, text_height = size.GetWidth(), size.GetHeight()
            dc.DrawRectangleList([(0, 0, text_width, text_height)], back_pen, back_brush)
            dc.DrawText(self._text, 0, 0)

    def set_location(self, location: int):
        """
        Set the current location of the probability display, or the index to which the arrow is pointing.

        :param location: A integer, being the frame or index to make this probability display center and point to.

        :raises ValueError: If location is outside 0 through len(data) - 1.
        """
        if(not (0 <= location < self._data.shape[0])):
            raise ValueError(
                f"Location {location} is not within the range: 0 through {self._data.shape[0] - 1}."
            )
        self._current_index = location
        self.Refresh()

    def get_location(self) -> int:
        """
        Get the current location of the probability display, or the index to which the arrow is pointing.

        :returns: A integer, being the frame or index of the currently pointed to location.
        """
        return self._current_index

    def set_data(self, data: np.ndarray):
        """
        Set all of the data.

        :param data: Numpy array of numbers, will be copied over into the internal data store and displayed on next
                     redraw.
        """
        self._data[:] = data
        self._max_data_point = np.nanmax(self._data)
        self._refresh_bad_locations()
        self.Refresh()

    def get_data(self) -> np.ndarray:
        """
        Get all the data.

        :returns: Numpy array of numbers, the data of this probability display. The returned array is a read only
                  view...
        """
        view = self._data.view()
        view.flags.writeable = False
        return view

    def set_data_at(self, frame: int, value: float):
        """
        Set the data at the given frame to the given value.

        :param frame: The frame or index to set the data value at.
        :param value: A float or number, the value to assign at the data point. NaN marks the frame as user
                      modified.
        """
        self._data[frame] = value
        # The plot scale only grows here; it is recomputed from scratch in set_data.
        self._max_data_point = np.nanmax([self._max_data_point, value])
        self._refresh_bad_locations()
        self.Refresh()

    def get_data_at(self, frame: int) -> float:
        """
        Get the data at a given index or frame within the probability display.

        :param frame: The index or frame to get the probability data of.

        :returns: A float, being the data at the given frame.
        """
        return self._data[frame]

    def _refresh_bad_locations(self):
        """
        PRIVATE: Drop bad locations whose data value is not finite (NaN data marks user modified frames).
        """
        self._bad_locations = self._bad_locations[
            np.isfinite(self._data[self._bad_locations])
        ]

    def get_user_modified_locations(self) -> np.ndarray:
        """
        Get the current user modified locations...

        :returns: A numpy array of the indexes whose data value is NaN.
        """
        return np.flatnonzero(np.isnan(self._data))

    def set_prior_modified_user_locations(self, value: np.ndarray):
        """
        Set the indexes modified by the user during a prior pass. Bad locations in unbroken runs reaching these are
        drawn and navigated differently.

        :param value: Sorted array-like of indexes.
        """
        self._user_modified_from_last_pass = np.asarray(value, dtype=np.uint64)
        self.Refresh()

    def get_prior_modified_user_locations(self) -> np.ndarray:
        """
        Get the indexes modified by the user during a prior pass.

        :returns: A uint64 numpy array of indexes.
        """
        return self._user_modified_from_last_pass

    def set_bad_locations(self, locations: np.ndarray):
        """
        Set the list of indexes specifying poorly annotated locations during the video.

        :param locations: Sorted list or array of integers, the indexes of poorly annotated locations within the
                          video, or None to clear them.
        """
        self._bad_locations = self._to_index_array(locations)
        self._refresh_bad_locations()
        self.Refresh()

    def get_bad_locations(self) -> np.ndarray:
        """
        Get the list of indexes specifying poorly annotated locations during the video.

        :returns: A read only view of the numpy array storing indexes of poorly annotated frames.
        """
        view = self._bad_locations.view()
        view.flags.writeable = False
        return view

    def get_prev_bad_location(self, location: int = None, orig_location=None, moves_done=0) -> int:
        """
        Get the previous bad location based on the current location in the probability display, wrapping around
        to the end of the data. If location is itself bad, the unbroken run of bad locations containing it is
        skipped, and bad locations reported as touched by a prior user modification are skipped as well.

        :param location: The location to search back from, defaults to the current location.
        :param orig_location: Internal recursion state, the location the search started at.
        :param moves_done: Internal recursion state, the distance travelled so far.

        :returns: An integer, the index of the nearest previous bad location. Returns the starting location if
                  there are no bad locations, or if every candidate was skipped.
        """
        if(location is None):
            location = self.get_location()
        # Work in Python ints: uint64 arithmetic with negative values overflows (or becomes float on NumPy 1.x).
        location = int(location)
        if(orig_location is None):
            orig_location = location

        if(len(self._bad_locations) == 0):
            return location

        idx = np.searchsorted(self._bad_locations, location, side="left")
        is_bad_spot = self._bad_locations[idx % len(self._bad_locations)] == location
        # A negative idx intentionally wraps around to the last bad location.
        idx -= 1

        if(is_bad_spot):
            # Skip the unbroken run of bad locations we are currently inside of.
            while(location - int(self._bad_locations[idx]) == 1):
                location = int(self._bad_locations[idx])
                idx -= 1

        if(self._is_touched(
            self._bad_locations[idx], self._bad_locations, self._user_modified_from_last_pass
        )):
            candidate = int(self._bad_locations[idx])
            val = candidate
            if(val >= location):
                # Wrapped around the start of the data.
                val = val - len(self._data)
            moves_done += location - val
            # Once we have travelled the whole data length, every candidate is touched; give up.
            if(moves_done >= len(self._data)):
                return orig_location
            return self.get_prev_bad_location(candidate, orig_location, moves_done)

        return int(self._bad_locations[idx])

    def get_next_bad_location(self, location: int = None, orig_location=None, moves_done=0) -> int:
        """
        Get the next bad location based on the current location in the probability display, wrapping around to
        the start of the data. If location is itself bad, the unbroken run of bad locations containing it is
        skipped, and bad locations reported as touched by a prior user modification are skipped as well.

        :param location: The location to search forward from, defaults to the current location.
        :param orig_location: Internal recursion state, the location the search started at.
        :param moves_done: Internal recursion state, the distance travelled so far.

        :returns: An integer, the index of the nearest next bad location. Returns the starting location if there
                  are no bad locations, or if every candidate was skipped.
        """
        if(location is None):
            location = self.get_location()
        # Work in Python ints so subtraction never wraps around as unsigned arithmetic.
        location = int(location)
        if(orig_location is None):
            orig_location = location

        if(len(self._bad_locations) == 0):
            return location

        idx = np.searchsorted(self._bad_locations, location, side="right")
        is_bad_spot = self._bad_locations[idx - 1] == location
        idx = idx % len(self._bad_locations)

        if(is_bad_spot):
            # Skip the unbroken run of bad locations we are currently inside of.
            while(int(self._bad_locations[idx]) - location == 1):
                location = int(self._bad_locations[idx])
                idx = (idx + 1) % len(self._bad_locations)

        if(self._is_touched(
            self._bad_locations[idx], self._bad_locations, self._user_modified_from_last_pass
        )):
            candidate = int(self._bad_locations[idx])
            val = candidate
            if(val <= location):
                # Wrapped around the end of the data.
                val = len(self._data) + val
            moves_done += val - location
            # Once we have travelled the whole data length, every candidate is touched; give up.
            if(moves_done >= len(self._data)):
                return orig_location
            # Keep searching forward (this previously recursed into get_prev_bad_location by mistake).
            return self.get_next_bad_location(candidate, orig_location, moves_done)

        return int(self._bad_locations[idx])

    def get_text(self) -> str:
        """
        Get the display text for this probability display.

        :returns: A string, being the display text.
        """
        return self._text

    def set_text(self, value: str):
        """
        Set the display text of this probability display.

        :param value: The string value to set the display text to, or None to display no text.
        """
        self._text = value
        self.Refresh()

    def set_segment_starts(self, value: Optional[np.ndarray]):
        """
        Set the segment start indexes. The first entry is not drawn; a marker is drawn midway before each other
        visible start.

        :param value: Sorted array of indexes, or None for no segment markers.
        """
        self._segment_starts = value
        self.Refresh()

    def set_segment_fix_frames(self, value: Optional[np.ndarray]):
        """
        Set the segment fix frame indexes, each drawn as a small triangle marker at the top of the plot.

        :param value: Sorted array of indexes, or None for no fix frame markers.
        """
        self._segment_fix_frames = value
        self.Refresh()
def test_demo_displayer():
    """
    Interactive demo: opens a window with two ProbabilityDisplayer widgets showing random data, whose location is
    driven by a slider underneath.

    This is not an automated test; it blocks in the wx main loop until the window is closed.
    """
    app = wx.App()
    frame = wx.Frame(None, wx.ID_ANY, "Test Window")
    layout = wx.BoxSizer(wx.VERTICAL)

    # Random data with a few NaN (user modified) entries; small values are flagged as bad locations.
    data = np.random.rand(100)
    data[np.random.randint(0, 100, 5)] = np.nan
    bad_locations = np.flatnonzero(data < 0.1)

    prob_display = ProbabilityDisplayer(frame, data, bad_locations, text="Test1")
    prob_display.set_segment_starts(np.unique(np.random.randint(0, 100, 5)))
    prob_display.set_segment_fix_frames(np.unique(np.random.randint(0, 100, 5)))
    layout.Add(prob_display, 1, wx.EXPAND)

    prob_display2 = ProbabilityDisplayer(frame, data, bad_locations, text="Test2")
    layout.Add(prob_display2, 1, wx.EXPAND)

    slider = wx.Slider(frame, minValue=0, maxValue=len(data) - 1)
    layout.Add(slider, 0, wx.EXPAND)

    def do(evt):
        prob_display.set_location(slider.GetValue())
        prob_display2.set_location(slider.GetValue())

    slider.Bind(wx.EVT_SLIDER, do)

    frame.SetSizerAndFit(layout)
    frame.SetSize(500, 100)
    frame.Show()
    app.MainLoop()
# Running this module directly launches the interactive demo window.
if __name__ == "__main__":
    test_demo_displayer()