diff --git a/game/ato/flightplans/flightplan.py b/game/ato/flightplans/flightplan.py index a6bf47b0..6cde192f 100644 --- a/game/ato/flightplans/flightplan.py +++ b/game/ato/flightplans/flightplan.py @@ -330,3 +330,6 @@ class FlightPlan(ABC, Generic[LayoutT]): self, flight_plan: FlightPlan[Any] ) -> TypeGuard[FormationFlightPlan[Any]]: return False + + def add_waypoint_actions(self) -> None: + pass diff --git a/game/ato/flightplans/ibuilder.py b/game/ato/flightplans/ibuilder.py index a91cd6f3..25ea3157 100644 --- a/game/ato/flightplans/ibuilder.py +++ b/game/ato/flightplans/ibuilder.py @@ -36,6 +36,7 @@ class IBuilder(ABC, Generic[FlightPlanT, LayoutT]): try: self._generate_package_waypoints_if_needed(dump_debug_info) self._flight_plan = self.build(dump_debug_info) + self._flight_plan.add_waypoint_actions() except NavMeshError as ex: color = "blue" if self.flight.squadron.player else "red" raise PlanningError( diff --git a/game/ato/flightplans/loiter.py b/game/ato/flightplans/loiter.py index 454d5336..4d6e9dce 100644 --- a/game/ato/flightplans/loiter.py +++ b/game/ato/flightplans/loiter.py @@ -5,7 +5,9 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar +from game.flightplan.waypointactions.hold import Hold from game.typeguard import self_type_guard +from game.utils import Speed from .flightplan import FlightPlan from .standard import StandardFlightPlan, StandardLayout @@ -49,3 +51,13 @@ class LoiterFlightPlan(StandardFlightPlan[LayoutT], ABC): self, flight_plan: FlightPlan[Any] ) -> TypeGuard[LoiterFlightPlan[Any]]: return True + + def provide_push_time(self) -> datetime: + return self.push_time + + def add_waypoint_actions(self) -> None: + hold = self.layout.hold + speed = self.flight.unit_type.patrol_speed + if speed is None: + speed = Speed.from_mach(0.6, hold.alt) + hold.add_action(Hold(self.provide_push_time, hold.alt, speed)) diff --git a/game/ato/flightplans/waypointbuilder.py b/game/ato/flightplans/waypointbuilder.py index 8cb75cdd..499cfd95 100644 --- a/game/ato/flightplans/waypointbuilder.py +++ b/game/ato/flightplans/waypointbuilder.py @@ -168,6 +168,9 @@ class WaypointBuilder: "HOLD", FlightWaypointType.LOITER, position, + # Bug: DCS only accepts MSL altitudes for the orbit task and 500 meters is + # below the ground for most if not all of NTTR (and lots of places in other + # maps). meters(500) if self.is_helo else self.doctrine.rendezvous_altitude, alt_type, description="Wait until push time", diff --git a/game/ato/flightstate/actionstate.py b/game/ato/flightstate/actionstate.py new file mode 100644 index 00000000..15884c01 --- /dev/null +++ b/game/ato/flightstate/actionstate.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from game.flightplan.waypointactions.waypointaction import WaypointAction + + +class ActionState: + def __init__(self, action: WaypointAction) -> None: + self.action = action + self._finished = False + + def describe(self) -> str: + return self.action.describe() + + def finish(self) -> None: + self._finished = True + + def is_finished(self) -> bool: + return self._finished + + def on_game_tick(self, time: datetime, duration: timedelta) -> None: + self.action.update_state(self, time, duration) diff --git a/game/ato/flightstate/inflight.py b/game/ato/flightstate/inflight.py index 53e4f852..3654255e 100644 --- a/game/ato/flightstate/inflight.py +++ b/game/ato/flightstate/inflight.py @@ -1,12 +1,15 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections import deque from datetime import datetime, timedelta from typing import TYPE_CHECKING from dcs import Point +from dcs.task import Task from game.ato.flightstate import Completed +from game.ato.flightstate.actionstate import ActionState from game.ato.flightstate.flightstate import FlightState from game.ato.flightwaypoint import FlightWaypoint from game.ato.flightwaypointtype import FlightWaypointType @@ -37,6 +40,15 @@ class InFlight(FlightState, ABC): self.total_time_to_next_waypoint = self.travel_time_between_waypoints() self.elapsed_time = timedelta() self.current_waypoint_elapsed = False + self.pending_actions: deque[ActionState] = deque( + ActionState(a) for a in self.current_waypoint.actions + ) + + @property + def current_action(self) -> ActionState | None: + if self.pending_actions: + return self.pending_actions[0] + return None @property def cancelable(self) -> bool: @@ -80,7 +92,6 @@ class InFlight(FlightState, ABC): return initial_fuel def next_waypoint_state(self) -> FlightState: - from .loiter import Loiter from .racetrack import RaceTrack from .navigating import Navigating @@ -89,8 +100,6 @@ class InFlight(FlightState, ABC): return Completed(self.flight, self.settings) if self.next_waypoint.waypoint_type is FlightWaypointType.PATROL_TRACK: return RaceTrack(self.flight, self.settings, new_index) - if self.next_waypoint.waypoint_type is FlightWaypointType.LOITER: - return Loiter(self.flight, self.settings, new_index) return Navigating(self.flight, self.settings, new_index) def advance_to_next_waypoint(self) -> FlightState: @@ -102,6 +111,12 @@ class InFlight(FlightState, ABC): def on_game_tick( self, events: GameUpdateEvents, time: datetime, duration: timedelta ) -> None: + if (action := self.current_action) is not None: + action.on_game_tick(time, duration) + if action.is_finished(): + self.pending_actions.popleft() + return + self.elapsed_time += duration if self.elapsed_time > self.total_time_to_next_waypoint: new_state = self.advance_to_next_waypoint() @@ -152,11 +167,3 @@ class InFlight(FlightState, ABC): @property def spawn_type(self) -> StartType: return StartType.IN_FLIGHT - - @property - def description(self) -> str: - if self.has_aborted: - abort = "(Aborted) " - else: - abort = "" - return f"{abort}Flying to {self.next_waypoint.name}" diff --git a/game/ato/flightstate/loiter.py b/game/ato/flightstate/loiter.py deleted file mode 100644 index c23129c2..00000000 --- a/game/ato/flightstate/loiter.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from datetime import timedelta -from typing import TYPE_CHECKING - -from dcs import Point - -from game.ato.flightstate import FlightState, InFlight -from game.ato.flightstate.navigating import Navigating -from game.utils import Distance, Speed - -if TYPE_CHECKING: - from game.ato.flight import Flight - from game.settings import Settings - - -class Loiter(InFlight): - def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None: - assert flight.flight_plan.is_loiter(flight.flight_plan) - self.hold_duration = flight.flight_plan.hold_duration - super().__init__(flight, settings, waypoint_index) - - def estimate_position(self) -> Point: - return self.current_waypoint.position - - def estimate_altitude(self) -> tuple[Distance, str]: - return self.current_waypoint.alt, self.current_waypoint.alt_type - - def estimate_speed(self) -> Speed: - return self.flight.unit_type.preferred_patrol_speed(self.estimate_altitude()[0]) - - def estimate_fuel(self) -> float: - # TODO: Estimate loiter consumption per minute? - return self.estimate_fuel_at_current_waypoint() - - def next_waypoint_state(self) -> FlightState: - # Do not automatically advance to the next waypoint. Just proceed from the - # current one with the normal flying state. - return Navigating(self.flight, self.settings, self.waypoint_index) - - def travel_time_between_waypoints(self) -> timedelta: - return self.hold_duration - - @property - def description(self) -> str: - return f"Loitering for {self.hold_duration - self.elapsed_time}" diff --git a/game/ato/flightstate/navigating.py b/game/ato/flightstate/navigating.py index 44603205..fc97c2be 100644 --- a/game/ato/flightstate/navigating.py +++ b/game/ato/flightstate/navigating.py @@ -80,3 +80,14 @@ class Navigating(InFlight): @property def spawn_type(self) -> StartType: return StartType.IN_FLIGHT + + @property + def description(self) -> str: + if (action := self.current_action) is not None: + return action.describe() + + if self.has_aborted: + abort = "(Aborted) " + else: + abort = "" + return f"{abort}Flying to {self.next_waypoint.name}" diff --git a/game/ato/flightwaypoint.py b/game/ato/flightwaypoint.py index ac625f57..7be69688 100644 --- a/game/ato/flightwaypoint.py +++ b/game/ato/flightwaypoint.py @@ -7,6 +7,7 @@ from typing import Literal, TYPE_CHECKING from dcs import Point from game.ato.flightwaypointtype import FlightWaypointType +from game.flightplan.waypointactions.waypointaction import WaypointAction from game.theater.theatergroup import TheaterUnit from game.utils import Distance, meters @@ -39,6 +40,8 @@ class FlightWaypoint: # The minimum amount of fuel remaining at this waypoint in pounds. min_fuel: float | None = None + actions: list[WaypointAction] = field(default_factory=list) + # These are set very late by the air conflict generator (part of mission # generation). We do it late so that we don't need to propagate changes # to waypoint times whenever the player alters the package TOT or the @@ -46,6 +49,9 @@ class FlightWaypoint: tot: datetime | None = None departure_time: datetime | None = None + def add_action(self, action: WaypointAction) -> None: + self.actions.append(action) + @property def x(self) -> float: return self.position.x diff --git a/game/flightplan/waypointactions/__init__.py b/game/flightplan/waypointactions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/game/flightplan/waypointactions/hold.py b/game/flightplan/waypointactions/hold.py new file mode 100644 index 00000000..9b2775e8 --- /dev/null +++ b/game/flightplan/waypointactions/hold.py @@ -0,0 +1,54 @@ +from collections.abc import Iterator +from datetime import datetime, timedelta + +from dcs.task import Task, OrbitAction, ControlledTask + +from game.ato.flightstate.actionstate import ActionState +from game.provider import Provider +from game.utils import Distance, Speed +from .taskcontext import TaskContext +from .waypointaction import WaypointAction + + +class Hold(WaypointAction): + """Loiter at a location until a push time to synchronize with other flights. + + Taxi behavior is extremely unpredictable, so we cannot reliably predict ETAs for + waypoints without first fixing a time for one waypoint by holding until a sync time. + This is typically done with a dedicated hold point. If the flight reaches the hold + point before their push time, they will loiter at that location rather than fly to + their next waypoint as a speed that's often dangerously slow. + """ + + def __init__( + self, push_time_provider: Provider[datetime], altitude: Distance, speed: Speed + ) -> None: + self._push_time_provider = push_time_provider + self._altitude = altitude + self._speed = speed + + def describe(self) -> str: + return self._push_time_provider().strftime("Holding until %H:%M:%S") + + def update_state( + self, state: ActionState, time: datetime, duration: timedelta + ) -> None: + if self._push_time_provider() <= time: + state.finish() + + def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]: + remaining_time = self._push_time_provider() - ctx.mission_start_time + if remaining_time <= timedelta(): + return + + loiter = ControlledTask( + OrbitAction( + altitude=int(self._altitude.meters), + pattern=OrbitAction.OrbitPattern.Circle, + speed=self._speed.kph, + ) + ) + # The DCS task is serialized using the time from mission start, not the actual + # time. + loiter.stop_after_time(int(remaining_time.total_seconds())) + yield loiter diff --git a/game/flightplan/waypointactions/taskcontext.py b/game/flightplan/waypointactions/taskcontext.py new file mode 100644 index 00000000..8a3a3884 --- /dev/null +++ b/game/flightplan/waypointactions/taskcontext.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass +from datetime import datetime + + +@dataclass(frozen=True) +class TaskContext: + mission_start_time: datetime diff --git a/game/flightplan/waypointactions/waypointaction.py b/game/flightplan/waypointactions/waypointaction.py new file mode 100644 index 00000000..52095bf5 --- /dev/null +++ b/game/flightplan/waypointactions/waypointaction.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Iterator +from datetime import datetime, timedelta +from typing import TYPE_CHECKING + +from dcs.task import Task + +from .taskcontext import TaskContext + +if TYPE_CHECKING: + from game.ato.flightstate.actionstate import ActionState + + +class WaypointAction(ABC): + @abstractmethod + def describe(self) -> str: + ... + + @abstractmethod + def update_state( + self, state: ActionState, time: datetime, duration: timedelta + ) -> None: + ... + + @abstractmethod + def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]: + ... diff --git a/game/missiongenerator/aircraft/flightgroupspawner.py b/game/missiongenerator/aircraft/flightgroupspawner.py index 33855bf0..b8a51b82 100644 --- a/game/missiongenerator/aircraft/flightgroupspawner.py +++ b/game/missiongenerator/aircraft/flightgroupspawner.py @@ -161,6 +161,8 @@ class FlightGroupSpawner: ) group.points[0].alt_type = alt_type + # We don't need to add waypoint tasks for the starting waypoint here because we + # do that in WaypointGenerator. return group def _generate_at_airport(self, name: str, airport: Airport) -> FlyingGroup[Any]: diff --git a/game/missiongenerator/aircraft/waypoints/holdpoint.py b/game/missiongenerator/aircraft/waypoints/holdpoint.py deleted file mode 100644 index a5d46cd1..00000000 --- a/game/missiongenerator/aircraft/waypoints/holdpoint.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging - -from dcs.point import MovingPoint -from dcs.task import ControlledTask, OptFormation, OrbitAction - -from game.ato.flightplans.loiter import LoiterFlightPlan -from .pydcswaypointbuilder import PydcsWaypointBuilder - - -class HoldPointBuilder(PydcsWaypointBuilder): - def add_tasks(self, waypoint: MovingPoint) -> None: - loiter = ControlledTask( - OrbitAction(altitude=waypoint.alt, pattern=OrbitAction.OrbitPattern.Circle) - ) - if not isinstance(self.flight.flight_plan, LoiterFlightPlan): - flight_plan_type = self.flight.flight_plan.__class__.__name__ - logging.error( - f"Cannot configure hold for for {self.flight} because " - f"{flight_plan_type} does not define a push time. AI will push " - "immediately and may flight unsuitable speeds." - ) - return - push_time = self.flight.flight_plan.push_time - self.waypoint.departure_time = push_time - loiter.stop_after_time(int((push_time - self.now).total_seconds())) - waypoint.add_task(loiter) - waypoint.add_task(OptFormation.finger_four_close()) diff --git a/game/missiongenerator/aircraft/waypoints/pydcswaypointbuilder.py b/game/missiongenerator/aircraft/waypoints/pydcswaypointbuilder.py index d946c588..eb51a235 100644 --- a/game/missiongenerator/aircraft/waypoints/pydcswaypointbuilder.py +++ b/game/missiongenerator/aircraft/waypoints/pydcswaypointbuilder.py @@ -11,6 +11,7 @@ from dcs.unitgroup import FlyingGroup from game.ato import Flight, FlightWaypoint from game.ato.flightwaypointtype import FlightWaypointType from game.ato.traveltime import GroundSpeed +from game.flightplan.waypointactions.taskcontext import TaskContext from game.missiongenerator.missiondata import MissionData from game.theater import MissionTarget, TheaterUnit from game.unitmap import UnitMap @@ -86,7 +87,10 @@ class PydcsWaypointBuilder: return waypoint def add_tasks(self, waypoint: MovingPoint) -> None: - pass + ctx = TaskContext(self.now) + for action in self.waypoint.actions: + for task in action.iter_tasks(ctx): + waypoint.add_task(task) def set_waypoint_tot(self, waypoint: MovingPoint, tot: datetime) -> None: self.waypoint.tot = tot diff --git a/game/missiongenerator/aircraft/waypoints/waypointgenerator.py b/game/missiongenerator/aircraft/waypoints/waypointgenerator.py index d0180850..0a688c4f 100644 --- a/game/missiongenerator/aircraft/waypoints/waypointgenerator.py +++ b/game/missiongenerator/aircraft/waypoints/waypointgenerator.py @@ -29,7 +29,6 @@ from .baiingress import BaiIngressBuilder from .casingress import CasIngressBuilder from .deadingress import DeadIngressBuilder from .default import DefaultWaypointBuilder -from .holdpoint import HoldPointBuilder from .joinpoint import JoinPointBuilder from .landingpoint import LandingPointBuilder from .landingzone import LandingZoneBuilder @@ -142,7 +141,6 @@ class WaypointGenerator: FlightWaypointType.INGRESS_SWEEP: SweepIngressBuilder, FlightWaypointType.JOIN: JoinPointBuilder, FlightWaypointType.LANDING_POINT: LandingPointBuilder, - FlightWaypointType.LOITER: HoldPointBuilder, FlightWaypointType.PATROL: RaceTrackEndBuilder, FlightWaypointType.PATROL_TRACK: RaceTrackBuilder, FlightWaypointType.PICKUP_ZONE: LandingZoneBuilder, diff --git a/game/provider.py b/game/provider.py new file mode 100644 index 00000000..a2cc3cf0 --- /dev/null +++ b/game/provider.py @@ -0,0 +1,4 @@ +from typing import TypeAlias, TypeVar, Callable + +T = TypeVar("T") +Provider: TypeAlias = Callable[[], T] diff --git a/tests/ato/__init__.py b/tests/ato/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ato/flightstate/__init__.py b/tests/ato/flightstate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ato/flightstate/test_actionstate.py b/tests/ato/flightstate/test_actionstate.py new file mode 100644 index 00000000..aeb4348b --- /dev/null +++ b/tests/ato/flightstate/test_actionstate.py @@ -0,0 +1,29 @@ +from collections.abc import Iterator +from datetime import timedelta, datetime + +from dcs.task import Task + +from game.ato.flightstate.actionstate import ActionState +from game.flightplan.waypointactions.taskcontext import TaskContext +from game.flightplan.waypointactions.waypointaction import WaypointAction + + +class TestAction(WaypointAction): + def describe(self) -> str: + return "" + + def update_state( + self, state: ActionState, time: datetime, duration: timedelta + ) -> None: + pass + + def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]: + yield from [] + + +def test_actionstate() -> None: + action = TestAction() + state = ActionState(action) + assert not state.is_finished() + state.finish() + assert state.is_finished() diff --git a/tests/flightplan/waypointactions/test_hold.py b/tests/flightplan/waypointactions/test_hold.py new file mode 100644 index 00000000..3831569b --- /dev/null +++ b/tests/flightplan/waypointactions/test_hold.py @@ -0,0 +1,64 @@ +from datetime import datetime, timedelta + +from game.ato.flightstate.actionstate import ActionState +from game.flightplan.waypointactions.hold import Hold +from game.flightplan.waypointactions.taskcontext import TaskContext +from game.utils import meters, kph + + +def test_hold_tasks() -> None: + t0 = datetime(1999, 3, 28) + tasks = list( + Hold(lambda: t0 + timedelta(minutes=5), meters(8000), kph(400)).iter_tasks( + TaskContext(t0 + timedelta(minutes=1)) + ) + ) + assert len(tasks) == 1 + task = tasks[0] + assert task.id == "ControlledTask" + assert task.params["stopCondition"]["time"] == 4 * 60 + assert task.params["task"]["id"] == "Orbit" + assert task.params["task"]["params"]["altitude"] == 8000 + assert task.params["task"]["params"]["pattern"] == "Circle" + assert task.params["task"]["params"]["speed"] == kph(400).meters_per_second + + +def test_hold_task_at_or_after_push() -> None: + t0 = datetime(1999, 3, 28) + assert not list( + Hold(lambda: t0, meters(8000), kph(400)).iter_tasks(TaskContext(t0)) + ) + assert not list( + Hold(lambda: t0, meters(8000), kph(400)).iter_tasks( + TaskContext(t0 + timedelta(minutes=1)) + ) + ) + + +def test_hold_tick() -> None: + t0 = datetime(1999, 3, 28) + task = Hold(lambda: t0 + timedelta(minutes=5), meters(8000), kph(400)) + state = ActionState(task) + task.update_state(state, t0, timedelta()) + assert not state.is_finished() + task.update_state(state, t0 + timedelta(minutes=1), timedelta(minutes=1)) + assert not state.is_finished() + task.update_state(state, t0 + timedelta(minutes=2), timedelta(minutes=1)) + assert not state.is_finished() + task.update_state(state, t0 + timedelta(minutes=3), timedelta(minutes=1)) + assert not state.is_finished() + task.update_state(state, t0 + timedelta(minutes=4), timedelta(minutes=1)) + assert not state.is_finished() + task.update_state(state, t0 + timedelta(minutes=5), timedelta(minutes=1)) + assert state.is_finished() + task.update_state(state, t0 + timedelta(minutes=6), timedelta(minutes=1)) + assert state.is_finished() + + +def test_hold_description() -> None: + assert ( + Hold( + lambda: datetime(1999, 3, 28) + timedelta(minutes=5), meters(8000), kph(400) + ).describe() + == "Holding until 00:05:00" + )