From 4631ee0d74657d824963a0843d7ad9908afda428 Mon Sep 17 00:00:00 2001 From: zhexu14 <64713351+zhexu14@users.noreply.github.com> Date: Mon, 18 Dec 2023 13:42:31 +1100 Subject: [PATCH] Doctrine load from YAML (#3291) This PR refactors the Doctrine class to load from YAML files in the resources folder instead of being hardcoded as a step towards making doctrines moddable (Issue #829). I haven't added anything to the changelog as a couple of things should get cleaned up first: - As far as I can tell, the flags in the Doctrine class (cap, cas, sead etc.) aren't used anywhere. Need to test further, and if they're truly not used, will remove them. - Probably need to update the Wiki --- game/data/doctrine.py | 202 ++++++++---------- game/factions/faction.py | 18 +- game/flightplan/waypointsolverloader.py | 7 +- resources/doctrines/coldwar.yaml | 32 +++ resources/doctrines/modern.yaml | 32 +++ resources/doctrines/ww2.yaml | 31 +++ tests/data/test_doctrine.py | 43 ++++ tests/flightplan/test_ipsolver.py | 6 +- .../waypointsolvertestcasereducer.py | 7 +- 9 files changed, 235 insertions(+), 143 deletions(-) create mode 100644 resources/doctrines/coldwar.yaml create mode 100644 resources/doctrines/modern.yaml create mode 100644 resources/doctrines/ww2.yaml create mode 100644 tests/data/test_doctrine.py diff --git a/game/data/doctrine.py b/game/data/doctrine.py index 4e74d6e0..8a260984 100644 --- a/game/data/doctrine.py +++ b/game/data/doctrine.py @@ -1,3 +1,9 @@ +from __future__ import annotations + +from pathlib import Path +import yaml +from typing import ClassVar + from dataclasses import dataclass from datetime import timedelta @@ -15,6 +21,16 @@ class GroundUnitProcurementRatios: except KeyError: return 0.0 + @staticmethod + def from_dict(data: dict[str, float]) -> GroundUnitProcurementRatios: + unit_class_enum_from_name = {unit.value: unit for unit in UnitClass} + r = {} + for unit_class in data: + if unit_class not in unit_class_enum_from_name: + raise ValueError(f"Could not find unit type {unit_class}") + r[unit_class_enum_from_name[unit_class]] = float(data[unit_class]) + return GroundUnitProcurementRatios(r) + @dataclass(frozen=True) class Doctrine: @@ -79,122 +95,78 @@ class Doctrine: ground_unit_procurement_ratios: GroundUnitProcurementRatios + _by_name: ClassVar[dict[str, Doctrine]] = {} + _loaded: ClassVar[bool] = False -MODERN_DOCTRINE = Doctrine( - "modern", - cap=True, - cas=True, - sead=True, - strike=True, - antiship=True, - rendezvous_altitude=feet(25000), - hold_distance=nautical_miles(25), - push_distance=nautical_miles(20), - join_distance=nautical_miles(20), - max_ingress_distance=nautical_miles(45), - min_ingress_distance=nautical_miles(10), - ingress_altitude=feet(20000), - min_patrol_altitude=feet(15000), - max_patrol_altitude=feet(33000), - pattern_altitude=feet(5000), - cap_duration=timedelta(minutes=30), - cap_min_track_length=nautical_miles(15), - cap_max_track_length=nautical_miles(40), - cap_min_distance_from_cp=nautical_miles(10), - cap_max_distance_from_cp=nautical_miles(40), - cap_engagement_range=nautical_miles(50), - cas_duration=timedelta(minutes=30), - sweep_distance=nautical_miles(60), - ground_unit_procurement_ratios=GroundUnitProcurementRatios( - { - UnitClass.TANK: 3, - UnitClass.ATGM: 2, - UnitClass.APC: 2, - UnitClass.IFV: 3, - UnitClass.ARTILLERY: 1, - UnitClass.SHORAD: 2, - UnitClass.RECON: 1, - } - ), -) + @classmethod + def register(cls, doctrine: Doctrine) -> None: + if doctrine.name in cls._by_name: + duplicate = cls._by_name[doctrine.name] + raise ValueError(f"Doctrine {doctrine.name} is already loaded") + cls._by_name[doctrine.name] = doctrine -COLDWAR_DOCTRINE = Doctrine( - name="coldwar", - cap=True, - cas=True, - sead=True, - strike=True, - antiship=True, - rendezvous_altitude=feet(22000), - hold_distance=nautical_miles(15), - push_distance=nautical_miles(10), - join_distance=nautical_miles(10), - max_ingress_distance=nautical_miles(30), - min_ingress_distance=nautical_miles(10), - ingress_altitude=feet(18000), - min_patrol_altitude=feet(10000), - max_patrol_altitude=feet(24000), - pattern_altitude=feet(5000), - cap_duration=timedelta(minutes=30), - cap_min_track_length=nautical_miles(12), - cap_max_track_length=nautical_miles(24), - cap_min_distance_from_cp=nautical_miles(8), - cap_max_distance_from_cp=nautical_miles(25), - cap_engagement_range=nautical_miles(35), - cas_duration=timedelta(minutes=30), - sweep_distance=nautical_miles(40), - ground_unit_procurement_ratios=GroundUnitProcurementRatios( - { - UnitClass.TANK: 4, - UnitClass.ATGM: 2, - UnitClass.APC: 3, - UnitClass.IFV: 2, - UnitClass.ARTILLERY: 1, - UnitClass.SHORAD: 2, - UnitClass.RECON: 1, - } - ), -) + @classmethod + def named(cls, name: str) -> Doctrine: + if not cls._loaded: + cls.load_all() + return cls._by_name[name] -WWII_DOCTRINE = Doctrine( - name="ww2", - cap=True, - cas=True, - sead=False, - strike=True, - antiship=True, - hold_distance=nautical_miles(10), - push_distance=nautical_miles(5), - join_distance=nautical_miles(5), - rendezvous_altitude=feet(10000), - max_ingress_distance=nautical_miles(7), - min_ingress_distance=nautical_miles(5), - ingress_altitude=feet(8000), - min_patrol_altitude=feet(4000), - max_patrol_altitude=feet(15000), - pattern_altitude=feet(5000), - cap_duration=timedelta(minutes=30), - cap_min_track_length=nautical_miles(8), - cap_max_track_length=nautical_miles(18), - cap_min_distance_from_cp=nautical_miles(0), - cap_max_distance_from_cp=nautical_miles(5), - cap_engagement_range=nautical_miles(20), - cas_duration=timedelta(minutes=30), - sweep_distance=nautical_miles(10), - ground_unit_procurement_ratios=GroundUnitProcurementRatios( - { - UnitClass.TANK: 3, - UnitClass.ATGM: 3, - UnitClass.APC: 3, - UnitClass.ARTILLERY: 1, - UnitClass.SHORAD: 3, - UnitClass.RECON: 1, - } - ), -) + @classmethod + def all_doctrines(cls) -> list[Doctrine]: + if not cls._loaded: + cls.load_all() + return list(cls._by_name.values()) -ALL_DOCTRINES = [ - COLDWAR_DOCTRINE, - MODERN_DOCTRINE, - WWII_DOCTRINE, -] + @classmethod + def load_all(cls) -> None: + if cls._loaded: + return + for doctrine_file_path in Path("resources/doctrines").glob("**/*.yaml"): + with doctrine_file_path.open(encoding="utf8") as doctrine_file: + data = yaml.safe_load(doctrine_file) + cls.register( + Doctrine( + name=data["name"], + cap=data["cap"], + cas=data["cas"], + sead=data["sead"], + strike=data["strike"], + antiship=data["antiship"], + rendezvous_altitude=feet(data["rendezvous_altitude_ft_msl"]), + hold_distance=nautical_miles(data["hold_distance_nm"]), + push_distance=nautical_miles(data["push_distance_nm"]), + join_distance=nautical_miles(data["join_distance_nm"]), + max_ingress_distance=nautical_miles( + data["max_ingress_distance_nm"] + ), + min_ingress_distance=nautical_miles( + data["min_ingress_distance_nm"] + ), + ingress_altitude=feet(data["ingress_altitude_ft_msl"]), + min_patrol_altitude=feet(data["min_patrol_altitude_ft_msl"]), + max_patrol_altitude=feet(data["max_patrol_altitude_ft_msl"]), + pattern_altitude=feet(data["pattern_altitude_ft_msl"]), + cap_duration=timedelta(minutes=data["cap_duration_minutes"]), + cap_min_track_length=nautical_miles( + data["cap_min_track_length_nm"] + ), + cap_max_track_length=nautical_miles( + data["cap_max_track_length_nm"] + ), + cap_min_distance_from_cp=nautical_miles( + data["cap_min_distance_from_cp_nm"] + ), + cap_max_distance_from_cp=nautical_miles( + data["cap_max_distance_from_cp_nm"] + ), + cap_engagement_range=nautical_miles( + data["cap_engagement_range_nm"] + ), + cas_duration=timedelta(minutes=data["cas_duration_minutes"]), + sweep_distance=nautical_miles(data["sweep_distance_nm"]), + ground_unit_procurement_ratios=GroundUnitProcurementRatios.from_dict( + data["ground_unit_procurement_ratios"] + ), + ) + ) + cls._loaded = True diff --git a/game/factions/faction.py b/game/factions/faction.py index 3a408d26..a7079595 100644 --- a/game/factions/faction.py +++ b/game/factions/faction.py @@ -19,12 +19,7 @@ from game.data.building_data import ( WW2_FREE, WW2_GERMANY_BUILDINGS, ) -from game.data.doctrine import ( - COLDWAR_DOCTRINE, - Doctrine, - MODERN_DOCTRINE, - WWII_DOCTRINE, -) +from game.data.doctrine import Doctrine from game.data.groups import GroupRole from game.data.units import UnitClass from game.dcs.aircrafttype import AircraftType @@ -106,7 +101,7 @@ class Faction: jtac_unit: Optional[AircraftType] = field(default=None) # doctrine - doctrine: Doctrine = field(default=MODERN_DOCTRINE) + doctrine: Doctrine = field(default=Doctrine.named("modern")) # List of available building layouts for this faction building_set: List[str] = field(default_factory=list) @@ -238,14 +233,7 @@ class Faction: # Load doctrine doctrine = json.get("doctrine", "modern") - if doctrine == "modern": - faction.doctrine = MODERN_DOCTRINE - elif doctrine == "coldwar": - faction.doctrine = COLDWAR_DOCTRINE - elif doctrine == "ww2": - faction.doctrine = WWII_DOCTRINE - else: - faction.doctrine = MODERN_DOCTRINE + faction.doctrine = Doctrine.named(doctrine) # Load the building set faction.building_set = [] diff --git a/game/flightplan/waypointsolverloader.py b/game/flightplan/waypointsolverloader.py index 83852b65..530724f7 100644 --- a/game/flightplan/waypointsolverloader.py +++ b/game/flightplan/waypointsolverloader.py @@ -11,17 +11,14 @@ from shapely import transform from shapely.geometry import shape from shapely.geometry.base import BaseGeometry -from game.data.doctrine import Doctrine, ALL_DOCTRINES +from game.data.doctrine import Doctrine from .ipsolver import IpSolver from .waypointsolver import WaypointSolver from ..theater.theaterloader import TERRAINS_BY_NAME def doctrine_from_name(name: str) -> Doctrine: - for doctrine in ALL_DOCTRINES: - if doctrine.name == name: - return doctrine - raise KeyError + return Doctrine.named(name) def geometry_ll_to_xy(geometry: BaseGeometry, terrain: Terrain) -> BaseGeometry: diff --git a/resources/doctrines/coldwar.yaml b/resources/doctrines/coldwar.yaml new file mode 100644 index 00000000..be1a9fce --- /dev/null +++ b/resources/doctrines/coldwar.yaml @@ -0,0 +1,32 @@ +name: coldwar +cap: true +cas: true +sead: true +strike: true +antiship: true +rendezvous_altitude_ft_msl: 22000 +hold_distance_nm: 15 +push_distance_nm: 10 +join_distance_nm: 10 +max_ingress_distance_nm: 30 +min_ingress_distance_nm: 10 +ingress_altitude_ft_msl: 18000 +min_patrol_altitude_ft_msl: 10000 +max_patrol_altitude_ft_msl: 24000 +pattern_altitude_ft_msl: 5000 +cap_duration_minutes: 30 +cap_min_track_length_nm: 12 +cap_max_track_length_nm: 24 +cap_min_distance_from_cp_nm: 8 +cap_max_distance_from_cp_nm: 25 +cap_engagement_range_nm: 35 +cas_duration_minutes: 30 +sweep_distance_nm: 40 +ground_unit_procurement_ratios: + Tank: 4 + ATGM: 2 + APC: 3 + IFV: 2 + Artillery: 1 + SHORAD: 2 + Recon: 1 diff --git a/resources/doctrines/modern.yaml b/resources/doctrines/modern.yaml new file mode 100644 index 00000000..c2b4a314 --- /dev/null +++ b/resources/doctrines/modern.yaml @@ -0,0 +1,32 @@ +name: modern +cap: true +cas: true +sead: true +strike: true +antiship: true +rendezvous_altitude_ft_msl: 25000 +hold_distance_nm: 25 +push_distance_nm: 20 +join_distance_nm: 20 +max_ingress_distance_nm: 45 +min_ingress_distance_nm: 10 +ingress_altitude_ft_msl: 20000 +min_patrol_altitude_ft_msl: 15000 +max_patrol_altitude_ft_msl: 33000 +pattern_altitude_ft_msl: 5000 +cap_duration_minutes: 30 +cap_min_track_length_nm: 15 +cap_max_track_length_nm: 40 +cap_min_distance_from_cp_nm: 10 +cap_max_distance_from_cp_nm: 40 +cap_engagement_range_nm: 50 +cas_duration_minutes: 30 +sweep_distance_nm: 60 +ground_unit_procurement_ratios: + Tank: 3 + ATGM: 2 + APC: 2 + IFV: 3 + Artillery: 1 + SHORAD: 2 + Recon: 1 diff --git a/resources/doctrines/ww2.yaml b/resources/doctrines/ww2.yaml new file mode 100644 index 00000000..20549444 --- /dev/null +++ b/resources/doctrines/ww2.yaml @@ -0,0 +1,31 @@ +name: ww2 +cap: true +cas: true +sead: false +strike: true +antiship: true +hold_distance_nm: 10 +push_distance_nm: 5 +join_distance_nm: 5 +rendezvous_altitude_ft_msl: 10000 +max_ingress_distance_nm: 7 +min_ingress_distance_nm: 5 +ingress_altitude_ft_msl: 8000 +min_patrol_altitude_ft_msl: 4000 +max_patrol_altitude_ft_msl: 15000 +pattern_altitude_ft_msl: 5000 +cap_duration_minutes: 30 +cap_min_track_length_nm: 8 +cap_max_track_length_nm: 18 +cap_min_distance_from_cp_nm: 0 +cap_max_distance_from_cp_nm: 5 +cap_engagement_range_nm: 20 +cas_duration_minutes: 30 +sweep_distance_nm: 10 +ground_unit_procurement_ratios: + Tank: 3 + ATGM: 3 + APC: 3 + Artillery: 1 + SHORAD: 3 + Recon: 1 diff --git a/tests/data/test_doctrine.py b/tests/data/test_doctrine.py new file mode 100644 index 00000000..bd5c0c74 --- /dev/null +++ b/tests/data/test_doctrine.py @@ -0,0 +1,43 @@ +import pytest + +from game.data.doctrine import Doctrine, GroundUnitProcurementRatios +from game.data.units import UnitClass + + +def test_ground_unit_procurement_ratios_empty() -> None: + r = GroundUnitProcurementRatios({}) + for unit_class in UnitClass: + assert r.for_unit_class(unit_class) == 0.0 + + +def test_ground_unit_procurement_ratios_single_item() -> None: + r = GroundUnitProcurementRatios({UnitClass.TANK: 1}) + for unit_class in UnitClass: + if unit_class == UnitClass.TANK: + assert r.for_unit_class(unit_class) == 1.0 + else: + assert r.for_unit_class(unit_class) == 0.0 + + +def test_ground_unit_procurement_ratios_multiple_items() -> None: + r = GroundUnitProcurementRatios({UnitClass.TANK: 1, UnitClass.ATGM: 1}) + for unit_class in UnitClass: + if unit_class in [UnitClass.TANK, UnitClass.ATGM]: + assert r.for_unit_class(unit_class) == 0.5 + else: + assert r.for_unit_class(unit_class) == 0.0 + + +def test_ground_unit_procurement_ratios_from_dict() -> None: + r = GroundUnitProcurementRatios.from_dict({"Tank": 1, "ATGM": 1}) + for unit_class in UnitClass: + if unit_class in [UnitClass.TANK, UnitClass.ATGM]: + assert r.for_unit_class(unit_class) == 0.5 + else: + assert r.for_unit_class(unit_class) == 0.0 + + +def test_doctrine() -> None: + # This test checks for the presence of a doctrine named "modern" as this doctrine is used as a default + modern_doctrine = Doctrine.named("modern") + assert modern_doctrine.name == "modern" diff --git a/tests/flightplan/test_ipsolver.py b/tests/flightplan/test_ipsolver.py index ec03c9a4..d1acc91a 100644 --- a/tests/flightplan/test_ipsolver.py +++ b/tests/flightplan/test_ipsolver.py @@ -8,7 +8,7 @@ import pytest from dcs.terrain import Caucasus from shapely import Point, MultiPolygon, Polygon, unary_union -from game.data.doctrine import ALL_DOCTRINES +from game.data.doctrine import Doctrine from game.flightplan.ipsolver import IpSolver from game.flightplan.waypointsolver import NoSolutionsError from game.flightplan.waypointstrategy import point_at_heading @@ -70,7 +70,7 @@ def fuzzed_solver_fixture( departure = Point(0, 0) target = point_at_heading(departure, target_heading, fuzzed_target_distance) solver = IpSolver( - departure, target, random.choice(ALL_DOCTRINES), fuzzed_threat_poly + departure, target, random.choice(Doctrine.all_doctrines()), fuzzed_threat_poly ) solver.set_debug_properties(tmp_path, Caucasus()) return solver @@ -98,4 +98,4 @@ def test_fuzz_ipsolver(fuzzed_solver: IpSolver, run_number: int) -> None: def test_can_construct_solver_with_empty_threat() -> None: - IpSolver(Point(0, 0), Point(0, 0), ALL_DOCTRINES[0], MultiPolygon([])) + IpSolver(Point(0, 0), Point(0, 0), Doctrine.named("coldwar"), MultiPolygon([])) diff --git a/tests/flightplan/waypointsolvertestcasereducer.py b/tests/flightplan/waypointsolvertestcasereducer.py index 366ee86c..ab0c5ad5 100644 --- a/tests/flightplan/waypointsolvertestcasereducer.py +++ b/tests/flightplan/waypointsolvertestcasereducer.py @@ -8,7 +8,7 @@ from typing import Any, TypeVar, Generic from shapely import Point, MultiPolygon from shapely.geometry import shape -from game.data.doctrine import Doctrine, ALL_DOCTRINES +from game.data.doctrine import Doctrine from game.flightplan.ipsolver import IpSolver from game.flightplan.waypointsolver import WaypointSolver, NoSolutionsError from game.flightplan.waypointsolverloader import WaypointSolverLoader @@ -18,10 +18,7 @@ ReducerT = TypeVar("ReducerT") def doctrine_from_name(name: str) -> Doctrine: - for doctrine in ALL_DOCTRINES: - if doctrine.name == name: - return doctrine - raise KeyError + return Doctrine.named(name) class Reducer(Generic[ReducerT], Iterator[ReducerT], ABC):