"""Persist and restore SoloTool sessions in a local directory or on a FileBrowser server."""

from typing import Protocol
from abc import abstractmethod
from . import SoloTool
from pathlib import Path
from glob import glob
import json
import requests
from os import getenv
from re import search


class SessionManager():
    """Load and save SoloTool sessions through a storage backend chosen by URL scheme.

    A stored session is a JSON list with one entry per song, each of the form
    ``{"path": <song path>, "key_points": <key points>, "vol": <volume>}``.
    """

    def __init__(self, sessionPath: str):
        """Select the storage backend for *sessionPath*.

        Args:
            sessionPath: a local directory, optionally prefixed with ``file://``,
                or an ``http://`` / ``https://`` server URL.

        Raises:
            ValueError: if *sessionPath* starts with any other ``scheme://``.
        """
        self._sessionPath = sessionPath
        match = search(r"^([A-Za-z0-9]+://)", sessionPath)
        # URL schemes are case-insensitive, so normalise before comparing.
        scheme = match.group(0).lower() if match else None
        if scheme is None:
            self._backend = _FileSystemBackend(sessionPath)
        elif scheme == "file://":
            # Strip the scheme; otherwise Path("file:///x") becomes the
            # relative path "file:/x" rather than "/x".
            self._backend = _FileSystemBackend(sessionPath[len(scheme):])
        elif scheme in ["http://", "https://"]:
            self._backend = _FileBrowserBackend(sessionPath)
        else:
            raise ValueError(f"Unsupported session path: {sessionPath}")

    def getSessions(self) -> list[str]:
        """Return the ids of all sessions available in the backend."""
        return self._backend.listIds()

    def loadSession(self, id: str, player=None) -> SoloTool:
        """Build a new SoloTool populated with the songs of session *id*.

        Missing ``key_points`` default to an empty list and a missing ``vol``
        defaults to 1.0.
        """
        session = self._backend.read(id)
        st = SoloTool(player=player)
        for entry in session:
            st.addSong(
                entry["path"],
                keyPoints=entry.get("key_points", []),
                volume=entry.get("vol", 1.0),
            )
        return st

    def saveSession(self, soloTool: SoloTool, id: str) -> None:
        """Serialise the songs, key points and volumes of *soloTool* as session *id*."""
        # Reads SoloTool's private per-song lists, indexed in parallel with .songs.
        session = [
            {
                "path": song,
                "key_points": soloTool._keyPoints[i],
                "vol": soloTool._volumes[i],
            }
            for i, song in enumerate(soloTool.songs)
        ]
        self._backend.write(session, id)


class _Backend(Protocol):
    """Storage interface for sessions, each addressed by a string id."""

    @abstractmethod
    def listIds(self) -> list[str]:
        """Return the ids of all stored sessions."""
        raise NotImplementedError

    @abstractmethod
    def read(self, id: str) -> list[dict]:
        """Return the list of song entries stored under *id*."""
        raise NotImplementedError

    @abstractmethod
    def write(self, session: list[dict], id: str) -> None:
        """Store the list of song entries *session* under *id*."""
        raise NotImplementedError


class _FileSystemBackend(_Backend):
    """Stores each session as ``<sessionPath>/<id>.json`` on the local filesystem."""

    def __init__(self, sessionPath: str):
        self._sessionPath = Path(sessionPath)

    def listIds(self) -> list[str]:
        return [Path(f).stem for f in glob(f"{self._sessionPath}/*.json")]

    def read(self, id: str) -> list[dict]:
        # Explicit UTF-8: json.dump's default ASCII output is valid UTF-8, and this
        # avoids depending on the platform's locale encoding.
        with open(self._sessionPath / f"{id}.json", "r", encoding="utf-8") as f:
            return json.load(f)

    def write(self, session: list[dict], id: str) -> None:
        with open(self._sessionPath / f"{id}.json", "w", encoding="utf-8") as f:
            json.dump(session, f)


class _FileBrowserBackend(_Backend):
    """Stores sessions as ``<id>.json`` files through a FileBrowser server's HTTP API.

    Login credentials are read from the ``ST_USER`` and ``ST_PASS`` environment
    variables; the token returned by ``/api/login`` is sent in the ``X-Auth`` header.
    """

    # Seconds to wait for the server; without a timeout requests blocks forever.
    _timeout = 30

    def __init__(self, serverUrl: str):
        # Drop trailing slashes so the URLs built below do not contain "//api".
        self._baseUrl = serverUrl.rstrip("/")
        self._username = getenv("ST_USER")
        self._password = getenv("ST_PASS")
        self._apiKey = self._getApiKey()

    def listIds(self) -> list[str]:
        url = f"{self._baseUrl}/api/resources"
        response = self._request("GET", url)
        return [
            item["name"][:-len(".json")]
            for item in response.json()["items"]
            if item["extension"] == ".json"
        ]

    def read(self, id: str) -> list[dict]:
        url = f"{self._baseUrl}/api/raw/{id}.json"
        response = self._request("GET", url)
        return response.json()

    def write(self, session: list[dict], id: str) -> None:
        # NOTE(review): confirm the server's PUT creates the file when it does not
        # exist yet; some FileBrowser versions only create resources via POST.
        url = f"{self._baseUrl}/api/resources/{id}.json"
        self._request("PUT", url, json=session)

    def _getApiKey(self) -> str:
        """Log in and return the auth token.

        Raises:
            requests.HTTPError: if the login request is rejected.
        """
        response = requests.post(
            f"{self._baseUrl}/api/login",
            json={"username": self._username, "password": self._password},
            timeout=self._timeout,
        )
        # Fail loudly on bad credentials instead of using the error body as a token.
        response.raise_for_status()
        return response.text

    def _request(self, verb: str, url: str, **kwargs) -> requests.Response:
        """Send an authenticated request, re-logging in once on HTTP 401.

        Raises:
            requests.HTTPError: if the final response has an error status.
        """
        kwargs.setdefault("timeout", self._timeout)
        headers = {"X-Auth": self._apiKey}
        response = requests.request(verb, url, headers=headers, **kwargs)
        if response.status_code == requests.codes.UNAUTHORIZED:
            # if unauthorized, the key might have expired
            self._apiKey = self._getApiKey()
            headers["X-Auth"] = self._apiKey
            response = requests.request(verb, url, headers=headers, **kwargs)
        response.raise_for_status()
        return response