diff --git a/backend/core/include/commands.h b/backend/core/include/commands.h index 95142b11..5119a9eb 100644 --- a/backend/core/include/commands.h +++ b/backend/core/include/commands.h @@ -5,6 +5,11 @@ #include "logger.h" #include "datatypes.h" +struct CommandResult { + string hash; + string result; +}; + namespace CommandPriority { enum CommandPriorities { LOW, MEDIUM, HIGH, IMMEDIATE }; }; diff --git a/backend/core/include/scheduler.h b/backend/core/include/scheduler.h index d84f7901..304a9da1 100644 --- a/backend/core/include/scheduler.h +++ b/backend/core/include/scheduler.h @@ -13,7 +13,13 @@ public: void execute(lua_State* L); void handleRequest(string key, json::value value, string username, json::value& answer); bool checkSpawnPoints(int spawnPoints, string coalition); - bool isCommandExecuted(string commandHash) { return (find(executedCommandsHashes.begin(), executedCommandsHashes.end(), commandHash) != executedCommandsHashes.end()); } + bool isCommandExecuted(string commandHash) { + for (auto& commandResult : executedCommandResults) { + if (commandResult.hash == commandHash) { + return true; + } + } + } void setFrameRate(double newFrameRate) { frameRate = newFrameRate; } void setRestrictSpawns(bool newRestrictSpawns) { restrictSpawns = newRestrictSpawns; } @@ -36,7 +42,7 @@ public: private: list commands; - list executedCommandsHashes; + list executedCommandResults; unsigned int load = 0; double frameRate = 0; diff --git a/backend/core/src/commands.cpp b/backend/core/src/commands.cpp index 240bfb87..b0b0c93d 100644 --- a/backend/core/src/commands.cpp +++ b/backend/core/src/commands.cpp @@ -49,7 +49,6 @@ string SpawnGroundUnits::getString() << "heading = " << spawnOptions[i].heading << ", " << "liveryID = " << "\"" << spawnOptions[i].liveryID << "\"" << ", " << "skill = \"" << spawnOptions[i].skill << "\"" << "}, "; - } std::ostringstream commandSS; @@ -59,6 +58,7 @@ string SpawnGroundUnits::getString() << "coalition = " << "\"" << coalition << "\"" << ", " << "country = \"" << country << "\", " << "units = " << "{" << unitsSS.str() << "}" << "}"; + commandSS << ", \"" << this->getHash() << "\""; return commandSS.str(); } @@ -85,6 +85,7 @@ string SpawnNavyUnits::getString() << "coalition = " << "\"" << coalition << "\"" << ", " << "country = \"" << country << "\", " << "units = " << "{" << unitsSS.str() << "}" << "}"; + commandSS << ", \"" << this->getHash() << "\""; return commandSS.str(); } @@ -113,6 +114,7 @@ string SpawnAircrafts::getString() << "airbaseName = \"" << airbaseName << "\", " << "country = \"" << country << "\", " << "units = " << "{" << unitsSS.str() << "}" << "}"; + commandSS << ", \"" << this->getHash() << "\""; return commandSS.str(); } @@ -142,6 +144,7 @@ string SpawnHelicopters::getString() << "airbaseName = \"" << airbaseName << "\", " << "country = \"" << country << "\", " << "units = " << "{" << unitsSS.str() << "}" << "}"; + commandSS << ", \"" << this->getHash() << "\""; return commandSS.str(); } diff --git a/backend/core/src/core.cpp b/backend/core/src/core.cpp index c5376e56..a7ddf538 100644 --- a/backend/core/src/core.cpp +++ b/backend/core/src/core.cpp @@ -23,6 +23,7 @@ Scheduler* scheduler = nullptr; /* Data jsons */ json::value missionData = json::value::object(); json::value drawingsByLayer = json::value::object(); +json::value executionResults = json::value::object(); mutex mutexLock; string sessionHash; @@ -174,5 +175,16 @@ extern "C" DllExport int coreDrawingsData(lua_State* L) lua_getfield(L, -1, "drawingsByLayer"); luaTableToJSON(L, -1, drawingsByLayer); + return(0); +} + +extern "C" DllExport int coreSetExecutionResults(lua_State* L) +{ + /* Lock for thread safety */ + lock_guard guard(mutexLock); + + lua_getglobal(L, "Olympus"); + lua_getfield(L, -1, "executionResults"); + luaTableToJSON(L, -1, executionResults, true); return(0); } \ No newline at end of file diff --git a/backend/core/src/scheduler.cpp b/backend/core/src/scheduler.cpp index ddb0af38..2bdb69cc 100644 --- a/backend/core/src/scheduler.cpp +++ b/backend/core/src/scheduler.cpp @@ -52,10 +52,11 @@ void Scheduler::execute(lua_State* L) if (command->getPriority() == priority) { string commandString = "Olympus.protectedCall(" + command->getString() + ")"; - if (dostring_in(L, "server", (commandString))) + string resultString = ""; + if (dostring_in(L, "server", (commandString), resultString)) log("Error executing command " + commandString); else - log("Command '" + commandString + "' executed correctly, current load " + to_string(getLoad())); + log("Command '" + commandString + "' executed correctly, current load " + to_string(getLoad()) + ", result string: " + resultString); /* Adjust the load depending on the fps */ double fpsMultiplier = 20; @@ -64,7 +65,10 @@ void Scheduler::execute(lua_State* L) load = static_cast(command->getLoad() * fpsMultiplier); commands.remove(command); - executedCommandsHashes.push_back(command->getHash()); + CommandResult commandResult = { + command->getHash(), resultString + }; + executedCommandResults.push_back(commandResult); command->executeCallback(); /* Execute the command callback (this is a lambda function that can be used to execute a function when the command is run) */ delete command; return; @@ -192,7 +196,6 @@ void Scheduler::handleRequest(string key, json::value value, string username, js string airbaseName = to_string(value[L"airbaseName"]); string country = to_string(value[L"country"]); - int spawnPoints = value[L"spawnPoints"].as_number().to_int32(); if (!checkSpawnPoints(spawnPoints, coalition)) { log(username + " insufficient spawn points ", true); diff --git a/backend/core/src/server.cpp b/backend/core/src/server.cpp index 12f5cc5e..e00a1fe0 100644 --- a/backend/core/src/server.cpp +++ b/backend/core/src/server.cpp @@ -18,6 +18,7 @@ extern WeaponsManager* weaponsManager; extern Scheduler* scheduler; extern json::value missionData; extern json::value drawingsByLayer; +extern json::value executionResults; extern mutex mutexLock; extern string sessionHash; extern string instancePath; @@ -149,6 +150,10 @@ void Server::handle_get(http_request request) } else if (URI.compare(COMMANDS_URI) == 0 && query.find(L"commandHash") != query.end()) { answer[L"commandExecuted"] = json::value(scheduler->isCommandExecuted(to_string(query[L"commandHash"]))); + if (executionResults.has_field(query[L"commandHash"])) + answer[L"commandResult"] = executionResults[query[L"commandHash"]]; + else + answer[L"commandResult"] = json::value::null(); } /* Drawings data*/ else if (URI.compare(DRAWINGS_URI) == 0 && drawingsByLayer.has_object_field(L"drawings")) { diff --git a/backend/dcstools/include/dcstools.h b/backend/dcstools/include/dcstools.h index d70795b3..0f4f1bed 100644 --- a/backend/dcstools/include/dcstools.h +++ b/backend/dcstools/include/dcstools.h @@ -7,6 +7,7 @@ void DllExport LogWarning(lua_State* L, string message); void DllExport LogError(lua_State* L, string message); void DllExport Log(lua_State* L, string message, unsigned int level); int DllExport dostring_in(lua_State* L, string target, string command); +int DllExport dostring_in(lua_State* L, string target, string command, string& result); void DllExport getAllUnits(lua_State* L, map& unitJSONs); unsigned int DllExport TACANChannelToFrequency(unsigned int channel, char XY); diff --git a/backend/dcstools/src/dcstools.cpp b/backend/dcstools/src/dcstools.cpp index ae655e68..ff179ddf 100644 --- a/backend/dcstools/src/dcstools.cpp +++ b/backend/dcstools/src/dcstools.cpp @@ -101,7 +101,24 @@ int dostring_in(lua_State* L, string target, string command) lua_getfield(L, -1, "dostring_in"); lua_pushstring(L, target.c_str()); lua_pushstring(L, command.c_str()); - return lua_pcall(L, 2, 0, 0); + int res = lua_pcall(L, 2, 0, 0); + return res; +} + +int dostring_in(lua_State* L, string target, string command, string &result) +{ + lua_getglobal(L, "net"); + lua_getfield(L, -1, "dostring_in"); + lua_pushstring(L, target.c_str()); + lua_pushstring(L, command.c_str()); + int res = lua_pcall(L, 2, 0, 0); + + // Get the first result in the stack + if (lua_isstring(L, -1)) { + result = lua_tostring(L, -1); + } + + return res; } unsigned int TACANChannelToFrequency(unsigned int channel, char XY) diff --git a/backend/olympus/src/olympus.cpp b/backend/olympus/src/olympus.cpp index 456014f7..5467a189 100644 --- a/backend/olympus/src/olympus.cpp +++ b/backend/olympus/src/olympus.cpp @@ -12,6 +12,7 @@ typedef int(__stdcall* f_coreUnitsData)(lua_State* L); typedef int(__stdcall* f_coreWeaponsData)(lua_State* L); typedef int(__stdcall* f_coreMissionData)(lua_State* L); typedef int(__stdcall* f_coreDrawingsData)(lua_State* L); +typedef int(__stdcall* f_coreSetExecutionResults)(lua_State* L); f_coreInit coreInit = nullptr; f_coreDeinit coreDeinit = nullptr; f_coreFrame coreFrame = nullptr; @@ -19,6 +20,7 @@ f_coreUnitsData coreUnitsData = nullptr; f_coreWeaponsData coreWeaponsData = nullptr; f_coreMissionData coreMissionData = nullptr; f_coreDrawingsData coreDrawingsData = nullptr; +f_coreSetExecutionResults coreExecutionResults = nullptr; string modPath; @@ -117,6 +119,13 @@ static int onSimulationStart(lua_State* L) goto error; } + coreExecutionResults = (f_coreSetExecutionResults)GetProcAddress(hGetProcIDDLL, "coreSetExecutionResults"); + if (!coreExecutionResults) + { + LogError(L, "Error getting coreSetExecutionResults ProcAddress from DLL"); + goto error; + } + coreInit(L, modPath.c_str()); LogInfo(L, "Module loaded and started successfully."); @@ -213,6 +222,15 @@ static int setDrawingsData(lua_State* L) return 0; } +static int setExecutionResults(lua_State* L) +{ + if (coreExecutionResults) + { + coreExecutionResults(L); + } + return 0; +} + static const luaL_Reg Map[] = { {"onSimulationStart", onSimulationStart}, {"onSimulationFrame", onSimulationFrame}, @@ -221,6 +239,7 @@ static const luaL_Reg Map[] = { {"setWeaponsData", setWeaponsData }, {"setMissionData", setMissionData }, {"setDrawingsData", setDrawingsData }, + {"setExecutionResults", setExecutionResults }, {NULL, NULL} }; diff --git a/backend/utils/src/utils.cpp b/backend/utils/src/utils.cpp index a959cdfe..a0ebfdf7 100644 --- a/backend/utils/src/utils.cpp +++ b/backend/utils/src/utils.cpp @@ -1,5 +1,6 @@ #include "framework.h" #include "utils.h" +#include // Get current date/time, format is YYYY-MM-DD.HH:mm:ss const std::string CurrentDateTime() @@ -44,7 +45,11 @@ std::string to_string(const std::wstring& wstr) std::string random_string(size_t length) { - srand(static_cast(time(NULL))); + // Use nanoseconds since epoch as a seed for random number generation + auto now = std::chrono::high_resolution_clock::now(); + auto nanos = std::chrono::duration_cast(now.time_since_epoch()).count(); + srand(static_cast(nanos)); + auto randchar = []() -> char { const char charset[] = diff --git a/frontend/server/src/audio/srshandler.ts b/frontend/server/src/audio/srshandler.ts index d42dd8e6..b0a87e7b 100644 --- a/frontend/server/src/audio/srshandler.ts +++ b/frontend/server/src/audio/srshandler.ts @@ -1,5 +1,6 @@ import { MessageType } from "./audiopacket"; import { defaultSRSData } from "./defaultdata"; +import { AudioPacket } from "./audiopacket"; /* TCP/IP socket */ var net = require("net"); @@ -113,6 +114,16 @@ export class SRSHandler { switch (data[0]) { case MessageType.audio: const encodedData = new Uint8Array(data.slice(1)); + + // Decoded the data for sanity check + if (encodedData.length < 22) { + console.log("Received audio data is too short, ignoring."); + return; + } + + let packet = new AudioPacket(); + packet.fromByteArray(encodedData); + this.udp.send(encodedData, this.SRSPort, "127.0.0.1", (error) => { if (error) console.log(`Error sending data to SRS server: ${error}`); }); diff --git a/scripts/python/API/.vscode/launch.json b/scripts/python/API/.vscode/launch.json index 2d2e73eb..96695e79 100644 --- a/scripts/python/API/.vscode/launch.json +++ b/scripts/python/API/.vscode/launch.json @@ -4,13 +4,21 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ - { - "name": "App", + "name": "Voice control", "type": "debugpy", "request": "launch", - "program": "app.py", - "console": "integratedTerminal" + "program": "voice_control.py", + "console": "integratedTerminal", + "justMyCode": false, + }, + { + "name": "Test bed", + "type": "debugpy", + "request": "launch", + "program": "testbed.py", + "console": "integratedTerminal", + "justMyCode": false, } ] } \ No newline at end of file diff --git a/scripts/python/API/api.py b/scripts/python/API/api.py new file mode 100644 index 00000000..545fc23e --- /dev/null +++ b/scripts/python/API/api.py @@ -0,0 +1,506 @@ +import json +import time +import requests +import base64 +import signal +import sys +import logging +import os +import tempfile +import asyncio +from google.cloud import speech, texttospeech + +# Custom imports +from data.data_extractor import DataExtractor +from unit.unit import Unit +from data.unit_spawn_table import UnitSpawnTable +from data.data_types import LatLng + +class API: + def __init__(self, username: str = "API", databases_location: str = "databases"): + self.base_url = None + self.config = None + self.logs = {} + self.units: dict[str, Unit] = {} + self.username = username + self.databases_location = databases_location + self.interval = 1 # Default update interval in seconds + self.on_update_callback = None + self.on_startup_callback = None + self.should_stop = False + self.running = False + + self.units_update_timestamp = 0 + + # Setup logging + self.logger = logging.getLogger(f"DCSOlympus.API") + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter('[%(asctime)s] %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.INFO) + + # Read the config file olympus.json + try: + with open("olympus.json", "r") as file: + # Load the JSON configuration + self.config = json.load(file) + except FileNotFoundError: + self.logger.error("Configuration file olympus.json not found.") + + self.password = self.config.get("authentication").get("gameMasterPassword") + address = self.config.get("backend").get("address") + port = self.config.get("backend").get("port", None) + + if port: + self.base_url = f"http://{address}:{port}/olympus" + else: + self.base_url = f"https://{address}/olympus" + + # Read the aircraft, helicopter, groundunit and navyunit databases as json files + try: + with open(f"{self.databases_location}/aircraftdatabase.json", "r", -1, 'utf-8') as file: + self.aircraft_database = json.load(file) + except FileNotFoundError: + self.logger.error("Aircraft database file not found.") + + try: + with open(f"{self.databases_location}/helicopterdatabase.json", "r", -1, 'utf-8') as file: + self.helicopter_database = json.load(file) + except FileNotFoundError: + self.logger.error("Helicopter database file not found.") + + try: + with open(f"{self.databases_location}/groundunitdatabase.json", "r", -1, 'utf-8') as file: + self.groundunit_database = json.load(file) + except FileNotFoundError: + self.logger.error("Ground unit database file not found.") + + try: + with open(f"{self.databases_location}/navyunitdatabase.json", "r", -1, 'utf-8') as file: + self.navyunit_database = json.load(file) + except FileNotFoundError: + self.logger.error("Navy unit database file not found.") + + def _get(self, endpoint): + credentials = f"{self.username}:{self.password}" + base64_encoded_credentials = base64.b64encode(credentials.encode()).decode() + + headers = { + "Authorization": f"Basic {base64_encoded_credentials}" + } + response = requests.get(f"{self.base_url}/{endpoint}", headers=headers) + if response.status_code == 200: + return response + else: + response.raise_for_status() + + def _put(self, data): + credentials = f"{self.username}:{self.password}" + base64_encoded_credentials = base64.b64encode(credentials.encode()).decode() + + headers = { + "Authorization": f"Basic {base64_encoded_credentials}", + "Content-Type": "application/json" + } + response = requests.put(f"{self.base_url}", headers=headers, json=data) + if response.status_code == 200: + return response + else: + response.raise_for_status() + + def _setup_signal_handlers(self): + def signal_handler(signum, frame): + self.logger.info(f"Received signal {signum}, initiating graceful shutdown...") + self.stop() + + # Register signal handlers + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + if hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, signal_handler) # Termination signal + + def get_units(self): + """ + Get all units from the API. Notice that if the API is not running, update_units() must be manually called first. + Returns: + dict: A dictionary of Unit objects indexed by their unit ID. + """ + return self.units + + def get_logs(self): + """ + Get the logs from the API. Notice that if the API is not running, update_logs() must be manually called first. + Returns: + dict: A dictionary of log entries indexed by their log ID. + """ + return self.logs + + def update_units(self, time=0): + """ + Fetch the list of units from the API. + Args: + time (int): The time in milliseconds from Unix epoch to fetch units from. Default is 0, which fetches all units. + If time is greater than 0, it fetches units updated after that time. + Returns: + dict: A dictionary of Unit objects indexed by their unit ID. + """ + response = self._get("units") + if response.status_code == 200 and len(response.content) > 0: + try: + data_extractor = DataExtractor(response.content) + + # Extract the update timestamp + self.units_update_timestamp = data_extractor.extract_uint64() + self.logger.debug(f"Update Timestamp: {self.units_update_timestamp}") + + while data_extractor.get_seek_position() < len(response.content): + # Extract the unit ID + unit_id = data_extractor.extract_uint32() + + if unit_id not in self.units: + # Create a new Unit instance if it doesn't exist + self.units[unit_id] = Unit(unit_id, self) + + self.units[unit_id].update_from_data_extractor(data_extractor) + + return self.units + + except ValueError: + self.logger.error("Failed to parse JSON response") + else: + self.logger.error(f"Failed to fetch units: {response.status_code} - {response.text}") + + + def update_logs(self, time = 0): + """ + Fetch the logs from the API. + Args: + time (int): The time in milliseconds from Unix epoch to fetch logs from. Default is 0, which fetches all logs. + Returns: + list: A list of log entries. + """ + endpoint = "/logs" + endpoint += f"?time={time}" + response = self._get(endpoint) + if response.status_code == 200: + try: + self.logs = json.loads(response.content.decode('utf-8')) + return self.logs + except ValueError: + self.logger.error("Failed to parse JSON response") + else: + self.logger.error(f"Failed to fetch logs: {response.status_code} - {response.text}") + + def spawn_aircrafts(self, units: list[UnitSpawnTable], coalition: str, airbaseName: str, country: str, immediate: bool, spawnPoints: int = 0): + """ + Spawn aircraft units at the specified location or airbase. + Args: + units (list[UnitSpawnTable]): List of UnitSpawnTable objects representing the aircraft to spawn. + coalition (str): The coalition to which the units belong. ("blue", "red", "neutral") + airbaseName (str): The name of the airbase where the units will be spawned. Leave "" for air spawn. + country (str): The country of the units. + immediate (bool): Whether to spawn the units immediately or not, overriding the scheduler. + spawnPoints (int): Amount of spawn points to use, default is 0. + """ + command = { + "units": [unit.toJSON() for unit in units], + "coalition": coalition, + "airbaseName": airbaseName, + "country": country, + "immediate": immediate, + "spawnPoints": spawnPoints, + } + data = { "spawnAircrafts": command } + response = self._put(data) + + def spawn_helicopters(self, units: list[UnitSpawnTable], coalition: str, airbaseName: str, country: str, immediate: bool, spawnPoints: int = 0): + """ + Spawn helicopter units at the specified location or airbase. + Args: + units (list[UnitSpawnTable]): List of UnitSpawnTable objects representing the helicopters to spawn. + coalition (str): The coalition to which the units belong. ("blue", "red", "neutral") + airbaseName (str): The name of the airbase where the units will be spawned. Leave "" for air spawn. + country (str): The country of the units. + immediate (bool): Whether to spawn the units immediately or not, overriding the scheduler. + spawnPoints (int): Amount of spawn points to use, default is 0. + """ + command = { + "units": [unit.toJSON() for unit in units], + "coalition": coalition, + "airbaseName": airbaseName, + "country": country, + "immediate": immediate, + "spawnPoints": spawnPoints, + } + data = { "spawnHelicopters": command } + response = self._put(data) + + def spawn_ground_units(self, units: list[UnitSpawnTable], coalition: str, country: str, immediate: bool, spawnPoints: int, execution_callback): + """ + Spawn ground units at the specified location. + Args: + units (list[UnitSpawnTable]): List of UnitSpawnTable objects representing the ground units to spawn. + coalition (str): The coalition to which the units belong. ("blue", "red", "neutral") + country (str): The country of the units. + immediate (bool): Whether to spawn the units immediately or not, overriding the scheduler. + spawnPoints (int): Amount of spawn points to use. + execution_callback (function): An async callback function to execute after the command is processed. + """ + command = { + "units": [unit.toJSON() for unit in units], + "coalition": coalition, + "country": country, + "immediate": immediate, + "spawnPoints": spawnPoints, + } + data = { "spawnGroundUnits": command } + response = self._put(data) + + # Parse the response as JSON + try: + response_data = response.json() + command_hash = response_data.get("commandHash", None) + if command_hash: + self.logger.info(f"Ground units spawned successfully. Command Hash: {command_hash}") + # Start a background task to check if the command was executed + asyncio.create_task(self._check_command_executed(command_hash, execution_callback, wait_for_result=True,)) + else: + self.logger.error("Command hash not found in response") + + + except ValueError: + self.logger.error("Failed to parse JSON response") + + async def _check_command_executed(self, command_hash: str, execution_callback, wait_for_result: bool, max_wait_time: int = 60): + """ + Check if a command has been executed by polling the API. + """ + start_time = time.time() + while True: + response = self._get(f"commands?commandHash={command_hash}") + if response.status_code == 200: + try: + data = response.json() + if data.get("commandExecuted") == True and (data.get("commandResult") is not None or (not wait_for_result)): + self.logger.info(f"Command {command_hash} executed successfully, command result: {data.get('commandResult')}") + if execution_callback: + await execution_callback(data.get("commandResult")) + break + elif data.get("status") == "failed": + self.logger.error(f"Command {command_hash} failed to execute.") + break + except ValueError: + self.logger.error("Failed to parse JSON response") + if time.time() - start_time > max_wait_time: + self.logger.warning(f"Timeout: Command {command_hash} did not complete within {max_wait_time} seconds.") + break + await asyncio.sleep(1) + + def spawn_navy_units(self, units: list[UnitSpawnTable], coalition: str, country: str, immediate: bool, spawnPoints: int = 0): + """ + Spawn navy units at the specified location. + Args: + units (list[UnitSpawnTable]): List of UnitSpawnTable objects representing the navy units to spawn. + coalition (str): The coalition to which the units belong. ("blue", "red", "neutral") + country (str): The country of the units. + immediate (bool): Whether to spawn the units immediately or not, overriding the scheduler. + spawnPoints (int): Amount of spawn points to use, default is 0. + """ + command = { + "units": [unit.toJSON() for unit in units], + "coalition": coalition, + "country": country, + "immediate": immediate, + "spawnPoints": spawnPoints, + } + data = { "spawnNavyUnits": command } + response = self._put(data) + + def create_radio_listener(self): + """ + Create an audio listener instance. + + Returns: + AudioListener: An instance of the AudioListener class. + """ + from radio.radio_listener import RadioListener + return RadioListener(self, "localhost", self.config.get("audio").get("WSPort")) + + def register_on_update_callback(self, callback): + """ + Register a callback function to be called on each update. + + Args: + callback (function): The function to call on update. Can be sync or async. + The function should accept a single argument, which is the API instance. + """ + self.on_update_callback = callback + + def register_on_startup_callback(self, callback): + """ + Register a callback function to be called on startup. + Args: + callback (function): The function to call on startup. Can be sync or async. + The function should accept a single argument, which is the API instance. + """ + self.on_startup_callback = callback + + def set_log_level(self, level): + """ + Set the logging level for the API. + + Args: + level: Logging level (e.g., logging.DEBUG, logging.INFO, logging.WARNING, self.logger.error) + """ + self.logger.setLevel(level) + self.logger.info(f"Log level set to {logging.getLevelName(level)}") + + def stop(self): + """ + Stop the API service gracefully. + """ + self.logger.info("Stopping API service...") + self.should_stop = True + + async def _run_callback_async(self, callback, *args): + """ + Run a callback asynchronously, handling both sync and async callbacks. + """ + try: + if asyncio.iscoroutinefunction(callback): + await callback(*args) + else: + callback(*args) + except Exception as e: + # Log the error but don't crash the update process + self.logger.error(f"Error in callback: {e}") + + def generate_audio_message(text: str, gender: str = "male", code: str = "en-US") -> str: + """ + Generate a WAV file from text using Google Text-to-Speech API. + Remember to manually delete the generated file after use! + + Args: + text (str): The text to synthesize. + gender (str): The gender of the voice (male or female). + code (str): The language code (e.g., en-US). + + Returns: + str: The filename of the generated WAV file. + """ + client = texttospeech.TextToSpeechClient() + input_text = texttospeech.SynthesisInput(text=text) + voice = texttospeech.VoiceSelectionParams( + language_code=code, + ssml_gender=texttospeech.SsmlVoiceGender.MALE if gender == "male" else texttospeech.SsmlVoiceGender.FEMALE + ) + audio_config = texttospeech.AudioConfig( + audio_encoding=texttospeech.AudioEncoding.LINEAR16, + sample_rate_hertz=16000 + ) + response = client.synthesize_speech( + input=input_text, + voice=voice, + audio_config=audio_config + ) + # Save the response audio to a WAV file + temp_dir = tempfile.gettempdir() + file_name = os.path.join(temp_dir, next(tempfile._get_candidate_names()) + ".wav") + with open(file_name, "wb") as out: + out.write(response.audio_content) + + return file_name + + def send_command(self, command: str): + """ + Send a command to the API. + + Args: + command (str): The command to send. + """ + response = self._put(command) + if response.status_code == 200: + self.logger.info(f"Command sent successfully: {command}") + else: + self.logger.error(f"Failed to send command: {response.status_code} - {response.text}") + + def run(self): + """ + Start the API service. + + This method initializes the API and starts the necessary components. + Sets up signal handlers for graceful shutdown. + """ + asyncio.run(self._run_async()) + + def get_closest_units(self, coalitions: list[str], categories: list[str], position: LatLng, operate_as: str | None = None, max_number: int = 1, max_distance: float = 10000) -> list[Unit]: + """ + Get the closest units of a specific coalition and category to a given position. + Units are filtered by coalition, category, and optionally by operating role. + + + Args: + coalitions (list[str]): List of coalitions to filter by (e.g., ["blue", "red"]). + categories (list[str]): List of categories to filter by (e.g., ["aircraft", "groundunit"]). + position (LatLng): The position to measure distance from. + operate_as (str | None): Optional list of operating roles to filter by (either "red" or "blue"). Default is None. + max_number (int): Maximum number of closest units to return. Default is 1. + max_distance (float): Maximum distance to consider for the closest unit. Default is 10000 meters. + """ + closest_units = [] + closest_distance = max_distance + + # Iterate through all units and find the closest ones that match the criteria + for unit in self.units.values(): + if unit.alive and unit.coalition in coalitions and unit.category.lower() in categories and (operate_as is None or unit.operate_as == operate_as or unit.coalition is not "neutral"): + distance = position.distance_to(unit.position) + if distance < closest_distance: + closest_distance = distance + closest_units = [unit] + elif distance == closest_distance: + closest_units.append(unit) + + # Sort the closest units by distance + closest_units.sort(key=lambda u: position.distance_to(u.position)) + + # Limit the number of closest units returned + closest_units = closest_units[:max_number] + + return closest_units + + async def _run_async(self): + """ + Async implementation of the API service loop. + """ + # Setup signal handlers for graceful shutdown + self._setup_signal_handlers() + + # Here you can add any initialization logic if needed + self.logger.info("API started") + self.logger.info("Press Ctrl+C to stop gracefully") + + self.running = True + self.should_stop = False + + # Call the startup callback if registered + if self.on_startup_callback: + try: + await self._run_callback_async(self.on_startup_callback, self) + except Exception as e: + self.logger.error(f"Error in startup callback: {e}") + + try: + while not self.should_stop: + # Update units from the last update timestamp + self.update_units(self.units_update_timestamp) + + if self.on_update_callback: + await self._run_callback_async(self.on_update_callback, self) + await asyncio.sleep(self.interval) + except KeyboardInterrupt: + self.logger.info("Keyboard interrupt received") + self.stop() + finally: + self.logger.info("API stopped") + self.running = False diff --git a/scripts/python/API/app.py b/scripts/python/API/app.py deleted file mode 100644 index 65e4d1eb..00000000 --- a/scripts/python/API/app.py +++ /dev/null @@ -1,248 +0,0 @@ -import hashlib -import json -from math import pi -from random import random -import requests -import base64 -from data_extractor import DataExtractor -from data_types import LatLng -from unit import Unit -from unit_spawn_table import UnitSpawnTable -from datetime import datetime - -class API: - def __init__(self, username: str = "API", databases_location: str = "databases"): - self.base_url = None - self.config = None - self.units: dict[str, Unit] = {} - self.username = username - self.databases_location = databases_location - - # Read the config file olympus.json - try: - with open("olympus.json", "r") as file: - # Load the JSON configuration - self.config = json.load(file) - except FileNotFoundError: - raise Exception("Configuration file olympus.json not found.") - - self.password = self.config.get("authentication").get("gameMasterPassword") - address = self.config.get("backend").get("address") - port = self.config.get("backend").get("port", None) - - if port: - self.base_url = f"http://{address}:{port}/olympus" - else: - self.base_url = f"https://{address}/olympus" - - # Read the aircraft, helicopter, groundunit and navyunit databases as json files - try: - with open(f"{self.databases_location}/aircraftdatabase.json", "r", -1, 'utf-8') as file: - self.aircraft_database = json.load(file) - except FileNotFoundError: - raise Exception("Aircraft database file not found.") - - try: - with open(f"{self.databases_location}/helicopterdatabase.json", "r", -1, 'utf-8') as file: - self.helicopter_database = json.load(file) - except FileNotFoundError: - raise Exception("Helicopter database file not found.") - - try: - with open(f"{self.databases_location}/groundunitdatabase.json", "r", -1, 'utf-8') as file: - self.groundunit_database = json.load(file) - except FileNotFoundError: - raise Exception("Ground unit database file not found.") - - try: - with open(f"{self.databases_location}/navyunitdatabase.json", "r", -1, 'utf-8') as file: - self.navyunit_database = json.load(file) - except FileNotFoundError: - raise Exception("Navy unit database file not found.") - - def get(self, endpoint): - credentials = f"{self.username}:{self.password}" - base64_encoded_credentials = base64.b64encode(credentials.encode()).decode() - - headers = { - "Authorization": f"Basic {base64_encoded_credentials}" - } - response = requests.get(f"{self.base_url}/{endpoint}", headers=headers) - if response.status_code == 200: - return response - else: - response.raise_for_status() - - def put(self, endpoint, data): - credentials = f"{self.username}:{self.password}" - base64_encoded_credentials = base64.b64encode(credentials.encode()).decode() - - headers = { - "Authorization": f"Basic {base64_encoded_credentials}", - "Content-Type": "application/json" - } - response = requests.put(f"{self.base_url}/{endpoint}", headers=headers, json=data) - if response.status_code == 200: - return response - else: - response.raise_for_status() - - def get_units(self): - response = self.get("/units") - if response.status_code == 200: - try: - data_extractor = DataExtractor(response.content) - - # Extract the update timestamp - update_timestamp = data_extractor.extract_uint64() - print(f"Update Timestamp: {update_timestamp}") - - while data_extractor.get_seek_position() < len(response.content): - # Extract the unit ID - unit_id = data_extractor.extract_uint32() - - if unit_id not in self.units: - # Create a new Unit instance if it doesn't exist - self.units[unit_id] = Unit(unit_id) - - self.units[unit_id].update_from_data_extractor(data_extractor) - - return self.units - - except ValueError: - raise Exception("Failed to parse JSON response") - else: - raise Exception(f"Failed to fetch units: {response.status_code} - {response.text}") - - - def get_logs(self, time = 0): - endpoint = "/logs" - endpoint += f"?time={time}" - response = self.get(endpoint) - if response.status_code == 200: - try: - logs = json.loads(response.content.decode('utf-8')) - return logs - except ValueError: - raise Exception("Failed to parse JSON response") - else: - raise Exception(f"Failed to fetch logs: {response.status_code} - {response.text}") - - def spawn_aircrafts(self, units: list[UnitSpawnTable], coalition: str, airbaseName: str, country: str, immediate: bool, spawnPoints: list[LatLng]): - command = { - "units": [unit.toJSON() for unit in units], - "coalition": coalition, - "airbaseName": airbaseName, - "country": country, - "immediate": immediate, - "spawnPoints": spawnPoints, - } - data = { "spawnAircrafts": command } - self.put("", data) - - def spanw_helicopters(self, units: list[UnitSpawnTable], coalition: str, airbaseName: str, country: str, immediate: bool, spawnPoints: list[LatLng]): - command = { - "units": [unit.toJSON() for unit in units], - "coalition": coalition, - "airbaseName": airbaseName, - "country": country, - "immediate": immediate, - "spawnPoints": spawnPoints, - } - data = { "spawnHelicopters": command } - self.put("", data) - - def spawn_ground_units(self, units: list[UnitSpawnTable], coalition: str, country: str, immediate: bool, spawnPoints: list[LatLng]): - command = { - "units": [unit.toJSON() for unit in units], - "coalition": coalition, - "country": country, - "immediate": immediate, - "spawnPoints": spawnPoints, - } - data = { "spawnGroundUnits": command } - self.put("", data) - - def spawn_navy_units(self, units: list[UnitSpawnTable], coalition: str, country: str, immediate: bool, spawnPoints: list[LatLng]): - command = { - "units": [unit.toJSON() for unit in units], - "coalition": coalition, - "country": country, - "immediate": immediate, - "spawnPoints": spawnPoints, - } - data = { "spawnNavyUnits": command } - self.put("", data) - - def delete_unit(self, ID: int, explosion: bool, explosionType: str, immediate: bool): - command = { - "ID": ID, - "explosion": explosion, - "explosionType": explosionType, - "immediate": immediate, - } - data = { "deleteUnit": command } - self.put("", data) - -if __name__ == "__main__": - api = API() - try: - # Example usage - # Get the units from the API - print("Fetching units...") - units = api.get_units() - print("Units:", units) - - # Example of spawning aircrafts - print("Spawning aircrafts...") - spawn_units = [ - UnitSpawnTable( - unit_type="A-10C_2", - location=LatLng(lat=35.0, lng=35.0, alt=1000), - skill="High", - livery_id="Default", - altitude=1000, - loadout="Default", - heading=pi/2 - ) - ] - api.spawn_aircrafts(spawn_units, "blue", "", "", False, 0) - - # Spawn a random navy unit - print("Spawning navy units...") - random_navy_unit = list(api.navyunit_database.keys())[int(random() * len(api.navyunit_database))] - - spawn_navy_units = [ - UnitSpawnTable( - unit_type=random_navy_unit, - location=LatLng(lat=35.0, lng=35.0, alt=0), - skill="High", - livery_id="Default", - altitude=None, - loadout=None, - heading=0 - ) - ] - api.spawn_navy_units(spawn_navy_units, "blue", "", False, 0) - - # Example of deleting a unit - # Get all the unit of type A-10C_2 - a10_units = [unit for unit in units.values() if unit.name == "A-10C_2"] - for unit in a10_units: - api.delete_unit(unit.ID, explosion=False, explosionType="", immediate=False) - - # Fetch logs from the API - print("Fetching logs...") - logs = api.get_logs()["logs"] - - # Pretty print the logs - print("Logs:") - # The log is a dictionary. The key is the timestamp and the value is the log message - for timestamp, log_message in logs.items(): - # The timestamp is in milliseconds from unix epoch - timestamp = int(timestamp) / 1000 # Convert to seconds - iso_time = datetime.fromtimestamp(timestamp).isoformat() - print(f"{iso_time}: {log_message}") - - except Exception as e: - print("An error occurred:", e) \ No newline at end of file diff --git a/scripts/python/API/audio/audio_packet.py b/scripts/python/API/audio/audio_packet.py new file mode 100644 index 00000000..70945e9b --- /dev/null +++ b/scripts/python/API/audio/audio_packet.py @@ -0,0 +1,205 @@ +from enum import Enum +from typing import List, Dict, Optional +import struct + +packet_id = 0 + +class MessageType(Enum): + AUDIO = 0 + SETTINGS = 1 + CLIENTS_DATA = 2 + +class AudioPacket: + def __init__(self): + # Mandatory data + self._frequencies: List[Dict[str, int]] = [] + self._audio_data: Optional[bytes] = None + self._transmission_guid: Optional[str] = None + self._client_guid: Optional[str] = None + + # Default data + self._unit_id: int = 0 + self._hops: int = 0 + + # Usually internally set only + self._packet_id: Optional[int] = None + + def from_byte_array(self, byte_array: bytes): + total_length = self._byte_array_to_integer(byte_array[0:2]) + audio_length = self._byte_array_to_integer(byte_array[2:4]) + frequencies_length = self._byte_array_to_integer(byte_array[4:6]) + + # Perform some sanity checks + if total_length != len(byte_array): + print(f"Warning, audio packet expected length is {total_length} but received length is {len(byte_array)}, aborting...") + return + + if frequencies_length % 10 != 0: + print(f"Warning, audio packet frequencies data length is {frequencies_length} which is not a multiple of 10, aborting...") + return + + # Extract the audio data + self._audio_data = byte_array[6:6 + audio_length] + + # Extract the frequencies + offset = 6 + audio_length + for idx in range(frequencies_length // 10): + self._frequencies.append({ + 'frequency': self._byte_array_to_double(byte_array[offset:offset + 8]), + 'modulation': byte_array[offset + 8], + 'encryption': byte_array[offset + 9] + }) + offset += 10 + + # Extract the remaining data + self._unit_id = self._byte_array_to_integer(byte_array[offset:offset + 4]) + offset += 4 + self._packet_id = self._byte_array_to_integer(byte_array[offset:offset + 8]) + offset += 8 + self._hops = self._byte_array_to_integer(byte_array[offset:offset + 1]) + offset += 1 + self._transmission_guid = byte_array[offset:offset + 22].decode('utf-8', errors='ignore') + offset += 22 + self._client_guid = byte_array[offset:offset + 22].decode('utf-8', errors='ignore') + offset += 22 + + + def to_byte_array(self) -> Optional[bytes]: + global packet_id + + # Perform some sanity checks + if len(self._frequencies) == 0: + print("Warning, could not encode audio packet, no frequencies data provided, aborting...") + return None + + if self._audio_data is None: + print("Warning, could not encode audio packet, no audio data provided, aborting...") + return None + + if self._transmission_guid is None: + print("Warning, could not encode audio packet, no transmission GUID provided, aborting...") + return None + + if self._client_guid is None: + print("Warning, could not encode audio packet, no client GUID provided, aborting...") + return None + + # Prepare the array for the header + header = [0, 0, 0, 0, 0, 0] + + # Encode the frequencies data + frequencies_data = [] + for data in self._frequencies: + frequencies_data.extend(self._double_to_byte_array(data['frequency'])) + frequencies_data.append(data['modulation']) + frequencies_data.append(data['encryption']) + + # If necessary increase the packet_id + if self._packet_id is None: + self._packet_id = packet_id + packet_id += 1 + + # Encode unitID, packetID, hops + enc_unit_id = self._integer_to_byte_array(self._unit_id, 4) + enc_packet_id = self._integer_to_byte_array(self._packet_id, 8) + enc_hops = [self._hops] + + # Assemble packet + encoded_data = [] + encoded_data.extend(header) + encoded_data.extend(list(self._audio_data)) + encoded_data.extend(frequencies_data) + encoded_data.extend(enc_unit_id) + encoded_data.extend(enc_packet_id) + encoded_data.extend(enc_hops) + encoded_data.extend(list(self._transmission_guid.encode('utf-8'))) + encoded_data.extend(list(self._client_guid.encode('utf-8'))) + + # Set the lengths of the parts + enc_packet_len = self._integer_to_byte_array(len(encoded_data), 2) + encoded_data[0] = enc_packet_len[0] + encoded_data[1] = enc_packet_len[1] + + enc_audio_len = self._integer_to_byte_array(len(self._audio_data), 2) + encoded_data[2] = enc_audio_len[0] + encoded_data[3] = enc_audio_len[1] + + frequency_audio_len = self._integer_to_byte_array(len(frequencies_data), 2) + encoded_data[4] = frequency_audio_len[0] + encoded_data[5] = frequency_audio_len[1] + + return bytes([0] + encoded_data) + + # Utility methods for byte array conversion + def _byte_array_to_integer(self, byte_array: bytes) -> int: + if len(byte_array) == 1: + return struct.unpack(' float: + return struct.unpack(' List[int]: + if length == 1: + return list(struct.pack(' List[int]: + return list(struct.pack(' List[Dict[str, int]]: + return self._frequencies + + def set_audio_data(self, audio_data: bytes): + self._audio_data = audio_data + + def get_audio_data(self) -> Optional[bytes]: + return self._audio_data + + def set_transmission_guid(self, transmission_guid: str): + self._transmission_guid = transmission_guid + + def get_transmission_guid(self) -> Optional[str]: + return self._transmission_guid + + def set_client_guid(self, client_guid: str): + self._client_guid = client_guid + + def get_client_guid(self) -> Optional[str]: + return self._client_guid + + def set_unit_id(self, unit_id: int): + self._unit_id = unit_id + + def get_unit_id(self) -> int: + return self._unit_id + + def set_packet_id(self, packet_id: int): + self._packet_id = packet_id + + def get_packet_id(self) -> Optional[int]: + return self._packet_id + + def set_hops(self, hops: int): + self._hops = hops + + def get_hops(self) -> int: + return self._hops \ No newline at end of file diff --git a/scripts/python/API/audio/audio_recorder.py b/scripts/python/API/audio/audio_recorder.py new file mode 100644 index 00000000..075e0b4b --- /dev/null +++ b/scripts/python/API/audio/audio_recorder.py @@ -0,0 +1,75 @@ +import threading +import opuslib # TODO: important, setup dll recognition +import wave +from typing import Callable + +from audio.audio_packet import AudioPacket +import tempfile +import os + +class AudioRecorder: + def __init__(self, api): + self.packets: list[AudioPacket] = [] + self.silence_timer = None + self.recording_callback = None + self.api = api + + def register_recording_callback(self, callback: Callable[[AudioPacket], None]): + """Set the callback function for handling recorded audio packets.""" + self.recording_callback = callback + + def add_packet(self, packet: AudioPacket): + self.packets.append(packet) + + # Start a countdown timer to stop recording after 2 seconds of silence + self.start_silence_timer() + + def stop_recording(self): + if self.silence_timer: + self.silence_timer.cancel() + self.silence_timer = None + + # Extract the client GUID from the first packet if available + unit_ID = self.packets[0].get_unit_id() if self.packets else None + + # Process the recorded packets + if self.packets: + print(f"Stopping recording, total packets: {len(self.packets)}") + + # Reorder the packets according to their packet ID + self.packets.sort(key=lambda p: p.get_packet_id()) + + # Decode to audio data using the opus codec + opus_decoder = opuslib.Decoder(16000, 1) + audio_data = bytearray() + for packet in self.packets: + decoded_data = opus_decoder.decode(packet.get_audio_data(), frame_size=6400) + audio_data.extend(decoded_data) + + # Save the audio into a temporary wav file with a random name in the tempo folder + temp_dir = tempfile.gettempdir() + file_name = os.path.join(temp_dir, next(tempfile._get_candidate_names()) + ".wav") + with wave.open(file_name, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(16000) + wav_file.writeframes(audio_data) + + if self.recording_callback: + self.recording_callback(file_name, unit_ID) + + # Clear the packets after saving and delete the temporary file + os.remove(file_name) + self.packets.clear() + else: + print("No packets recorded.") + + def start_silence_timer(self): + if self.silence_timer: + self.silence_timer.cancel() + + # Set a timer for 2 seconds + self.silence_timer = threading.Timer(2.0, self.stop_recording) + self.silence_timer.start() + + \ No newline at end of file diff --git a/scripts/python/API/data_extractor.py b/scripts/python/API/data/data_extractor.py similarity index 98% rename from scripts/python/API/data_extractor.py rename to scripts/python/API/data/data_extractor.py index e65f61d3..de1608dd 100644 --- a/scripts/python/API/data_extractor.py +++ b/scripts/python/API/data/data_extractor.py @@ -1,6 +1,6 @@ import struct from typing import List -from data_types import LatLng, TACAN, Radio, GeneralSettings, Ammo, Contact, Offset +from data.data_types import LatLng, TACAN, Radio, GeneralSettings, Ammo, Contact, Offset class DataExtractor: def __init__(self, buffer: bytes): diff --git a/scripts/python/API/data_indexes.py b/scripts/python/API/data/data_indexes.py similarity index 100% rename from scripts/python/API/data_indexes.py rename to scripts/python/API/data/data_indexes.py diff --git a/scripts/python/API/data/data_types.py b/scripts/python/API/data/data_types.py new file mode 100644 index 00000000..c4017a31 --- /dev/null +++ b/scripts/python/API/data/data_types.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass +from typing import List, Optional + +from utils.utils import bearing_to, distance, project_with_bearing_and_distance + +@dataclass +class LatLng: + lat: float + lng: float + alt: float + + def toJSON(self): + """Convert LatLng to a JSON serializable dictionary.""" + return { + "lat": self.lat, + "lng": self.lng, + "alt": self.alt + } + + def project_with_bearing_and_distance(self, d, bearing): + """ + Project this LatLng point with a bearing and distance. + Args: + d: Distance in meters to project. + bearing: Bearing in radians. + Returns: + A new LatLng point projected from this point. + + """ + (new_lat, new_lng) = project_with_bearing_and_distance(self.lat, self.lng, d, bearing) + return LatLng(new_lat, new_lng, self.alt) + + def distance_to(self, other): + """ + Calculate the distance to another LatLng point. + Args: + other: Another LatLng point. + Returns: + Distance in meters to the other point. + """ + return distance(self.lat, self.lng, other.lat, other.lng) + + def bearing_to(self, other): + """ + Calculate the bearing to another LatLng point. + Args: + other: Another LatLng point. + Returns: + Bearing in radians to the other point. + """ + return bearing_to(self.lat, self.lng, other.lat, other.lng) + +@dataclass +class TACAN: + is_on: bool + channel: int + xy: str + callsign: str + +@dataclass +class Radio: + frequency: int + callsign: int + callsign_number: int + +@dataclass +class GeneralSettings: + prohibit_jettison: bool + prohibit_aa: bool + prohibit_ag: bool + prohibit_afterburner: bool + prohibit_air_wpn: bool + +@dataclass +class Ammo: + quantity: int + name: str + guidance: int + category: int + missile_category: int + +@dataclass +class Contact: + id: int + detection_method: int + +@dataclass +class Offset: + x: float + y: float + z: float \ No newline at end of file diff --git a/scripts/python/API/roes.py b/scripts/python/API/data/roes.py similarity index 100% rename from scripts/python/API/roes.py rename to scripts/python/API/data/roes.py diff --git a/scripts/python/API/states.py b/scripts/python/API/data/states.py similarity index 100% rename from scripts/python/API/states.py rename to scripts/python/API/data/states.py diff --git a/scripts/python/API/unit_spawn_table.py b/scripts/python/API/data/unit_spawn_table.py similarity index 96% rename from scripts/python/API/unit_spawn_table.py rename to scripts/python/API/data/unit_spawn_table.py index 29a69980..c615e598 100644 --- a/scripts/python/API/unit_spawn_table.py +++ b/scripts/python/API/data/unit_spawn_table.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from typing import Optional -from data_types import LatLng +from data.data_types import LatLng @dataclass class UnitSpawnTable: diff --git a/scripts/python/API/data_types.py b/scripts/python/API/data_types.py deleted file mode 100644 index d5b22cb9..00000000 --- a/scripts/python/API/data_types.py +++ /dev/null @@ -1,56 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -@dataclass -class LatLng: - lat: float - lng: float - alt: float - - def toJSON(self): - """Convert LatLng to a JSON serializable dictionary.""" - return { - "lat": self.lat, - "lng": self.lng, - "alt": self.alt - } - -@dataclass -class TACAN: - is_on: bool - channel: int - xy: str - callsign: str - -@dataclass -class Radio: - frequency: int - callsign: int - callsign_number: int - -@dataclass -class GeneralSettings: - prohibit_jettison: bool - prohibit_aa: bool - prohibit_ag: bool - prohibit_afterburner: bool - prohibit_air_wpn: bool - -@dataclass -class Ammo: - quantity: int - name: str - guidance: int - category: int - missile_category: int - -@dataclass -class Contact: - id: int - detection_method: int - -@dataclass -class Offset: - x: float - y: float - z: float \ No newline at end of file diff --git a/scripts/python/API/olympus.json b/scripts/python/API/olympus.json index 2196d9e0..7159c6ee 100644 --- a/scripts/python/API/olympus.json +++ b/scripts/python/API/olympus.json @@ -1,9 +1,10 @@ { "backend": { - "address": "refugees.dcsolympus.com/direct" + "address": "localhost", + "port": 4512 }, "authentication": { - "gameMasterPassword": "a00a5973aacb17e4659125fbe10f4160d096dd84b2f586d2d75669462a30106d", + "gameMasterPassword": "a474219e5e9503c84d59500bb1bda3d9ade81e52d9fa1c234278770892a6dd74", "blueCommanderPassword": "7d2e1ef898b21db7411f725a945b76ec8dcad340ed705eaf801bc82be6fe8a4a", "redCommanderPassword": "abc5de7abdb8ed98f6d11d22c9d17593e339fde9cf4b9e170541b4f41af937e3" }, diff --git a/scripts/python/API/radio/radio_listener.py b/scripts/python/API/radio/radio_listener.py new file mode 100644 index 00000000..6b5765e8 --- /dev/null +++ b/scripts/python/API/radio/radio_listener.py @@ -0,0 +1,414 @@ +""" +Audio Listener Module + +WebSocket-based audio listener for real-time audio communication. +""" + +import asyncio +import random +import websockets +import logging +import threading +from typing import Dict, Optional, Callable, Any +import json +from google.cloud import speech +from google.cloud.speech import SpeechContext + +from audio.audio_packet import AudioPacket, MessageType +from audio.audio_recorder import AudioRecorder +from utils.utils import coalition_to_enum + +import wave +import opuslib +import time + +class RadioListener: + """ + WebSocket audio listener that connects to a specified address and port + to receive audio messages with graceful shutdown handling. + """ + + def __init__(self, api, address: str = "localhost", port: int = 5000): + """ + Initialize the RadioListener. + + Args: + address (str): WebSocket server address + port (int): WebSocket server port + message_callback: Optional callback function for handling received messages + """ + self.api = api + + self.address = address + self.port = port + self.websocket_url = f"ws://{address}:{port}" + self.message_callback = None + self.clients_callback = None + + self.frequency = 0 + self.modulation = 0 + self.encryption = 0 + self.coalition = "blue" + self.speech_contexts = [] + + self.audio_recorders: Dict[str, AudioRecorder] = {} + + # The guid is a random 22 char string, used to identify the radio + self._guid = ''.join(random.choice('abcdefghijklmnopqrstuvwxyz0123456789') for _ in range(22)) + + # Connection and control + self._websocket: Optional[websockets.WebSocketServerProtocol] = None + self._running = False + self._should_stop = False + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._thread: Optional[threading.Thread] = None + + # Clients data + self.clients_data: dict = {} + + # Setup logging + self.logger = logging.getLogger(f"RadioListener-{address}:{port}") + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter('[%(asctime)s] %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.INFO) + + async def _handle_message(self, message: bytes) -> None: + """ + Handle received WebSocket message. + + Args: + message: Raw message from WebSocket + """ + try: + # Extract the first byte to determine message type + message_type = message[0] + + if message_type == MessageType.AUDIO.value: + audio_packet = AudioPacket() + audio_packet.from_byte_array(message[1:]) + + if audio_packet.get_transmission_guid() != self._guid: + if audio_packet.get_transmission_guid() not in self.audio_recorders: + recorder = AudioRecorder(self.api) + self.audio_recorders[audio_packet.get_transmission_guid()] = recorder + recorder.register_recording_callback(self._recording_callback) + + self.audio_recorders[audio_packet.get_transmission_guid()].add_packet(audio_packet) + elif message_type == MessageType.CLIENTS_DATA.value: + clients_data = json.loads(message[1:]) + self.clients_data = clients_data + if self.clients_callback: + self.clients_callback(clients_data) + + except Exception as e: + self.logger.error(f"Error handling message: {e}") + + def _recording_callback(self, wav_filename: str, unit_id: str) -> None: + """ + Callback for when audio data is recorded. + + Args: + recorder: The AudioRecorder instance + audio_data: The recorded audio data + """ + if self.message_callback: + with open(wav_filename, 'rb') as audio_file: + audio_content = audio_file.read() + + client = speech.SpeechClient() + config = speech.RecognitionConfig( + language_code="en", + encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16, + sample_rate_hertz=16000, + speech_contexts=[self.speech_contexts] + ) + audio = speech.RecognitionAudio(content=audio_content) + + # Synchronous speech recognition request + response = client.recognize(config=config, audio=audio) + + # Extract recognized text + recognized_text = " ".join([result.alternatives[0].transcript for result in response.results]) + + self.message_callback(recognized_text, unit_id) + else: + self.logger.warning("No message callback registered to handle recorded audio") + + async def _listen(self) -> None: + """Main WebSocket listening loop.""" + retry_count = 0 + max_retries = 5 + retry_delay = 2.0 + + while not self._should_stop and retry_count < max_retries: + try: + self.logger.info(f"Connecting to WebSocket at {self.websocket_url}") + + async with websockets.connect( + self.websocket_url, + ping_interval=20, + ping_timeout=10, + close_timeout=10 + ) as websocket: + self._websocket = websocket + self._running = True + retry_count = 0 # Reset retry count on successful connection + + self.logger.info("WebSocket connection established") + + # Send the sync radio settings message + await self._sync_radio_settings() + + # Listen for messages + async for message in websocket: + if self._should_stop: + break + await self._handle_message(message) + + except websockets.exceptions.ConnectionClosed: + self.logger.warning("WebSocket connection closed") + if not self._should_stop: + retry_count += 1 + if retry_count < max_retries: + self.logger.info(f"Retrying connection in {retry_delay} seconds... (attempt {retry_count}/{max_retries})") + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 1.5, 30.0) # Exponential backoff, max 30 seconds + else: + self.logger.error("Max retries reached, giving up") + break + except websockets.exceptions.InvalidURI: + self.logger.error(f"Invalid WebSocket URI: {self.websocket_url}") + break + except OSError as e: + self.logger.error(f"Connection error: {e}") + if not self._should_stop: + retry_count += 1 + if retry_count < max_retries: + self.logger.info(f"Retrying connection in {retry_delay} seconds... (attempt {retry_count}/{max_retries})") + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 1.5, 30.0) + else: + self.logger.error("Max retries reached, giving up") + break + except Exception as e: + self.logger.error(f"Unexpected error in WebSocket listener: {e}") + break + + self._running = False + self._websocket = None + self.logger.info("Audio listener stopped") + + def _run_event_loop(self) -> None: + """Run the asyncio event loop in a separate thread.""" + try: + # Create new event loop for this thread + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + # Run the listener + self._loop.run_until_complete(self._listen()) + + except Exception as e: + self.logger.error(f"Error in event loop: {e}") + finally: + # Clean up + if self._loop and not self._loop.is_closed(): + self._loop.close() + self._loop = None + + async def _sync_radio_settings(self): + """Send the radio settings of each radio to the SRS backend""" + message = { + "type": "Settings update", + "guid": self._guid, + "coalition": coalition_to_enum(self.coalition), + "settings": [ + { + "frequency": self.frequency, + "modulation": self.modulation, + "ptt": False, + } + ] + } + + if self._websocket: + message_bytes = json.dumps(message).encode('utf-8') + data = bytes([MessageType.AUDIO.SETTINGS.value]) + message_bytes + await self._websocket.send(data) + + async def _send_message(self, message: Any) -> bool: + """ + Send a message through the WebSocket connection. + + Args: + message: Message to send (will be JSON-encoded if not a string) + + Returns: + bool: True if message was sent successfully, False otherwise + """ + if not self.is_connected(): + self.logger.warning("Cannot send message: WebSocket not connected") + return False + + try: + # Convert message to string if needed + if isinstance(message, str): + data = message + else: + data = json.dumps(message) + + await self._websocket.send(data) + self.logger.debug(f"Sent message: {data}") + return True + + except Exception as e: + self.logger.error(f"Error sending message: {e}") + return False + + def register_message_callback(self, callback: Callable[[str, str], None]) -> None: + """Set the callback function for handling received messages. + Args: + callback (Callable[[str, str], None]): Function to call with recognized text and unit ID""" + self.message_callback = callback + + def register_clients_callback(self, callback: Callable[[dict], None]) -> None: + """Set the callback function for handling clients data.""" + self.clients_callback = callback + + def set_speech_contexts(self, contexts: list[SpeechContext]) -> None: + """ + Set the speech contexts for speech recognition. + + Args: + contexts (list[SpeechContext]): List of SpeechContext objects + """ + self.speech_contexts = contexts + + def start(self, frequency: int, modulation: int, encryption: int) -> None: + """Start the audio listener in a separate thread. + + Args: + frequency (int): Transmission frequency in Hz + modulation (int): Modulation type (0 for AM, 1 for FM, etc.) + encryption (int): Encryption type (0 for none, 1 for simple, etc., TODO) + """ + if self._running or self._thread is not None: + self.logger.warning("RadioListener is already running") + return + + self._should_stop = False + self._thread = threading.Thread(target=self._run_event_loop, daemon=True) + self._thread.start() + + self.logger.info(f"RadioListener started, connecting to {self.websocket_url}") + self.frequency = frequency + self.modulation = modulation + self.encryption = encryption + + def transmit_on_frequency(self, file_name: str, frequency: float, modulation: int, encryption: int) -> bool: + """ + Transmit a WAV file as OPUS frames over the websocket. + Args: + file_name (str): Path to the input WAV file (linear16, mono, 16kHz) + frequency (float): Transmission frequency + modulation (int): Modulation type + encryption (int): Encryption type + Returns: + bool: True if transmission succeeded, False otherwise + """ + + try: + # Open WAV file + with wave.open(file_name, 'rb') as wf: + if wf.getnchannels() != 1 or wf.getframerate() != 16000 or wf.getsampwidth() != 2: + self.logger.error("Input WAV must be mono, 16kHz, 16-bit (linear16)") + return False + frame_size = int(16000 * 0.04) # 40ms frames = 640 samples + encoder = opuslib.Encoder(16000, 1, opuslib.APPLICATION_AUDIO) + packet_id = 0 + while True: + pcm_bytes = wf.readframes(frame_size) + if not pcm_bytes or len(pcm_bytes) < frame_size * 2: + break + # Encode PCM to OPUS + try: + opus_data = encoder.encode(pcm_bytes, frame_size) + except Exception as e: + self.logger.error(f"Opus encoding failed: {e}") + return False + # Create AudioPacket + packet = AudioPacket() + packet.set_packet_id(packet_id) + packet.set_audio_data(opus_data) + packet.set_frequencies([{ + 'frequency': frequency, + 'modulation': modulation, + 'encryption': encryption + }]) + packet.set_transmission_guid(self._guid) + packet.set_client_guid(self._guid) + # Serialize and send over websocket + if self._websocket and self._loop and not self._loop.is_closed(): + data = packet.to_byte_array() + fut = asyncio.run_coroutine_threadsafe(self._websocket.send(data), self._loop) + try: + fut.result(timeout=2.0) + except Exception as send_err: + self.logger.error(f"Failed to send packet {packet_id}: {send_err}") + return False + else: + self.logger.error("WebSocket not connected") + return False + packet_id += 1 + time.sleep(0.04) # Simulate real-time transmission + self.logger.info(f"Transmitted {packet_id} packets from {file_name}") + return True + except Exception as e: + self.logger.error(f"Transmit failed: {e}") + return False + + def stop(self) -> None: + """Stop the audio listener gracefully.""" + if not self._running and self._thread is None: + self.logger.info("RadioListener is not running") + return + + self.logger.info("Stopping RadioListener...") + self._should_stop = True + + # Close WebSocket connection if active + if self._websocket and self._loop: + # Schedule the close in the event loop + if not self._loop.is_closed(): + asyncio.run_coroutine_threadsafe(self._websocket.close(), self._loop) + + # Wait for thread to finish + if self._thread: + self._thread.join(timeout=5.0) + if self._thread.is_alive(): + self.logger.warning("Thread did not stop gracefully within timeout") + self._thread = None + + self._running = False + self.logger.info("RadioListener stopped") + + def is_running(self) -> bool: + """Check if the audio listener is currently running.""" + return self._running + + def is_connected(self) -> bool: + """Check if WebSocket is currently connected.""" + return self._websocket is not None and not self._websocket.closed + + def __enter__(self): + """Context manager entry.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit with graceful shutdown.""" + self.stop() + diff --git a/scripts/python/API/testbed.py b/scripts/python/API/testbed.py new file mode 100644 index 00000000..298eb96b --- /dev/null +++ b/scripts/python/API/testbed.py @@ -0,0 +1,186 @@ +import asyncio +from random import randrange +from api import API, Unit, UnitSpawnTable +from math import pi + +# Setup a logger for the module +import logging +logger = logging.getLogger("TestBed") +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +formatter = logging.Formatter('[%(asctime)s] %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) + +units_to_delete = None + +############################################################################################# +# This class represents a disembarked infantry unit that will engage in combat +# after disembarking from a vehicle. It will move forward and engage the closest enemy. +############################################################################################# +class DisembarkedInfantry(Unit): + def __str__(self): + return f"DisembarkedInfrantry(unit_id={self.unit_id}, group_id={self.group_id}, position={self.position}, heading={self.heading})" + + def start_fighting(self, random_bearing: bool = False): + """ + Start the fighting process for the unit. The unit will go forward 30 meters in the direction of the closest enemy and then start a firefight + with the closest enemy unit. + """ + logger.info(f"Unit {self.unit_id} is now fighting.") + + # Pick a random target + target = self.pick_random_target() + + if random_bearing: + # If random_bearing is True or no target is found, use a random bearing + bearing = randrange(0, 100) / 100 * pi * 2 + elif target is None: + # If no target is found, use the unit's current heading + bearing = self.heading + else: + bearing = self.position.bearing_to(target.position) + + # Project the unit's position 30 meters in front of its current heading + destination = self.position.project_with_bearing_and_distance(30, bearing) + + # Set the destination for the unit + self.set_path([destination]) + + # Register a callback for when the unit reaches its destination + self.register_on_destination_reached_callback( + self.on_destination_reached, + destination, + threshold=15.0, + timeout=30.0 # Timeout after 30 seconds if the destination is not reached + ) + + def pick_random_target(self): + # Find the closest enemy unit + targets = self.api.get_closest_units( + ["neutral", "red" if self.coalition == "blue" else "blue"], + ["groundunit"], + self.position, + "red" if self.coalition == "blue" else "blue", + 10 + ) + # Pick a random enemy from the list + target = targets[randrange(len(targets))] if targets else None + return target + + async def on_destination_reached(self, _, reached: bool): + if not reached: + logger.info(f"Unit {self} did not reach its destination.") + else: + logger.info(f"Unit {self} has reached its destination.") + + target = self.pick_random_target() + + if target is None: + logger.info("No enemies found nearby. Resuming patrol.") + await asyncio.sleep(1) + self.start_fighting(not reached) # Restart the fighting process, randomizing the bearing if not reached + else: + # Compute the bearing to the target + bearing_to_enemy = self.position.bearing_to(target.position) + + # Simulate a firefight in the direction of the enemy + firefight_destination = self.position.project_with_bearing_and_distance(30, bearing_to_enemy) + self.simulate_fire_fight(firefight_destination.lat, firefight_destination.lng, firefight_destination.alt + 1) + + await asyncio.sleep(10) # Simulate some time spent in firefight + self.start_fighting() # Restart the fighting process + +############################################################################################# +# This function is called when the API starts up. It will delete all blue units that are not human and alive. +############################################################################################# +def on_api_startup(api: API): + global units_to_delete + logger.info("API started") + + # Get all the units from the API + units = api.update_units() + + # Initialize the list to hold units to delete + units_to_delete = [] + + # Delete the units + for unit in units.values(): + if unit.alive and not unit.human and unit.coalition == "blue": + units_to_delete.append(unit) + try: + unit.delete_unit(False, "", True) + unit.register_on_property_change_callback("alive", on_unit_alive_change) + + logger.info(f"Deleted unit: {unit}") + except Exception as e: + logger.error(f"Failed to delete unit {unit}: {e}") + +################################################################################################# +# This function is called when a unit's alive property changes. If the unit is deleted, +# it will be removed from the units_to_delete list. If all units are deleted, it will spawn a new unit. +################################################################################################# +def on_unit_alive_change(unit: Unit, value: bool): + global units_to_delete + if units_to_delete is None: + logger.error("units_to_delete is not initialized.") + return + + # Check if the unit has been deleted + if value is False: + if unit in units_to_delete: + units_to_delete.remove(unit) + logger.info(f"Unit {unit} has been deleted and removed from the list.") + else: + logger.warning(f"Unit {unit} is not in the deletion list, but it is marked as dead.") + +############################################################################################## +# This function is called when the API updates. It checks if all units have been deleted and +# if so, it spawns new units near a human unit that is alive and on the ground. +############################################################################################## +def on_api_update(api: API): + global units_to_delete + if units_to_delete is not None and len(units_to_delete) == 0: + logger.info("All units have been deleted successfully.") + units_to_delete = None + + logger.info("Spawning a disembarked infantry units.") + units = api.get_units() + + # Find the first human unit that is alive and on the ground + for unit in units.values(): + if unit.human and unit.alive and not unit.airborne: + for i in range(50): + # Spawn a new unit 10 meters in from of the human + spawn_position = unit.position.project_with_bearing_and_distance(10, unit.heading + pi / 2 + 0.2 * i) + spawn_table: UnitSpawnTable = UnitSpawnTable( + unit_type="Soldier M4", + location=spawn_position, + heading=unit.heading + pi / 2 + 0.2 * i, + skill="High", + livery_id="" + ) + + async def execution_callback(new_group_ID: int): + logger.info(f"New units spawned, groupID: {new_group_ID}") + + units = api.get_units() + for new_unit in units.values(): + if new_unit.group_id == new_group_ID: + logger.info(f"New unit spawned: {new_unit}") + + new_unit.__class__ = DisembarkedInfantry + new_unit.start_fighting() + + api.spawn_ground_units([spawn_table], unit.coalition, "", True, 0, lambda new_group_ID: execution_callback(new_group_ID)) + logger.info(f"Spawned new unit succesfully at {spawn_position} with heading {unit.heading}") + +if __name__ == "__main__": + api = API() + + api.register_on_update_callback(on_api_update) + api.register_on_startup_callback(on_api_startup) + + api.run() + + \ No newline at end of file diff --git a/scripts/python/API/unit.py b/scripts/python/API/unit.py deleted file mode 100644 index 148e284c..00000000 --- a/scripts/python/API/unit.py +++ /dev/null @@ -1,214 +0,0 @@ -from data_extractor import DataExtractor -from data_indexes import DataIndexes -from data_types import LatLng, TACAN, Radio, GeneralSettings, Ammo, Contact, Offset -from typing import List - -from roes import ROES -from states import states -from utils import enum_to_coalition - -class Unit: - def __init__(self, id: int): - self.ID = id - - # Data controlled directly by the backend - self.alive = False - self.alarm_state = "AUTO" - self.human = False - self.controlled = False - self.coalition = "neutral" - self.country = 0 - self.name = "" - self.unit_name = "" - self.callsign = "" - self.group_id = 0 - self.unit_id = 0 - self.group_name = "" - self.state = "" - self.task = "" - self.has_task = False - self.position = LatLng(0, 0, 0) - self.speed = 0.0 - self.horizontal_velocity = 0.0 - self.vertical_velocity = 0.0 - self.heading = 0.0 - self.track = 0.0 - self.is_active_tanker = False - self.is_active_awacs = False - self.on_off = True - self.follow_roads = False - self.fuel = 0 - self.desired_speed = 0.0 - self.desired_speed_type = "CAS" - self.desired_altitude = 0.0 - self.desired_altitude_type = "ASL" - self.leader_id = 0 - self.formation_offset = Offset(0, 0, 0) - self.target_id = 0 - self.target_position = LatLng(0, 0, 0) - self.roe = "" - self.reaction_to_threat = "" - self.emissions_countermeasures = "" - self.tacan = TACAN(False, 0, "X", "TKR") - self.radio = Radio(124000000, 1, 1) - self.general_settings = GeneralSettings(False, False, False, False, False) - self.ammo: List[Ammo] = [] - self.contacts: List[Contact] = [] - self.active_path: List[LatLng] = [] - self.is_leader = False - self.operate_as = "blue" - self.shots_scatter = 2 - self.shots_intensity = 2 - self.health = 100 - self.racetrack_length = 0.0 - self.racetrack_anchor = LatLng(0, 0, 0) - self.racetrack_bearing = 0.0 - self.airborne = False - - self.previous_total_ammo = 0 - self.total_ammo = 0 - - def __repr__(self): - return f"Unit(id={self.ID}, name={self.name}, coalition={self.coalition}, position={self.position})" - - def update_from_data_extractor(self, data_extractor: DataExtractor): - datum_index = 0 - - while datum_index != DataIndexes.END_OF_DATA.value: - datum_index = data_extractor.extract_uint8() - - if datum_index == DataIndexes.CATEGORY.value: - data_extractor.extract_string() - elif datum_index == DataIndexes.ALIVE.value: - self.alive = data_extractor.extract_bool() - elif datum_index == DataIndexes.RADAR_STATE.value: - self.radar_state = data_extractor.extract_bool() - elif datum_index == DataIndexes.HUMAN.value: - self.human = data_extractor.extract_bool() - elif datum_index == DataIndexes.CONTROLLED.value: - self.controlled = data_extractor.extract_bool() - elif datum_index == DataIndexes.COALITION.value: - new_coalition = enum_to_coalition(data_extractor.extract_uint8()) - self.coalition = new_coalition - elif datum_index == DataIndexes.COUNTRY.value: - self.country = data_extractor.extract_uint8() - elif datum_index == DataIndexes.NAME.value: - self.name = data_extractor.extract_string() - elif datum_index == DataIndexes.UNIT_NAME.value: - self.unit_name = data_extractor.extract_string() - elif datum_index == DataIndexes.CALLSIGN.value: - self.callsign = data_extractor.extract_string() - elif datum_index == DataIndexes.UNIT_ID.value: - self.unit_id = data_extractor.extract_uint32() - elif datum_index == DataIndexes.GROUP_ID.value: - self.group_id = data_extractor.extract_uint32() - elif datum_index == DataIndexes.GROUP_NAME.value: - self.group_name = data_extractor.extract_string() - elif datum_index == DataIndexes.STATE.value: - self.state = states[data_extractor.extract_uint8()] - elif datum_index == DataIndexes.TASK.value: - self.task = data_extractor.extract_string() - elif datum_index == DataIndexes.HAS_TASK.value: - self.has_task = data_extractor.extract_bool() - elif datum_index == DataIndexes.POSITION.value: - self.position = data_extractor.extract_lat_lng() - elif datum_index == DataIndexes.SPEED.value: - self.speed = data_extractor.extract_float64() - elif datum_index == DataIndexes.HORIZONTAL_VELOCITY.value: - self.horizontal_velocity = data_extractor.extract_float64() - elif datum_index == DataIndexes.VERTICAL_VELOCITY.value: - self.vertical_velocity = data_extractor.extract_float64() - elif datum_index == DataIndexes.HEADING.value: - self.heading = data_extractor.extract_float64() - elif datum_index == DataIndexes.TRACK.value: - self.track = data_extractor.extract_float64() - elif datum_index == DataIndexes.IS_ACTIVE_TANKER.value: - self.is_active_tanker = data_extractor.extract_bool() - elif datum_index == DataIndexes.IS_ACTIVE_AWACS.value: - self.is_active_awacs = data_extractor.extract_bool() - elif datum_index == DataIndexes.ON_OFF.value: - self.on_off = data_extractor.extract_bool() - elif datum_index == DataIndexes.FOLLOW_ROADS.value: - self.follow_roads = data_extractor.extract_bool() - elif datum_index == DataIndexes.FUEL.value: - self.fuel = data_extractor.extract_uint16() - elif datum_index == DataIndexes.DESIRED_SPEED.value: - self.desired_speed = data_extractor.extract_float64() - elif datum_index == DataIndexes.DESIRED_SPEED_TYPE.value: - self.desired_speed_type = "GS" if data_extractor.extract_bool() else "CAS" - elif datum_index == DataIndexes.DESIRED_ALTITUDE.value: - self.desired_altitude = data_extractor.extract_float64() - elif datum_index == DataIndexes.DESIRED_ALTITUDE_TYPE.value: - self.desired_altitude_type = "AGL" if data_extractor.extract_bool() else "ASL" - elif datum_index == DataIndexes.LEADER_ID.value: - self.leader_id = data_extractor.extract_uint32() - elif datum_index == DataIndexes.FORMATION_OFFSET.value: - self.formation_offset = data_extractor.extract_offset() - elif datum_index == DataIndexes.TARGET_ID.value: - self.target_id = data_extractor.extract_uint32() - elif datum_index == DataIndexes.TARGET_POSITION.value: - self.target_position = data_extractor.extract_lat_lng() - elif datum_index == DataIndexes.ROE.value: - self.roe = ROES[data_extractor.extract_uint8()] - elif datum_index == DataIndexes.ALARM_STATE.value: - self.alarm_state = self.enum_to_alarm_state(data_extractor.extract_uint8()) - elif datum_index == DataIndexes.REACTION_TO_THREAT.value: - self.reaction_to_threat = self.enum_to_reaction_to_threat(data_extractor.extract_uint8()) - elif datum_index == DataIndexes.EMISSIONS_COUNTERMEASURES.value: - self.emissions_countermeasures = self.enum_to_emission_countermeasure(data_extractor.extract_uint8()) - elif datum_index == DataIndexes.TACAN.value: - self.tacan = data_extractor.extract_tacan() - elif datum_index == DataIndexes.RADIO.value: - self.radio = data_extractor.extract_radio() - elif datum_index == DataIndexes.GENERAL_SETTINGS.value: - self.general_settings = data_extractor.extract_general_settings() - elif datum_index == DataIndexes.AMMO.value: - self.ammo = data_extractor.extract_ammo() - self.previous_total_ammo = self.total_ammo - self.total_ammo = sum(ammo.quantity for ammo in self.ammo) - elif datum_index == DataIndexes.CONTACTS.value: - self.contacts = data_extractor.extract_contacts() - elif datum_index == DataIndexes.ACTIVE_PATH.value: - self.active_path = data_extractor.extract_active_path() - elif datum_index == DataIndexes.IS_LEADER.value: - self.is_leader = data_extractor.extract_bool() - elif datum_index == DataIndexes.OPERATE_AS.value: - self.operate_as = self.enum_to_coalition(data_extractor.extract_uint8()) - elif datum_index == DataIndexes.SHOTS_SCATTER.value: - self.shots_scatter = data_extractor.extract_uint8() - elif datum_index == DataIndexes.SHOTS_INTENSITY.value: - self.shots_intensity = data_extractor.extract_uint8() - elif datum_index == DataIndexes.HEALTH.value: - self.health = data_extractor.extract_uint8() - elif datum_index == DataIndexes.RACETRACK_LENGTH.value: - self.racetrack_length = data_extractor.extract_float64() - elif datum_index == DataIndexes.RACETRACK_ANCHOR.value: - self.racetrack_anchor = data_extractor.extract_lat_lng() - elif datum_index == DataIndexes.RACETRACK_BEARING.value: - self.racetrack_bearing = data_extractor.extract_float64() - elif datum_index == DataIndexes.TIME_TO_NEXT_TASKING.value: - self.time_to_next_tasking = data_extractor.extract_float64() - elif datum_index == DataIndexes.BARREL_HEIGHT.value: - self.barrel_height = data_extractor.extract_float64() - elif datum_index == DataIndexes.MUZZLE_VELOCITY.value: - self.muzzle_velocity = data_extractor.extract_float64() - elif datum_index == DataIndexes.AIM_TIME.value: - self.aim_time = data_extractor.extract_float64() - elif datum_index == DataIndexes.SHOTS_TO_FIRE.value: - self.shots_to_fire = data_extractor.extract_uint32() - elif datum_index == DataIndexes.SHOTS_BASE_INTERVAL.value: - self.shots_base_interval = data_extractor.extract_float64() - elif datum_index == DataIndexes.SHOTS_BASE_SCATTER.value: - self.shots_base_scatter = data_extractor.extract_float64() - elif datum_index == DataIndexes.ENGAGEMENT_RANGE.value: - self.engagement_range = data_extractor.extract_float64() - elif datum_index == DataIndexes.TARGETING_RANGE.value: - self.targeting_range = data_extractor.extract_float64() - elif datum_index == DataIndexes.AIM_METHOD_RANGE.value: - self.aim_method_range = data_extractor.extract_float64() - elif datum_index == DataIndexes.ACQUISITION_RANGE.value: - self.acquisition_range = data_extractor.extract_float64() - elif datum_index == DataIndexes.AIRBORNE.value: - self.airborne = data_extractor.extract_bool() - - \ No newline at end of file diff --git a/scripts/python/API/unit/temp_replace.py b/scripts/python/API/unit/temp_replace.py new file mode 100644 index 00000000..8fe34a54 --- /dev/null +++ b/scripts/python/API/unit/temp_replace.py @@ -0,0 +1,18 @@ +import re + +# Read the file +with open('unit.py', 'r', encoding='utf-8') as f: + content = f.read() + +# Pattern to match callback invocations +pattern = r'self\.on_property_change_callbacks\[\"(\w+)\"\]\(self, self\.(\w+)\)' +replacement = r'self._trigger_callback("\1", self.\2)' + +# Replace all matches +new_content = re.sub(pattern, replacement, content) + +# Write back to file +with open('unit.py', 'w', encoding='utf-8') as f: + f.write(new_content) + +print('Updated all callback invocations') diff --git a/scripts/python/API/unit/unit.py b/scripts/python/API/unit/unit.py new file mode 100644 index 00000000..7ce5a759 --- /dev/null +++ b/scripts/python/API/unit/unit.py @@ -0,0 +1,763 @@ +from typing import List +import asyncio + +from data.data_extractor import DataExtractor +from data.data_indexes import DataIndexes +from data.data_types import LatLng, TACAN, Radio, GeneralSettings, Ammo, Contact, Offset +from data.roes import ROES +from data.states import states +from utils.utils import enum_to_coalition + +class Unit: + def __init__(self, id: int, api): + from api import API + + self.ID = id + self.api: API = api + + # Data controlled directly by the backend + self.category = "" + self.alive = False + self.alarm_state = "AUTO" + self.human = False + self.controlled = False + self.coalition = "neutral" + self.country = 0 + self.name = "" + self.unit_name = "" + self.callsign = "" + self.group_id = 0 + self.unit_id = 0 + self.group_name = "" + self.state = "" + self.task = "" + self.has_task = False + self.position = LatLng(0, 0, 0) + self.speed = 0.0 + self.horizontal_velocity = 0.0 + self.vertical_velocity = 0.0 + self.heading = 0.0 + self.track = 0.0 + self.is_active_tanker = False + self.is_active_awacs = False + self.on_off = True + self.follow_roads = False + self.fuel = 0 + self.desired_speed = 0.0 + self.desired_speed_type = "CAS" + self.desired_altitude = 0.0 + self.desired_altitude_type = "ASL" + self.leader_id = 0 + self.formation_offset = Offset(0, 0, 0) + self.target_id = 0 + self.target_position = LatLng(0, 0, 0) + self.roe = "" + self.reaction_to_threat = "" + self.emissions_countermeasures = "" + self.tacan = TACAN(False, 0, "X", "TKR") + self.radio = Radio(124000000, 1, 1) + self.general_settings = GeneralSettings(False, False, False, False, False) + self.ammo: List[Ammo] = [] + self.contacts: List[Contact] = [] + self.active_path: List[LatLng] = [] + self.is_leader = False + self.operate_as = "blue" + self.shots_scatter = 2 + self.shots_intensity = 2 + self.health = 100 + self.racetrack_length = 0.0 + self.racetrack_anchor = LatLng(0, 0, 0) + self.racetrack_bearing = 0.0 + self.airborne = False + self.radar_state = False + self.time_to_next_tasking = 0.0 + self.barrel_height = 0.0 + self.muzzle_velocity = 0.0 + self.aim_time = 0.0 + self.shots_to_fire = 0 + self.shots_base_interval = 0.0 + self.shots_base_scatter = 0.0 + self.engagement_range = 0.0 + self.targeting_range = 0.0 + self.aim_method_range = 0.0 + self.acquisition_range = 0.0 + + self.previous_total_ammo = 0 + self.total_ammo = 0 + + self.on_property_change_callbacks = {} + self.on_destination_reached_callback = None + self.destination = None + self.destination_reached_threshold = 10 + self.destination_reached_timeout = None + self.destination_reached_start_time = None + + def __repr__(self): + return f"Unit(id={self.ID}, name={self.name}, coalition={self.coalition}, position={self.position})" + + def register_on_property_change_callback(self, property_name: str, callback): + """ + Register a callback function that will be called when a property changes. + Args: + property_name (str): The name of the property to watch. + callback (function): The function to call when the property changes. The callback should accept two parameters: the unit and the new value of the property. + """ + if property_name not in self.on_property_change_callbacks: + self.on_property_change_callbacks[property_name] = callback + + def unregister_on_property_change_callback(self, property_name: str): + """ + Unregister a callback function for a property. + Args: + property_name (str): The name of the property to stop watching. + """ + if property_name in self.on_property_change_callbacks: + del self.on_property_change_callbacks[property_name] + + def register_on_destination_reached_callback(self, callback, destination: LatLng, threshold: float = 10, timeout: float = None): + """ + Register a callback function that will be called when the unit reaches its destination. + If the destination is not reached within the specified timeout, the callback will also be called with `False`. + + Args: + callback (function): The function to call when the destination is reached. The callback should accept two parameters: the unit and a boolean indicating whether the destination was reached. + destination (LatLng): The destination that the unit is expected to reach. + threshold (float): The distance threshold in meters to consider the destination reached. Default is 10 meters. + """ + self.on_destination_reached_callback = callback + self.destination = destination + self.destination_reached_threshold = threshold + self.destination_reached_timeout = timeout + self.destination_reached_start_time = asyncio.get_event_loop().time() if timeout else None + + def unregister_on_destination_reached_callback(self): + """ + Unregister the callback function for destination reached. + """ + self.on_destination_reached_callback = None + self.destination = None + + def _trigger_callback(self, property_name: str, value): + """ + Trigger a property change callback, executing it in the asyncio event loop if available. + Args: + property_name (str): The name of the property that changed. + value: The new value of the property. + """ + if property_name in self.on_property_change_callbacks: + callback = self.on_property_change_callbacks[property_name] + try: + # Try to get the current event loop and schedule the callback + loop = asyncio.get_running_loop() + loop.create_task(self._run_callback_async(callback, self, value)) + except RuntimeError: + # No event loop running, execute synchronously + callback(self, value) + + async def _run_callback_async(self, callback, *args): + """ + Run a callback asynchronously, handling both sync and async callbacks. + """ + try: + if asyncio.iscoroutinefunction(callback): + await callback(*args) + else: + callback(*args) + except Exception as e: + # Log the error but don't crash the update process + print(f"Error in property change callback: {e}") + + def _trigger_destination_reached_callback(self, reached: bool): + """ + Trigger the destination reached callback, executing it in the asyncio event loop if available. + Args: + reached (bool): Whether the destination was reached or not. + """ + if self.on_destination_reached_callback: + try: + # Try to get the current event loop and schedule the callback + loop = asyncio.get_running_loop() + loop.create_task(self._run_callback_async(self.on_destination_reached_callback, self, reached)) + except RuntimeError: + # No event loop running, execute synchronously + self.on_destination_reached_callback(self, reached) + + def update_from_data_extractor(self, data_extractor: DataExtractor): + datum_index = 0 + + while datum_index != DataIndexes.END_OF_DATA.value: + datum_index = data_extractor.extract_uint8() + + if datum_index == DataIndexes.CATEGORY.value: + category = data_extractor.extract_string() + if category != self.category: + self.category = category + # Trigger callbacks for property change + if "category" in self.on_property_change_callbacks: + self._trigger_callback("category", self.category) + elif datum_index == DataIndexes.ALIVE.value: + alive = data_extractor.extract_bool() + if alive != self.alive: + self.alive = alive + # Trigger callbacks for property change + if "alive" in self.on_property_change_callbacks: + self._trigger_callback("alive", self.alive) + elif datum_index == DataIndexes.RADAR_STATE.value: + radar_state = data_extractor.extract_bool() + if radar_state != self.radar_state: + self.radar_state = radar_state + # Trigger callbacks for property change + if "radar_state" in self.on_property_change_callbacks: + self._trigger_callback("radar_state", self.radar_state) + elif datum_index == DataIndexes.HUMAN.value: + human = data_extractor.extract_bool() + if human != self.human: + self.human = human + # Trigger callbacks for property change + if "human" in self.on_property_change_callbacks: + self._trigger_callback("human", self.human) + elif datum_index == DataIndexes.CONTROLLED.value: + controlled = data_extractor.extract_bool() + if controlled != self.controlled: + self.controlled = controlled + # Trigger callbacks for property change + if "controlled" in self.on_property_change_callbacks: + self._trigger_callback("controlled", self.controlled) + elif datum_index == DataIndexes.COALITION.value: + coalition = enum_to_coalition(data_extractor.extract_uint8()) + if coalition != self.coalition: + self.coalition = coalition + # Trigger callbacks for property change + if "coalition" in self.on_property_change_callbacks: + self._trigger_callback("coalition", self.coalition) + elif datum_index == DataIndexes.COUNTRY.value: + country = data_extractor.extract_uint8() + if country != self.country: + self.country = country + # Trigger callbacks for property change + if "country" in self.on_property_change_callbacks: + self._trigger_callback("country", self.country) + elif datum_index == DataIndexes.NAME.value: + name = data_extractor.extract_string() + if name != self.name: + self.name = name + # Trigger callbacks for property change + if "name" in self.on_property_change_callbacks: + self._trigger_callback("name", self.name) + elif datum_index == DataIndexes.UNIT_NAME.value: + unit_name = data_extractor.extract_string() + if unit_name != self.unit_name: + self.unit_name = unit_name + # Trigger callbacks for property change + if "unit_name" in self.on_property_change_callbacks: + self._trigger_callback("unit_name", self.unit_name) + elif datum_index == DataIndexes.CALLSIGN.value: + callsign = data_extractor.extract_string() + if callsign != self.callsign: + self.callsign = callsign + # Trigger callbacks for property change + if "callsign" in self.on_property_change_callbacks: + self._trigger_callback("callsign", self.callsign) + elif datum_index == DataIndexes.UNIT_ID.value: + unit_id = data_extractor.extract_uint32() + if unit_id != self.unit_id: + self.unit_id = unit_id + # Trigger callbacks for property change + if "unit_id" in self.on_property_change_callbacks: + self._trigger_callback("unit_id", self.unit_id) + elif datum_index == DataIndexes.GROUP_ID.value: + group_id = data_extractor.extract_uint32() + if group_id != self.group_id: + self.group_id = group_id + # Trigger callbacks for property change + if "group_id" in self.on_property_change_callbacks: + self._trigger_callback("group_id", self.group_id) + elif datum_index == DataIndexes.GROUP_NAME.value: + group_name = data_extractor.extract_string() + if group_name != self.group_name: + self.group_name = group_name + # Trigger callbacks for property change + if "group_name" in self.on_property_change_callbacks: + self._trigger_callback("group_name", self.group_name) + elif datum_index == DataIndexes.STATE.value: + state = states[data_extractor.extract_uint8()] + if state != self.state: + self.state = state + # Trigger callbacks for property change + if "state" in self.on_property_change_callbacks: + self._trigger_callback("state", self.state) + elif datum_index == DataIndexes.TASK.value: + task = data_extractor.extract_string() + if task != self.task: + self.task = task + # Trigger callbacks for property change + if "task" in self.on_property_change_callbacks: + self._trigger_callback("task", self.task) + elif datum_index == DataIndexes.HAS_TASK.value: + has_task = data_extractor.extract_bool() + if has_task != self.has_task: + self.has_task = has_task + # Trigger callbacks for property change + if "has_task" in self.on_property_change_callbacks: + self._trigger_callback("has_task", self.has_task) + elif datum_index == DataIndexes.POSITION.value: + position = data_extractor.extract_lat_lng() + if position != self.position: + self.position = position + # Trigger callbacks for property change + if "position" in self.on_property_change_callbacks: + self._trigger_callback("position", self.position) + + if self.on_destination_reached_callback and self.destination: + reached = self.position.distance_to(self.destination) < self.destination_reached_threshold + if reached or ( + self.destination_reached_timeout and + (asyncio.get_event_loop().time() - self.destination_reached_start_time) > self.destination_reached_timeout + ): + self._trigger_destination_reached_callback(reached) + self.unregister_on_destination_reached_callback() + elif datum_index == DataIndexes.SPEED.value: + speed = data_extractor.extract_float64() + if speed != self.speed: + self.speed = speed + # Trigger callbacks for property change + if "speed" in self.on_property_change_callbacks: + self._trigger_callback("speed", self.speed) + elif datum_index == DataIndexes.HORIZONTAL_VELOCITY.value: + horizontal_velocity = data_extractor.extract_float64() + if horizontal_velocity != self.horizontal_velocity: + self.horizontal_velocity = horizontal_velocity + # Trigger callbacks for property change + if "horizontal_velocity" in self.on_property_change_callbacks: + self._trigger_callback("horizontal_velocity", self.horizontal_velocity) + elif datum_index == DataIndexes.VERTICAL_VELOCITY.value: + vertical_velocity = data_extractor.extract_float64() + if vertical_velocity != self.vertical_velocity: + self.vertical_velocity = vertical_velocity + # Trigger callbacks for property change + if "vertical_velocity" in self.on_property_change_callbacks: + self._trigger_callback("vertical_velocity", self.vertical_velocity) + elif datum_index == DataIndexes.HEADING.value: + heading = data_extractor.extract_float64() + if heading != self.heading: + self.heading = heading + # Trigger callbacks for property change + if "heading" in self.on_property_change_callbacks: + self._trigger_callback("heading", self.heading) + elif datum_index == DataIndexes.TRACK.value: + track = data_extractor.extract_float64() + if track != self.track: + self.track = track + # Trigger callbacks for property change + if "track" in self.on_property_change_callbacks: + self._trigger_callback("track", self.track) + elif datum_index == DataIndexes.IS_ACTIVE_TANKER.value: + is_active_tanker = data_extractor.extract_bool() + if is_active_tanker != self.is_active_tanker: + self.is_active_tanker = is_active_tanker + # Trigger callbacks for property change + if "is_active_tanker" in self.on_property_change_callbacks: + self._trigger_callback("is_active_tanker", self.is_active_tanker) + elif datum_index == DataIndexes.IS_ACTIVE_AWACS.value: + is_active_awacs = data_extractor.extract_bool() + if is_active_awacs != self.is_active_awacs: + self.is_active_awacs = is_active_awacs + # Trigger callbacks for property change + if "is_active_awacs" in self.on_property_change_callbacks: + self._trigger_callback("is_active_awacs", self.is_active_awacs) + elif datum_index == DataIndexes.ON_OFF.value: + on_off = data_extractor.extract_bool() + if on_off != self.on_off: + self.on_off = on_off + # Trigger callbacks for property change + if "on_off" in self.on_property_change_callbacks: + self._trigger_callback("on_off", self.on_off) + elif datum_index == DataIndexes.FOLLOW_ROADS.value: + follow_roads = data_extractor.extract_bool() + if follow_roads != self.follow_roads: + self.follow_roads = follow_roads + # Trigger callbacks for property change + if "follow_roads" in self.on_property_change_callbacks: + self._trigger_callback("follow_roads", self.follow_roads) + elif datum_index == DataIndexes.FUEL.value: + fuel = data_extractor.extract_uint16() + if fuel != self.fuel: + self.fuel = fuel + # Trigger callbacks for property change + if "fuel" in self.on_property_change_callbacks: + self._trigger_callback("fuel", self.fuel) + elif datum_index == DataIndexes.DESIRED_SPEED.value: + desired_speed = data_extractor.extract_float64() + if desired_speed != self.desired_speed: + self.desired_speed = desired_speed + # Trigger callbacks for property change + if "desired_speed" in self.on_property_change_callbacks: + self._trigger_callback("desired_speed", self.desired_speed) + elif datum_index == DataIndexes.DESIRED_SPEED_TYPE.value: + desired_speed_type = "GS" if data_extractor.extract_bool() else "CAS" + if desired_speed_type != self.desired_speed_type: + self.desired_speed_type = desired_speed_type + # Trigger callbacks for property change + if "desired_speed_type" in self.on_property_change_callbacks: + self._trigger_callback("desired_speed_type", self.desired_speed_type) + elif datum_index == DataIndexes.DESIRED_ALTITUDE.value: + desired_altitude = data_extractor.extract_float64() + if desired_altitude != self.desired_altitude: + self.desired_altitude = desired_altitude + # Trigger callbacks for property change + if "desired_altitude" in self.on_property_change_callbacks: + self._trigger_callback("desired_altitude", self.desired_altitude) + elif datum_index == DataIndexes.DESIRED_ALTITUDE_TYPE.value: + desired_altitude_type = "AGL" if data_extractor.extract_bool() else "ASL" + if desired_altitude_type != self.desired_altitude_type: + self.desired_altitude_type = desired_altitude_type + # Trigger callbacks for property change + if "desired_altitude_type" in self.on_property_change_callbacks: + self._trigger_callback("desired_altitude_type", self.desired_altitude_type) + elif datum_index == DataIndexes.LEADER_ID.value: + leader_id = data_extractor.extract_uint32() + if leader_id != self.leader_id: + self.leader_id = leader_id + # Trigger callbacks for property change + if "leader_id" in self.on_property_change_callbacks: + self._trigger_callback("leader_id", self.leader_id) + elif datum_index == DataIndexes.FORMATION_OFFSET.value: + formation_offset = data_extractor.extract_offset() + if formation_offset != self.formation_offset: + self.formation_offset = formation_offset + # Trigger callbacks for property change + if "formation_offset" in self.on_property_change_callbacks: + self._trigger_callback("formation_offset", self.formation_offset) + elif datum_index == DataIndexes.TARGET_ID.value: + target_id = data_extractor.extract_uint32() + if target_id != self.target_id: + self.target_id = target_id + # Trigger callbacks for property change + if "target_id" in self.on_property_change_callbacks: + self._trigger_callback("target_id", self.target_id) + elif datum_index == DataIndexes.TARGET_POSITION.value: + target_position = data_extractor.extract_lat_lng() + if target_position != self.target_position: + self.target_position = target_position + # Trigger callbacks for property change + if "target_position" in self.on_property_change_callbacks: + self._trigger_callback("target_position", self.target_position) + elif datum_index == DataIndexes.ROE.value: + roe = ROES[data_extractor.extract_uint8()] + if roe != self.roe: + self.roe = roe + # Trigger callbacks for property change + if "roe" in self.on_property_change_callbacks: + self._trigger_callback("roe", self.roe) + elif datum_index == DataIndexes.ALARM_STATE.value: + alarm_state = self.enum_to_alarm_state(data_extractor.extract_uint8()) + if alarm_state != self.alarm_state: + self.alarm_state = alarm_state + # Trigger callbacks for property change + if "alarm_state" in self.on_property_change_callbacks: + self._trigger_callback("alarm_state", self.alarm_state) + elif datum_index == DataIndexes.REACTION_TO_THREAT.value: + reaction_to_threat = self.enum_to_reaction_to_threat(data_extractor.extract_uint8()) + if reaction_to_threat != self.reaction_to_threat: + self.reaction_to_threat = reaction_to_threat + # Trigger callbacks for property change + if "reaction_to_threat" in self.on_property_change_callbacks: + self._trigger_callback("reaction_to_threat", self.reaction_to_threat) + elif datum_index == DataIndexes.EMISSIONS_COUNTERMEASURES.value: + emissions_countermeasures = self.enum_to_emission_countermeasure(data_extractor.extract_uint8()) + if emissions_countermeasures != self.emissions_countermeasures: + self.emissions_countermeasures = emissions_countermeasures + # Trigger callbacks for property change + if "emissions_countermeasures" in self.on_property_change_callbacks: + self._trigger_callback("emissions_countermeasures", self.emissions_countermeasures) + elif datum_index == DataIndexes.TACAN.value: + tacan = data_extractor.extract_tacan() + if tacan != self.tacan: + self.tacan = tacan + # Trigger callbacks for property change + if "tacan" in self.on_property_change_callbacks: + self._trigger_callback("tacan", self.tacan) + elif datum_index == DataIndexes.RADIO.value: + radio = data_extractor.extract_radio() + if radio != self.radio: + self.radio = radio + # Trigger callbacks for property change + if "radio" in self.on_property_change_callbacks: + self._trigger_callback("radio", self.radio) + elif datum_index == DataIndexes.GENERAL_SETTINGS.value: + general_settings = data_extractor.extract_general_settings() + if general_settings != self.general_settings: + self.general_settings = general_settings + # Trigger callbacks for property change + if "general_settings" in self.on_property_change_callbacks: + self._trigger_callback("general_settings", self.general_settings) + elif datum_index == DataIndexes.AMMO.value: + ammo = data_extractor.extract_ammo() + if ammo != self.ammo: + self.ammo = ammo + self.previous_total_ammo = self.total_ammo + self.total_ammo = sum(ammo.quantity for ammo in self.ammo) + # Trigger callbacks for property change + if "ammo" in self.on_property_change_callbacks: + self._trigger_callback("ammo", self.ammo) + elif datum_index == DataIndexes.CONTACTS.value: + contacts = data_extractor.extract_contacts() + if contacts != self.contacts: + self.contacts = contacts + # Trigger callbacks for property change + if "contacts" in self.on_property_change_callbacks: + self._trigger_callback("contacts", self.contacts) + elif datum_index == DataIndexes.ACTIVE_PATH.value: + active_path = data_extractor.extract_active_path() + if active_path != self.active_path: + self.active_path = active_path + # Trigger callbacks for property change + if "active_path" in self.on_property_change_callbacks: + self._trigger_callback("active_path", self.active_path) + elif datum_index == DataIndexes.IS_LEADER.value: + is_leader = data_extractor.extract_bool() + if is_leader != self.is_leader: + self.is_leader = is_leader + # Trigger callbacks for property change + if "is_leader" in self.on_property_change_callbacks: + self._trigger_callback("is_leader", self.is_leader) + elif datum_index == DataIndexes.OPERATE_AS.value: + operate_as = enum_to_coalition(data_extractor.extract_uint8()) + if operate_as != self.operate_as: + self.operate_as = operate_as + # Trigger callbacks for property change + if "operate_as" in self.on_property_change_callbacks: + self._trigger_callback("operate_as", self.operate_as) + elif datum_index == DataIndexes.SHOTS_SCATTER.value: + shots_scatter = data_extractor.extract_uint8() + if shots_scatter != self.shots_scatter: + self.shots_scatter = shots_scatter + # Trigger callbacks for property change + if "shots_scatter" in self.on_property_change_callbacks: + self._trigger_callback("shots_scatter", self.shots_scatter) + elif datum_index == DataIndexes.SHOTS_INTENSITY.value: + shots_intensity = data_extractor.extract_uint8() + if shots_intensity != self.shots_intensity: + self.shots_intensity = shots_intensity + # Trigger callbacks for property change + if "shots_intensity" in self.on_property_change_callbacks: + self._trigger_callback("shots_intensity", self.shots_intensity) + elif datum_index == DataIndexes.HEALTH.value: + health = data_extractor.extract_uint8() + if health != self.health: + self.health = health + # Trigger callbacks for property change + if "health" in self.on_property_change_callbacks: + self._trigger_callback("health", self.health) + elif datum_index == DataIndexes.RACETRACK_LENGTH.value: + racetrack_length = data_extractor.extract_float64() + if racetrack_length != self.racetrack_length: + self.racetrack_length = racetrack_length + # Trigger callbacks for property change + if "racetrack_length" in self.on_property_change_callbacks: + self._trigger_callback("racetrack_length", self.racetrack_length) + elif datum_index == DataIndexes.RACETRACK_ANCHOR.value: + racetrack_anchor = data_extractor.extract_lat_lng() + if racetrack_anchor != self.racetrack_anchor: + self.racetrack_anchor = racetrack_anchor + # Trigger callbacks for property change + if "racetrack_anchor" in self.on_property_change_callbacks: + self._trigger_callback("racetrack_anchor", self.racetrack_anchor) + elif datum_index == DataIndexes.RACETRACK_BEARING.value: + racetrack_bearing = data_extractor.extract_float64() + if racetrack_bearing != self.racetrack_bearing: + self.racetrack_bearing = racetrack_bearing + # Trigger callbacks for property change + if "racetrack_bearing" in self.on_property_change_callbacks: + self._trigger_callback("racetrack_bearing", self.racetrack_bearing) + elif datum_index == DataIndexes.TIME_TO_NEXT_TASKING.value: + time_to_next_tasking = data_extractor.extract_float64() + if time_to_next_tasking != self.time_to_next_tasking: + self.time_to_next_tasking = time_to_next_tasking + # Trigger callbacks for property change + if "time_to_next_tasking" in self.on_property_change_callbacks: + self._trigger_callback("time_to_next_tasking", self.time_to_next_tasking) + elif datum_index == DataIndexes.BARREL_HEIGHT.value: + barrel_height = data_extractor.extract_float64() + if barrel_height != self.barrel_height: + self.barrel_height = barrel_height + # Trigger callbacks for property change + if "barrel_height" in self.on_property_change_callbacks: + self._trigger_callback("barrel_height", self.barrel_height) + elif datum_index == DataIndexes.MUZZLE_VELOCITY.value: + muzzle_velocity = data_extractor.extract_float64() + if muzzle_velocity != self.muzzle_velocity: + self.muzzle_velocity = muzzle_velocity + # Trigger callbacks for property change + if "muzzle_velocity" in self.on_property_change_callbacks: + self._trigger_callback("muzzle_velocity", self.muzzle_velocity) + elif datum_index == DataIndexes.AIM_TIME.value: + aim_time = data_extractor.extract_float64() + if aim_time != self.aim_time: + self.aim_time = aim_time + # Trigger callbacks for property change + if "aim_time" in self.on_property_change_callbacks: + self._trigger_callback("aim_time", self.aim_time) + elif datum_index == DataIndexes.SHOTS_TO_FIRE.value: + shots_to_fire = data_extractor.extract_uint32() + if shots_to_fire != self.shots_to_fire: + self.shots_to_fire = shots_to_fire + # Trigger callbacks for property change + if "shots_to_fire" in self.on_property_change_callbacks: + self._trigger_callback("shots_to_fire", self.shots_to_fire) + elif datum_index == DataIndexes.SHOTS_BASE_INTERVAL.value: + shots_base_interval = data_extractor.extract_float64() + if shots_base_interval != self.shots_base_interval: + self.shots_base_interval = shots_base_interval + # Trigger callbacks for property change + if "shots_base_interval" in self.on_property_change_callbacks: + self._trigger_callback("shots_base_interval", self.shots_base_interval) + elif datum_index == DataIndexes.SHOTS_BASE_SCATTER.value: + shots_base_scatter = data_extractor.extract_float64() + if shots_base_scatter != self.shots_base_scatter: + self.shots_base_scatter = shots_base_scatter + # Trigger callbacks for property change + if "shots_base_scatter" in self.on_property_change_callbacks: + self._trigger_callback("shots_base_scatter", self.shots_base_scatter) + elif datum_index == DataIndexes.ENGAGEMENT_RANGE.value: + engagement_range = data_extractor.extract_float64() + if engagement_range != self.engagement_range: + self.engagement_range = engagement_range + # Trigger callbacks for property change + if "engagement_range" in self.on_property_change_callbacks: + self._trigger_callback("engagement_range", self.engagement_range) + elif datum_index == DataIndexes.TARGETING_RANGE.value: + targeting_range = data_extractor.extract_float64() + if targeting_range != self.targeting_range: + self.targeting_range = targeting_range + # Trigger callbacks for property change + if "targeting_range" in self.on_property_change_callbacks: + self._trigger_callback("targeting_range", self.targeting_range) + elif datum_index == DataIndexes.AIM_METHOD_RANGE.value: + aim_method_range = data_extractor.extract_float64() + if aim_method_range != self.aim_method_range: + self.aim_method_range = aim_method_range + # Trigger callbacks for property change + if "aim_method_range" in self.on_property_change_callbacks: + self._trigger_callback("aim_method_range", self.aim_method_range) + elif datum_index == DataIndexes.ACQUISITION_RANGE.value: + acquisition_range = data_extractor.extract_float64() + if acquisition_range != self.acquisition_range: + self.acquisition_range = acquisition_range + # Trigger callbacks for property change + if "acquisition_range" in self.on_property_change_callbacks: + self._trigger_callback("acquisition_range", self.acquisition_range) + elif datum_index == DataIndexes.AIRBORNE.value: + airborne = data_extractor.extract_bool() + if airborne != self.airborne: + self.airborne = airborne + # Trigger callbacks for property change + if "airborne" in self.on_property_change_callbacks: + self._trigger_callback("airborne", self.airborne) + + # --- API functions requiring ID --- + def set_path(self, path: List[LatLng]): + return self.api.send_command({"setPath": {"ID": self.ID, "path": [latlng.toJSON() for latlng in path]}}) + + def attack_unit(self, target_id: int): + return self.api.send_command({"attackUnit": {"ID": self.ID, "targetID": target_id}}) + + def follow_unit(self, target_id: int, offset_x=0, offset_y=0, offset_z=0): + return self.api.send_command({"followUnit": {"ID": self.ID, "targetID": target_id, "offsetX": offset_x, "offsetY": offset_y, "offsetZ": offset_z}}) + + def delete_unit(self, explosion=False, explosion_type="", immediate=True): + return self.api.send_command({"deleteUnit": {"ID": self.ID, "explosion": explosion, "explosionType": explosion_type, "immediate": immediate}}) + + def land_at(self, lat: float, lng: float): + return self.api.send_command({"landAt": {"ID": self.ID, "location": {"lat": lat, "lng": lng}}}) + + def change_speed(self, change: str): + return self.api.send_command({"changeSpeed": {"ID": self.ID, "change": change}}) + + def set_speed(self, speed: float): + return self.api.send_command({"setSpeed": {"ID": self.ID, "speed": speed}}) + + def set_speed_type(self, speed_type: str): + return self.api.send_command({"setSpeedType": {"ID": self.ID, "speedType": speed_type}}) + + def change_altitude(self, change: str): + return self.api.send_command({"changeAltitude": {"ID": self.ID, "change": change}}) + + def set_altitude_type(self, altitude_type: str): + return self.api.send_command({"setAltitudeType": {"ID": self.ID, "altitudeType": altitude_type}}) + + def set_altitude(self, altitude: float): + return self.api.send_command({"setAltitude": {"ID": self.ID, "altitude": altitude}}) + + def set_roe(self, roe: int): + return self.api.send_command({"setROE": {"ID": self.ID, "ROE": roe}}) + + def set_alarm_state(self, alarm_state: int): + return self.api.send_command({"setAlarmState": {"ID": self.ID, "alarmState": alarm_state}}) + + def set_reaction_to_threat(self, reaction_to_threat: int): + return self.api.send_command({"setReactionToThreat": {"ID": self.ID, "reactionToThreat": reaction_to_threat}}) + + def set_emissions_countermeasures(self, emissions_countermeasures: int): + return self.api.send_command({"setEmissionsCountermeasures": {"ID": self.ID, "emissionsCountermeasures": emissions_countermeasures}}) + + def set_on_off(self, on_off: bool): + return self.api.send_command({"setOnOff": {"ID": self.ID, "onOff": on_off}}) + + def set_follow_roads(self, follow_roads: bool): + return self.api.send_command({"setFollowRoads": {"ID": self.ID, "followRoads": follow_roads}}) + + def set_operate_as(self, operate_as: int): + return self.api.send_command({"setOperateAs": {"ID": self.ID, "operateAs": operate_as}}) + + def refuel(self): + return self.api.send_command({"refuel": {"ID": self.ID}}) + + def bomb_point(self, lat: float, lng: float): + return self.api.send_command({"bombPoint": {"ID": self.ID, "location": {"lat": lat, "lng": lng}}}) + + def carpet_bomb(self, lat: float, lng: float): + return self.api.send_command({"carpetBomb": {"ID": self.ID, "location": {"lat": lat, "lng": lng}}}) + + def bomb_building(self, lat: float, lng: float): + return self.api.send_command({"bombBuilding": {"ID": self.ID, "location": {"lat": lat, "lng": lng}}}) + + def fire_at_area(self, lat: float, lng: float): + return self.api.send_command({"fireAtArea": {"ID": self.ID, "location": {"lat": lat, "lng": lng}}}) + + def fire_laser(self, lat: float, lng: float, code: int): + return self.api.send_command({"fireLaser": {"ID": self.ID, "location": {"lat": lat, "lng": lng}, "code": code}}) + + def fire_infrared(self, lat: float, lng: float): + return self.api.send_command({"fireInfrared": {"ID": self.ID, "location": {"lat": lat, "lng": lng}}}) + + def simulate_fire_fight(self, lat: float, lng: float, altitude: float): + return self.api.send_command({"simulateFireFight": {"ID": self.ID, "location": {"lat": lat, "lng": lng}, "altitude": altitude}}) + + def scenic_aaa(self, coalition: str): + return self.api.send_command({"scenicAAA": {"ID": self.ID, "coalition": coalition}}) + + def miss_on_purpose(self, coalition: str): + return self.api.send_command({"missOnPurpose": {"ID": self.ID, "coalition": coalition}}) + + def land_at_point(self, lat: float, lng: float): + return self.api.send_command({"landAtPoint": {"ID": self.ID, "location": {"lat": lat, "lng": lng}}}) + + def set_shots_scatter(self, shots_scatter: int): + return self.api.send_command({"setShotsScatter": {"ID": self.ID, "shotsScatter": shots_scatter}}) + + def set_shots_intensity(self, shots_intensity: int): + return self.api.send_command({"setShotsIntensity": {"ID": self.ID, "shotsIntensity": shots_intensity}}) + + def set_racetrack(self, lat: float, lng: float, bearing: float, length: float): + return self.api.send_command({"setRacetrack": {"ID": self.ID, "location": {"lat": lat, "lng": lng}, "bearing": bearing, "length": length}}) + + def set_advanced_options(self, is_active_tanker: bool, is_active_awacs: bool, tacan: dict, radio: dict, general_settings: dict): + return self.api.send_command({"setAdvancedOptions": {"ID": self.ID, "isActiveTanker": is_active_tanker, "isActiveAWACS": is_active_awacs, "TACAN": tacan, "radio": radio, "generalSettings": general_settings}}) + + def set_engagement_properties(self, barrel_height, muzzle_velocity, aim_time, shots_to_fire, shots_base_interval, shots_base_scatter, engagement_range, targeting_range, aim_method_range, acquisition_range): + return self.api.send_command({"setEngagementProperties": {"ID": self.ID, "barrelHeight": barrel_height, "muzzleVelocity": muzzle_velocity, "aimTime": aim_time, "shotsToFire": shots_to_fire, "shotsBaseInterval": shots_base_interval, "shotsBaseScatter": shots_base_scatter, "engagementRange": engagement_range, "targetingRange": targeting_range, "aimMethodRange": aim_method_range, "acquisitionRange": acquisition_range}}) + + + + \ No newline at end of file diff --git a/scripts/python/API/utils.py b/scripts/python/API/utils.py deleted file mode 100644 index 4db06c53..00000000 --- a/scripts/python/API/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -def enum_to_coalition(coalition_id: int) -> str: - if coalition_id == 0: - return "neutral" - elif coalition_id == 1: - return "red" - elif coalition_id == 2: - return "blue" - return "" - - -def coalition_to_enum(coalition: str) -> int: - if coalition == "neutral": - return 0 - elif coalition == "red": - return 1 - elif coalition == "blue": - return 2 - return 0 -1 \ No newline at end of file diff --git a/scripts/python/API/utils/utils.py b/scripts/python/API/utils/utils.py new file mode 100644 index 00000000..8528f7a7 --- /dev/null +++ b/scripts/python/API/utils/utils.py @@ -0,0 +1,83 @@ +from math import asin, atan2, cos, degrees, radians, sin, sqrt + +def enum_to_coalition(coalition_id: int) -> str: + if coalition_id == 0: + return "neutral" + elif coalition_id == 1: + return "red" + elif coalition_id == 2: + return "blue" + return "" + + +def coalition_to_enum(coalition: str) -> int: + if coalition == "neutral": + return 0 + elif coalition == "red": + return 1 + elif coalition == "blue": + return 2 + return 0 + +def project_with_bearing_and_distance(lat1, lon1, d, bearing, R=6371000): + """ + lat: initial latitude, in degrees + lon: initial longitude, in degrees + d: target distance from initial in meters + bearing: (true) heading in radians + R: optional radius of sphere, defaults to mean radius of earth + + Returns new lat/lon coordinate {d}m from initial, in degrees + """ + lat1 = radians(lat1) + lon1 = radians(lon1) + a = bearing + lat2 = asin(sin(lat1) * cos(d/R) + cos(lat1) * sin(d/R) * cos(a)) + lon2 = lon1 + atan2( + sin(a) * sin(d/R) * cos(lat1), + cos(d/R) - sin(lat1) * sin(lat2) + ) + return (degrees(lat2), degrees(lon2),) + +def distance(lat1, lng1, lat2, lng2): + """ + Calculate the Haversine distance. + Args: + lat1: Latitude of the first point + lng1: Longitude of the first point + lat2: Latitude of the second point + lng2: Longitude of the second point + Returns: + Distance in meters between the two points. + """ + radius = 6371000 + + dlat = radians(lat2 - lat1) + dlon = radians(lng2 - lng1) + a = (sin(dlat / 2) * sin(dlat / 2) + + cos(radians(lat1)) * cos(radians(lat2)) * + sin(dlon / 2) * sin(dlon / 2)) + c = 2 * atan2(sqrt(a), sqrt(1 - a)) + d = radius * c + + return d + +def bearing_to(lat1, lng1, lat2, lng2): + """ + Calculate the bearing from one point to another. + Args: + lat1: Latitude of the first point + lng1: Longitude of the first point + lat2: Latitude of the second point + lng2: Longitude of the second point + Returns: + Bearing in radians from the first point to the second. + """ + dLon = (lng2 - lng1) + x = cos(radians(lat2)) * sin(radians(dLon)) + y = cos(radians(lat1)) * sin(radians(lat2)) - sin(radians(lat1)) * cos(radians(lat2)) * cos(radians(dLon)) + brng = atan2(x,y) + brng = brng + + return brng + \ No newline at end of file diff --git a/scripts/python/API/voice_control.py b/scripts/python/API/voice_control.py new file mode 100644 index 00000000..bc1b1a43 --- /dev/null +++ b/scripts/python/API/voice_control.py @@ -0,0 +1,81 @@ + +from math import pi + +from api import API, UnitSpawnTable +from radio.radio_listener import RadioListener + +# Setup a logger for the module +import logging +logger = logging.getLogger("OlympusVoiceControl") +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +formatter = logging.Formatter('[%(asctime)s] %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) + +# Function to handle received messages +# This function will be called when a message is received on the radio frequency +def on_message_received(recognized_text: str, unit_id: str, api: API, listener: RadioListener): + logger.info(f"Received message from {unit_id}: {recognized_text}") + + units = api.update_units() + + # Extract the unit that sent the message + if not units: + logger.warning("No units available in API, unable to process audio.") + return + + if unit_id not in units: + logger.warning(f"Unit ID {unit_id} not found in API units, unable to process audio.") + return + + unit = units[unit_id] + + # Check for troop disembarkment request (expanded) + keywords = [ + "disembark troops", + "deploy troops", + "unload troops", + "drop off troops", + "let troops out", + "troops disembark", + "troops out", + "extract infantry", + "release soldiers", + "disembark infantry", + "release troops" + ] + is_disembarkment = any(kw in recognized_text.lower() for kw in keywords) + + # Check if "olympus" is mentioned + is_olympus = "olympus" in recognized_text.lower() + + if is_olympus and is_disembarkment: + logger.info("Troop disembarkment requested!") + + # Use the API to spawn an infrantry unit 10 meters away from the unit + spawn_location = unit.position.project_with_bearing_and_distance(bearing=unit.heading+pi/2, d=10) + spawn_table: UnitSpawnTable = UnitSpawnTable( + unit_type="Soldier M4", + location=spawn_location, + heading=unit.heading+pi/2, + skill="High", + livery_id="" + ) + api.spawn_ground_units([spawn_table], unit.coalition, "", True, 0) + message_filename = api.generate_audio_message("Roger, disembarking") + listener.transmit_on_frequency(message_filename, listener.frequency, listener.modulation, listener.encryption) + else: + logger.info("Did not understand the message or no disembarkment request found.") + message_filename = api.generate_audio_message("I did not understand") + listener.transmit_on_frequency(message_filename, listener.frequency, listener.modulation, listener.encryption) + +if __name__ == "__main__": + api = API() + logger.info("API initialized") + + listener = api.create_radio_listener() + listener.start(frequency=251.000e6, modulation=0, encryption=0) + listener.register_message_callback(lambda wav_filename, unit_id, api=api, listener=listener: on_message_received(wav_filename, unit_id, api, listener)) + + api.run() \ No newline at end of file