from datetime import datetime
from enum import Enum
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
class TrackerAssignmentFunction(str, Enum):
    """Assignment/distance functions used to associate tracks with measurements.

    Because of the ``str`` mixin, each member compares equal to its string
    value (e.g. ``TrackerAssignmentFunction.EUCLIDEAN == "euclidean"``).
    """

    EUCLIDEAN = "euclidean"
    HAVERSINE = "haversine"
    INTERSECTION_OVER_UNION = "intersection_over_union"
    MAHALANOBIS = "mahalanobis"
class TrackerLabelBehavior(str, Enum):
    """Options that determine how class labels are assigned to tracks.

    NOTE(review): the names suggest "drop labels", "use the most recent
    label" and "use the most frequent label" respectively — confirm against
    the tracker service before relying on that reading.
    """

    OMIT = "omit"
    LAST = "last"
    MAJORITY = "majority"
class TrackerKind(str, Enum):
    """The available tracker kinds.

    Each value pairs a measurement geometry (``point`` or ``box``) with the
    coordinate system named in the value string.
    """

    POINT_UNITLESS = "point_unitless"
    POINT_LAT_LONG = "point_latitude_longitude"
    POINT_UTM = "point_utm"
    BOX_UNITLESS = "box_unitless"
    BOX_LAT_LONG = "box_latitude_longitude"
class NewCreateTrackerRequest(BaseModel):
    """Request payload for creating a tracker.

    ``state_transition`` and ``process_noise_covariance`` are optional. When
    given, each must be a square matrix whose size depends on ``kind``:
    7x7 for ``TrackerKind.BOX_UNITLESS`` and 4x4 for every other kind.

    Because ``use_enum_values`` is enabled, enum-typed fields hold their plain
    string values after validation.
    """

    project_id: str
    name: str
    kind: TrackerKind
    max_missing_updates: int
    min_lifetime_before_active: int
    assignment_function: TrackerAssignmentFunction
    assignment_threshold: float
    label_behavior: TrackerLabelBehavior
    state_transition: list[list[float]] | None = None
    process_noise_covariance: list[list[float]] | None = None

    @staticmethod
    def _is_square(matrix, size):
        """Return True when ``matrix`` has exactly ``size`` rows of ``size`` entries."""
        return len(matrix) == size and all(len(row) == size for row in matrix)

    @model_validator(mode="after")
    def check_state_transition(self):
        """Validate the shape of ``state_transition`` against ``kind``.

        :raises ValueError: if the matrix is present but has the wrong size.
        """
        transition = self.state_transition
        if transition is None:
            return self
        # With use_enum_values, ``kind`` is a plain str here; the str-mixin
        # enum member still compares equal to it.
        if self.kind == TrackerKind.BOX_UNITLESS:
            if not self._is_square(transition, 7):
                raise ValueError(
                    "state_transition must be a 7x7 matrix when kind is 'box_unitless'"
                )
        elif not self._is_square(transition, 4):
            raise ValueError(
                "state_transition must be a 4x4 matrix when kind is 'point_unitless', 'point_latitude_longitude', 'point_utm', or 'box_latitude_longitude'"
            )
        return self

    @model_validator(mode="after")
    def check_process_noise_covariance(self):
        """Validate the shape of ``process_noise_covariance`` against ``kind``.

        :raises ValueError: if the matrix is present but has the wrong size.
        """
        noise = self.process_noise_covariance
        if noise is None:
            return self
        if self.kind == TrackerKind.BOX_UNITLESS:
            if not self._is_square(noise, 7):
                raise ValueError(
                    "process_noise_covariance must be a 7x7 matrix when kind is 'box_unitless'"
                )
        elif not self._is_square(noise, 4):
            raise ValueError(
                "process_noise_covariance must be a 4x4 matrix when kind is 'point_unitless', 'point_latitude_longitude', 'point_utm', or 'box_latitude_longitude'"
            )
        return self

    model_config = ConfigDict(use_enum_values=True)
class Measurement(BaseModel):
    """Fields shared by every measurement type submitted to a tracker.

    Subclasses add their coordinate fields and check the shape of
    ``covariance``.
    """

    # Class label for this measurement.
    label: str
    # Score attached to the measurement.
    score: float
    # Optional covariance matrix; an empty list when the caller omits it.
    covariance: list[list[float]] = Field(default_factory=list)
    # Measurement identifier. The name shadows the builtin ``id`` only inside
    # the class body; it is part of the public schema, so it is not renamed.
    id: str
