Formalize waypoint actions.

Create a WaypointAction class that defines the actions taken at a
waypoint. These will often map one-to-one with DCS waypoint actions but
can also be higher level and generate multiple actions. Once everything
has migrated all waypoint-type-specific behaviors of
PydcsWaypointBuilder will be gone, and it'll be easier to keep the sim
behaviors in sync with the mission generator behaviors.

For now only hold has been migrated. This is actually probably the most
complicated action we have (starting with this may have been a mistake,
but it did find all the rough edges quickly) since it affects waypoint
timings and flight position during simulation. That part isn't handled
as neatly as I'd like because the FlightState still has to special case
LOITER points to avoid simulating the wrong waypoint position. At some
point we should probably start tracking real positions in FlightState,
and when we do that will be solved.
This commit is contained in:
Dan Albert 2023-08-11 00:02:40 -07:00
parent 2f385086fd
commit 87441b8939
22 changed files with 273 additions and 87 deletions

View File

@ -330,3 +330,6 @@ class FlightPlan(ABC, Generic[LayoutT]):
self, flight_plan: FlightPlan[Any]
) -> TypeGuard[FormationFlightPlan[Any]]:
return False
def add_waypoint_actions(self) -> None:
pass

View File

@ -36,6 +36,7 @@ class IBuilder(ABC, Generic[FlightPlanT, LayoutT]):
try:
self._generate_package_waypoints_if_needed(dump_debug_info)
self._flight_plan = self.build(dump_debug_info)
self._flight_plan.add_waypoint_actions()
except NavMeshError as ex:
color = "blue" if self.flight.squadron.player else "red"
raise PlanningError(

View File

@ -5,7 +5,9 @@ from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar
from game.flightplan.waypointactions.hold import Hold
from game.typeguard import self_type_guard
from game.utils import Speed
from .flightplan import FlightPlan
from .standard import StandardFlightPlan, StandardLayout
@ -49,3 +51,13 @@ class LoiterFlightPlan(StandardFlightPlan[LayoutT], ABC):
self, flight_plan: FlightPlan[Any]
) -> TypeGuard[LoiterFlightPlan[Any]]:
return True
def provide_push_time(self) -> datetime:
return self.push_time
def add_waypoint_actions(self) -> None:
hold = self.layout.hold
speed = self.flight.unit_type.patrol_speed
if speed is None:
speed = Speed.from_mach(0.6, hold.alt)
hold.add_action(Hold(self.provide_push_time, hold.alt, speed))

View File

@ -168,6 +168,9 @@ class WaypointBuilder:
"HOLD",
FlightWaypointType.LOITER,
position,
# Bug: DCS only accepts MSL altitudes for the orbit task and 500 meters is
# below the ground for most if not all of NTTR (and lots of places in other
# maps).
meters(500) if self.is_helo else self.doctrine.rendezvous_altitude,
alt_type,
description="Wait until push time",

View File

@ -0,0 +1,25 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from game.flightplan.waypointactions.waypointaction import WaypointAction
class ActionState:
def __init__(self, action: WaypointAction) -> None:
self.action = action
self._finished = False
def describe(self) -> str:
return self.action.describe()
def finish(self) -> None:
self._finished = True
def is_finished(self) -> bool:
return self._finished
def on_game_tick(self, time: datetime, duration: timedelta) -> None:
self.action.update_state(self, time, duration)

View File

