Source code for chariot.tracker.models

from datetime import datetime
from enum import Enum

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator


# Defines the set of assignment/distance functions that can be used
# to associate tracks with measurements
class TrackerAssignmentFunction(str, Enum):
    """Assignment/distance functions a tracker can use to associate tracks
    with measurements.

    Because the enum also subclasses ``str``, every member compares equal to
    its string value (e.g. ``TrackerAssignmentFunction.EUCLIDEAN == "euclidean"``).
    """

    # Member order is part of the public surface (iteration order), keep it.
    EUCLIDEAN = "euclidean"
    HAVERSINE = "haversine"
    INTERSECTION_OVER_UNION = "intersection_over_union"
    MAHALANOBIS = "mahalanobis"
# Defines the set of label behavior options that determine
# how labels are assigned to tracks
class TrackerLabelBehavior(str, Enum):
    """Strategies that decide how a class label is assigned to a track.

    Members are ``str`` subclasses and compare equal to their string values.
    NOTE(review): the exact semantics of each strategy (e.g. what "majority"
    counts over) live in the tracker service, not in this module.
    """

    OMIT = "omit"
    LAST = "last"
    MAJORITY = "majority"
# Defines the available tracker types/kinds
class TrackerKind(str, Enum):
    """Available tracker kinds (measurement geometry and coordinate system).

    Members are ``str`` subclasses and compare equal to their string values,
    so a plain string such as ``"box_unitless"`` matches ``BOX_UNITLESS``.
    """

    # Point trackers.
    POINT_UNITLESS = "point_unitless"
    POINT_LAT_LONG = "point_latitude_longitude"
    POINT_UTM = "point_utm"
    # Box trackers.
    BOX_UNITLESS = "box_unitless"
    BOX_LAT_LONG = "box_latitude_longitude"
def _check_square_matrix(
    field_name: str, matrix: list[list[float]] | None, kind: str
) -> None:
    """Raise ``ValueError`` unless *matrix* has the square size *kind* requires.

    ``box_unitless`` trackers require a 7x7 matrix; every other tracker kind
    requires a 4x4 matrix. ``None`` means "not supplied" and is accepted.

    :param field_name: Name of the model field being checked; used in the error message.
    :param matrix: The matrix to check, or ``None``.
    :param kind: The tracker kind, either a ``TrackerKind`` member or its string value.
    :raises ValueError: If the row count or any row length differs from the required size.
    """
    if matrix is None:
        return
    # TrackerKind is a str enum, so this comparison also matches the plain
    # string value the model stores when use_enum_values=True.
    if kind == TrackerKind.BOX_UNITLESS:
        dim = 7
        allowed_kinds = "'box_unitless'"
    else:
        dim = 4
        allowed_kinds = (
            "'point_unitless', 'point_latitude_longitude', 'point_utm', "
            "or 'box_latitude_longitude'"
        )
    if len(matrix) != dim or any(len(row) != dim for row in matrix):
        raise ValueError(
            f"{field_name} must be a {dim}x{dim} matrix when kind is {allowed_kinds}"
        )


class NewCreateTrackerRequest(BaseModel):
    """Defines a request to create a tracker.

    :param project_id: The project the tracker belongs to.
    :param name: The tracker name.
    :param kind: The tracker kind; it also fixes the required size of the
        optional matrices (7x7 for ``box_unitless``, 4x4 otherwise).
    :param max_missing_updates: Missed-update limit used to decide when a track is lost.
    :param min_lifetime_before_active: NOTE(review): presumably the lifetime a
        track must reach before it is reported as active. Confirm.
    :param assignment_function: Function used to associate tracks with measurements.
    :param assignment_threshold: Threshold applied to the assignment cost.
    :param label_behavior: How labels are assigned to tracks.
    :param state_transition: Optional state transition matrix.
    :param process_noise_covariance: Optional process noise covariance matrix.
    """

    project_id: str
    name: str
    kind: TrackerKind
    max_missing_updates: int
    min_lifetime_before_active: int
    assignment_function: TrackerAssignmentFunction
    assignment_threshold: float
    label_behavior: TrackerLabelBehavior
    state_transition: list[list[float]] | None = None
    process_noise_covariance: list[list[float]] | None = None

    # Store enum fields as their plain string values.
    model_config = ConfigDict(use_enum_values=True)

    @model_validator(mode="after")
    def check_state_transition(self):
        """Ensure ``state_transition``, if given, has the size ``kind`` requires."""
        _check_square_matrix("state_transition", self.state_transition, self.kind)
        return self

    @model_validator(mode="after")
    def check_process_noise_covariance(self):
        """Ensure ``process_noise_covariance``, if given, has the size ``kind`` requires."""
        _check_square_matrix(
            "process_noise_covariance", self.process_noise_covariance, self.kind
        )
        return self