class UnitlessPoint(Measurement):
    """A point measurement with unitless ``x``/``y`` coordinates."""

    x: float
    y: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Ensure a supplied covariance is a 2x2 matrix.

        An empty list means "no covariance" and is returned unchanged.

        :raises ValueError: if ``v`` is non-empty and not 2x2.
        """
        if not v:
            # Return the list itself: an after-validator's return value
            # becomes the field value, so returning None here would replace
            # the list with None and break the ``list[list[float]]`` type.
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class LatLongPoint(Measurement):
    """A point measurement expressed as ``latitude``/``longitude``."""

    latitude: float
    longitude: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Ensure a supplied covariance is a 2x2 matrix.

        An empty list means "no covariance" and is returned unchanged.

        :raises ValueError: if ``v`` is non-empty and not 2x2.
        """
        if not v:
            # Return the list itself: an after-validator's return value
            # becomes the field value, so returning None here would replace
            # the list with None and break the ``list[list[float]]`` type.
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class UTMPoint(Measurement):
    """A point measurement expressed in UTM ``northing``/``easting`` plus ``zone``.

    NOTE(review): the expected format of ``zone`` is not enforced or stated
    here — confirm against the tracker service.
    """

    northing: float
    easting: float
    zone: str

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Ensure a supplied covariance is a 2x2 matrix.

        An empty list means "no covariance" and is returned unchanged.

        :raises ValueError: if ``v`` is non-empty and not 2x2.
        """
        if not v:
            # Return the list itself: an after-validator's return value
            # becomes the field value, so returning None here would replace
            # the list with None and break the ``list[list[float]]`` type.
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class UnitlessBox(Measurement):
    """A box measurement with unitless ``x``, ``y``, ``w``, ``h`` values.

    NOTE(review): presumably ``w``/``h`` are width/height; whether ``x``/``y``
    is a corner or the center is not stated here — confirm.
    """

    x: float
    y: float
    w: float
    h: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Ensure a supplied covariance is a 4x4 matrix.

        An empty list means "no covariance" and is returned unchanged.

        :raises ValueError: if ``v`` is non-empty and not 4x4.
        """
        if not v:
            # Return the list itself: an after-validator's return value
            # becomes the field value, so returning None here would replace
            # the list with None and break the ``list[list[float]]`` type.
            return v
        if len(v) != 4 or any(len(row) != 4 for row in v):
            raise ValueError("covariance must be a 4x4 matrix")
        return v
class NewUpdateTrackerRequest(BaseModel):
    """Request payload that feeds a batch of measurements into a tracker.

    Measurements are grouped by geometry; every list defaults to empty.

    :param tracker_id: ID of the tracker to update.
    :param external_input: Additional input of type ``ExternalTrackerInput``.
    :param points_unitless: Unitless point measurements.
    :param points_lat_long: Latitude/longitude point measurements.
    :param points_utm: UTM point measurements.
    :param boxes_unitless: Unitless box measurements.
    :param wait_for_output: Whether the caller wants to wait for output.
    :param wait_time: How long to wait; defaults to ``100.0``.
    """

    tracker_id: str
    # Quoted as a forward reference: ``ExternalTrackerInput`` is not imported
    # at the top of this module, so an unquoted annotation is evaluated at
    # class creation and raises NameError. Pydantic resolves the string lazily
    # from this module's namespace (on first validation or ``model_rebuild()``).
    external_input: "ExternalTrackerInput"
    points_unitless: list[UnitlessPoint] = Field(default_factory=list)
    points_lat_long: list[LatLongPoint] = Field(default_factory=list)
    points_utm: list[UTMPoint] = Field(default_factory=list)
    boxes_unitless: list[UnitlessBox] = Field(default_factory=list)
    wait_for_output: bool = False
    # NOTE(review): the unit of wait_time (seconds? milliseconds?) is not
    # stated in this module — confirm against the service.
    wait_time: float = 100.0
class Tracker(BaseModel):
    """A tracker's configuration, identified by ``tracker_id``.

    Carries the same tuning fields as a create-tracker request. No
    ``model_config`` is set here, so enum-typed fields keep their enum members
    after validation. No matrix-shape validation is performed on
    ``state_transition`` or ``process_noise_covariance``.
    """

    tracker_id: str
    name: str
    kind: TrackerKind
    # Tuning parameters for track lifecycle.
    max_missing_updates: int
    min_lifetime_before_active: int
    # Track-to-measurement association settings.
    assignment_function: TrackerAssignmentFunction
    assignment_threshold: float
    label_behavior: TrackerLabelBehavior
    # Optional matrices; absent unless the tracker was configured with them.
    state_transition: list[list[float]] | None = None
    process_noise_covariance: list[list[float]] | None = None
