Fix merged classvars in UnitType descendants.

```
>>> class Foo:
...     bar = 0
...     @classmethod
...     def set_bar(cls, v):
...             cls.bar = v
...
>>> class Bar(Foo):
...     ...
...
>>> Bar.set_bar(1)
>>> Bar.bar
1
>>> Foo.bar
0
>>> class Foo:
...     bar = {}
...     @classmethod
...     def add(cls, k, v):
...             cls.bar[k] = v
...
>>> class Bar(Foo):
...     pass
...
>>> Bar.add(0, 1)
>>> Bar.bar
{0: 1}
>>> Foo.bar
{0: 1}
```

The collections are copied by reference into the descendants, whereas
_loaded is copied by value, so that one can stay. Before this patch,
every subtype was loading because _loaded was set per subclass, but they
were all registering with a common collection defined by UnitType rather
than their own class.
This commit is contained in:
Dan Albert 2023-04-26 22:36:00 -07:00
parent 06b74c4ca6
commit 2f2ebff674
4 changed files with 50 additions and 35 deletions

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Type from typing import Any, ClassVar, Dict, Iterator, Optional, TYPE_CHECKING, Type
import yaml import yaml
from dcs.helicopters import helicopter_map from dcs.helicopters import helicopter_map
@ -196,6 +197,16 @@ class AircraftType(UnitType[Type[FlyingType]]):
# will be set to true for helos by default # will be set to true for helos by default
can_carry_crates: bool can_carry_crates: bool
_by_name: ClassVar[dict[str, AircraftType]] = {}
_by_unit_type: ClassVar[dict[type[FlyingType], list[AircraftType]]] = defaultdict(
list
)
@classmethod
def register(cls, unit_type: AircraftType) -> None:
cls._by_name[unit_type.name] = unit_type
cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type)
@property @property
def flyable(self) -> bool: def flyable(self) -> bool:
return self.dcs_unit_type.flyable return self.dcs_unit_type.flyable
@ -316,17 +327,13 @@ class AircraftType(UnitType[Type[FlyingType]]):
def named(cls, name: str) -> AircraftType: def named(cls, name: str) -> AircraftType:
if not cls._loaded: if not cls._loaded:
cls._load_all() cls._load_all()
unit = cls._by_name[name] return cls._by_name[name]
assert isinstance(unit, AircraftType)
return unit
@classmethod @classmethod
def for_dcs_type(cls, dcs_unit_type: Type[FlyingType]) -> Iterator[AircraftType]: def for_dcs_type(cls, dcs_unit_type: Type[FlyingType]) -> Iterator[AircraftType]:
if not cls._loaded: if not cls._loaded:
cls._load_all() cls._load_all()
for unit in cls._by_unit_type[dcs_unit_type]: yield from cls._by_unit_type[dcs_unit_type]
assert isinstance(unit, AircraftType)
yield unit
@staticmethod @staticmethod
def each_dcs_type() -> Iterator[Type[FlyingType]]: def each_dcs_type() -> Iterator[Type[FlyingType]]:

View File

@ -1,9 +1,10 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Iterator, Optional, Type from typing import Any, ClassVar, Iterator, Optional, Type
import yaml import yaml
from dcs.unittype import VehicleType from dcs.unittype import VehicleType
@ -59,21 +60,27 @@ class GroundUnitType(UnitType[Type[VehicleType]]):
# Some units like few Launchers have to be placed backwards to be able to fire. # Some units like few Launchers have to be placed backwards to be able to fire.
reversed_heading: bool = False reversed_heading: bool = False
_by_name: ClassVar[dict[str, GroundUnitType]] = {}
_by_unit_type: ClassVar[
dict[type[VehicleType], list[GroundUnitType]]
] = defaultdict(list)
@classmethod
def register(cls, unit_type: GroundUnitType) -> None:
cls._by_name[unit_type.name] = unit_type
cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type)
@classmethod @classmethod
def named(cls, name: str) -> GroundUnitType: def named(cls, name: str) -> GroundUnitType:
if not cls._loaded: if not cls._loaded:
cls._load_all() cls._load_all()
unit = cls._by_name[name] return cls._by_name[name]
assert isinstance(unit, GroundUnitType)
return unit
@classmethod @classmethod
def for_dcs_type(cls, dcs_unit_type: Type[VehicleType]) -> Iterator[GroundUnitType]: def for_dcs_type(cls, dcs_unit_type: Type[VehicleType]) -> Iterator[GroundUnitType]:
if not cls._loaded: if not cls._loaded:
cls._load_all() cls._load_all()
for unit in cls._by_unit_type[dcs_unit_type]: yield from cls._by_unit_type[dcs_unit_type]
assert isinstance(unit, GroundUnitType)
yield unit
@staticmethod @staticmethod
def each_dcs_type() -> Iterator[Type[VehicleType]]: def each_dcs_type() -> Iterator[Type[VehicleType]]:

