summaryrefslogtreecommitdiffstats
path: root/lab_control/test/mock_jds6600_device.py
blob: c027573f3ddd6777a245cabe73d9d32fc1b86b0f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import pty
import re
import termios
import threading
from typing import List, Optional

class MockJDS6600Device():
    """Simulated JDS6600 signal generator attached to a pseudo-terminal.

    The constructor opens a pty pair and starts a background thread that
    reads newline-terminated requests from the master side and writes replies
    back to it. Code under test talks to the slave side, whose device path is
    returned by getPortName(). Call stop() to close the pty and join the thread.
    """

    class ChannelState:
        """State of one output channel as last written by the client."""

        def __init__(self):
            self.on: bool = False
            # Hz (written value / 100); None until the first frequency write.
            self.frequency: Optional[float] = None
            # Written value / 1000 (presumably mV -> V, TODO confirm); None until written.
            self.amplitude: Optional[float] = None
            # Raw waveform code as written; None until written.
            self.function: Optional[int] = None

    # Request syntax ":<opcode><function>=<args>." e.g. ":w23=100000,0.".
    _REQUEST_PATTERN = re.compile(r":(?P<opcode>[wrab])(?P<function>\d+)=(?P<args>.*)\.")

    def __init__(self):
        self._master, self._slave = pty.openpty()
        # closefd=False: the raw master fd is closed explicitly in stop().
        self._masterFile = os.fdopen(self._master, mode="r+b", closefd=False, buffering=0)

        self._portName = os.ttyname(self._slave)
        self._channels = [MockJDS6600Device.ChannelState() for _ in range(2)]

        # Number of upcoming frequency writes that should "fail" by storing 0.0.
        self._injectedFailureCounter = 0

        # daemon=True so a test that never reaches stop() cannot keep the
        # interpreter alive while this thread is blocked in readline().
        self._mainThread = threading.Thread(target=self._mainLoop, daemon=True)
        self._mainThread.start()

    def _mainLoop(self) -> None:
        """Serve requests from the pty master until it is closed."""
        while True:
            try:
                line = self._masterFile.readline()
            except (OSError, ValueError):
                # OSError: the pty was torn down while reading.
                # ValueError: stop() closed _masterFile between iterations.
                break
            if not line:
                # EOF: nothing more can arrive; stop instead of spinning.
                break

            # errors="replace" so stray non-UTF-8 bytes cannot kill the thread.
            response = self._handleRequest(line.decode(errors="replace").strip())
            if response is None:
                continue

            try:
                self._masterFile.write(response.encode())
            except (OSError, ValueError):
                break

    def _handleRequest(self, request: str) -> Optional[str]:
        """Apply one request to the simulated state and build the reply.

        Args:
            request: One request line, already decoded and stripped.

        Returns:
            The reply text (CRLF-terminated), or None when the request is not
            understood or its arguments are malformed; the mock then stays silent.
        """
        m = self._REQUEST_PATTERN.search(request)
        if not m:
            return None

        opcode = m.group("opcode")
        function = int(m.group("function"))
        args = m.group("args").split(",")

        try:
            return self._execute(opcode, function, args)
        except (IndexError, ValueError):
            # Too few or non-numeric arguments: treat like an unknown request
            # instead of letting the exception kill the serving thread.
            return None

    def _execute(self, opcode: str, function: int, args: List[str]) -> Optional[str]:
        """Dispatch a parsed request on its function number.

        May raise IndexError/ValueError for malformed ``args``. No state is
        mutated before the arguments have been parsed.
        """
        # channel on/off: args are "<ch1>,<ch2>", "1" meaning on
        if function == 20:
            if opcode == "w":
                # Parse both before assigning so a short request changes nothing.
                ch1On, ch2On = args[0] == "1", args[1] == "1"
                self._channels[0].on = ch1On
                self._channels[1].on = ch2On
                return ":ok\r\n"
            if opcode == "r":
                return f":r20={int(self._channels[0].on)},{int(self._channels[1].on)}.\r\n"

        # channel function shape: 21 -> CH1, 22 -> CH2
        elif function in (21, 22):
            if opcode == "w":
                self._channels[function - 21].function = int(args[0])
                return ":ok\r\n"

        # channel frequency: 23 -> CH1, 24 -> CH2
        elif function in (23, 24):
            channel = self._channels[function - 23]
            if opcode == "w":
                # Actual device takes a second argument for scaling, here we ignore it and always use 0 (Hz)
                frequency = float(args[0]) / 100.0

                if self._injectedFailureCounter > 0:
                    # Simulated failure: the device "ignores" the requested value.
                    channel.frequency = 0.0
                    self._injectedFailureCounter -= 1
                else:
                    channel.frequency = frequency
                return ":ok\r\n"
            if opcode == "r":
                # Report 0 for a channel that was never written rather than
                # crashing on int(None).
                # NOTE(review): writes divide the argument by 100, but reads
                # report the Hz value unscaled; confirm this matches the driver.
                frequency = channel.frequency
                value = 0 if frequency is None else int(frequency)
                return f":r{function}={value},0.\r\n"

        # channel amplitude: 25 -> CH1, 26 -> CH2
        elif function in (25, 26):
            if opcode == "w":
                self._channels[function - 25].amplitude = float(args[0]) / 1000.0
                return ":ok\r\n"

        # Unknown request format, no response
        return None

    def stop(self) -> None:
        """Close both pty ends and wait for the serving thread to exit."""
        self._masterFile.close()
        os.close(self._master)
        # Closing the slave end is expected to make the serving thread's
        # blocked readline() return or raise, so join() can complete.
        os.close(self._slave)
        self._mainThread.join()

    def checkPortConfiguration(self) -> None:
        """Assert the slave port is set to 115200 baud, 8 data bits, 1 stop bit, no parity.

        Raises:
            AssertionError: if any of these settings differ.
        """
        iflag, oflag, cflag, lflag, ispeed, ospeed, cc = termios.tcgetattr(self._slave)

        # JDS6600 configuration taken from manual
        assert ispeed == termios.B115200
        assert ospeed == termios.B115200
        assert (cflag & termios.CSIZE) == termios.CS8
        assert (cflag & termios.CSTOPB) == 0
        assert (cflag & (termios.PARENB | termios.PARODD)) == 0

    def getPortName(self) -> str:
        """Return the device path of the slave end for the code under test."""
        return self._portName

    def _channel(self, ch: int) -> "MockJDS6600Device.ChannelState":
        """Return the state for 1-based channel ``ch``.

        Raises:
            IndexError: if ``ch`` is outside 1..number of channels. Plain
                ``list[ch - 1]`` would silently map 0 onto the last channel.
        """
        if not 1 <= ch <= len(self._channels):
            raise IndexError(f"channel must be in 1..{len(self._channels)}, got {ch}")
        return self._channels[ch - 1]

    def isOn(self, ch: int) -> bool:
        """Return whether 1-based channel ``ch`` was last switched on."""
        return self._channel(ch).on

    def getFrequency(self, ch: int) -> Optional[float]:
        """Return the stored frequency (Hz) of channel ``ch``, or None if never written."""
        return self._channel(ch).frequency

    def getAmplitude(self, ch: int) -> Optional[float]:
        """Return the stored amplitude of channel ``ch``, or None if never written."""
        return self._channel(ch).amplitude

    def getFunction(self, ch: int) -> Optional[int]:
        """Return the stored waveform code of channel ``ch``, or None if never written."""
        return self._channel(ch).function

    def injectFailures(self, count: int) -> None:
        """Make the next ``count`` frequency writes store 0.0 instead of the requested value."""
        self._injectedFailureCounter += count