aboutsummaryrefslogtreecommitdiffstats
path: root/solo-tool-project/src/solo_tool/storage.py
blob: 0c5577ffb2cb414bb611e9a1651b8e18ecd7739f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
from abc import abstractmethod
from glob import glob
from os import getenv
from pathlib import Path
from typing import Optional
from typing import Protocol
from urllib.parse import quote

import requests

class StorageBackend(Protocol):
    """Interface shared by all session/recording stores.

    Concrete backends subclass this explicitly and must override every
    method below; calling an un-overridden method raises NotImplementedError.
    """

    @abstractmethod
    def listSessions(self) -> list[str]:
        """Return the ids of every stored session."""
        raise NotImplementedError()

    @abstractmethod
    def readSession(self, id: str) -> dict:
        """Load and return the session stored under ``id``."""
        raise NotImplementedError()

    @abstractmethod
    def writeSession(self, session: dict, id: str) -> None:
        """Persist ``session`` under ``id``."""
        raise NotImplementedError()

    @abstractmethod
    def writeRecording(self, recording: Path, destination: str) -> None:
        """Store the local file ``recording`` under the name ``destination``."""
        raise NotImplementedError()

class FileSystemStorageBackend(StorageBackend):
    """Storage backend that keeps sessions as JSON files on the local disk.

    Each session is stored as ``<storagePath>/sessions/<id>.json``.
    """

    def __init__(self, storagePath: str):
        self._storagePath = Path(storagePath)

    def _sessionFile(self, sessionId: str) -> Path:
        """Return the path of the JSON file holding session ``sessionId``."""
        return self._storagePath / "sessions" / f"{sessionId}.json"

    def listSessions(self) -> list[str]:
        """Return the ids (file stems) of all ``*.json`` sessions, sorted.

        Path.glob matches only the final pattern component, so glob
        metacharacters in the storage path itself (e.g. ``[``) are taken
        literally instead of silently matching nothing. Hidden files are
        skipped, as glob.glob did.
        """
        sessionsDir = self._storagePath / "sessions"
        return sorted(
            entry.stem
            for entry in sessionsDir.glob("*.json")
            if not entry.name.startswith(".")
        )

    def readSession(self, id: str) -> dict:
        """Load session ``id`` from its JSON file.

        Raises FileNotFoundError if the session does not exist.
        """
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(self._sessionFile(id), "r", encoding="utf-8") as f:
            return json.load(f)

    def writeSession(self, session: dict, id: str) -> None:
        """Write ``session`` as JSON to the file for ``id``, replacing it."""
        target = self._sessionFile(id)
        # Create sessions/ on first use so fresh storage directories work.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as f:
            json.dump(session, f)

    def writeRecording(self, recording: Path, destination: str) -> None:
        """Not implemented for this backend: the recording is ignored."""
        # NOTE(review): recordings are silently discarded here, unlike the
        # FileBrowser backend which uploads them; confirm this is intended.
        pass

class FileBrowserStorageBackend(StorageBackend):
    """Storage backend that talks to a File Browser server over HTTP.

    Credentials come from the ``ST_USER`` / ``ST_PASS`` environment
    variables and are exchanged for an auth token (sent as ``X-Auth``) at
    construction time. Sessions live under ``sessions/`` and recordings
    under ``recordings/`` on the server.
    """

    def __init__(self, serverUrl: str):
        self._baseUrl = serverUrl
        self._username = getenv("ST_USER")
        self._password = getenv("ST_PASS")
        # Log in eagerly so bad credentials / an unreachable server fail here.
        self._apiKey = self._getApiKey()

    def listSessions(self) -> list[str]:
        """Return the ids of all ``.json`` files in the server's ``sessions/``."""
        url = f"{self._baseUrl}/api/resources/sessions"
        response = self._request("GET", url)
        return [
            item["name"].removesuffix(".json")
            for item in response.json()["items"]
            if item["extension"] == ".json"
        ]

    def readSession(self, id: str) -> dict:
        """Download and parse the raw JSON of session ``id``."""
        # Quote the id so characters such as '#', '?' or '%' stay part of
        # the path instead of being parsed as URL syntax.
        url = f"{self._baseUrl}/api/raw/sessions/{quote(id)}.json"
        response = self._request("GET", url)
        return json.loads(response.content)

    def writeSession(self, session: dict, id: str) -> None:
        """Upload ``session`` as JSON to ``sessions/<id>.json`` via PUT."""
        url = f"{self._baseUrl}/api/resources/sessions/{quote(id)}.json"
        self._request("PUT", url, json=session)

    def writeRecording(self, recording: Path, destination: str) -> None:
        """Upload the local file ``recording`` to ``recordings/<destination>``."""
        url = f"{self._baseUrl}/api/resources/recordings/{quote(destination)}"
        with open(recording, "rb") as file:
            self._request("POST", url, {"Content-Type": "audio/mpeg"}, data=file)

    def _getApiKey(self) -> str:
        """Log in with the configured credentials and return the auth token.

        Raises requests.HTTPError if the login is rejected.
        """
        response = requests.post(
            f"{self._baseUrl}/api/login",
            json={"username": self._username, "password": self._password},
        )
        # Without this check a failed login's error body would be stored and
        # sent as the token on every later request.
        response.raise_for_status()
        return response.text

    @staticmethod
    def _bodyPosition(body) -> Optional[int]:
        """Return the current offset of a seekable request body, else None."""
        if not (hasattr(body, "seek") and hasattr(body, "tell")):
            return None
        try:
            return body.tell()
        except OSError:
            # e.g. pipes report tell() as unsupported; can't rewind those.
            return None

    def _request(self, verb: str, url: str, moreHeaders: Optional[dict] = None, **kwargs):
        """Send an authenticated request, re-logging in once on HTTP 401.

        ``moreHeaders`` are merged under the ``X-Auth`` header; ``kwargs`` are
        passed to ``requests.request``. Raises requests.HTTPError for any
        non-success final response.
        """
        headers = (moreHeaders or {}) | {"X-Auth": self._apiKey}
        body = kwargs.get("data")
        # A file-like body is consumed by the first attempt; remember where it
        # started so the retry re-sends the full payload instead of nothing.
        bodyStart = self._bodyPosition(body)
        response = requests.request(verb, url, headers=headers, **kwargs)
        if response.status_code == requests.codes.UNAUTHORIZED:
            # if unauthorized, the key might have expired
            self._apiKey = self._getApiKey()
            headers["X-Auth"] = self._apiKey
            if bodyStart is not None:
                body.seek(bodyStart)
            response = requests.request(verb, url, headers=headers, **kwargs)
        response.raise_for_status()
        return response