
Tracking Sea Megafauna from Aerial Footage – Handling High Camera Motion and Low Object Velocity #1725

Closed
1 task done
LilBabines opened this issue Nov 4, 2024 · 9 comments
Labels: question (Further information is requested)

Comments


LilBabines commented Nov 4, 2024

Search before asking

  • I have searched the Yolo Tracking issues and found no similar bug report.

Question

Description

I'm working on tracking sea megafauna from aerial footage. The camera is fixed under a moving plane, which results in significant background motion while the objects (sea animals) have relatively low velocity, appearing almost stationary.

I've tested various parameter configurations and generated example tracks, available in this Google Drive link, using the following code:

import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
from boxmot import DeepOcSort
import os

def yolo_to_bbox(yolo_bbox, img_width, img_height):
    # YOLO label line: "class x_center y_center width height [conf]" (normalized coordinates)
    try:
        class_id, x_center, y_center, width, height, conf = map(float, yolo_bbox.split())
    except ValueError:
        # label file without a confidence column: assign a dummy high confidence
        class_id, x_center, y_center, width, height = map(float, yolo_bbox.split())
        conf = 1000  # :)
    x_center, y_center = x_center * img_width, y_center * img_height
    width, height = width * img_width, height * img_height
    x1, y1 = int(x_center - width / 2), int(y_center - height / 2)
    x2, y2 = int(x_center + width / 2), int(y_center + height / 2)
    return x1, y1, x2, y2, conf, int(class_id)


tracker = DeepOcSort(
    reid_weights=Path('osnet_x0_25_msmt17.pt'),
    device=0,
    half=False,
    asso_func='iou',
    embedding_off=True,
    delta_t=10,
    inertia=0.8,
    min_hits=10,
    max_age=7,
    det_thresh=0.45,
)

output_dir = './test_print'
# Alternatively, encode the configuration in the directory name:
# output_dir = f'{tracker.asso_func}_d{tracker.delta_t}_i{tracker.inertia}_mh{tracker.min_hits}_ma{tracker.max_age}_dt{tracker.det_thresh}_emmOFF{tracker.embedding_off}_Qxy{tracker.Q_xy_scaling}_Qs{tracker.Q_s_scaling}'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f'{output_dir}/output', exist_ok=True)
os.makedirs(f'{output_dir}/output_black', exist_ok=True)

images = Path('data/test_print')
labels = Path('runs/predict/video_207/labels')

for file in images.iterdir():

    name = file.stem
    img_path = f"{images}/{name}.jpeg"

    im = cv2.imread(img_path)
    img_height, img_width, _ = im.shape[:3]

    # Read detections from the YOLO label file, if it exists
    if (labels / f'{name}.txt').exists():
        with open(labels / f'{name}.txt', 'r') as f:
            dets = f.readlines()
        dets = np.array([yolo_to_bbox(det, img_width, img_height) for det in dets], dtype=np.float32)
    else:
        dets = np.empty((0, 6))

    tracker.update(dets, im) 
    tracker.plot_results(im, show_trajectories=False)    
    
    black_img = np.zeros((im.shape[0], im.shape[1], 3), dtype=np.uint8)
    tracker.plot_results(black_img, show_trajectories=False)

    if dets.shape[0] != 0:
        for det in dets:
            if det[4] >= tracker.det_thresh:
                cv2.rectangle(im, det[:2].astype(int), det[2:4].astype(int), (150, 50, 100), 4)
                cv2.rectangle(black_img, det[:2].astype(int), det[2:4].astype(int), (150, 50, 100), 4)

    im = cv2.resize(im, (int(im.shape[1]*(1/2)), int(im.shape[0]*(1/2))))
    black_img = cv2.resize(black_img, (int(black_img.shape[1]*(1/2)), int(black_img.shape[0]*(1/2))))
    
    cv2.imwrite(f'./{output_dir}/output/{name}.jpeg', im)
    cv2.imwrite(f'./{output_dir}/output_black/{name}.jpeg', black_img)
    # im = cv2.resize(im, (int(im.shape[1]*(4/5)), int(im.shape[0]*(4/5))))
    cv2.putText(im, name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.imshow('BoxMOT detection',im)
    
    key = cv2.waitKey(1) & 0xFF
    if key == ord(' ') or key == ord('q'):
        break

cv2.destroyAllWindows()

After some investigation, I found that apply_affine_correction(transform) from the SOF camera motion compensation in DeepOCSORT (and similarly ECC in StrongSORT) may be overestimating camera motion, which causes the Kalman filter prediction to drift too far from the detections.
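
For illustration, here is a minimal sketch of the kind of check I ran: warp the previous frame's boxes with the estimated transform and overlay them on the current frame to see whether the compensation overshoots the detections. The helper name is mine, and the 2x3 affine layout is an assumption about the transform returned by the CMC module.

import numpy as np

def warp_box(box_xyxy, transform):
    # transform: assumed to be a 2x3 affine mapping previous-frame coords to current-frame coords
    m, t = transform[:, :2], transform[:, 2].reshape(2, 1)
    pts = np.asarray(box_xyxy, dtype=np.float32).reshape(2, 2).T  # columns: (x1, y1), (x2, y2)
    warped = m @ pts + t
    return warped.T.reshape(-1)  # back to [x1, y1, x2, y2]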

Questions:

  • Are there any existing trackers or parameter configurations that would be effective for this use case?
  • I haven't found many implementations specifically designed for aerial multi-object tracking. Would customizing a tracker be necessary to meet the constraints of high background motion with low object velocity?

Thank you for your work and any guidance you can provide!

LilBabines added the question label on Nov 4, 2024

amabilee commented Nov 4, 2024

Take a look at DeepSORT with Custom Deep Association Networks. This tracker extends the original SORT algorithm by integrating appearance information based on a deep appearance descriptor.


mikel-brostrom commented Nov 5, 2024

Are there any existing trackers or parameter configurations that would be effective for this use case?

I would start by trying out ecc as the camera motion compensation algorithm. If this does not fully solve the issue, also modify the process noise covariance values in DeepOCSORT to increase the uncertainty/noise level:

Q_xy_scaling:
  type: loguniform
  default: 0.01      # from the default parameters
  range: [0.01, 1]
Q_s_scaling:
  type: loguniform
  default: 0.0001    # from the default parameters
  range: [0.0001, 1]

You can increase the default values by factors of ten. I usually see bad results for sof on different datasets, as you have already noticed. I believe these two adjustments may solve your tracking issues. Let me know if this helps! 😄
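
A minimal sketch of how this could be wired into the script above. The Q_xy_scaling / Q_s_scaling keyword arguments and the tracker.cmc attribute are assumptions about the installed boxmot version (the two scalings are at least referenced as attributes in the original script); depending on the release, the CMC backend may need to be switched elsewhere:

from pathlib import Path
from boxmot import DeepOcSort
from boxmot.motion.cmc import get_cmc_method

tracker = DeepOcSort(
    reid_weights=Path('osnet_x0_25_msmt17.pt'),
    device=0,
    half=False,
    embedding_off=True,   # ReID already disabled in the original script
    det_thresh=0.45,
    Q_xy_scaling=0.1,     # default 0.01, increased tenfold
    Q_s_scaling=0.001,    # default 0.0001, increased tenfold
)

# Swap the default SOF camera motion compensation for ECC
# (the `cmc` attribute name is an assumption about the boxmot internals)
tracker.cmc = get_cmc_method('ecc')()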


mikel-brostrom commented Nov 5, 2024

Take a look at DeepSORT with Custom Deep Association Networks. This tracker extends the original SORT algorithm by integrating appearance information based on a deep appearance descriptor.

I believe the deep appearance descriptor is not the most relevant aspect for this issue, and its linear-motion KF is not the best configuration for the task at hand.

@mikel-brostrom

Ohh. @LilBabines, I recommend you start with OCSORT for the sake of simplicity. I don't see the ReID model being of any use here due to the small objects you are detecting.

@LilBabines

Thank you very much for your help!

Ohh. @LilBabines, I recommend you start with OCSORT for the sake of simplicity. I don't see the ReID model being of any use here due to the small objects you are detecting.

I agree, I've already disabled appearance-based ReID.

I would start by trying out ecc as the camera motion compensation algorithm.

The ECC method produces excellent predictions with wrap_mode = cv2.MOTION_AFFINE. Calling both trk.predict() and trk.apply_affine_correction(transform) overestimates object velocity, so I'll try to use the "CMC predictions" directly in the matching step instead of relying on the KF predictions.
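
A sketch of the intended matching step: the CMC-warped last boxes play the role of the motion predictions and are associated to the detections by IoU. scipy is used here only for illustration; the code I post below relies on boxmot's associate()/linear_assignment helpers instead.

import numpy as np
from scipy.optimize import linear_sum_assignment

def iou_matrix(trk_boxes, det_boxes):
    # boxes as [x1, y1, x2, y2]; returns an (n_trk, n_det) IoU matrix
    tl = np.maximum(trk_boxes[:, None, :2], det_boxes[None, :, :2])
    br = np.minimum(trk_boxes[:, None, 2:], det_boxes[None, :, 2:])
    inter = np.prod(np.clip(br - tl, 0, None), axis=2)
    area_t = np.prod(trk_boxes[:, 2:] - trk_boxes[:, :2], axis=1)
    area_d = np.prod(det_boxes[:, 2:] - det_boxes[:, :2], axis=1)
    return inter / (area_t[:, None] + area_d[None, :] - inter + 1e-9)

# warped_trk_boxes: each track's last box warped by the ECC affine for this frame
# rows, cols = linear_sum_assignment(-iou_matrix(warped_trk_boxes, dets[:, :4]))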

I’ll keep you updated on my progress, including which parameters work best for me and any adjustments I make to adapt OCSORT.

Thanks again for your guidance!


mikel-brostrom commented Nov 5, 2024

If the KF is overshooting, you can try decreasing the aforementioned values by a factor of ten instead of increasing them.


mikel-brostrom commented Nov 5, 2024

The KF is tuned for MOT16/17, so it may not serve your case out of the box at all.


LilBabines commented Nov 12, 2024

AerialOcSort.zip

Here is a custom OC-SORT tracker tailored for aerial footage. I've completely removed the Kalman filter, relying solely on camera motion compensation for predictions. This setup works effectively for my specific use case, where there is significant background motion and the objects (sea animals) have relatively low velocities, making them appear nearly stationary. However, it may not perform well in other scenarios.

Using Kalman prediction as a fallback when the ECC warp matrix cannot be computed would likely be a valuable improvement.
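
As a rough illustration of that fallback (the helper below is hypothetical and not part of the attached code): when no warp can be estimated for a frame, an identity affine keeps each track at its last box, and a Kalman prediction could be substituted at that point instead.

import numpy as np

def transform_or_identity(transform):
    # Hypothetical guard: fall back to an identity 2x3 affine when the ECC warp
    # is missing or invalid, so the CMC "prediction" degenerates to the last box.
    if transform is None or not np.all(np.isfinite(transform)):
        return np.eye(2, 3, dtype=np.float32)
    return transform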

# Adapted from
# https://github.com/mikel-brostrom/boxmot/blob/master/boxmot/trackers/ocsort/ocsort.py
# Adaptor : [email protected]

from boxmot.motion.cmc import get_cmc_method
import cv2
import numpy as np
from collections import deque

from boxmot.utils.association import associate, linear_assignment
from boxmot.trackers.basetracker import BaseTracker

def k_previous_obs(observations, cur_age, k):
    if len(observations) == 0:
        return [-1, -1, -1, -1, -1]
    for i in range(k):
        dt = k - i
        if cur_age - dt in observations:
            return observations[cur_age - dt]
    max_age = max(observations.keys())
    return observations[max_age]


def convert_x_to_bbox(x, score=None):
    """
    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
      [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
    """
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    if score is None:
        return np.array(
            [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]
        ).reshape((1, 4))
    else:
        return np.array(
            [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]
        ).reshape((1, 5))


def speed_direction(bbox1, bbox2):
    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
    speed = np.array([cy2 - cy1, cx2 - cx1])
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm


class CMCBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.
    """

    count = 0

    def __init__(self, bbox, cls, det_ind, delta_t=3, max_obs=50):
        """
        Initialises a tracker using initial bounding box.

        """
        self.det_ind = det_ind
        self.time_since_update = 0
        self.id = CMCBoxTracker.count
        CMCBoxTracker.count += 1
        self.max_obs = max_obs
        self.history = deque([bbox], maxlen=self.max_obs)
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        self.conf = bbox[-1]
        self.cls = cls
        self.last_observation = np.array([-1, -1, -1, -1, -1])  # placeholder
        self.observations = dict([(0, bbox)])
        self.velocity = None
        self.delta_t = delta_t

    def update(self, bbox, cls, det_ind):
        """
        Updates the track state with an observed bbox (or with its own CMC prediction
        when bbox is None, i.e. the track was not matched this frame).
        """
        self.det_ind = det_ind
        if bbox is not None:  # the track matched a detection
            self.time_since_update = 0
            self.hits += 1
            self.hit_streak += 1
            self.history[-1] = bbox  # the CMC prediction is replaced with the matched detection
            self.conf = bbox[-1]  # detection confidence (CMC predictions carry conf 0)
            self.cls = cls
        else:
            bbox = self.history[-1]  # keep the CMC prediction as the observation
            self.conf = 0

        # Find previous observation \Delta t steps away
        previous_box = None
        for i in range(self.delta_t):
            dt = self.delta_t - i
            if self.age - dt in self.observations:
                previous_box = self.observations[self.age - dt]
                break
        if previous_box is None:
            previous_box = self.last_observation
        # Estimate the track speed direction with observations \Delta t steps away
        self.velocity = speed_direction(previous_box, bbox)

        
        self.last_observation = bbox # can be cmc preds
        self.observations[self.age] = bbox


    def predict(self, affine):
        """
        Warps the last observation (or previous CMC prediction) with the estimated
        affine camera motion and returns the predicted bounding box.
        """
        bbox = self.history[-1][:4]  # last observation or CMC prediction

        # split the 2x3 affine into its linear part and translation
        m = affine[:, :2]
        t = affine[:, 2].reshape(2, 1)

        # apply the affine to the two box corners
        ps = bbox.reshape(2, 2).T
        ps = m @ ps + t
        pred = ps.T.reshape(-1)

        # update track age / streak bookkeeping
        self.age += 1
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        pred = np.append(pred, 0)  # conf zero marks a CMC prediction
        self.history.append(pred)
        return self.history[-1]

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return self.history[-1]
    

class AerialSort(BaseTracker):
    """
    AerialSort Tracker: An adaptation of OcSort for aerial tracking that relies solely on camera motion compensation (CMC).

    Args:
        per_class (bool, optional): Whether to perform per-class tracking. If True, tracks are maintained separately for each object class.
        wrap_mode (int, optional): The warp mode to use for motion compensation (computed with `cv2.findTransformECC()`). Options include cv2.MOTION_TRANSLATION, cv2.MOTION_EUCLIDEAN, cv2.MOTION_AFFINE and cv2.MOTION_HOMOGRAPHY.
        det_thresh (float, optional): Detection confidence threshold. Detections below this threshold are ignored in the first association step.
        max_age (int, optional): Maximum number of frames to keep a track alive without any detections.
        min_hits (int, optional): Minimum number of hits required to confirm a track.
        asso_threshold (float, optional): Threshold for the association step in data association. Controls the maximum distance allowed between tracklets and detections for a match.
        delta_t (int, optional): Time delta for velocity estimation.
        asso_func (str, optional): Association function to use for data association. Options include "iou" for IoU-based association.
        inertia (float, optional): Weight for inertia in motion modeling. Higher values make tracks less responsive to changes.
        use_byte (bool, optional): Whether to use BYTE association in the second association step.
    """
    def __init__(
        self,
        per_class: bool = False,
        wrap_mode: int = cv2.MOTION_AFFINE,
        det_thresh: float = 0.2,
        max_age: int = 10,
        min_hits: int = 3,
        asso_threshold: float = 0.2,
        delta_t: int = 3,
        asso_func: str = "iou",
        inertia: float = 0.2,
        use_byte: bool = False,
    ):
        super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func)
        """
        Sets key parameters for SORT
        """
        self.per_class = per_class
        self.max_age = max_age
        self.min_hits = min_hits
        self.asso_threshold = asso_threshold
        self.frame_count = 0
        self.det_thresh = det_thresh
        self.delta_t = delta_t
        self.inertia = inertia
        self.use_byte = use_byte
        CMCBoxTracker.count = 0
        self.cmc = get_cmc_method('ecc')(warp_mode=wrap_mode) 

    @BaseTracker.on_first_frame_setup
    @BaseTracker.per_class_decorator
    def update(self, dets: np.ndarray, img: np.ndarray) -> np.ndarray:
        """
        Params:
          dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections
        (use np.empty((0, 5)) for frames without detections).
        Returns the a similar array, where the last column is the object ID.
        NOTE: The number of objects returned may differ from the number of detections provided.
        """

        self.check_inputs(dets, img)

        self.frame_count += 1
        h, w = img.shape[0:2]

        dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])
        confs = dets[:, 4]

        inds_low = confs > 0.1
        inds_high = confs < self.det_thresh
        inds_second = np.logical_and(
            inds_low, inds_high
        )
        dets_second = dets[inds_second]  # detections confs > 0.1 for second matching when use_byte=True
        remain_inds = confs > self.det_thresh
        dets = dets[remain_inds]

        # compute the camera motion compensation (ECC) warp matrix for this frame
        transform = self.cmc.apply(img)

        # get predicted locations from existing trackers.
        trks = np.zeros((len(self.active_tracks), 5))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos = self.active_tracks[t].predict(transform) # return cmc prediction, "last_obs @ transform_matrix"
            trk[:] = pos
            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.active_tracks.pop(t)

        velocities = np.array(
            [
                trk.velocity if trk.velocity is not None else np.array((0, 0))
                for trk in self.active_tracks
            ]
        )
        last_boxes = np.array([trk.last_observation for trk in self.active_tracks])
        k_observations = np.array(
            [
                k_previous_obs(trk.observations, trk.age, self.delta_t)
                for trk in self.active_tracks
            ]
        )

        """
            First round of association
        """
        matched, unmatched_dets, unmatched_trks = associate(
            dets[:, 0:5], trks, self.asso_func, self.asso_threshold, velocities, k_observations, self.inertia, w, h
        )
        for m in matched:
            self.active_tracks[m[1]].update(dets[m[0], :5], dets[m[0], 5], dets[m[0], 6])

        """
            Second round of association by OCR
        """
        # BYTE association
        if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
            u_trks = trks[unmatched_trks]
            iou_left = self.asso_func(
                dets_second, u_trks
            )  # iou between low score detections and unmatched tracks
            iou_left = np.array(iou_left)
            if iou_left.max() > self.asso_threshold:
                """
                NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
                get a higher performance especially on MOT17/MOT20 datasets. But we keep it
                uniform here for simplicity
                """
                matched_indices = linear_assignment(-iou_left)
                to_remove_trk_indices = []
                for m in matched_indices:
                    det_ind, trk_ind = m[0], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.asso_threshold:
                        continue
                    self.active_tracks[trk_ind].update(
                        dets_second[det_ind, :5], dets_second[det_ind, 5], dets_second[det_ind, 6]
                    )
                    to_remove_trk_indices.append(trk_ind)
                unmatched_trks = np.setdiff1d(
                    unmatched_trks, np.array(to_remove_trk_indices)
                )

        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
            left_dets = dets[unmatched_dets]
            left_trks = last_boxes[unmatched_trks]
            iou_left = self.asso_func(left_dets, left_trks)
            iou_left = np.array(iou_left)
            if iou_left.max() > self.asso_threshold:
                """
                NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
                get a higher performance especially on MOT17/MOT20 datasets. But we keep it
                uniform here for simplicity
                """
                rematched_indices = linear_assignment(-iou_left)
                to_remove_det_indices = []
                to_remove_trk_indices = []
                for m in rematched_indices:
                    det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.asso_threshold:
                        continue
                    self.active_tracks[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5], dets[det_ind, 6])
                    to_remove_det_indices.append(det_ind)
                    to_remove_trk_indices.append(trk_ind)
                unmatched_dets = np.setdiff1d(
                    unmatched_dets, np.array(to_remove_det_indices)
                )
                unmatched_trks = np.setdiff1d(
                    unmatched_trks, np.array(to_remove_trk_indices)
                )

        for m in unmatched_trks:
            self.active_tracks[m].update(None, None, None)

        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = CMCBoxTracker(dets[i, :5], dets[i, 5], dets[i, 6], delta_t=self.delta_t, max_obs=self.max_obs)
            self.active_tracks.append(trk)
        i = len(self.active_tracks)


        for trk in reversed(self.active_tracks):
            if trk.last_observation.sum() < 0:
                d = trk.get_state()[:4]  # keep only the bbox coordinates
            else:
                """
                it is optional to use the recent observation or the CMC prediction here,
                we didn't notice a significant difference
                """
                d = trk.last_observation[:4]
                
            if (trk.time_since_update < 1) and (
                trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits
            ):
                # +1 as MOT benchmark requires positive
                ret.append(
                    np.concatenate((d, [trk.id + 1], [trk.conf], [trk.cls], [trk.det_ind])).reshape(
                        1, -1
                    )
                )
            i -= 1
            # remove dead tracklet
            if trk.time_since_update > self.max_age:
                self.active_tracks.pop(i)
        if len(ret) > 0:
            return np.concatenate(ret)
        return np.array([])

    def plot_box_on_img(self, img: np.ndarray, box: tuple, conf: float, cls: int, id: int, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray:
        """
        Draws a bounding box with ID, confidence, and class information on an image.

        Parameters:
        - img (np.ndarray): The image array to draw on.
        - box (tuple): The bounding box coordinates as (x1, y1, x2, y2).
        - conf (float): Confidence score of the detection.
        - cls (int): Class ID of the detection.
        - id (int): Unique identifier for the detection.
        - thickness (int): The thickness of the bounding box.
        - fontscale (float): The font scale for the text.

        Returns:
        - np.ndarray: The image array with the bounding box drawn on it.
        """

        img = cv2.rectangle(
            img,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            self.id_to_color(id),
            thickness
        )
        img = cv2.putText(
            img,
            f'id: {int(id)}, conf: {conf:.2f}, c: {int(cls)}',
            (int(box[0]), int(box[1]) - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            fontscale,
            self.id_to_color(id),
            thickness
        )
        return img


    def plot_trackers_trajectories(self, img: np.ndarray, observations: list, id: int) -> np.ndarray:
        """
        Draws the trajectories of tracked objects based on historical observations. Each point
        in the trajectory is represented by a circle, with the thickness increasing for more
        recent observations to visualize the path of movement.

        Parameters:
        - img (np.ndarray): The image array on which to draw the trajectories.
        - observations (list): A list of bounding box coordinates representing the historical
        observations of a tracked object. Each observation is in the format (x1, y1, x2, y2).
        - id (int): The unique identifier of the tracked object for color consistency in visualization.

        Returns:
        - np.ndarray: The image array with the trajectories drawn on it.
        """
        for i, box in enumerate(observations):
            trajectory_thickness = int(np.sqrt(float(i + 1)) * 1.2)
            if box[4] > 0:  # the observation is a detection

                img = cv2.circle(
                    img,
                    (int((box[0] + box[2]) / 2),
                    int((box[1] + box[3]) / 2)), 
                    2,
                    color=self.id_to_color(int(id)),
                    thickness=trajectory_thickness
                )
            else:  # the observation is a CMC prediction
                img = cv2.drawMarker(
                    img,
                    (int((box[0] + box[2]) / 2),
                    int((box[1] + box[3]) / 2)),
                    color=self.id_to_color(int(id)),
                    markerType=cv2.MARKER_CROSS,
                    markerSize=15,
                    thickness=trajectory_thickness
                )
        return img


    def plot_results(self, img: np.ndarray, show_trajectories: bool, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray:
        """
        Visualizes the trajectories of all active tracks on the image. For each track,
        it draws the latest bounding box and the path of movement if the history of
        observations is longer than two. This helps in understanding the movement patterns
        of each tracked object.

        Parameters:
        - img (np.ndarray): The image array on which to draw the trajectories and bounding boxes.
        - show_trajectories (bool): Whether to show the trajectories.
        - thickness (int): The thickness of the bounding box.
        - fontscale (float): The font scale for the text.

        Returns:
        - np.ndarray: The image array with trajectories and bounding boxes of all active tracks.
        """

        # per-class tracking keeps active tracks in a dict keyed by class
        if self.per_class_active_tracks is not None:
            for k in self.per_class_active_tracks.keys():
                active_tracks = self.per_class_active_tracks[k]
                for a in active_tracks:
                    if a.history:
                        if a.hits >= self.min_hits:
                            box = a.history[-1]
                            img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale)
                            if show_trajectories:
                                img = self.plot_trackers_trajectories(img, a.history, a.id)
        else:
            for a in self.active_tracks:
                if a.history:
                    if a.hits >= self.min_hits:  # len(a.history_observations) > 2
                        box = a.history[-1]
                        img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale)
                        if show_trajectories:
                            img = self.plot_trackers_trajectories(img, a.history, a.id)
                
        return img
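
A minimal usage sketch of the class above, mirroring the detection format of the script earlier in the thread (the parameter values are only examples):

import cv2
import numpy as np

tracker = AerialSort(
    wrap_mode=cv2.MOTION_AFFINE,
    det_thresh=0.45,
    min_hits=10,
    max_age=7,
)

# For each frame: dets is an (N, 6) array of [x1, y1, x2, y2, conf, cls]
# (use np.empty((0, 6)) when there are no detections), im is the BGR frame.
# tracks = tracker.update(dets, im)
# -> rows of [x1, y1, x2, y2, id, conf, cls, det_ind]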


mikel-brostrom commented Nov 12, 2024

Glad you found a solution 🚀. And thanks for posting it here 😄