@ -1,12 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from dcs import Point
from dcs.task import Task
from game.ato.flightstate import Completed
from game.ato.flightstate.actionstate import ActionState
from game.ato.flightstate.flightstate import FlightState
from game.ato.flightwaypoint import FlightWaypoint
from game.ato.flightwaypointtype import FlightWaypointType
@ -37,6 +40,15 @@ class InFlight(FlightState, ABC):
self.total_time_to_next_waypoint = self.travel_time_between_waypoints()
self.elapsed_time = timedelta()
self.current_waypoint_elapsed = False
self.pending_actions: deque[ActionState] = deque(
ActionState(a) for a in self.current_waypoint.actions
)
@property
def current_action(self) -> ActionState | None:
if self.pending_actions:
return self.pending_actions[0]
return None
@property
def cancelable(self) -> bool:
@ -80,7 +92,6 @@ class InFlight(FlightState, ABC):
return initial_fuel
def next_waypoint_state(self) -> FlightState:
from .loiter import Loiter
from .racetrack import RaceTrack
from .navigating import Navigating
@ -89,8 +100,6 @@ class InFlight(FlightState, ABC):
return Completed(self.flight, self.settings)
if self.next_waypoint.waypoint_type is FlightWaypointType.PATROL_TRACK:
return RaceTrack(self.flight, self.settings, new_index)
if self.next_waypoint.waypoint_type is FlightWaypointType.LOITER:
return Loiter(self.flight, self.settings, new_index)
return Navigating(self.flight, self.settings, new_index)
def advance_to_next_waypoint(self) -> FlightState:
@ -102,6 +111,12 @@ class InFlight(FlightState, ABC):
def on_game_tick(
self, events: GameUpdateEvents, time: datetime, duration: timedelta
) -> None:
if (action := self.current_action) is not None:
action.on_game_tick(time, duration)
if action.is_finished():
self.pending_actions.popleft()
return
self.elapsed_time += duration
if self.elapsed_time > self.total_time_to_next_waypoint:
new_state = self.advance_to_next_waypoint()
@ -152,11 +167,3 @@ class InFlight(FlightState, ABC):
@property
def spawn_type(self) -> StartType:
return StartType.IN_FLIGHT
@property
def description(self) -> str:
if self.has_aborted:
abort = "(Aborted) "
else:
abort = ""
return f"{abort}Flying to {self.next_waypoint.name}"

View File

@ -1,46 +0,0 @@
from __future__ import annotations
from datetime import timedelta
from typing import TYPE_CHECKING
from dcs import Point
from game.ato.flightstate import FlightState, InFlight
from game.ato.flightstate.navigating import Navigating
from game.utils import Distance, Speed
if TYPE_CHECKING:
from game.ato.flight import Flight
from game.settings import Settings
class Loiter(InFlight):
def __init__(self, flight: Flight, settings: Settings, waypoint_index: int) -> None:
assert flight.flight_plan.is_loiter(flight.flight_plan)
self.hold_duration = flight.flight_plan.hold_duration
super().__init__(flight, settings, waypoint_index)
def estimate_position(self) -> Point:
return self.current_waypoint.position
def estimate_altitude(self) -> tuple[Distance, str]:
return self.current_waypoint.alt, self.current_waypoint.alt_type
def estimate_speed(self) -> Speed:
return self.flight.unit_type.preferred_patrol_speed(self.estimate_altitude()[0])
def estimate_fuel(self) -> float:
# TODO: Estimate loiter consumption per minute?
return self.estimate_fuel_at_current_waypoint()
def next_waypoint_state(self) -> FlightState:
# Do not automatically advance to the next waypoint. Just proceed from the
# current one with the normal flying state.
return Navigating(self.flight, self.settings, self.waypoint_index)
def travel_time_between_waypoints(self) -> timedelta:
return self.hold_duration
@property
def description(self) -> str:
return f"Loitering for {self.hold_duration - self.elapsed_time}"

View File

@ -80,3 +80,14 @@ class Navigating(InFlight):
@property
def spawn_type(self) -> StartType:
return StartType.IN_FLIGHT
@property
def description(self) -> str:
if (action := self.current_action) is not None:
return action.describe()
if self.has_aborted:
abort = "(Aborted) "
else:
abort = ""
return f"{abort}Flying to {self.next_waypoint.name}"

View File

@ -7,6 +7,7 @@ from typing import Literal, TYPE_CHECKING
from dcs import Point
from game.ato.flightwaypointtype import FlightWaypointType
from game.flightplan.waypointactions.waypointaction import WaypointAction
from game.theater.theatergroup import TheaterUnit
from game.utils import Distance, meters
@ -39,6 +40,8 @@ class FlightWaypoint:
# The minimum amount of fuel remaining at this waypoint in pounds.
min_fuel: float | None = None
actions: list[WaypointAction] = field(default_factory=list)
# These are set very late by the air conflict generator (part of mission
# generation). We do it late so that we don't need to propagate changes
# to waypoint times whenever the player alters the package TOT or the
@ -46,6 +49,9 @@ class FlightWaypoint:
tot: datetime | None = None
departure_time: datetime | None = None
def add_action(self, action: WaypointAction) -> None:
self.actions.append(action)
@property
def x(self) -> float:
return self.position.x

View File

@ -0,0 +1,54 @@
from collections.abc import Iterator
from datetime import datetime, timedelta
from dcs.task import Task, OrbitAction, ControlledTask
from game.ato.flightstate.actionstate import ActionState
from game.provider import Provider
from game.utils import Distance, Speed
from .taskcontext import TaskContext
from .waypointaction import WaypointAction
class Hold(WaypointAction):
"""Loiter at a location until a push time to synchronize with other flights.
Taxi behavior is extremely unpredictable, so we cannot reliably predict ETAs for
waypoints without first fixing a time for one waypoint by holding until a sync time.
This is typically done with a dedicated hold point. If the flight reaches the hold
point before their push time, they will loiter at that location rather than fly to
their next waypoint as a speed that's often dangerously slow.
"""
def __init__(
self, push_time_provider: Provider[datetime], altitude: Distance, speed: Speed
) -> None:
self._push_time_provider = push_time_provider
self._altitude = altitude
self._speed = speed
def describe(self) -> str:
return self._push_time_provider().strftime("Holding until %H:%M:%S")
def update_state(
self, state: ActionState, time: datetime, duration: timedelta
) -> None:
if self._push_time_provider() <= time:
state.finish()
def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]:
remaining_time = self._push_time_provider() - ctx.mission_start_time
if remaining_time <= timedelta():
return
loiter = ControlledTask(
OrbitAction(
altitude=int(self._altitude.meters),
pattern=OrbitAction.OrbitPattern.Circle,
speed=self._speed.kph,
)
)
# The DCS task is serialized using the time from mission start, not the actual
# time.
loiter.stop_after_time(int(remaining_time.total_seconds()))
yield loiter

View File

@ -0,0 +1,7 @@
from dataclasses import dataclass
from datetime import datetime
@dataclass(frozen=True)
class TaskContext:
mission_start_time: datetime

View File

@ -0,0 +1,29 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterator
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from dcs.task import Task
from .taskcontext import TaskContext
if TYPE_CHECKING:
from game.ato.flightstate.actionstate import ActionState
class WaypointAction(ABC):
@abstractmethod
def describe(self) -> str:
...
@abstractmethod
def update_state(
self, state: ActionState, time: datetime, duration: timedelta
) -> None:
...
@abstractmethod
def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]:
...

View File

@ -161,6 +161,8 @@ class FlightGroupSpawner:
)
group.points[0].alt_type = alt_type
# We don't need to add waypoint tasks for the starting waypoint here because we
# do that in WaypointGenerator.
return group
def _generate_at_airport(self, name: str, airport: Airport) -> FlyingGroup[Any]:

View File

@ -1,27 +0,0 @@
import logging
from dcs.point import MovingPoint
from dcs.task import ControlledTask, OptFormation, OrbitAction
from game.ato.flightplans.loiter import LoiterFlightPlan
from .pydcswaypointbuilder import PydcsWaypointBuilder
class HoldPointBuilder(PydcsWaypointBuilder):
def add_tasks(self, waypoint: MovingPoint) -> None:
loiter = ControlledTask(
OrbitAction(altitude=waypoint.alt, pattern=OrbitAction.OrbitPattern.Circle)
)
if not isinstance(self.flight.flight_plan, LoiterFlightPlan):
flight_plan_type = self.flight.flight_plan.__class__.__name__
logging.error(
f"Cannot configure hold for for {self.flight} because "
f"{flight_plan_type} does not define a push time. AI will push "
"immediately and may flight unsuitable speeds."
)
return
push_time = self.flight.flight_plan.push_time
self.waypoint.departure_time = push_time
loiter.stop_after_time(int((push_time - self.now).total_seconds()))
waypoint.add_task(loiter)
waypoint.add_task(OptFormation.finger_four_close())

View File

@ -11,6 +11,7 @@ from dcs.unitgroup import FlyingGroup
from game.ato import Flight, FlightWaypoint
from game.ato.flightwaypointtype import FlightWaypointType
from game.ato.traveltime import GroundSpeed
from game.flightplan.waypointactions.taskcontext import TaskContext
from game.missiongenerator.missiondata import MissionData
from game.theater import MissionTarget, TheaterUnit
from game.unitmap import UnitMap
@ -86,7 +87,10 @@ class PydcsWaypointBuilder:
return waypoint
def add_tasks(self, waypoint: MovingPoint) -> None:
pass
ctx = TaskContext(self.now)
for action in self.waypoint.actions:
for task in action.iter_tasks(ctx):
waypoint.add_task(task)
def set_waypoint_tot(self, waypoint: MovingPoint, tot: datetime) -> None:
self.waypoint.tot = tot