class GeoPoint(BaseModel):
    """A point on the globe.

    :param latitude: Latitude in decimal format, expected between -90 and 90.
    :type latitude: float
    :param longitude: Longitude in decimal format, expected between -180 and 180.
    :type longitude: float

    .. note:: The ranges above are documented expectations only; this model
       does not validate them.
    """

    latitude: float
    longitude: float
class Track(BaseModel):
    """Defines an object track.

    Every field except ``measurement_id`` and ``coordinate`` has no default,
    so it must be present in the input, although its value may be ``None``.

    :param track_id: The reference ID for the track.
    :type track_id: str | None
    :param created_at: A timestamp of when the inference was created.
    :type created_at: datetime | None
    :param updated_at: A timestamp of when the inference was last updated.
    :type updated_at: datetime | None
    :param measurement_id: The measurement/detection level ID within the inference response. Defaults to ``None``.
    :type measurement_id: str | None
    :param lifetime: The count of frames the track has been alive for.
    :type lifetime: int | None
    :param total_updates: The number of total updates in which a measurement has been associated to the track.
    :type total_updates: int | None
    :param total_missed_updates: The number of total missed updates in which a measurement has not been associated to the track.
    :type total_missed_updates: int | None
    :param consecutive_updates: The total count of consecutive updates via measurement.
    :type consecutive_updates: int | None
    :param consecutive_missed_updates: The count of consecutive missed updates. If this exceeds the tracker's ``max_missing_updates``, the track status should be "lost".
    :type consecutive_missed_updates: int | None
    :param status: The track status: [new, active, lost].
    :type status: str | None
    :param label: The track class label.
    :type label: str | None
    :param measurement: The measurement used to update the tracker.
    :type measurement: dict | None
    :param measurement_uncertainty: The measurement uncertainty used to update the tracker.
    :type measurement_uncertainty: list[list[float]] | None
    :param predicted_estimated_state: The track's predicted estimate position/state before a potential measurement assignment.
    :type predicted_estimated_state: dict | None
    :param estimated_state: The track's updated estimate position/state after a potential measurement assignment.
    :type estimated_state: dict | None
    :param estimated_state_uncertainty: The track's state estimate uncertainty.
    :type estimated_state_uncertainty: list[list[float]] | None
    :param cost: The cost value from the assignment function.
    :type cost: float | None
    :param coordinate: The coordinate associated with the track. Defaults to ``None``.
    :type coordinate: GeoPoint | None
    """

    track_id: str | None
    created_at: datetime | None
    updated_at: datetime | None
    # Was annotated ``str`` with a ``None`` default: pydantic does not validate
    # defaults, so omitting it worked, but an explicit ``None`` (e.g. a JSON
    # ``null``) failed validation. The annotation now admits ``None``.
    measurement_id: str | None = None
    lifetime: int | None
    total_updates: int | None
    total_missed_updates: int | None
    consecutive_updates: int | None
    consecutive_missed_updates: int | None
    status: str | None
    label: str | None
    measurement: dict | None
    measurement_uncertainty: list[list[float]] | None
    predicted_estimated_state: dict | None
    estimated_state: dict | None
    estimated_state_uncertainty: list[list[float]] | None
    cost: float | None
    coordinate: GeoPoint | None = None
class NewStoreTracksRequest(BaseModel):
    """Defines a store tracks request.

    :param tracker_id: The tracker providing track outputs.
    :type tracker_id: str
    :param tracker_kind: The kind of tracker.
    :type tracker_kind: str
    :param project_id: The project containing the tracker.
    :type project_id: str
    :param model_id: The model providing the inference.
    :type model_id: str
    :param inference_id: The inference being used to update the tracker.
    :type inference_id: str
    :param ts: The timestamp indicating when the frame/packet/entity arrived at the sensor.
    :type ts: str
    :param sequence_number: An optional monotonically increasing integer such as a frame number. Defaults to ``None``.
    :type sequence_number: int | None
    :param tracks: The collection of tracks. Defaults to an empty list.
    :type tracks: list[Track]
    """

    tracker_id: str
    # NOTE(review): typed as plain str rather than TrackerKind, so any string
    # is accepted here — confirm whether that looseness is intentional.
    tracker_kind: str
    project_id: str
    model_id: str
    inference_id: str
    ts: str
    sequence_number: int | None = None
    # default_factory matches the mutable-default convention used by the
    # other models in this module.
    tracks: list[Track] = Field(default_factory=list)