class UnitlessBoxLocalizationMetadata(BaseModel):
    """Image and sensor/platform geometry supplied with a tracker input.

    Carried as ``ExternalTrackerInput.unitless_box_localization_metadata``.
    NOTE(review): angle units (degrees vs. radians) and the altitude datum
    are not visible in this module. Confirm against the tracker service.
    """

    # Image dimensions.
    image_width: int
    image_height: int

    # Platform attitude.
    platform_heading_angle: float
    platform_pitch_angle: float
    platform_roll_angle: float

    # Sensor orientation (presumably relative to the platform).
    sensor_relative_azimuth_angle: float
    sensor_relative_elevation_angle: float
    sensor_relative_roll_angle: float

    # Sensor field of view.
    sensor_horizontal_field_of_view: float
    sensor_vertical_field_of_view: float

    # Sensor position.
    sensor_latitude: float
    sensor_longitude: float
    sensor_true_altitude: float
class ExternalTrackerInput(BaseModel):
    """Context identifying the inference that drives a tracker update.

    The two ``modified_*_matrix`` fields are optional matrix overrides.
    NOTE(review): presumably they apply to this update only. Confirm with
    the tracker service.
    """

    # Which project/model/inference produced the measurements.
    project_id: str
    model_id: str
    inference_id: str

    # Timestamp string (format is not validated here) and optional ordering hint.
    ts: str
    sequence_number: int = 0

    # Optional per-input overrides and localization context.
    modified_state_transition_matrix: list[list[float]] | None = None
    modified_process_noise_covariance_matrix: list[list[float]] | None = None
    unitless_box_localization_metadata: UnitlessBoxLocalizationMetadata | None = None
class Measurement(BaseModel):
    """Fields shared by every measurement type sent to a tracker.

    ``covariance`` defaults to a fresh empty list per instance; the point and
    box subclasses validate its shape when one is supplied.
    """

    label: str
    score: float
    covariance: list[list[float]] = Field(default_factory=list)
    # Measurement identifier (field name is part of the wire format).
    id: str
