Tracking Sea Megafauna from Aerial Footage – Handling High Camera Motion and Low Object Velocity #1725
Comments
Take a look at DeepSORT with Custom Deep Association Networks. This tracker extends the original DeepSORT.
I would start by trying out the values in boxmot/boxmot/configs/ocsort.yaml, lines 36 to 44 (at commit a28ccef).
You can increase the default values by factors of ten. I usually see bad results for …
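To illustrate what a factor-of-ten change means in Kalman-filter terms, here is a minimal sketch using filterpy's KalmanFilter; the exact keys and defaults in ocsort.yaml may differ, so treat this only as a conceptual example:

import numpy as np
from filterpy.kalman import KalmanFilter

# Constant-velocity filter on (x, y): state [x, y, vx, vy], measurement [x, y].
kf = KalmanFilter(dim_x=4, dim_z=2)
kf.F = np.array([[1, 0, 1, 0],
                 [0, 1, 0, 1],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]], dtype=float)  # constant-velocity transition
kf.H = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0]], dtype=float)  # only positions are observed

# Increasing the noise terms makes the filter trust its motion model and/or the
# measurements less; this is the kind of factor-of-ten change suggested above.
kf.Q *= 10.0  # process noise
kf.R *= 10.0  # measurement noise
# If the filter is instead overshooting, divide these by ten rather than multiplying.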
I believe the deep appearance descriptor is not the most relevant aspect for this issue, and its linear-motion Kalman filter is not the best configuration for the task at hand.
Ohh. @LilBabines, I recommend starting with OCSORT for the sake of simplicity. I don't see the ReID model being of any use here, given the small objects you are detecting.
Thank you very much for your help!
I agree, I've already disabled appearance-based ReID.
I'll keep you updated on my progress, including which parameters work best for me and any adjustments I make to adapt OCSORT. Thanks again for your guidance!
If the KF is overshooting, you can try decreasing the aforementioned values by a factor of ten instead of increasing them.
The KF is tuned for MOT16/17, so it may not serve your case out of the box at all.
Here is a custom OC-SORT tracker tailored for aerial footage. I've completely removed the Kalman filter, relying solely on camera motion constraints for predictions. This setup works effectively for my specific use case, where there is significant background motion and the objects (sea animals) have relatively low velocities, making them appear nearly stationary. However, it may not perform well in other scenarios. Using Kalman prediction as a fallback when the ECC warp matrix cannot be computed would likely be a valuable improvement.
# Adapted from
# https://github.com/mikel-brostrom/boxmot/blob/master/boxmot/trackers/ocsort/ocsort.py
# Adaptor : [email protected]
from boxmot.motion.cmc import get_cmc_method
import cv2
import numpy as np
from collections import deque
from boxmot.utils.association import associate, linear_assignment
from boxmot.trackers.basetracker import BaseTracker
def k_previous_obs(observations, cur_age, k):
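    """
    Return the observation closest to `k` frames before `cur_age`, preferring the
    oldest frame within that window; fall back to the most recent observation, or
    to a [-1] * 5 placeholder when the track has no observations yet.
    """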
if len(observations) == 0:
return [-1, -1, -1, -1, -1]
for i in range(k):
dt = k - i
if cur_age - dt in observations:
return observations[cur_age - dt]
max_age = max(observations.keys())
return observations[max_age]
def convert_x_to_bbox(x, score=None):
"""
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
"""
w = np.sqrt(x[2] * x[3])
h = x[2] / w
if score is None:
return np.array(
[x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]
).reshape((1, 4))
else:
return np.array(
[x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]
).reshape((1, 5))
def speed_direction(bbox1, bbox2):
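    """
    Return the unit direction vector (in (dy, dx) order) from the centre of bbox1
    to the centre of bbox2; a small epsilon avoids division by zero.
    """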
cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
speed = np.array([cy2 - cy1, cx2 - cx1])
norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
return speed / norm
class CMCBoxTracker(object):
"""
This class represents the internal state of individual tracked objects observed as bbox.
"""
count = 0
def __init__(self, bbox, cls, det_ind, delta_t=3, max_obs=50):
"""
Initialises a tracker using initial bounding box.
"""
self.det_ind = det_ind
self.time_since_update = 0
self.id = CMCBoxTracker.count
CMCBoxTracker.count += 1
self.max_obs = max_obs
self.history = deque([bbox], maxlen=self.max_obs)
self.hits = 0
self.hit_streak = 0
self.age = 0
self.conf = bbox[-1]
self.cls = cls
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
self.observations = dict([(0, bbox)])
self.velocity = None
self.delta_t = delta_t
def update(self, bbox, cls, det_ind):
"""
Updates the state vector with observed bbox.
"""
self.det_ind = det_ind
        if bbox is not None:  # the tracker matched a detection
            self.time_since_update = 0
            self.hits += 1
            self.hit_streak += 1
            self.history[-1] = bbox  # the CMC prediction is replaced by the matched detection
            self.conf = bbox[-1]
            self.cls = cls
        else:
            bbox = self.history[-1]  # no match: fall back to the CMC prediction
            self.conf = 0  # confidence is zero for CMC predictions
# Find previous observation \Delta t steps away
previous_box = None
for i in range(self.delta_t):
dt = self.delta_t - i
if self.age - dt in self.observations:
previous_box = self.observations[self.age - dt]
break
if previous_box is None:
previous_box = self.last_observation
# Estimate the track speed direction with observations \Delta t steps away
self.velocity = speed_direction(previous_box, bbox)
self.last_observation = bbox # can be cmc preds
self.observations[self.age] = bbox
def predict(self,affine):
"""
Advances the state vector and returns the predicted bounding box estimate.
"""
bbox = self.history[-1][:4] # get last observation
# get affine matrix
m = affine[:, :2]
t = affine[:, 2].reshape(2, 1)
# apply affine matrix to bbox
ps = bbox.reshape(2, 2).T
ps = m @ ps + t
pred = ps.T.reshape(-1)
# update state
self.age += 1
if self.time_since_update > 0:
self.hit_streak = 0
self.time_since_update += 1
pred = np.append(pred,0) # conf zero for cmc prediction
self.history.append(pred)
return self.history[-1]
def get_state(self):
"""
Returns the current bounding box estimate.
"""
return self.history[-1]
class AerialSort(BaseTracker):
"""
AerialSort Tracker: An adaptation of OcSort for aerial tracking that relies solely on camera motion compensation (CMC).
Args:
per_class (bool, optional): Whether to perform per-class tracking. If True, tracks are maintained separately for each object class.
wrap_mode (int, optional): The warp mode to use for motion compensation (computed with `cv2.findTransformECC()`). Options include cv2.MOTION_TRANSLATION, cv2.MOTION_EUCLIDEAN, cv2.MOTION_AFFINE and cv2.MOTION_HOMOGRAPHY.
det_thresh (float, optional): Detection confidence threshold. Detections below this threshold are ignored in the first association step.
max_age (int, optional): Maximum number of frames to keep a track alive without any detections.
min_hits (int, optional): Minimum number of hits required to confirm a track.
asso_threshold (float, optional): Threshold for the association step in data association. Controls the maximum distance allowed between tracklets and detections for a match.
delta_t (int, optional): Time delta for velocity estimation.
asso_func (str, optional): Association function to use for data association. Options include "iou" for IoU-based association.
inertia (float, optional): Weight for inertia in motion modeling. Higher values make tracks less responsive to changes.
use_byte (bool, optional): Whether to use BYTE association in the second association step.
"""
def __init__(
self,
per_class: bool = False,
wrap_mode: int = cv2.MOTION_AFFINE,
det_thresh: float = 0.2,
max_age: int = 10,
min_hits: int = 3,
asso_threshold: float = 0.2,
delta_t: int = 3,
asso_func: str = "iou",
inertia: float = 0.2,
use_byte: bool = False,
):
super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func)
"""
Sets key parameters for SORT
"""
self.per_class = per_class
self.max_age = max_age
self.min_hits = min_hits
self.asso_threshold = asso_threshold
self.frame_count = 0
self.det_thresh = det_thresh
self.delta_t = delta_t
self.inertia = inertia
self.use_byte = use_byte
CMCBoxTracker.count = 0
self.cmc = get_cmc_method('ecc')(warp_mode=wrap_mode)
@BaseTracker.on_first_frame_setup
@BaseTracker.per_class_decorator
def update(self, dets: np.ndarray, img: np.ndarray) -> np.ndarray:
"""
Params:
        dets - a numpy array of detections in the format [[x1,y1,x2,y2,conf,cls],[x1,y1,x2,y2,conf,cls],...]
        Requires: this method must be called once for each frame, even with empty detections
        (use np.empty((0, 6)) for frames without detections).
        Returns a similar array with one row per confirmed track: [x1, y1, x2, y2, id, conf, cls, det_ind].
NOTE: The number of objects returned may differ from the number of detections provided.
"""
self.check_inputs(dets, img)
self.frame_count += 1
h, w = img.shape[0:2]
dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])
confs = dets[:, 4]
inds_low = confs > 0.1
inds_high = confs < self.det_thresh
inds_second = np.logical_and(
inds_low, inds_high
)
dets_second = dets[inds_second] # detections confs > 0.1 for second matching when use_byte=True
remain_inds = confs > self.det_thresh
dets = dets[remain_inds]
#compute cmc matrix
transform = self.cmc.apply(img)
# get predicted locations from existing trackers.
trks = np.zeros((len(self.active_tracks), 5))
to_del = []
ret = []
for t, trk in enumerate(trks):
pos = self.active_tracks[t].predict(transform) # return cmc prediction, "last_obs @ transform_matrix"
trk[:] = pos
if np.any(np.isnan(pos)):
to_del.append(t)
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
for t in reversed(to_del):
self.active_tracks.pop(t)
velocities = np.array(
[
trk.velocity if trk.velocity is not None else np.array((0, 0))
for trk in self.active_tracks
]
)
last_boxes = np.array([trk.last_observation for trk in self.active_tracks])
k_observations = np.array(
[
k_previous_obs(trk.observations, trk.age, self.delta_t)
for trk in self.active_tracks
]
)
"""
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets[:, 0:5], trks, self.asso_func, self.asso_threshold, velocities, k_observations, self.inertia, w, h
)
for m in matched:
self.active_tracks[m[1]].update(dets[m[0], :5], dets[m[0], 5], dets[m[0], 6])
"""
        Second round of association by OCR (Observation-Centric Recovery)
"""
# BYTE association
if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
u_trks = trks[unmatched_trks]
iou_left = self.asso_func(
dets_second, u_trks
) # iou between low score detections and unmatched tracks
iou_left = np.array(iou_left)
if iou_left.max() > self.asso_threshold:
"""
NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
matched_indices = linear_assignment(-iou_left)
to_remove_trk_indices = []
for m in matched_indices:
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.asso_threshold:
continue
self.active_tracks[trk_ind].update(
dets_second[det_ind, :5], dets_second[det_ind, 5], dets_second[det_ind, 6]
)
to_remove_trk_indices.append(trk_ind)
unmatched_trks = np.setdiff1d(
unmatched_trks, np.array(to_remove_trk_indices)
)
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
iou_left = self.asso_func(left_dets, left_trks)
iou_left = np.array(iou_left)
if iou_left.max() > self.asso_threshold:
"""
NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
rematched_indices = linear_assignment(-iou_left)
to_remove_det_indices = []
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.asso_threshold:
continue
self.active_tracks[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5], dets[det_ind, 6])
to_remove_det_indices.append(det_ind)
to_remove_trk_indices.append(trk_ind)
unmatched_dets = np.setdiff1d(
unmatched_dets, np.array(to_remove_det_indices)
)
unmatched_trks = np.setdiff1d(
unmatched_trks, np.array(to_remove_trk_indices)
)
for m in unmatched_trks:
self.active_tracks[m].update(None, None, None)
# create and initialise new trackers for unmatched detections
for i in unmatched_dets:
trk = CMCBoxTracker(dets[i, :5], dets[i, 5], dets[i, 6], delta_t=self.delta_t, max_obs=self.max_obs)
self.active_tracks.append(trk)
i = len(self.active_tracks)
for trk in reversed(self.active_tracks):
if trk.last_observation.sum() < 0:
                d = trk.get_state()[:4]  # keep only the box coordinates so output rows have a consistent width
else:
"""
this is optional to use the recent observation or the kalman filter prediction,
we didn't notice significant difference here
"""
d = trk.last_observation[:4]
if (trk.time_since_update < 1) and (
trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits
):
# +1 as MOT benchmark requires positive
ret.append(
np.concatenate((d, [trk.id + 1], [trk.conf], [trk.cls], [trk.det_ind])).reshape(
1, -1
)
)
i -= 1
# remove dead tracklet
if trk.time_since_update > self.max_age:
self.active_tracks.pop(i)
if len(ret) > 0:
return np.concatenate(ret)
return np.array([])
def plot_box_on_img(self, img: np.ndarray, box: tuple, conf: float, cls: int, id: int, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray:
"""
Draws a bounding box with ID, confidence, and class information on an image.
Parameters:
- img (np.ndarray): The image array to draw on.
- box (tuple): The bounding box coordinates as (x1, y1, x2, y2).
- conf (float): Confidence score of the detection.
- cls (int): Class ID of the detection.
- id (int): Unique identifier for the detection.
- thickness (int): The thickness of the bounding box.
- fontscale (float): The font scale for the text.
Returns:
- np.ndarray: The image array with the bounding box drawn on it.
"""
img = cv2.rectangle(
img,
(int(box[0]), int(box[1])),
(int(box[2]), int(box[3])),
self.id_to_color(id),
thickness
)
img = cv2.putText(
img,
f'id: {int(id)}, conf: {conf:.2f}, c: {int(cls)}',
(int(box[0]), int(box[1]) - 10),
cv2.FONT_HERSHEY_SIMPLEX,
fontscale,
self.id_to_color(id),
thickness
)
return img
def plot_trackers_trajectories(self, img: np.ndarray, observations: list, id: int) -> np.ndarray:
"""
Draws the trajectories of tracked objects based on historical observations. Each point
in the trajectory is represented by a circle, with the thickness increasing for more
recent observations to visualize the path of movement.
Parameters:
- img (np.ndarray): The image array on which to draw the trajectories.
- observations (list): A list of bounding box coordinates representing the historical
observations of a tracked object. Each observation is in the format (x1, y1, x2, y2).
- id (int): The unique identifier of the tracked object for color consistency in visualization.
Returns:
- np.ndarray: The image array with the trajectories drawn on it.
"""
for i, box in enumerate(observations):
trajectory_thickness = int(np.sqrt(float (i + 1)) * 1.2)
            if box[4] > 0:  # the observation is a detection
img = cv2.circle(
img,
(int((box[0] + box[2]) / 2),
int((box[1] + box[3]) / 2)),
2,
color=self.id_to_color(int(id)),
thickness=trajectory_thickness
)
            else:  # the observation is a cmc prediction
img = cv2.drawMarker(
img,
(int((box[0] + box[2]) / 2),
int((box[1] + box[3]) / 2)),
color=self.id_to_color(int(id)),
markerType=cv2.MARKER_CROSS,
markerSize=15,
thickness=trajectory_thickness
)
return img
def plot_results(self, img: np.ndarray, show_trajectories: bool, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray:
"""
Visualizes the trajectories of all active tracks on the image. For each track,
it draws the latest bounding box and the path of movement if the history of
observations is longer than two. This helps in understanding the movement patterns
of each tracked object.
Parameters:
- img (np.ndarray): The image array on which to draw the trajectories and bounding boxes.
- show_trajectories (bool): Whether to show the trajectories.
- thickness (int): The thickness of the bounding box.
- fontscale (float): The font scale for the text.
Returns:
- np.ndarray: The image array with trajectories and bounding boxes of all active tracks.
"""
# if values in dict
if self.per_class_active_tracks is not None:
for k in self.per_class_active_tracks.keys():
active_tracks = self.per_class_active_tracks[k]
for a in active_tracks:
if a.history:
if a.hits >= self.min_hits :
box = a.history[-1]
img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale)
if show_trajectories:
img = self.plot_trackers_trajectories(img, a.history, a.id)
else:
for a in self.active_tracks:
if a.history:
if a.hits >= self.min_hits : #len(a.history_observations) > 2
box = a.history[-1]
img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale)
if show_trajectories:
img = self.plot_trackers_trajectories(img, a.history, a.id)
        return img
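For context, here is a minimal sketch of how the class above could be driven frame by frame. run_detector and the video path are placeholders (not part of boxmot), and the detection layout follows the update() docstring:

import cv2
import numpy as np

tracker = AerialSort(wrap_mode=cv2.MOTION_AFFINE, det_thresh=0.2, max_age=10, min_hits=3)
cap = cv2.VideoCapture("aerial_clip.mp4")  # placeholder path

while True:
    ok, frame = cap.read()
    if not ok:
        break
    # run_detector stands in for any detector that returns an (N, 6) array of
    # [x1, y1, x2, y2, conf, cls]; pass an empty array when nothing is detected
    # so the CMC warp is still estimated for this frame.
    dets = run_detector(frame)
    if dets is None or len(dets) == 0:
        dets = np.empty((0, 6))
    tracks = tracker.update(dets, frame)  # rows: x1, y1, x2, y2, id, conf, cls, det_ind
    frame = tracker.plot_results(frame, show_trajectories=True)

cap.release()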
Glad you found a solution 🚀. And thanks for posting it here 😄
Question
Description
I'm working on tracking sea megafauna from aerial footage. The camera is fixed under a moving plane, which results in significant background motion while the objects (sea animals) have relatively low velocity, appearing almost stationary.
I've tested various parameter configurations and generated example tracks, available in this Google Drive link, produced with this code:
After some investigation, I found that apply_affine_correction(transform) from SOF in deepocsort (and similarly with ECC and strongsort) may be overestimating camera motion, which causes the Kalman filter prediction to drift too far from the detections.
Question:
Thank you for your work and any guidance you can provide!