Replace some isinstance calls with TypeGuard.

These aren't as ergonomic as I'd hoped because of
https://www.python.org/dev/peps/pep-0647/#narrowing-of-implicit-self-and-cls-parameters.

I added a decorator `@self_type_guard` so we can avoid needing to import
the descendent types in the typeguard implementation (which wouldn't fix
any circular imports, just move them).
This commit is contained in:
Dan Albert 2022-02-12 14:57:33 -08:00
parent 85e7b1762d
commit 2e901f3586
5 changed files with 57 additions and 7 deletions

View File

@ -12,7 +12,6 @@ from game.ato.flightwaypoint import FlightWaypoint
from game.ato.flightwaypointtype import FlightWaypointType from game.ato.flightwaypointtype import FlightWaypointType
from game.ato.starttype import StartType from game.ato.starttype import StartType
from game.utils import Distance, LBS_TO_KG, Speed, pairwise from game.utils import Distance, LBS_TO_KG, Speed, pairwise
from gen.flights.flightplan import LoiterFlightPlan
if TYPE_CHECKING: if TYPE_CHECKING:
from game.ato.flight import Flight from game.ato.flight import Flight
@ -44,7 +43,7 @@ class InFlight(FlightState, ABC):
# at a loiter point but still a regular InFlight (Loiter overrides this # at a loiter point but still a regular InFlight (Loiter overrides this
# method) that means we're traveling from the loiter point but no longer # method) that means we're traveling from the loiter point but no longer
# loitering. # loitering.
assert isinstance(self.flight.flight_plan, LoiterFlightPlan) assert self.flight.flight_plan.is_loiter(self.flight.flight_plan)
travel_time -= self.flight.flight_plan.hold_duration travel_time -= self.flight.flight_plan.hold_duration
return travel_time return travel_time

View File

@ -8,7 +8,6 @@ from dcs import Point
from game.ato.flightstate import FlightState, InFlight from game.ato.flightstate import FlightState, InFlight
from game.ato.flightstate.navigating import Navigating from game.ato.flightstate.navigating import Navigating
from game.utils import Distance, Speed from game.utils import Distance, Speed
from gen.flights.flightplan import LoiterFlightPlan
if TYPE_CHECKING: if TYPE_CHECKING:
from game.ato.flight import Flight from game.ato.flight import Flight
@ -17,7 +16,7 @@ if TYPE_CHECKING:
class Loiter(InFlight): class Loiter(InFlight):
def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None: def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None:
assert isinstance(flight.flight_plan, LoiterFlightPlan) assert flight.flight_plan.is_loiter(flight.flight_plan)
self.hold_duration = flight.flight_plan.hold_duration self.hold_duration = flight.flight_plan.hold_duration
super().__init__(flight, settings, waypoint_index) super().__init__(flight, settings, waypoint_index)

View File