class UnitlessPoint(Measurement):
    """A point measurement in unitless ``x``/``y`` coordinates."""

    x: float
    y: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Require ``covariance`` to be empty or a 2x2 matrix.

        :param v: The covariance value being validated.
        :returns: ``v`` unchanged.
        :raises ValueError: If ``v`` is non-empty and not 2x2.
        """
        # Empty means "no covariance supplied". Return the list itself: a bare
        # ``return`` would replace the field with None, violating its list type.
        if not v:
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class LatLongPoint(Measurement):
    """A point measurement given as latitude/longitude."""

    latitude: float
    longitude: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Require ``covariance`` to be empty or a 2x2 matrix.

        :param v: The covariance value being validated.
        :returns: ``v`` unchanged.
        :raises ValueError: If ``v`` is non-empty and not 2x2.
        """
        # Empty means "no covariance supplied". Return the list itself: a bare
        # ``return`` would replace the field with None, violating its list type.
        if not v:
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class UTMPoint(Measurement):
    """A point measurement given as UTM northing/easting within a zone."""

    northing: float
    easting: float
    zone: str

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Require ``covariance`` to be empty or a 2x2 matrix.

        :param v: The covariance value being validated.
        :returns: ``v`` unchanged.
        :raises ValueError: If ``v`` is non-empty and not 2x2.
        """
        # Empty means "no covariance supplied". Return the list itself: a bare
        # ``return`` would replace the field with None, violating its list type.
        if not v:
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class UnitlessBox(Measurement):
    """A box measurement in unitless coordinates (``x``, ``y``, ``w``, ``h``)."""

    x: float
    y: float
    w: float
    h: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v):
        """Require ``covariance`` to be empty or a 4x4 matrix.

        :param v: The covariance value being validated.
        :returns: ``v`` unchanged.
        :raises ValueError: If ``v`` is non-empty and not 4x4.
        """
        # Empty means "no covariance supplied". Return the list itself: a bare
        # ``return`` would replace the field with None, violating its list type.
        if not v:
            return v
        if len(v) != 4 or any(len(row) != 4 for row in v):
            raise ValueError("covariance must be a 4x4 matrix")
        return v
class NewUpdateTrackerRequest(BaseModel):
    """Defines a request that feeds one batch of measurements to a tracker.

    Every measurement list defaults to a fresh empty list, so callers only
    populate the list matching their measurement type.
    NOTE(review): ``wait_for_output`` / ``wait_time`` presumably control
    whether and how long the caller blocks for tracker output. The unit of
    ``wait_time`` is not visible here, so confirm it.
    """

    tracker_id: str
    external_input: ExternalTrackerInput

    # Measurements, one list per geometry/coordinate system.
    points_unitless: list[UnitlessPoint] = Field(default_factory=list)
    points_lat_long: list[LatLongPoint] = Field(default_factory=list)
    points_utm: list[UTMPoint] = Field(default_factory=list)
    boxes_unitless: list[UnitlessBox] = Field(default_factory=list)

    # Output waiting behaviour.
    wait_for_output: bool = False
    wait_time: float = 100.0
class Tracker(BaseModel):
    """An existing tracker and its configuration, identified by ``tracker_id``.

    Carries the same configuration fields as a tracker-creation request.
    """

    # Identity.
    tracker_id: str
    name: str
    kind: TrackerKind

    # Track lifecycle thresholds.
    max_missing_updates: int
    min_lifetime_before_active: int

    # Track/measurement association.
    assignment_function: TrackerAssignmentFunction
    assignment_threshold: float
    label_behavior: TrackerLabelBehavior

    # Optional filter matrices.
    state_transition: list[list[float]] | None = None
    process_noise_covariance: list[list[float]] | None = None
class GeoPoint(BaseModel):
    """A point on the globe, in decimal format.

    :param latitude: Latitude, expected in [-90, 90]. This range is not
        enforced by the model.
    :type latitude: float
    :param longitude: Longitude, expected in [-180, 180]. This range is not
        enforced by the model.
    :type longitude: float
    """

    latitude: float
    longitude: float
class Track(BaseModel):
    """Defines an object track.

    Fields annotated ``X | None`` without a default are still required keys
    (their value may be ``None``). Only ``measurement_id`` and ``coordinate``
    may be omitted.

    :param track_id: The reference ID for the track.
    :type track_id: str | None
    :param created_at: A timestamp of when the inference was created.
    :type created_at: datetime | None
    :param updated_at: A timestamp of when the inference was last updated.
    :type updated_at: datetime | None
    :param measurement_id: The measurement/detection level ID within the inference response.
    :type measurement_id: str | None
    :param lifetime: The count of frames the track has been alive for.
    :type lifetime: int | None
    :param total_updates: The number of total updates in which a measurement has been associated to the track.
    :type total_updates: int | None
    :param total_missed_updates: The number of total missed updates in which a measurement has not been associated to the track.
    :type total_missed_updates: int | None
    :param consecutive_updates: The total count of consecutive updates via measurement.
    :type consecutive_updates: int | None
    :param consecutive_missed_updates: The count of consecutive missed updates.
        If this exceeds the tracker's ``max_missing_updates``, the track status should be "lost".
    :type consecutive_missed_updates: int | None
    :param status: The track status: [new, active, lost].
    :type status: str | None
    :param label: The track class label.
    :type label: str | None
    :param measurement: The measurement used to update the tracker.
    :type measurement: dict | None
    :param measurement_uncertainty: The measurement uncertainty used to update the tracker.
    :type measurement_uncertainty: list[list[float]] | None
    :param predicted_estimated_state: The track's predicted estimate position/state before a potential measurement assignment.
    :type predicted_estimated_state: dict | None
    :param estimated_state: The track's updated estimate position/state after a potential measurement assignment.
    :type estimated_state: dict | None
    :param estimated_state_uncertainty: The track's state estimate uncertainty.
    :type estimated_state_uncertainty: list[list[float]] | None
    :param cost: The cost value from the assignment function.
    :type cost: float | None
    :param coordinate: The coordinate associated with the track.
    :type coordinate: GeoPoint | None
    """

    track_id: str | None
    created_at: datetime | None
    updated_at: datetime | None
    # Was annotated ``str`` with a ``None`` default, which rejected an explicit
    # ``None`` even though the field is documented as optional.
    measurement_id: str | None = None
    lifetime: int | None
    total_updates: int | None
    total_missed_updates: int | None
    consecutive_updates: int | None
    consecutive_missed_updates: int | None
    status: str | None
    label: str | None
    measurement: dict | None
    measurement_uncertainty: list[list[float]] | None
    predicted_estimated_state: dict | None
    estimated_state: dict | None
    estimated_state_uncertainty: list[list[float]] | None
    cost: float | None
    coordinate: GeoPoint | None = None
class NewStoreTracksRequest(BaseModel):
    """Defines a store tracks request.

    :param tracker_id: The tracker providing track outputs.
    :type tracker_id: str
    :param tracker_kind: The kind of tracker.
    :type tracker_kind: str
    :param project_id: The project containing the tracker.
    :type project_id: str
    :param model_id: The model providing the inference.
    :type model_id: str
    :param inference_id: The inference being used to update the tracker.
    :type inference_id: str
    :param ts: The timestamp indicating when the frame/packet/entity arrived at the sensor.
    :type ts: str
    :param sequence_number: An optional monotonically increasing integer such as a frame number.
    :type sequence_number: int | None
    :param tracks: The collection of tracks; defaults to an empty list.
    :type tracks: list[Track]
    """

    tracker_id: str
    tracker_kind: str
    project_id: str
    model_id: str
    inference_id: str
    ts: str
    sequence_number: int | None = None
    # default_factory gives each instance its own list, consistent with the
    # other list-valued fields in this module.
    tracks: list[Track] = Field(default_factory=list)