diff --git a/game/dcs/aircrafttype.py b/game/dcs/aircrafttype.py index 511c9394..92e39f1c 100644 --- a/game/dcs/aircrafttype.py +++ b/game/dcs/aircrafttype.py @@ -1,10 +1,11 @@ from __future__ import annotations import logging +from collections import defaultdict from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Type +from typing import Any, ClassVar, Dict, Iterator, Optional, TYPE_CHECKING, Type import yaml from dcs.helicopters import helicopter_map @@ -197,6 +198,16 @@ class AircraftType(UnitType[Type[FlyingType]]): # when no TGP is mounted on any station. has_built_in_target_pod: bool + _by_name: ClassVar[dict[str, AircraftType]] = {} + _by_unit_type: ClassVar[dict[type[FlyingType], list[AircraftType]]] = defaultdict( + list + ) + + @classmethod + def register(cls, unit_type: AircraftType) -> None: + cls._by_name[unit_type.name] = unit_type + cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type) + @property def flyable(self) -> bool: return self.dcs_unit_type.flyable @@ -321,17 +332,13 @@ class AircraftType(UnitType[Type[FlyingType]]): def named(cls, name: str) -> AircraftType: if not cls._loaded: cls._load_all() - unit = cls._by_name[name] - assert isinstance(unit, AircraftType) - return unit + return cls._by_name[name] @classmethod def for_dcs_type(cls, dcs_unit_type: Type[FlyingType]) -> Iterator[AircraftType]: if not cls._loaded: cls._load_all() - for unit in cls._by_unit_type[dcs_unit_type]: - assert isinstance(unit, AircraftType) - yield unit + yield from cls._by_unit_type[dcs_unit_type] @staticmethod def each_dcs_type() -> Iterator[Type[FlyingType]]: diff --git a/game/dcs/groundunittype.py b/game/dcs/groundunittype.py index 100c4546..1030ba14 100644 --- a/game/dcs/groundunittype.py +++ b/game/dcs/groundunittype.py @@ -1,9 +1,10 @@ from __future__ import annotations import logging +from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterator, Optional, Type +from typing import Any, ClassVar, Iterator, Optional, Type import yaml from dcs.unittype import VehicleType @@ -59,21 +60,27 @@ class GroundUnitType(UnitType[Type[VehicleType]]): # Some units like few Launchers have to be placed backwards to be able to fire. reversed_heading: bool = False + _by_name: ClassVar[dict[str, GroundUnitType]] = {} + _by_unit_type: ClassVar[ + dict[type[VehicleType], list[GroundUnitType]] + ] = defaultdict(list) + + @classmethod + def register(cls, unit_type: GroundUnitType) -> None: + cls._by_name[unit_type.name] = unit_type + cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type) + @classmethod def named(cls, name: str) -> GroundUnitType: if not cls._loaded: cls._load_all() - unit = cls._by_name[name] - assert isinstance(unit, GroundUnitType) - return unit + return cls._by_name[name] @classmethod def for_dcs_type(cls, dcs_unit_type: Type[VehicleType]) -> Iterator[GroundUnitType]: if not cls._loaded: cls._load_all() - for unit in cls._by_unit_type[dcs_unit_type]: - assert isinstance(unit, GroundUnitType) - yield unit + yield from cls._by_unit_type[dcs_unit_type] @staticmethod def each_dcs_type() -> Iterator[Type[VehicleType]]: diff --git a/game/dcs/shipunittype.py b/game/dcs/shipunittype.py index 0b34f75b..cf36d1cd 100644 --- a/game/dcs/shipunittype.py +++ b/game/dcs/shipunittype.py @@ -1,9 +1,10 @@ from __future__ import annotations import logging +from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import Iterator, Type +from typing import ClassVar, Iterator, Type import yaml from dcs.ships import ship_map @@ -15,21 +16,27 @@ from game.dcs.unittype import UnitType @dataclass(frozen=True) class ShipUnitType(UnitType[Type[ShipType]]): + _by_name: ClassVar[dict[str, ShipUnitType]] = {} + _by_unit_type: ClassVar[dict[type[ShipType], list[ShipUnitType]]] = defaultdict( + list + ) + + @classmethod + def register(cls, unit_type: ShipUnitType) -> None: + cls._by_name[unit_type.name] = unit_type + cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type) + @classmethod def named(cls, name: str) -> ShipUnitType: if not cls._loaded: cls._load_all() - unit = cls._by_name[name] - assert isinstance(unit, ShipUnitType) - return unit + return cls._by_name[name] @classmethod def for_dcs_type(cls, dcs_unit_type: Type[ShipType]) -> Iterator[ShipUnitType]: if not cls._loaded: cls._load_all() - for unit in cls._by_unit_type[dcs_unit_type]: - assert isinstance(unit, ShipUnitType) - yield unit + yield from cls._by_unit_type[dcs_unit_type] @staticmethod def each_dcs_type() -> Iterator[Type[ShipType]]: diff --git a/game/dcs/unittype.py b/game/dcs/unittype.py index 8660d871..c0678278 100644 --- a/game/dcs/unittype.py +++ b/game/dcs/unittype.py @@ -1,10 +1,9 @@ from __future__ import annotations from abc import ABC -from collections import defaultdict from dataclasses import dataclass from functools import cached_property -from typing import Any, ClassVar, Generic, Iterator, Type, TypeVar +from typing import ClassVar, Generic, Iterator, Self, Type, TypeVar from dcs.unittype import UnitType as DcsUnitType @@ -25,10 +24,6 @@ class UnitType(ABC, Generic[DcsUnitTypeT]): price: int unit_class: UnitClass - _by_name: ClassVar[dict[str, UnitType[Any]]] = {} - _by_unit_type: ClassVar[dict[Type[DcsUnitType], list[UnitType[Any]]]] = defaultdict( - list - ) _loaded: ClassVar[bool] = False def __str__(self) -> str: @@ -39,16 +34,15 @@ class UnitType(ABC, Generic[DcsUnitTypeT]): return self.dcs_unit_type.id @classmethod - def register(cls, unit_type: UnitType[Any]) -> None: - cls._by_name[unit_type.name] = unit_type - cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type) - - @classmethod - def named(cls, name: str) -> UnitType[Any]: + def register(cls, unit_type: Self) -> None: raise NotImplementedError @classmethod - def for_dcs_type(cls, dcs_unit_type: DcsUnitTypeT) -> Iterator[UnitType[Any]]: + def named(cls, name: str) -> Self: + raise NotImplementedError + + @classmethod + def for_dcs_type(cls, dcs_unit_type: DcsUnitTypeT) -> Iterator[Self]: raise NotImplementedError @staticmethod @@ -56,7 +50,7 @@ class UnitType(ABC, Generic[DcsUnitTypeT]): raise NotImplementedError @classmethod - def _each_variant_of(cls, unit: DcsUnitTypeT) -> Iterator[UnitType[Any]]: + def _each_variant_of(cls, unit: DcsUnitTypeT) -> Iterator[Self]: raise NotImplementedError @classmethod