|
from typing import Literal, Optional |
|
|
|
import requests |
|
from pydantic import BaseModel |
|
|
|
import docker |
|
from inference.core.devices.utils import GLOBAL_DEVICE_ID |
|
from inference.core.env import API_BASE_URL, API_KEY |
|
from inference.core.logger import logger |
|
from inference.core.utils.url_utils import wrap_url |
|
from inference.enterprise.device_manager.container_service import get_container_by_id |
|
|
|
|
|
class Command(BaseModel):
    """Schema of a single device command payload.

    The ``id``, ``containerId`` and ``command`` fields correspond to the keys
    that ``handle_command`` reads from a command dictionary.
    """

    # Identifier of the command; echoed back as ``commandId`` when acknowledging.
    id: str
    # Identifier used to look up the container the command targets.
    containerId: str
    # Action to perform. "start" is listed because ``handle_command``
    # dispatches it alongside the other actions.
    command: Literal[
        "restart",
        "stop",
        "ping",
        "snapshot",
        "start",
        "update_version",
    ]
    # NOTE(review): presumably the device the command is addressed to - confirm.
    deviceId: str
    # NOTE(review): presumably a timestamp of when the command was requested;
    # the unit is not visible in this file - confirm.
    requested_on: Optional[int] = None
|
|
|
|
|
def fetch_commands(*, timeout: float = 10.0):
    """Fetch this device's commands from the API and handle each of them.

    Issues a GET to the device's ``commands`` endpoint and passes every entry
    of the response's ``data`` list to ``handle_command``, in order.

    Args:
        timeout: Seconds to wait for the API before ``requests`` gives up, so
            an unresponsive server cannot block the caller indefinitely.

    Raises:
        requests.RequestException: If the request fails or times out.
        ValueError: If the response body is not valid JSON.
    """
    url = wrap_url(
        f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands?api_key={API_KEY}"
    )
    payload = requests.get(url, timeout=timeout).json()
    # ``data`` may be missing or explicitly null; both mean "nothing to do".
    for cmd in payload.get("data") or []:
        handle_command(cmd)
|
|
|
|
|
def handle_command(cmd_payload: dict): |
|
was_processed = False |
|
container_id = cmd_payload.get("containerId") |
|
container = get_container_by_id(container_id) |
|
if not container: |
|
logger.warn(f"Container with id {container_id} not found") |
|
ack_command(cmd_payload.get("id"), was_processed) |
|
return |
|
cmd = cmd_payload.get("command") |
|
data = None |
|
match cmd: |
|
case "restart": |
|
was_processed, data = container.restart() |
|
case "stop": |
|
was_processed, data = container.stop() |
|
case "ping": |
|
was_processed, data = container.ping() |
|
case "snapshot": |
|
was_processed, data = container.snapshot() |
|
case "start": |
|
was_processed, data = container.start() |
|
case "update_version": |
|
was_processed, data = handle_version_update(container) |
|
case _: |
|
logger.error("Unknown command: {}".format(cmd)) |
|
return ack_command(cmd_payload.get("id"), was_processed, data=data) |
|
|
|
|
|
def ack_command(command_id, was_processed, data=None, *, timeout: float = 10.0):
    """Report the outcome of a command back to the API.

    Posts a JSON body with the API key, the command id and the processed flag
    to the device's ``commands/ack`` endpoint.

    Args:
        command_id: Identifier of the command being acknowledged.
        was_processed: Whether the command was carried out.
        data: Optional result payload; only included when truthy.
        timeout: Seconds to wait for the API before ``requests`` gives up, so
            an unresponsive server cannot block the caller indefinitely.

    Raises:
        requests.RequestException: If the request fails or times out.
    """
    post_body = {
        "api_key": API_KEY,
        "commandId": command_id,
        "wasProcessed": was_processed,
    }
    # Falsy results (None, empty containers) are left out of the body.
    if data:
        post_body["data"] = data
    url = wrap_url(f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands/ack")
    response = requests.post(url, json=post_body, timeout=timeout)
    if not response.ok:
        # Surface rejected acknowledgements instead of dropping them silently.
        logger.warning(
            f"Ack for command {command_id} failed with HTTP {response.status_code}"
        )
|
|
|
|
|
def _image_repository(image: str) -> str:
    """Return an image reference without its tag or digest."""
    # A digest suffix ("repo@sha256:...") is never part of the repository.
    name = image.split("@", 1)[0]
    # A colon introduces a tag only when it comes after the last "/"; an
    # earlier colon belongs to a registry "host:port" prefix.
    tag_sep = name.rfind(":")
    if tag_sep > name.rfind("/"):
        name = name[:tag_sep]
    return name


def handle_version_update(container):
    """Replace a container with a new one running the ``latest`` image tag.

    Reads the container's startup config, kills the container, and starts a
    new container from the same image repository at ``:latest``, reusing the
    config's detach, privileged, labels, port bindings and environment
    settings on the host network.

    Args:
        container: Object exposing ``get_startup_config()`` and ``kill()``.

    Returns:
        ``(True, None)`` when the new container was started, ``(False, None)``
        when any step raised.
    """
    try:
        config = container.get_startup_config()
        image_name = _image_repository(config["image"])
        # Build the client before killing anything, so a client set-up
        # failure does not take the running container down.
        client = docker.from_env()
        container.kill()
        # NOTE(review): per the Docker SDK docs, run() pulls only when the
        # image is missing locally, so a stale local ":latest" is reused -
        # confirm whether an explicit pull is wanted here.
        new_container = client.containers.run(
            image=f"{image_name}:latest",
            detach=config["detach"],
            privileged=config["privileged"],
            labels=config["labels"],
            ports=config["port_bindings"],
            environment=config["env"],
            network="host",
        )
        logger.info(f"New container started {new_container}")
        return True, None
    except Exception as e:
        # Same message as before, now with the traceback attached.
        logger.exception(e)
        return False, None
|
|