From 27b5f24a0f5f9598d94c826a3a0c44a7489c512f Mon Sep 17 00:00:00 2001 From: Dan Albert Date: Mon, 7 Jun 2021 18:54:21 -0700 Subject: [PATCH] Move unit purchase off find_unittype. --- game/db.py | 2 +- game/factions/faction.py | 10 ++++-- .../windows/basemenu/NewUnitTransferDialog.py | 4 +-- .../ground_forces/QArmorRecruitmentMenu.py | 35 +++++++------------ 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/game/db.py b/game/db.py index 22ab6079..fdc5dc50 100644 --- a/game/db.py +++ b/game/db.py @@ -1332,7 +1332,7 @@ def upgrade_to_supercarrier(unit, name: str): def find_unittype(for_task: Type[MainTask], country_name: str) -> List[Type[UnitType]]: - return [x for x in UNIT_BY_TASK[for_task] if x in FACTIONS[country_name].units] + return [x for x in UNIT_BY_TASK[for_task] if x in FACTIONS[country_name].all_units] MANPADS: List[Type[VehicleType]] = [ diff --git a/game/factions/faction.py b/game/factions/faction.py index ef06ef0d..d17da280 100644 --- a/game/factions/faction.py +++ b/game/factions/faction.py @@ -3,7 +3,7 @@ from game.data.groundunitclass import GroundUnitClass import logging from dataclasses import dataclass, field -from typing import Optional, Dict, Type, List, Any, cast +from typing import Optional, Dict, Type, List, Any, cast, Iterator import dcs from dcs.countries import country_dict @@ -240,7 +240,7 @@ class Faction: return faction @property - def units(self) -> List[Type[UnitType]]: + def all_units(self) -> List[Type[UnitType]]: return ( self.infantry_units + self.aircrafts @@ -251,6 +251,12 @@ class Faction: + self.logistics_units ) + @property + def ground_units(self) -> Iterator[Type[VehicleType]]: + yield from self.artillery_units + yield from self.frontline_units + yield from self.logistics_units + def unit_loader(unit: str, class_repository: List[Any]) -> Optional[Type[UnitType]]: """ diff --git a/qt_ui/windows/basemenu/NewUnitTransferDialog.py b/qt_ui/windows/basemenu/NewUnitTransferDialog.py index 5b58cc4f..57288d5f 100644 --- a/qt_ui/windows/basemenu/NewUnitTransferDialog.py +++ b/qt_ui/windows/basemenu/NewUnitTransferDialog.py @@ -166,9 +166,7 @@ class ScrollingUnitTransferGrid(QFrame): scroll_content = QWidget() task_box_layout = QGridLayout() - unit_types = set( - db.find_unittype(PinpointStrike, self.game_model.game.player_name) - ) + unit_types = set(self.game_model.game.faction_for(player=True).ground_units) sorted_units = sorted( {u for u in unit_types if self.cp.base.total_units_of_type(u)}, key=lambda u: db.unit_get_expanded_info( diff --git a/qt_ui/windows/basemenu/ground_forces/QArmorRecruitmentMenu.py b/qt_ui/windows/basemenu/ground_forces/QArmorRecruitmentMenu.py index 0a421a6a..3f104e7d 100644 --- a/qt_ui/windows/basemenu/ground_forces/QArmorRecruitmentMenu.py +++ b/qt_ui/windows/basemenu/ground_forces/QArmorRecruitmentMenu.py @@ -7,10 +7,8 @@ from PySide2.QtWidgets import ( QScrollArea, QVBoxLayout, QWidget, - QMessageBox, ) -from dcs.task import PinpointStrike -from dcs.unittype import FlyingType, UnitType +from dcs.unittype import UnitType from game import db from game.theater import ControlPoint @@ -32,31 +30,24 @@ class QArmorRecruitmentMenu(QFrame, QRecruitBehaviour): def init_ui(self): main_layout = QVBoxLayout() - units = { - PinpointStrike: db.find_unittype( - PinpointStrike, self.game_model.game.player_name - ), - } - scroll_content = QWidget() task_box_layout = QGridLayout() scroll_content.setLayout(task_box_layout) row = 0 - for task_type in units.keys(): - units_column = list(set(units[task_type])) - if len(units_column) == 0: - continue - units_column.sort( - key=lambda u: db.unit_get_expanded_info( - self.game_model.game.player_country, u, "name" - ) + unit_types = list( + set(self.game_model.game.faction_for(player=True).ground_units) + ) + unit_types.sort( + key=lambda u: db.unit_get_expanded_info( + self.game_model.game.player_country, u, "name" ) - for unit_type in units_column: - row = self.add_purchase_row(unit_type, task_box_layout, row) - stretch = QVBoxLayout() - stretch.addStretch() - task_box_layout.addLayout(stretch, row, 0) + ) + for unit_type in unit_types: + row = self.add_purchase_row(unit_type, task_box_layout, row) + stretch = QVBoxLayout() + stretch.addStretch() + task_box_layout.addLayout(stretch, row, 0) scroll_content.setLayout(task_box_layout) scroll = QScrollArea()