View File

@ -1,9 +1,10 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Iterator, Type from typing import ClassVar, Iterator, Type
import yaml import yaml
from dcs.ships import ship_map from dcs.ships import ship_map
@ -15,21 +16,27 @@ from game.dcs.unittype import UnitType
@dataclass(frozen=True) @dataclass(frozen=True)
class ShipUnitType(UnitType[Type[ShipType]]): class ShipUnitType(UnitType[Type[ShipType]]):
_by_name: ClassVar[dict[str, ShipUnitType]] = {}
_by_unit_type: ClassVar[dict[type[ShipType], list[ShipUnitType]]] = defaultdict(
list
)
@classmethod
def register(cls, unit_type: ShipUnitType) -> None:
cls._by_name[unit_type.name] = unit_type
cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type)
@classmethod @classmethod
def named(cls, name: str) -> ShipUnitType: def named(cls, name: str) -> ShipUnitType:
if not cls._loaded: if not cls._loaded:
cls._load_all() cls._load_all()
unit = cls._by_name[name] return cls._by_name[name]
assert isinstance(unit, ShipUnitType)
return unit
@classmethod @classmethod
def for_dcs_type(cls, dcs_unit_type: Type[ShipType]) -> Iterator[ShipUnitType]: def for_dcs_type(cls, dcs_unit_type: Type[ShipType]) -> Iterator[ShipUnitType]:
if not cls._loaded: if not cls._loaded:
cls._load_all() cls._load_all()
for unit in cls._by_unit_type[dcs_unit_type]: yield from cls._by_unit_type[dcs_unit_type]
assert isinstance(unit, ShipUnitType)
yield unit
@staticmethod @staticmethod
def each_dcs_type() -> Iterator[Type[ShipType]]: def each_dcs_type() -> Iterator[Type[ShipType]]:

View File

@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC from abc import ABC
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import Any, ClassVar, Generic, Iterator, Type, TypeVar from typing import ClassVar, Generic, Iterator, Self, Type, TypeVar
from dcs.unittype import UnitType as DcsUnitType from dcs.unittype import UnitType as DcsUnitType
@ -25,10 +24,6 @@ class UnitType(ABC, Generic[DcsUnitTypeT]):
price: int price: int
unit_class: UnitClass unit_class: UnitClass
_by_name: ClassVar[dict[str, UnitType[Any]]] = {}
_by_unit_type: ClassVar[dict[Type[DcsUnitType], list[UnitType[Any]]]] = defaultdict(
list
)
_loaded: ClassVar[bool] = False _loaded: ClassVar[bool] = False
def __str__(self) -> str: def __str__(self) -> str:
@ -39,16 +34,15 @@ class UnitType(ABC, Generic[DcsUnitTypeT]):
return self.dcs_unit_type.id return self.dcs_unit_type.id
@classmethod @classmethod
def register(cls, unit_type: UnitType[Any]) -> None: def register(cls, unit_type: Self) -> None:
cls._by_name[unit_type.name] = unit_type
cls._by_unit_type[unit_type.dcs_unit_type].append(unit_type)
@classmethod
def named(cls, name: str) -> UnitType[Any]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def for_dcs_type(cls, dcs_unit_type: DcsUnitTypeT) -> Iterator[UnitType[Any]]: def named(cls, name: str) -> Self:
raise NotImplementedError
@classmethod
def for_dcs_type(cls, dcs_unit_type: DcsUnitTypeT) -> Iterator[Self]:
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
@ -56,7 +50,7 @@ class UnitType(ABC, Generic[DcsUnitTypeT]):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def _each_variant_of(cls, unit: DcsUnitTypeT) -> Iterator[UnitType[Any]]: def _each_variant_of(cls, unit: DcsUnitTypeT) -> Iterator[Self]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod