diff --git a/game/ato/flightstate/actionstate.py b/game/ato/flightstate/actionstate.py index 15884c01..323ab709 100644 --- a/game/ato/flightstate/actionstate.py +++ b/game/ato/flightstate/actionstate.py @@ -21,5 +21,5 @@ class ActionState: def is_finished(self) -> bool: return self._finished - def on_game_tick(self, time: datetime, duration: timedelta) -> None: - self.action.update_state(self, time, duration) + def on_game_tick(self, time: datetime, duration: timedelta) -> timedelta: + return self.action.update_state(self, time, duration) diff --git a/game/ato/flightstate/inflight.py b/game/ato/flightstate/inflight.py index 3654255e..fc0be5e3 100644 --- a/game/ato/flightstate/inflight.py +++ b/game/ato/flightstate/inflight.py @@ -6,7 +6,6 @@ from datetime import datetime, timedelta from typing import TYPE_CHECKING from dcs import Point -from dcs.task import Task from game.ato.flightstate import Completed from game.ato.flightstate.actionstate import ActionState @@ -111,11 +110,12 @@ class InFlight(FlightState, ABC): def on_game_tick( self, events: GameUpdateEvents, time: datetime, duration: timedelta ) -> None: - if (action := self.current_action) is not None: - action.on_game_tick(time, duration) + while (action := self.current_action) is not None: + duration = action.on_game_tick(time, duration) if action.is_finished(): self.pending_actions.popleft() - return + if duration <= timedelta(): + return self.elapsed_time += duration if self.elapsed_time > self.total_time_to_next_waypoint: diff --git a/game/flightplan/waypointactions/hold.py b/game/flightplan/waypointactions/hold.py index 9b2775e8..71da9227 100644 --- a/game/flightplan/waypointactions/hold.py +++ b/game/flightplan/waypointactions/hold.py @@ -32,9 +32,12 @@ class Hold(WaypointAction): def update_state( self, state: ActionState, time: datetime, duration: timedelta - ) -> None: - if self._push_time_provider() <= time: + ) -> timedelta: + push_time = self._push_time_provider() + if push_time <= time: state.finish() + return time - push_time + return timedelta() def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]: remaining_time = self._push_time_provider() - ctx.mission_start_time diff --git a/game/flightplan/waypointactions/waypointaction.py b/game/flightplan/waypointactions/waypointaction.py index 52095bf5..a82a8678 100644 --- a/game/flightplan/waypointactions/waypointaction.py +++ b/game/flightplan/waypointactions/waypointaction.py @@ -21,7 +21,7 @@ class WaypointAction(ABC): @abstractmethod def update_state( self, state: ActionState, time: datetime, duration: timedelta - ) -> None: + ) -> timedelta: ... @abstractmethod diff --git a/tests/ato/flightstate/test_actionstate.py b/tests/ato/flightstate/test_actionstate.py index aeb4348b..b62864e7 100644 --- a/tests/ato/flightstate/test_actionstate.py +++ b/tests/ato/flightstate/test_actionstate.py @@ -14,8 +14,8 @@ class TestAction(WaypointAction): def update_state( self, state: ActionState, time: datetime, duration: timedelta - ) -> None: - pass + ) -> timedelta: + return timedelta() def iter_tasks(self, ctx: TaskContext) -> Iterator[Task]: yield from [] diff --git a/tests/flightplan/waypointactions/test_hold.py b/tests/flightplan/waypointactions/test_hold.py index 3831569b..e5f86a1b 100644 --- a/tests/flightplan/waypointactions/test_hold.py +++ b/tests/flightplan/waypointactions/test_hold.py @@ -39,19 +39,21 @@ def test_hold_tick() -> None: t0 = datetime(1999, 3, 28) task = Hold(lambda: t0 + timedelta(minutes=5), meters(8000), kph(400)) state = ActionState(task) - task.update_state(state, t0, timedelta()) + assert not task.update_state(state, t0, timedelta()) assert not state.is_finished() - task.update_state(state, t0 + timedelta(minutes=1), timedelta(minutes=1)) + assert not task.update_state(state, t0 + timedelta(minutes=1), timedelta(minutes=1)) assert not state.is_finished() - task.update_state(state, t0 + timedelta(minutes=2), timedelta(minutes=1)) + assert not task.update_state(state, t0 + timedelta(minutes=2), timedelta(minutes=1)) assert not state.is_finished() - task.update_state(state, t0 + timedelta(minutes=3), timedelta(minutes=1)) + assert not task.update_state(state, t0 + timedelta(minutes=3), timedelta(minutes=1)) assert not state.is_finished() - task.update_state(state, t0 + timedelta(minutes=4), timedelta(minutes=1)) + assert not task.update_state(state, t0 + timedelta(minutes=4), timedelta(minutes=1)) assert not state.is_finished() - task.update_state(state, t0 + timedelta(minutes=5), timedelta(minutes=1)) + assert not task.update_state(state, t0 + timedelta(minutes=5), timedelta(minutes=1)) assert state.is_finished() - task.update_state(state, t0 + timedelta(minutes=6), timedelta(minutes=1)) + assert task.update_state( + state, t0 + timedelta(minutes=6), timedelta(minutes=1) + ) == timedelta(minutes=1) assert state.is_finished()