From aaf66107add4434deb121d0b2c6570dd3f3cf679 Mon Sep 17 00:00:00 2001 From: Dan Albert Date: Sun, 13 Aug 2023 11:35:12 -0700 Subject: [PATCH] Improve type inference in loiter flight plans. --- game/ato/flightplans/flightplan.py | 6 ++-- game/ato/flightplans/formation.py | 9 +++-- game/ato/flightplans/formationattack.py | 46 ++++++++++++------------- game/ato/flightplans/loiter.py | 11 ++++-- game/ato/flightplans/sweep.py | 2 +- 5 files changed, 42 insertions(+), 32 deletions(-) diff --git a/game/ato/flightplans/flightplan.py b/game/ato/flightplans/flightplan.py index c14d66ff..a6bf47b0 100644 --- a/game/ato/flightplans/flightplan.py +++ b/game/ato/flightplans/flightplan.py @@ -314,7 +314,9 @@ class FlightPlan(ABC, Generic[LayoutT]): raise NotImplementedError @self_type_guard - def is_loiter(self, flight_plan: FlightPlan[Any]) -> TypeGuard[LoiterFlightPlan]: + def is_loiter( + self, flight_plan: FlightPlan[Any] + ) -> TypeGuard[LoiterFlightPlan[Any]]: return False @self_type_guard @@ -326,5 +328,5 @@ class FlightPlan(ABC, Generic[LayoutT]): @self_type_guard def is_formation( self, flight_plan: FlightPlan[Any] - ) -> TypeGuard[FormationFlightPlan]: + ) -> TypeGuard[FormationFlightPlan[Any]]: return False diff --git a/game/ato/flightplans/formation.py b/game/ato/flightplans/formation.py index de54c526..b22854be 100644 --- a/game/ato/flightplans/formation.py +++ b/game/ato/flightplans/formation.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta from functools import cached_property -from typing import Any, TYPE_CHECKING, TypeGuard +from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar from game.typeguard import self_type_guard from game.utils import Speed @@ -25,7 +25,10 @@ class FormationLayout(LoiterLayout, ABC): nav_from: list[FlightWaypoint] -class FormationFlightPlan(LoiterFlightPlan, ABC): +LayoutT = TypeVar("LayoutT", bound=FormationLayout) + + +class FormationFlightPlan(LoiterFlightPlan[LayoutT], ABC): @property @abstractmethod def package_speed_waypoints(self) -> set[FlightWaypoint]: @@ -107,5 +110,5 @@ class FormationFlightPlan(LoiterFlightPlan, ABC): @self_type_guard def is_formation( self, flight_plan: FlightPlan[Any] - ) -> TypeGuard[FormationFlightPlan]: + ) -> TypeGuard[FormationFlightPlan[Any]]: return True diff --git a/game/ato/flightplans/formationattack.py b/game/ato/flightplans/formationattack.py index ce20eca7..def0b6c4 100644 --- a/game/ato/flightplans/formationattack.py +++ b/game/ato/flightplans/formationattack.py @@ -23,7 +23,29 @@ if TYPE_CHECKING: from ..flight import Flight -class FormationAttackFlightPlan(FormationFlightPlan, ABC): +@dataclass(frozen=True) +class FormationAttackLayout(FormationLayout): + ingress: FlightWaypoint + targets: list[FlightWaypoint] + + def iter_waypoints(self) -> Iterator[FlightWaypoint]: + yield self.departure + yield self.hold + yield from self.nav_to + yield self.join + yield self.ingress + yield from self.targets + yield self.split + if self.refuel is not None: + yield self.refuel + yield from self.nav_from + yield self.arrival + if self.divert is not None: + yield self.divert + yield self.bullseye + + +class FormationAttackFlightPlan(FormationFlightPlan[FormationAttackLayout], ABC): @property def package_speed_waypoints(self) -> set[FlightWaypoint]: return { @@ -95,28 +117,6 @@ class FormationAttackFlightPlan(FormationFlightPlan, ABC): return super().tot_for_waypoint(waypoint) -@dataclass(frozen=True) -class FormationAttackLayout(FormationLayout): - ingress: FlightWaypoint - targets: list[FlightWaypoint] - - def iter_waypoints(self) -> Iterator[FlightWaypoint]: - yield self.departure - yield self.hold - yield from self.nav_to - yield self.join - yield self.ingress - yield from self.targets - yield self.split - if self.refuel is not None: - yield self.refuel - yield from self.nav_from - yield self.arrival - if self.divert is not None: - yield self.divert - yield self.bullseye - - FlightPlanT = TypeVar("FlightPlanT", bound=FlightPlan[FormationAttackLayout]) LayoutT = TypeVar("LayoutT", bound=FormationAttackLayout) diff --git a/game/ato/flightplans/loiter.py b/game/ato/flightplans/loiter.py index 8b5dc108..454d5336 100644 --- a/game/ato/flightplans/loiter.py +++ b/game/ato/flightplans/loiter.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, TYPE_CHECKING, TypeGuard +from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar from game.typeguard import self_type_guard from .flightplan import FlightPlan @@ -18,7 +18,10 @@ class LoiterLayout(StandardLayout, ABC): hold: FlightWaypoint -class LoiterFlightPlan(StandardFlightPlan[Any], ABC): +LayoutT = TypeVar("LayoutT", bound=LoiterLayout) + + +class LoiterFlightPlan(StandardFlightPlan[LayoutT], ABC): @property def hold_duration(self) -> timedelta: return timedelta(minutes=5) @@ -42,5 +45,7 @@ class LoiterFlightPlan(StandardFlightPlan[Any], ABC): return travel_time + self.hold_duration @self_type_guard - def is_loiter(self, flight_plan: FlightPlan[Any]) -> TypeGuard[LoiterFlightPlan]: + def is_loiter( + self, flight_plan: FlightPlan[Any] + ) -> TypeGuard[LoiterFlightPlan[Any]]: return True diff --git a/game/ato/flightplans/sweep.py b/game/ato/flightplans/sweep.py index 93bd2c38..412a32b3 100644 --- a/game/ato/flightplans/sweep.py +++ b/game/ato/flightplans/sweep.py @@ -37,7 +37,7 @@ class SweepLayout(LoiterLayout): yield self.bullseye -class SweepFlightPlan(LoiterFlightPlan): +class SweepFlightPlan(LoiterFlightPlan[SweepLayout]): @staticmethod def builder_type() -> Type[Builder]: return Builder