@ -10,7 +10,6 @@ from game.ato import FlightType
from game.ato.flightstate import InFlight from game.ato.flightstate import InFlight
from game.threatzones import ThreatPoly from game.threatzones import ThreatPoly
from game.utils import Distance, Speed, dcs_to_shapely_point from game.utils import Distance, Speed, dcs_to_shapely_point
from gen.flights.flightplan import PatrollingFlightPlan
if TYPE_CHECKING: if TYPE_CHECKING:
from game.ato.flight import Flight from game.ato.flight import Flight
@ -19,7 +18,7 @@ if TYPE_CHECKING:
class RaceTrack(InFlight): class RaceTrack(InFlight):
def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None: def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None:
assert isinstance(flight.flight_plan, PatrollingFlightPlan) assert flight.flight_plan.is_patrol(flight.flight_plan)
self.patrol_duration = flight.flight_plan.patrol_duration self.patrol_duration = flight.flight_plan.patrol_duration
super().__init__(flight, settings, waypoint_index) super().__init__(flight, settings, waypoint_index)
self.commit_region = LineString( self.commit_region = LineString(

20
game/typeguard.py Normal file
View File

@ -0,0 +1,20 @@
from __future__ import annotations
from typing import Callable, TypeGuard, TypeVar
SelfT = TypeVar("SelfT")
BaseT = TypeVar("BaseT")
GuardT = TypeVar("GuardT")
def self_type_guard(
f: Callable[[SelfT, BaseT], TypeGuard[GuardT]]
) -> Callable[[SelfT, BaseT], TypeGuard[GuardT]]:
def decorator(s: SelfT, arg: BaseT) -> TypeGuard[GuardT]:
if id(s) != id(arg):
raise ValueError(
"self type guards must be called with self as the argument"
)
return f(s, arg)
return decorator

View File

@ -13,7 +13,15 @@ import random
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from functools import cached_property from functools import cached_property
from typing import Iterator, List, Optional, Set, TYPE_CHECKING, Tuple from typing import (
Iterator,
List,
Optional,
Set,
TYPE_CHECKING,
Tuple,
TypeGuard,
)
from dcs.mapping import Point from dcs.mapping import Point
from dcs.unit import Unit from dcs.unit import Unit
@ -42,6 +50,7 @@ from game.theater.theatergroundobject import (
EwrGroundObject, EwrGroundObject,
NavalGroundObject, NavalGroundObject,
) )
from game.typeguard import self_type_guard
from game.utils import Distance, Heading, Speed, feet, knots, meters, nautical_miles from game.utils import Distance, Heading, Speed, feet, knots, meters, nautical_miles
from .closestairfields import ObjectiveDistanceCache from .closestairfields import ObjectiveDistanceCache
from .traveltime import GroundSpeed, TravelTime from .traveltime import GroundSpeed, TravelTime
@ -320,6 +329,18 @@ class FlightPlan:
"""The time that the mission is complete and the flight RTBs.""" """The time that the mission is complete and the flight RTBs."""
raise NotImplementedError raise NotImplementedError
@self_type_guard
def is_loiter(self, flight_plan: FlightPlan) -> TypeGuard[LoiterFlightPlan]:
return False
@self_type_guard
def is_patrol(self, flight_plan: FlightPlan) -> TypeGuard[PatrollingFlightPlan]:
return False
@self_type_guard
def is_formation(self, flight_plan: FlightPlan) -> TypeGuard[FormationFlightPlan]:
return False
@dataclass(frozen=True) @dataclass(frozen=True)
class LoiterFlightPlan(FlightPlan): class LoiterFlightPlan(FlightPlan):
@ -357,6 +378,10 @@ class LoiterFlightPlan(FlightPlan):
def mission_departure_time(self) -> timedelta: def mission_departure_time(self) -> timedelta:
raise NotImplementedError raise NotImplementedError
@self_type_guard
def is_loiter(self, flight_plan: FlightPlan) -> TypeGuard[LoiterFlightPlan]:
return True
@dataclass(frozen=True) @dataclass(frozen=True)
class FormationFlightPlan(LoiterFlightPlan): class FormationFlightPlan(LoiterFlightPlan):
@ -442,6 +467,10 @@ class FormationFlightPlan(LoiterFlightPlan):
def mission_departure_time(self) -> timedelta: def mission_departure_time(self) -> timedelta:
return self.split_time return self.split_time
@self_type_guard
def is_formation(self, flight_plan: FlightPlan) -> TypeGuard[FormationFlightPlan]:
return True
@dataclass(frozen=True) @dataclass(frozen=True)
class PatrollingFlightPlan(FlightPlan): class PatrollingFlightPlan(FlightPlan):
@ -498,6 +527,10 @@ class PatrollingFlightPlan(FlightPlan):
def mission_departure_time(self) -> timedelta: def mission_departure_time(self) -> timedelta:
return self.patrol_end_time return self.patrol_end_time
@self_type_guard
def is_patrol(self, flight_plan: FlightPlan) -> TypeGuard[PatrollingFlightPlan]:
return True
@dataclass(frozen=True) @dataclass(frozen=True)
class BarCapFlightPlan(PatrollingFlightPlan): class BarCapFlightPlan(PatrollingFlightPlan):