diff --git a/game/ato/flightstate/inflight.py b/game/ato/flightstate/inflight.py index c3724e10..40987246 100644 --- a/game/ato/flightstate/inflight.py +++ b/game/ato/flightstate/inflight.py @@ -12,7 +12,6 @@ from game.ato.flightwaypoint import FlightWaypoint from game.ato.flightwaypointtype import FlightWaypointType from game.ato.starttype import StartType from game.utils import Distance, LBS_TO_KG, Speed, pairwise -from gen.flights.flightplan import LoiterFlightPlan if TYPE_CHECKING: from game.ato.flight import Flight @@ -44,7 +43,7 @@ class InFlight(FlightState, ABC): # at a loiter point but still a regular InFlight (Loiter overrides this # method) that means we're traveling from the loiter point but no longer # loitering. - assert isinstance(self.flight.flight_plan, LoiterFlightPlan) + assert self.flight.flight_plan.is_loiter(self.flight.flight_plan) travel_time -= self.flight.flight_plan.hold_duration return travel_time diff --git a/game/ato/flightstate/loiter.py b/game/ato/flightstate/loiter.py index 63ecea53..c23129c2 100644 --- a/game/ato/flightstate/loiter.py +++ b/game/ato/flightstate/loiter.py @@ -8,7 +8,6 @@ from dcs import Point from game.ato.flightstate import FlightState, InFlight from game.ato.flightstate.navigating import Navigating from game.utils import Distance, Speed -from gen.flights.flightplan import LoiterFlightPlan if TYPE_CHECKING: from game.ato.flight import Flight @@ -17,7 +16,7 @@ if TYPE_CHECKING: class Loiter(InFlight): def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None: - assert isinstance(flight.flight_plan, LoiterFlightPlan) + assert flight.flight_plan.is_loiter(flight.flight_plan) self.hold_duration = flight.flight_plan.hold_duration super().__init__(flight, settings, waypoint_index) diff --git a/game/ato/flightstate/racetrack.py b/game/ato/flightstate/racetrack.py index 327152fb..14bcd86c 100644 --- a/game/ato/flightstate/racetrack.py +++ b/game/ato/flightstate/racetrack.py @@ -10,7 +10,6 @@ from game.ato import FlightType from game.ato.flightstate import InFlight from game.threatzones import ThreatPoly from game.utils import Distance, Speed, dcs_to_shapely_point -from gen.flights.flightplan import PatrollingFlightPlan if TYPE_CHECKING: from game.ato.flight import Flight @@ -19,7 +18,7 @@ if TYPE_CHECKING: class RaceTrack(InFlight): def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None: - assert isinstance(flight.flight_plan, PatrollingFlightPlan) + assert flight.flight_plan.is_patrol(flight.flight_plan) self.patrol_duration = flight.flight_plan.patrol_duration super().__init__(flight, settings, waypoint_index) self.commit_region = LineString( diff --git a/game/typeguard.py b/game/typeguard.py new file mode 100644 index 00000000..eac463ac --- /dev/null +++ b/game/typeguard.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import Callable, TypeGuard, TypeVar + +SelfT = TypeVar("SelfT") +BaseT = TypeVar("BaseT") +GuardT = TypeVar("GuardT") + + +def self_type_guard( + f: Callable[[SelfT, BaseT], TypeGuard[GuardT]] +) -> Callable[[SelfT, BaseT], TypeGuard[GuardT]]: + def decorator(s: SelfT, arg: BaseT) -> TypeGuard[GuardT]: + if id(s) != id(arg): + raise ValueError( + "self type guards must be called with self as the argument" + ) + return f(s, arg) + + return decorator diff --git a/gen/flights/flightplan.py b/gen/flights/flightplan.py index 9d8d2caf..224440c7 100644 --- a/gen/flights/flightplan.py +++ b/gen/flights/flightplan.py @@ -13,7 +13,15 @@ import random from dataclasses import dataclass, field from datetime import timedelta from functools import cached_property -from typing import Iterator, List, Optional, Set, TYPE_CHECKING, Tuple +from typing import ( + Iterator, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + TypeGuard, +) from dcs.mapping import Point from dcs.unit import Unit @@ -42,6 +50,7 @@ from game.theater.theatergroundobject import ( EwrGroundObject, NavalGroundObject, ) +from game.typeguard import self_type_guard from game.utils import Distance, Heading, Speed, feet, knots, meters, nautical_miles from .closestairfields import ObjectiveDistanceCache from .traveltime import GroundSpeed, TravelTime @@ -320,6 +329,18 @@ class FlightPlan: """The time that the mission is complete and the flight RTBs.""" raise NotImplementedError + @self_type_guard + def is_loiter(self, flight_plan: FlightPlan) -> TypeGuard[LoiterFlightPlan]: + return False + + @self_type_guard + def is_patrol(self, flight_plan: FlightPlan) -> TypeGuard[PatrollingFlightPlan]: + return False + + @self_type_guard + def is_formation(self, flight_plan: FlightPlan) -> TypeGuard[FormationFlightPlan]: + return False + @dataclass(frozen=True) class LoiterFlightPlan(FlightPlan): @@ -357,6 +378,10 @@ class LoiterFlightPlan(FlightPlan): def mission_departure_time(self) -> timedelta: raise NotImplementedError + @self_type_guard + def is_loiter(self, flight_plan: FlightPlan) -> TypeGuard[LoiterFlightPlan]: + return True + @dataclass(frozen=True) class FormationFlightPlan(LoiterFlightPlan): @@ -442,6 +467,10 @@ class FormationFlightPlan(LoiterFlightPlan): def mission_departure_time(self) -> timedelta: return self.split_time + @self_type_guard + def is_formation(self, flight_plan: FlightPlan) -> TypeGuard[FormationFlightPlan]: + return True + @dataclass(frozen=True) class PatrollingFlightPlan(FlightPlan): @@ -498,6 +527,10 @@ class PatrollingFlightPlan(FlightPlan): def mission_departure_time(self) -> timedelta: return self.patrol_end_time + @self_type_guard + def is_patrol(self, flight_plan: FlightPlan) -> TypeGuard[PatrollingFlightPlan]: + return True + @dataclass(frozen=True) class BarCapFlightPlan(PatrollingFlightPlan):