from enum import Enum
from pydantic import BaseModel, ConfigDict, Field, field_validator
# Defines the set of assignment/distance functions that can be used
# to associate tracks with measurements
class TrackerAssignmentFunction(str, Enum):
    """Assignment/distance function used to associate tracks with measurements.

    Because the enum mixes in ``str``, every member is itself a string and
    compares equal to its value, e.g.
    ``TrackerAssignmentFunction.EUCLIDEAN == "euclidean"``.
    """

    EUCLIDEAN = "euclidean"
    HAVERSINE = "haversine"
    INTERSECTION_OVER_UNION = "intersection_over_union"
# Defines the available tracker types/kinds
class TrackerKind(str, Enum):
    """Kind of tracker.

    Members are strings equal to their values. The member names suggest each
    kind pairs with one measurement geometry (point vs. box; unitless,
    latitude/longitude or UTM coordinates). NOTE(review): confirm that
    pairing against the tracking backend.

    ``POINT_LAT_LONG`` serialises as the long form
    ``"point_latitude_longitude"``, not ``"point_lat_long"``.
    """

    POINT_UNITLESS = "point_unitless"
    POINT_LAT_LONG = "point_latitude_longitude"
    POINT_UTM = "point_utm"
    BOX_UNITLESS = "box_unitless"
class NewCreateTrackerRequest(BaseModel):
    """Request model for creating a tracker under ``project_id``.

    ``use_enum_values=True`` makes pydantic store the plain string value
    (not the enum member) in ``kind`` and ``assignment_function`` once
    validation succeeds.
    """

    model_config = ConfigDict(use_enum_values=True)

    project_id: str
    kind: TrackerKind
    # Presumably the number of updates a track may go unmatched before it is
    # dropped -- semantics live in the backend, confirm there.
    max_missing_updates: int
    # Presumably the number of updates a track must survive before it is
    # reported as active -- confirm against the backend.
    min_lifetime_before_active: int
    assignment_function: TrackerAssignmentFunction
    # Threshold applied with ``assignment_function``; its scale depends on the
    # chosen function (not validated here).
    assignment_threshold: float
class Measurement(BaseModel):
    """Fields common to every measurement type submitted to a tracker."""

    # Class/category label attached to the measurement.
    label: str
    # Presumably a detection confidence; range is not validated here.
    score: float
    # Optional covariance matrix; an empty list (the default) means none was
    # supplied. ``default_factory`` gives each instance its own list.
    covariance: list[list[float]] = Field(default_factory=list)
    # Measurement identifier supplied by the caller.
    id: str
class UnitlessPoint(Measurement):
    """Point measurement in unitless ``x``/``y`` coordinates."""

    x: float
    y: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v: list[list[float]]) -> list[list[float]]:
        """Require ``covariance`` to be empty or a 2x2 matrix.

        Args:
            v: The already type-validated covariance value.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` is non-empty and not 2x2.
        """
        # A pydantic validator's return value replaces the field value, so the
        # empty list must be returned explicitly -- a bare ``return`` would
        # store ``None`` in a field typed ``list[list[float]]``.
        if not v:
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class LatLongPoint(Measurement):
    """Point measurement given as ``latitude``/``longitude``.

    NOTE(review): units are presumably decimal degrees; ranges are not
    validated here -- confirm.
    """

    latitude: float
    longitude: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v: list[list[float]]) -> list[list[float]]:
        """Require ``covariance`` to be empty or a 2x2 matrix.

        Args:
            v: The already type-validated covariance value.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` is non-empty and not 2x2.
        """
        # A pydantic validator's return value replaces the field value, so the
        # empty list must be returned explicitly -- a bare ``return`` would
        # store ``None`` in a field typed ``list[list[float]]``.
        if not v:
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class UTMPoint(Measurement):
    """Point measurement in UTM coordinates (``northing``/``easting``/``zone``).

    ``zone`` is an unvalidated string; its expected format is not checked
    here.
    """

    northing: float
    easting: float
    zone: str

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v: list[list[float]]) -> list[list[float]]:
        """Require ``covariance`` to be empty or a 2x2 matrix.

        Args:
            v: The already type-validated covariance value.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` is non-empty and not 2x2.
        """
        # A pydantic validator's return value replaces the field value, so the
        # empty list must be returned explicitly -- a bare ``return`` would
        # store ``None`` in a field typed ``list[list[float]]``.
        if not v:
            return v
        if len(v) != 2 or any(len(row) != 2 for row in v):
            raise ValueError("covariance must be a 2x2 matrix")
        return v
class UnitlessBox(Measurement):
    """Box measurement in unitless coordinates (``x``, ``y``, ``w``, ``h``).

    NOTE(review): whether ``x``/``y`` is the box centre or a corner is not
    established here -- confirm against the tracking backend.
    """

    x: float
    y: float
    w: float
    h: float

    @field_validator("covariance", check_fields=False)
    @classmethod
    def validate_covariance(cls, v: list[list[float]]) -> list[list[float]]:
        """Require ``covariance`` to be empty or a 4x4 matrix.

        Args:
            v: The already type-validated covariance value.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` is non-empty and not 4x4.
        """
        # A pydantic validator's return value replaces the field value, so the
        # empty list must be returned explicitly -- a bare ``return`` would
        # store ``None`` in a field typed ``list[list[float]]``.
        if not v:
            return v
        if len(v) != 4 or any(len(row) != 4 for row in v):
            raise ValueError("covariance must be a 4x4 matrix")
        return v
class NewUpdateTrackerRequest(BaseModel):
    """Request carrying new measurements for the tracker ``tracker_id``.

    Every measurement list defaults to a fresh empty list. Presumably only
    the list matching the tracker's kind is populated -- confirm.
    """

    tracker_id: str
    # NOTE(review): ``ExternalTrackerInput`` is not defined or imported in the
    # visible part of this module; it must be in scope before this class.
    external_input: ExternalTrackerInput

    # Measurement batches, one list per geometry.
    points_unitless: list[UnitlessPoint] = Field(default_factory=list)
    points_lat_long: list[LatLongPoint] = Field(default_factory=list)
    points_utm: list[UTMPoint] = Field(default_factory=list)
    boxes_unitless: list[UnitlessBox] = Field(default_factory=list)
class Tracker(BaseModel):
    """Tracker description identified by ``tracker_id``.

    Carries the same tuning fields as the create request, keyed by
    ``tracker_id`` instead of ``project_id``.
    """

    tracker_id: str
    kind: TrackerKind
    max_missing_updates: int
    min_lifetime_before_active: int
    assignment_function: TrackerAssignmentFunction
    assignment_threshold: float