Improve type inference in loiter flight plans.

This commit is contained in:
Dan Albert 2023-08-13 11:35:12 -07:00
parent 59756ce14c
commit aaf66107ad
5 changed files with 42 additions and 32 deletions

View File

@ -314,7 +314,9 @@ class FlightPlan(ABC, Generic[LayoutT]):
raise NotImplementedError
@self_type_guard
def is_loiter(self, flight_plan: FlightPlan[Any]) -> TypeGuard[LoiterFlightPlan]:
def is_loiter(
self, flight_plan: FlightPlan[Any]
) -> TypeGuard[LoiterFlightPlan[Any]]:
return False
@self_type_guard
@ -326,5 +328,5 @@ class FlightPlan(ABC, Generic[LayoutT]):
@self_type_guard
def is_formation(
self, flight_plan: FlightPlan[Any]
) -> TypeGuard[FormationFlightPlan]:
) -> TypeGuard[FormationFlightPlan[Any]]:
return False

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import cached_property
from typing import Any, TYPE_CHECKING, TypeGuard
from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar
from game.typeguard import self_type_guard
from game.utils import Speed
@ -25,7 +25,10 @@ class FormationLayout(LoiterLayout, ABC):
nav_from: list[FlightWaypoint]
class FormationFlightPlan(LoiterFlightPlan, ABC):
LayoutT = TypeVar("LayoutT", bound=FormationLayout)
class FormationFlightPlan(LoiterFlightPlan[LayoutT], ABC):
@property
@abstractmethod
def package_speed_waypoints(self) -> set[FlightWaypoint]:
@ -107,5 +110,5 @@ class FormationFlightPlan(LoiterFlightPlan, ABC):
@self_type_guard
def is_formation(
self, flight_plan: FlightPlan[Any]
) -> TypeGuard[FormationFlightPlan]:
) -> TypeGuard[FormationFlightPlan[Any]]:
return True

View File

@ -23,7 +23,29 @@ if TYPE_CHECKING:
from ..flight import Flight
class FormationAttackFlightPlan(FormationFlightPlan, ABC):
@dataclass(frozen=True)
class FormationAttackLayout(FormationLayout):
ingress: FlightWaypoint
targets: list[FlightWaypoint]
def iter_waypoints(self) -> Iterator[FlightWaypoint]:
yield self.departure
yield self.hold
yield from self.nav_to
yield self.join
yield self.ingress
yield from self.targets
yield self.split
if self.refuel is not None:
yield self.refuel
yield from self.nav_from
yield self.arrival
if self.divert is not None:
yield self.divert
yield self.bullseye
class FormationAttackFlightPlan(FormationFlightPlan[FormationAttackLayout], ABC):
@property
def package_speed_waypoints(self) -> set[FlightWaypoint]:
return {
@ -95,28 +117,6 @@ class FormationAttackFlightPlan(FormationFlightPlan, ABC):
return super().tot_for_waypoint(waypoint)
@dataclass(frozen=True)
class FormationAttackLayout(FormationLayout):
ingress: FlightWaypoint
targets: list[FlightWaypoint]
def iter_waypoints(self) -> Iterator[FlightWaypoint]:
yield self.departure
yield self.hold
yield from self.nav_to
yield self.join
yield self.ingress
yield from self.targets
yield self.split
if self.refuel is not None:
yield self.refuel
yield from self.nav_from
yield self.arrival
if self.divert is not None:
yield self.divert
yield self.bullseye
FlightPlanT = TypeVar("FlightPlanT", bound=FlightPlan[FormationAttackLayout])
LayoutT = TypeVar("LayoutT", bound=FormationAttackLayout)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, TYPE_CHECKING, TypeGuard
from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar
from game.typeguard import self_type_guard
from .flightplan import FlightPlan
@ -18,7 +18,10 @@ class LoiterLayout(StandardLayout, ABC):
hold: FlightWaypoint
class LoiterFlightPlan(StandardFlightPlan[Any], ABC):
LayoutT = TypeVar("LayoutT", bound=LoiterLayout)
class LoiterFlightPlan(StandardFlightPlan[LayoutT], ABC):
@property
def hold_duration(self) -> timedelta:
return timedelta(minutes=5)
@ -42,5 +45,7 @@ class LoiterFlightPlan(StandardFlightPlan[Any], ABC):
return travel_time + self.hold_duration
@self_type_guard
def is_loiter(self, flight_plan: FlightPlan[Any]) -> TypeGuard[LoiterFlightPlan]:
def is_loiter(
self, flight_plan: FlightPlan[Any]
) -> TypeGuard[LoiterFlightPlan[Any]]:
return True

View File

@ -37,7 +37,7 @@ class SweepLayout(LoiterLayout):
yield self.bullseye
class SweepFlightPlan(LoiterFlightPlan):
class SweepFlightPlan(LoiterFlightPlan[SweepLayout]):
@staticmethod
def builder_type() -> Type[Builder]:
return Builder