View File

@ -29,7 +29,6 @@ from .baiingress import BaiIngressBuilder
from .casingress import CasIngressBuilder
from .deadingress import DeadIngressBuilder
from .default import DefaultWaypointBuilder
from .holdpoint import HoldPointBuilder
from .joinpoint import JoinPointBuilder
from .landingpoint import LandingPointBuilder
from .landingzone import LandingZoneBuilder
@ -142,7 +141,6 @@ class WaypointGenerator:
FlightWaypointType.INGRESS_SWEEP: SweepIngressBuilder,
FlightWaypointType.JOIN: JoinPointBuilder,
FlightWaypointType.LANDING_POINT: LandingPointBuilder,
FlightWaypointType.LOITER: HoldPointBuilder,
FlightWaypointType.PATROL: RaceTrackEndBuilder,
FlightWaypointType.PATROL_TRACK: RaceTrackBuilder,
FlightWaypointType.PICKUP_ZONE: LandingZoneBuilder,

4
game/provider.py Normal file
View File

@ -0,0 +1,4 @@
from typing import TypeAlias, TypeVar, Callable
T = TypeVar("T")
Provider: TypeAlias = Callable[[], T]

0
tests/ato/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,29 @@
from collections.abc import Iterator
from datetime import timedelta, datetime
from dcs.task import Task
from game.ato.flightstate.actionstate import ActionState
from game.flightplan.waypointactions.taskcontext import TaskContext
from game.flightplan.waypointactions.waypointaction import WaypointAction
class TestAction(WaypointAction):
def describe(self) -> str:
return ""
def update_state(
self, state: ActionState, time: datetime, duration: timedelta
) -> None:
pass
def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]:
yield from []
def test_actionstate() -> None:
action = TestAction()
state = ActionState(action)
assert not state.is_finished()
state.finish()
assert state.is_finished()

View File

@ -0,0 +1,64 @@
from datetime import datetime, timedelta
from game.ato.flightstate.actionstate import ActionState
from game.flightplan.waypointactions.hold import Hold
from game.flightplan.waypointactions.taskcontext import TaskContext
from game.utils import meters, kph
def test_hold_tasks() -> None:
t0 = datetime(1999, 3, 28)
tasks = list(
Hold(lambda: t0 + timedelta(minutes=5), meters(8000), kph(400)).iter_tasks(
TaskContext(t0 + timedelta(minutes=1))
)
)
assert len(tasks) == 1
task = tasks[0]
assert task.id == "ControlledTask"
assert task.params["stopCondition"]["time"] == 4 * 60
assert task.params["task"]["id"] == "Orbit"
assert task.params["task"]["params"]["altitude"] == 8000
assert task.params["task"]["params"]["pattern"] == "Circle"
assert task.params["task"]["params"]["speed"] == kph(400).meters_per_second
def test_hold_task_at_or_after_push() -> None:
t0 = datetime(1999, 3, 28)
assert not list(
Hold(lambda: t0, meters(8000), kph(400)).iter_tasks(TaskContext(t0))
)
assert not list(
Hold(lambda: t0, meters(8000), kph(400)).iter_tasks(
TaskContext(t0 + timedelta(minutes=1))
)
)
def test_hold_tick() -> None:
t0 = datetime(1999, 3, 28)
task = Hold(lambda: t0 + timedelta(minutes=5), meters(8000), kph(400))
state = ActionState(task)
task.update_state(state, t0, timedelta())
assert not state.is_finished()
task.update_state(state, t0 + timedelta(minutes=1), timedelta(minutes=1))
assert not state.is_finished()
task.update_state(state, t0 + timedelta(minutes=2), timedelta(minutes=1))
assert not state.is_finished()
task.update_state(state, t0 + timedelta(minutes=3), timedelta(minutes=1))
assert not state.is_finished()
task.update_state(state, t0 + timedelta(minutes=4), timedelta(minutes=1))
assert not state.is_finished()
task.update_state(state, t0 + timedelta(minutes=5), timedelta(minutes=1))
assert state.is_finished()
task.update_state(state, t0 + timedelta(minutes=6), timedelta(minutes=1))
assert state.is_finished()
def test_hold_description() -> None:
assert (
Hold(
lambda: datetime(1999, 3, 28) + timedelta(minutes=5), meters(8000), kph(400)
).describe()
== "Holding until 00:05:00"
)